-
Notifications
You must be signed in to change notification settings - Fork 3
Add base and decision task handler #28
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 9 commits
3fd5244
7818de9
5d759c3
75b8664
ea56208
08b0e3b
e269704
918d128
9f9c289
7d6eaec
27277e4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
import logging | ||
from abc import ABC, abstractmethod | ||
from typing import TypeVar, Generic | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
T = TypeVar('T') | ||
|
||
class BaseTaskHandler(ABC, Generic[T]): | ||
""" | ||
Base task handler that provides common functionality for processing tasks. | ||
|
||
This abstract class defines the interface and common behavior for task handlers | ||
that process different types of tasks (workflow decisions, activities, etc.). | ||
""" | ||
|
||
def __init__(self, client, task_list: str, identity: str, **options): | ||
""" | ||
Initialize the base task handler. | ||
|
||
Args: | ||
client: The Cadence client instance | ||
task_list: The task list name | ||
identity: Worker identity | ||
**options: Additional options for the handler | ||
""" | ||
self._client = client | ||
self._task_list = task_list | ||
self._identity = identity | ||
self._options = options | ||
|
||
async def handle_task(self, task: T) -> None: | ||
""" | ||
Handle a single task. | ||
|
||
This method provides the base implementation for task handling that includes: | ||
- Error handling | ||
- Cleanup | ||
|
||
Args: | ||
task: The task to handle | ||
""" | ||
try: | ||
# Handle the task implementation | ||
await self._handle_task_implementation(task) | ||
|
||
except Exception as e: | ||
logger.exception(f"Error handling task: {e}") | ||
await self.handle_task_failure(task, e) | ||
|
||
@abstractmethod | ||
async def _handle_task_implementation(self, task: T) -> None: | ||
""" | ||
Handle the actual task implementation. | ||
|
||
Args: | ||
task: The task to handle | ||
""" | ||
pass | ||
|
||
@abstractmethod | ||
async def handle_task_failure(self, task: T, error: Exception) -> None: | ||
""" | ||
Handle task processing failure. | ||
|
||
Args: | ||
task: The task that failed | ||
error: The exception that occurred | ||
""" | ||
pass |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,173 @@ | ||
import logging | ||
from typing import Dict | ||
|
||
from cadence.api.v1.common_pb2 import Payload | ||
from cadence.api.v1.service_worker_pb2 import ( | ||
PollForDecisionTaskResponse, | ||
RespondDecisionTaskCompletedRequest, | ||
RespondDecisionTaskFailedRequest | ||
) | ||
from cadence.api.v1.workflow_pb2 import DecisionTaskFailedCause | ||
from cadence.client import Client | ||
from cadence.worker._base_task_handler import BaseTaskHandler | ||
from cadence._internal.workflow.workflow_engine import WorkflowEngine, DecisionResult | ||
from cadence.workflow import WorkflowInfo | ||
from cadence.worker._registry import Registry | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
class DecisionTaskHandler(BaseTaskHandler[PollForDecisionTaskResponse]): | ||
""" | ||
Task handler for processing decision tasks. | ||
This handler processes decision tasks and generates decisions using the workflow engine. | ||
""" | ||
|
||
def __init__(self, client: Client, task_list: str, registry: Registry, identity: str = "unknown", **options): | ||
""" | ||
Initialize the decision task handler. | ||
Args: | ||
client: The Cadence client instance | ||
task_list: The task list name | ||
registry: Registry containing workflow functions | ||
identity: The worker identity | ||
**options: Additional options for the handler | ||
""" | ||
super().__init__(client, task_list, identity, **options) | ||
self._registry = registry | ||
self._workflow_engines: Dict[str, WorkflowEngine] = {} | ||
|
||
|
||
async def _handle_task_implementation(self, task: PollForDecisionTaskResponse) -> None: | ||
""" | ||
Handle a decision task implementation. | ||
Args: | ||
task: The decision task to handle | ||
""" | ||
# Extract workflow execution info | ||
workflow_execution = task.workflow_execution | ||
workflow_type = task.workflow_type | ||
|
||
if not workflow_execution or not workflow_type: | ||
logger.error("Decision task missing workflow execution or type") | ||
|
||
raise ValueError("Missing workflow execution or type") | ||
|
||
workflow_id = workflow_execution.workflow_id | ||
run_id = workflow_execution.run_id | ||
workflow_type_name = workflow_type.name | ||
|
||
logger.info(f"Processing decision task for workflow {workflow_id} (type: {workflow_type_name})") | ||
|
||
# Get or create workflow engine for this workflow execution | ||
engine_key = f"{workflow_id}:{run_id}" | ||
if engine_key not in self._workflow_engines: | ||
|
||
# Get the workflow function from registry | ||
try: | ||
workflow_func = self._registry.get_workflow(workflow_type_name) | ||
except KeyError: | ||
logger.error(f"Workflow type '{workflow_type_name}' not found in registry") | ||
raise KeyError(f"Workflow type '{workflow_type_name}' not found") | ||
|
||
# Create workflow info and engine | ||
workflow_info = WorkflowInfo( | ||
workflow_type=workflow_type_name, | ||
workflow_domain=self._client.domain, | ||
workflow_id=workflow_id, | ||
workflow_run_id=run_id | ||
) | ||
|
||
self._workflow_engines[engine_key] = WorkflowEngine( | ||
info=workflow_info, | ||
client=self._client, | ||
workflow_func=workflow_func | ||
) | ||
|
||
# Create workflow context and execute with it active | ||
workflow_engine = self._workflow_engines[engine_key] | ||
|
||
decision_result = await workflow_engine.process_decision(task) | ||
|
||
# Respond with the decisions | ||
await self._respond_decision_task_completed(task, decision_result) | ||
|
||
logger.info(f"Successfully processed decision task for workflow {workflow_id}") | ||
|
||
async def handle_task_failure(self, task: PollForDecisionTaskResponse, error: Exception) -> None: | ||
""" | ||
Handle decision task processing failure. | ||
Args: | ||
task: The task that failed | ||
error: The exception that occurred | ||
""" | ||
try: | ||
logger.error(f"Decision task failed: {error}") | ||
|
||
# Determine the failure cause | ||
cause = DecisionTaskFailedCause.DECISION_TASK_FAILED_CAUSE_UNHANDLED_DECISION | ||
if isinstance(error, KeyError): | ||
cause = DecisionTaskFailedCause.DECISION_TASK_FAILED_CAUSE_WORKFLOW_WORKER_UNHANDLED_FAILURE | ||
elif isinstance(error, ValueError): | ||
cause = DecisionTaskFailedCause.DECISION_TASK_FAILED_CAUSE_BAD_SCHEDULE_ACTIVITY_ATTRIBUTES | ||
|
||
# Create error details | ||
error_message = str(error).encode('utf-8') | ||
|
||
details = Payload(data=error_message) | ||
|
||
# Respond with failure | ||
await self._client.worker_stub.RespondDecisionTaskFailed( | ||
RespondDecisionTaskFailedRequest( | ||
task_token=task.task_token, | ||
cause=cause, | ||
identity=self._identity, | ||
details=details | ||
) | ||
) | ||
|
||
logger.info("Decision task failure response sent") | ||
|
||
except Exception: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe just don't handle exception at all and let worker to handle it |
||
logger.exception("Error handling decision task failure") | ||
|
||
async def _respond_decision_task_completed(self, task: PollForDecisionTaskResponse, decision_result: DecisionResult) -> None: | ||
""" | ||
Respond to the service that the decision task has been completed. | ||
Args: | ||
task: The original decision task | ||
decision_result: The result containing decisions and query results | ||
""" | ||
try: | ||
request = RespondDecisionTaskCompletedRequest( | ||
task_token=task.task_token, | ||
decisions=decision_result.decisions, | ||
identity=self._identity, | ||
return_new_decision_task=decision_result.force_create_new_decision_task, | ||
force_create_new_decision_task=decision_result.force_create_new_decision_task | ||
|
||
) | ||
|
||
# Add query results if present | ||
if decision_result.query_results: | ||
request.query_results.update(decision_result.query_results) | ||
|
||
await self._client.worker_stub.RespondDecisionTaskCompleted(request) | ||
logger.debug(f"Decision task completed with {len(decision_result.decisions)} decisions") | ||
|
||
except Exception: | ||
logger.exception("Error responding to decision task completion") | ||
raise | ||
|
||
def cleanup_workflow_engine(self, workflow_id: str, run_id: str) -> None: | ||
|
||
""" | ||
Clean up a workflow engine when workflow execution is complete. | ||
Args: | ||
workflow_id: The workflow ID | ||
run_id: The run ID | ||
""" | ||
engine_key = f"{workflow_id}:{run_id}" | ||
if engine_key in self._workflow_engines: | ||
del self._workflow_engines[engine_key] | ||
logger.debug(f"Cleaned up workflow engine for {workflow_id}:{run_id}") |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
#!/usr/bin/env python3 | ||
""" | ||
Unit tests for BaseTaskHandler class. | ||
""" | ||
|
||
import pytest | ||
from unittest.mock import Mock | ||
|
||
from cadence.worker._base_task_handler import BaseTaskHandler | ||
|
||
|
||
class ConcreteTaskHandler(BaseTaskHandler[str]): | ||
"""Concrete implementation of BaseTaskHandler for testing.""" | ||
|
||
def __init__(self, client, task_list: str, identity: str, **options): | ||
super().__init__(client, task_list, identity, **options) | ||
self._handle_task_implementation_called = False | ||
self._handle_task_failure_called = False | ||
self._last_task: str = "" | ||
self._last_error: Exception | None = None | ||
|
||
async def _handle_task_implementation(self, task: str) -> None: | ||
"""Test implementation of task handling.""" | ||
self._handle_task_implementation_called = True | ||
self._last_task = task | ||
if task == "raise_error": | ||
raise ValueError("Test error") | ||
|
||
async def handle_task_failure(self, task: str, error: Exception) -> None: | ||
"""Test implementation of task failure handling.""" | ||
self._handle_task_failure_called = True | ||
self._last_task = task | ||
self._last_error = error | ||
|
||
|
||
class TestBaseTaskHandler: | ||
"""Test cases for BaseTaskHandler.""" | ||
|
||
def test_initialization(self): | ||
"""Test BaseTaskHandler initialization.""" | ||
client = Mock() | ||
handler = ConcreteTaskHandler( | ||
client=client, | ||
task_list="test_task_list", | ||
identity="test_identity", | ||
option1="value1", | ||
option2="value2" | ||
) | ||
|
||
assert handler._client == client | ||
assert handler._task_list == "test_task_list" | ||
assert handler._identity == "test_identity" | ||
assert handler._options == {"option1": "value1", "option2": "value2"} | ||
|
||
@pytest.mark.asyncio | ||
async def test_handle_task_success(self): | ||
"""Test successful task handling.""" | ||
client = Mock() | ||
handler = ConcreteTaskHandler(client, "test_task_list", "test_identity") | ||
|
||
await handler.handle_task("test_task") | ||
|
||
# Verify implementation was called | ||
assert handler._handle_task_implementation_called | ||
assert not handler._handle_task_failure_called | ||
assert handler._last_task == "test_task" | ||
assert handler._last_error is None | ||
|
||
@pytest.mark.asyncio | ||
async def test_handle_task_failure(self): | ||
"""Test task handling with error.""" | ||
client = Mock() | ||
handler = ConcreteTaskHandler(client, "test_task_list", "test_identity") | ||
|
||
await handler.handle_task("raise_error") | ||
|
||
# Verify error handling was called | ||
assert handler._handle_task_implementation_called | ||
assert handler._handle_task_failure_called | ||
assert handler._last_task == "raise_error" | ||
assert isinstance(handler._last_error, ValueError) | ||
assert str(handler._last_error) == "Test error" | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_abstract_methods_not_implemented(self): | ||
"""Test that abstract methods raise NotImplementedError when not implemented.""" | ||
client = Mock() | ||
|
||
class IncompleteHandler(BaseTaskHandler[str]): | ||
async def _handle_task_implementation(self, task: str) -> None: | ||
raise NotImplementedError() | ||
|
||
async def handle_task_failure(self, task: str, error: Exception) -> None: | ||
raise NotImplementedError() | ||
|
||
handler = IncompleteHandler(client, "test_task_list", "test_identity") | ||
|
||
with pytest.raises(NotImplementedError): | ||
await handler._handle_task_implementation("test") | ||
|
||
with pytest.raises(NotImplementedError): | ||
await handler.handle_task_failure("test", Exception("test")) | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_generic_type_parameter(self): | ||
"""Test that the generic type parameter works correctly.""" | ||
client = Mock() | ||
|
||
class IntHandler(BaseTaskHandler[int]): | ||
async def _handle_task_implementation(self, task: int) -> None: | ||
pass | ||
|
||
async def handle_task_failure(self, task: int, error: Exception) -> None: | ||
pass | ||
|
||
handler = IntHandler(client, "test_task_list", "test_identity") | ||
|
||
# Should accept int tasks | ||
await handler.handle_task(42) | ||
|
||
# Type checker should catch type mismatches (this is more of a static analysis test) | ||
# In runtime, Python won't enforce the type, but the type hints are there for static analysis |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not needed at the moment. This is related to local activities and workflow query. We might have better solutions later than copying java legacy
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
those fields are used in RespondDecisionTaskCompletedRequest but we can remove them for now