-
Notifications
You must be signed in to change notification settings - Fork 3
Add base and decision task handler #28
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
3fd5244
add base and decision task handler
timl3136 7818de9
add unit test
timl3136 5d759c3
lint
timl3136 75b8664
improve context management
timl3136 ea56208
lint
timl3136 08b0e3b
remove context activate in handler
timl3136 e269704
Merge branch 'main' into base-handler-1
timl3136 918d128
lint
timl3136 9f9c289
fix test
timl3136 7d6eaec
respond to comments
timl3136 27277e4
lint
timl3136 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
import logging | ||
from abc import ABC, abstractmethod | ||
from typing import TypeVar, Generic | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
T = TypeVar('T') | ||
|
||
class BaseTaskHandler(ABC, Generic[T]): | ||
""" | ||
Base task handler that provides common functionality for processing tasks. | ||
|
||
This abstract class defines the interface and common behavior for task handlers | ||
that process different types of tasks (workflow decisions, activities, etc.). | ||
""" | ||
|
||
def __init__(self, client, task_list: str, identity: str, **options): | ||
""" | ||
Initialize the base task handler. | ||
|
||
Args: | ||
client: The Cadence client instance | ||
task_list: The task list name | ||
identity: Worker identity | ||
**options: Additional options for the handler | ||
""" | ||
self._client = client | ||
self._task_list = task_list | ||
self._identity = identity | ||
self._options = options | ||
|
||
async def handle_task(self, task: T) -> None: | ||
""" | ||
Handle a single task. | ||
|
||
This method provides the base implementation for task handling that includes: | ||
- Error handling | ||
- Cleanup | ||
|
||
Args: | ||
task: The task to handle | ||
""" | ||
try: | ||
# Handle the task implementation | ||
await self._handle_task_implementation(task) | ||
|
||
except Exception as e: | ||
logger.exception(f"Error handling task: {e}") | ||
await self.handle_task_failure(task, e) | ||
|
||
@abstractmethod | ||
async def _handle_task_implementation(self, task: T) -> None: | ||
""" | ||
Handle the actual task implementation. | ||
|
||
Args: | ||
task: The task to handle | ||
""" | ||
pass | ||
|
||
@abstractmethod | ||
async def handle_task_failure(self, task: T, error: Exception) -> None: | ||
""" | ||
Handle task processing failure. | ||
|
||
Args: | ||
task: The task that failed | ||
error: The exception that occurred | ||
""" | ||
pass |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
import logging | ||
|
||
from cadence.api.v1.common_pb2 import Payload | ||
from cadence.api.v1.service_worker_pb2 import ( | ||
PollForDecisionTaskResponse, | ||
RespondDecisionTaskCompletedRequest, | ||
RespondDecisionTaskFailedRequest | ||
) | ||
from cadence.api.v1.workflow_pb2 import DecisionTaskFailedCause | ||
from cadence.client import Client | ||
from cadence.worker._base_task_handler import BaseTaskHandler | ||
from cadence._internal.workflow.workflow_engine import WorkflowEngine, DecisionResult | ||
from cadence.workflow import WorkflowInfo | ||
from cadence.worker._registry import Registry | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
class DecisionTaskHandler(BaseTaskHandler[PollForDecisionTaskResponse]): | ||
""" | ||
Task handler for processing decision tasks. | ||
|
||
This handler processes decision tasks and generates decisions using the workflow engine. | ||
""" | ||
|
||
def __init__(self, client: Client, task_list: str, registry: Registry, identity: str = "unknown", **options): | ||
""" | ||
Initialize the decision task handler. | ||
|
||
Args: | ||
client: The Cadence client instance | ||
task_list: The task list name | ||
registry: Registry containing workflow functions | ||
identity: The worker identity | ||
**options: Additional options for the handler | ||
""" | ||
super().__init__(client, task_list, identity, **options) | ||
self._registry = registry | ||
self._workflow_engine: WorkflowEngine | ||
|
||
|
||
async def _handle_task_implementation(self, task: PollForDecisionTaskResponse) -> None: | ||
""" | ||
Handle a decision task implementation. | ||
|
||
Args: | ||
task: The decision task to handle | ||
""" | ||
# Extract workflow execution info | ||
workflow_execution = task.workflow_execution | ||
workflow_type = task.workflow_type | ||
|
||
if not workflow_execution or not workflow_type: | ||
logger.error("Decision task missing workflow execution or type. Task: %r", task) | ||
raise ValueError("Missing workflow execution or type") | ||
|
||
workflow_id = workflow_execution.workflow_id | ||
run_id = workflow_execution.run_id | ||
workflow_type_name = workflow_type.name | ||
|
||
logger.info(f"Processing decision task for workflow {workflow_id} (type: {workflow_type_name})") | ||
|
||
try: | ||
workflow_func = self._registry.get_workflow(workflow_type_name) | ||
except KeyError: | ||
logger.error(f"Workflow type '{workflow_type_name}' not found in registry") | ||
raise KeyError(f"Workflow type '{workflow_type_name}' not found") | ||
|
||
# Create workflow info and engine | ||
workflow_info = WorkflowInfo( | ||
workflow_type=workflow_type_name, | ||
workflow_domain=self._client.domain, | ||
workflow_id=workflow_id, | ||
workflow_run_id=run_id | ||
) | ||
|
||
self._workflow_engine = WorkflowEngine( | ||
info=workflow_info, | ||
client=self._client, | ||
workflow_func=workflow_func | ||
) | ||
|
||
decision_result = await self._workflow_engine.process_decision(task) | ||
|
||
# Respond with the decisions | ||
await self._respond_decision_task_completed(task, decision_result) | ||
|
||
logger.info(f"Successfully processed decision task for workflow {workflow_id}") | ||
|
||
async def handle_task_failure(self, task: PollForDecisionTaskResponse, error: Exception) -> None: | ||
""" | ||
Handle decision task processing failure. | ||
|
||
Args: | ||
task: The task that failed | ||
error: The exception that occurred | ||
""" | ||
logger.error(f"Decision task failed: {error}") | ||
|
||
# Determine the failure cause | ||
cause = DecisionTaskFailedCause.DECISION_TASK_FAILED_CAUSE_UNHANDLED_DECISION | ||
if isinstance(error, KeyError): | ||
cause = DecisionTaskFailedCause.DECISION_TASK_FAILED_CAUSE_WORKFLOW_WORKER_UNHANDLED_FAILURE | ||
elif isinstance(error, ValueError): | ||
cause = DecisionTaskFailedCause.DECISION_TASK_FAILED_CAUSE_BAD_SCHEDULE_ACTIVITY_ATTRIBUTES | ||
|
||
# Create error details | ||
# TODO: Use a data converter for error details serialization | ||
error_message = str(error).encode('utf-8') | ||
details = Payload(data=error_message) | ||
|
||
# Respond with failure | ||
try: | ||
await self._client.worker_stub.RespondDecisionTaskFailed( | ||
RespondDecisionTaskFailedRequest( | ||
task_token=task.task_token, | ||
cause=cause, | ||
identity=self._identity, | ||
details=details | ||
) | ||
) | ||
logger.info("Decision task failure response sent") | ||
except Exception: | ||
logger.exception("Error responding to decision task failure") | ||
|
||
|
||
async def _respond_decision_task_completed(self, task: PollForDecisionTaskResponse, decision_result: DecisionResult) -> None: | ||
""" | ||
Respond to the service that the decision task has been completed. | ||
|
||
Args: | ||
task: The original decision task | ||
decision_result: The result containing decisions and query results | ||
""" | ||
try: | ||
request = RespondDecisionTaskCompletedRequest( | ||
task_token=task.task_token, | ||
decisions=decision_result.decisions, | ||
identity=self._identity, | ||
return_new_decision_task=True, | ||
force_create_new_decision_task=False | ||
) | ||
|
||
await self._client.worker_stub.RespondDecisionTaskCompleted(request) | ||
logger.debug(f"Decision task completed with {len(decision_result.decisions)} decisions") | ||
|
||
except Exception: | ||
logger.exception("Error responding to decision task completion") | ||
raise |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
#!/usr/bin/env python3 | ||
""" | ||
Unit tests for BaseTaskHandler class. | ||
""" | ||
|
||
import pytest | ||
from unittest.mock import Mock | ||
|
||
from cadence.worker._base_task_handler import BaseTaskHandler | ||
|
||
|
||
class ConcreteTaskHandler(BaseTaskHandler[str]): | ||
"""Concrete implementation of BaseTaskHandler for testing.""" | ||
|
||
def __init__(self, client, task_list: str, identity: str, **options): | ||
super().__init__(client, task_list, identity, **options) | ||
self._handle_task_implementation_called = False | ||
self._handle_task_failure_called = False | ||
self._last_task: str = "" | ||
self._last_error: Exception | None = None | ||
|
||
async def _handle_task_implementation(self, task: str) -> None: | ||
"""Test implementation of task handling.""" | ||
self._handle_task_implementation_called = True | ||
self._last_task = task | ||
if task == "raise_error": | ||
raise ValueError("Test error") | ||
|
||
async def handle_task_failure(self, task: str, error: Exception) -> None: | ||
"""Test implementation of task failure handling.""" | ||
self._handle_task_failure_called = True | ||
self._last_task = task | ||
self._last_error = error | ||
|
||
|
||
class TestBaseTaskHandler: | ||
"""Test cases for BaseTaskHandler.""" | ||
|
||
def test_initialization(self): | ||
"""Test BaseTaskHandler initialization.""" | ||
client = Mock() | ||
handler = ConcreteTaskHandler( | ||
client=client, | ||
task_list="test_task_list", | ||
identity="test_identity", | ||
option1="value1", | ||
option2="value2" | ||
) | ||
|
||
assert handler._client == client | ||
assert handler._task_list == "test_task_list" | ||
assert handler._identity == "test_identity" | ||
assert handler._options == {"option1": "value1", "option2": "value2"} | ||
|
||
@pytest.mark.asyncio | ||
async def test_handle_task_success(self): | ||
"""Test successful task handling.""" | ||
client = Mock() | ||
handler = ConcreteTaskHandler(client, "test_task_list", "test_identity") | ||
|
||
await handler.handle_task("test_task") | ||
|
||
# Verify implementation was called | ||
assert handler._handle_task_implementation_called | ||
assert not handler._handle_task_failure_called | ||
assert handler._last_task == "test_task" | ||
assert handler._last_error is None | ||
|
||
@pytest.mark.asyncio | ||
async def test_handle_task_failure(self): | ||
"""Test task handling with error.""" | ||
client = Mock() | ||
handler = ConcreteTaskHandler(client, "test_task_list", "test_identity") | ||
|
||
await handler.handle_task("raise_error") | ||
|
||
# Verify error handling was called | ||
assert handler._handle_task_implementation_called | ||
assert handler._handle_task_failure_called | ||
assert handler._last_task == "raise_error" | ||
assert isinstance(handler._last_error, ValueError) | ||
assert str(handler._last_error) == "Test error" | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_abstract_methods_not_implemented(self): | ||
"""Test that abstract methods raise NotImplementedError when not implemented.""" | ||
client = Mock() | ||
|
||
class IncompleteHandler(BaseTaskHandler[str]): | ||
async def _handle_task_implementation(self, task: str) -> None: | ||
raise NotImplementedError() | ||
|
||
async def handle_task_failure(self, task: str, error: Exception) -> None: | ||
raise NotImplementedError() | ||
|
||
handler = IncompleteHandler(client, "test_task_list", "test_identity") | ||
|
||
with pytest.raises(NotImplementedError): | ||
await handler._handle_task_implementation("test") | ||
|
||
with pytest.raises(NotImplementedError): | ||
await handler.handle_task_failure("test", Exception("test")) | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_generic_type_parameter(self): | ||
"""Test that the generic type parameter works correctly.""" | ||
client = Mock() | ||
|
||
class IntHandler(BaseTaskHandler[int]): | ||
async def _handle_task_implementation(self, task: int) -> None: | ||
pass | ||
|
||
async def handle_task_failure(self, task: int, error: Exception) -> None: | ||
pass | ||
|
||
handler = IntHandler(client, "test_task_list", "test_identity") | ||
|
||
# Should accept int tasks | ||
await handler.handle_task(42) | ||
|
||
# Type checker should catch type mismatches (this is more of a static analysis test) | ||
# In runtime, Python won't enforce the type, but the type hints are there for static analysis |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe just don't handle exception at all and let worker to handle it