diff --git a/cadence/_internal/workflow/workflow_engine.py b/cadence/_internal/workflow/workflow_engine.py index 8c2f8da..00fac2c 100644 --- a/cadence/_internal/workflow/workflow_engine.py +++ b/cadence/_internal/workflow/workflow_engine.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from typing import Optional, Callable, Any from cadence._internal.workflow.context import Context from cadence.api.v1.decision_pb2 import Decision @@ -12,10 +13,11 @@ class DecisionResult: decisions: list[Decision] class WorkflowEngine: - def __init__(self, info: WorkflowInfo, client: Client): + def __init__(self, info: WorkflowInfo, client: Client, workflow_func: Optional[Callable[..., Any]] = None): self._context = Context(client, info) + self._workflow_func = workflow_func # TODO: Implement this - def process_decision(self, decision_task: PollForDecisionTaskResponse) -> DecisionResult: + async def process_decision(self, decision_task: PollForDecisionTaskResponse) -> DecisionResult: with self._context._activate(): return DecisionResult(decisions=[]) diff --git a/cadence/worker/_base_task_handler.py b/cadence/worker/_base_task_handler.py new file mode 100644 index 0000000..3fda7e7 --- /dev/null +++ b/cadence/worker/_base_task_handler.py @@ -0,0 +1,70 @@ +import logging +from abc import ABC, abstractmethod +from typing import TypeVar, Generic + +logger = logging.getLogger(__name__) + +T = TypeVar('T') + +class BaseTaskHandler(ABC, Generic[T]): + """ + Base task handler that provides common functionality for processing tasks. + + This abstract class defines the interface and common behavior for task handlers + that process different types of tasks (workflow decisions, activities, etc.). + """ + + def __init__(self, client, task_list: str, identity: str, **options): + """ + Initialize the base task handler. + + Args: + client: The Cadence client instance + task_list: The task list name + identity: Worker identity + **options: Additional options for the handler + """ + self._client = client + self._task_list = task_list + self._identity = identity + self._options = options + + async def handle_task(self, task: T) -> None: + """ + Handle a single task. + + This method provides the base implementation for task handling that includes: + - Error handling + - Cleanup + + Args: + task: The task to handle + """ + try: + # Handle the task implementation + await self._handle_task_implementation(task) + + except Exception as e: + logger.exception(f"Error handling task: {e}") + await self.handle_task_failure(task, e) + + @abstractmethod + async def _handle_task_implementation(self, task: T) -> None: + """ + Handle the actual task implementation. + + Args: + task: The task to handle + """ + pass + + @abstractmethod + async def handle_task_failure(self, task: T, error: Exception) -> None: + """ + Handle task processing failure. + + Args: + task: The task that failed + error: The exception that occurred + """ + pass diff --git a/cadence/worker/_decision_task_handler.py b/cadence/worker/_decision_task_handler.py new file mode 100644 index 0000000..636505f --- /dev/null +++ b/cadence/worker/_decision_task_handler.py @@ -0,0 +1,148 @@ +import logging + +from cadence.api.v1.common_pb2 import Payload +from cadence.api.v1.service_worker_pb2 import ( + PollForDecisionTaskResponse, + RespondDecisionTaskCompletedRequest, + RespondDecisionTaskFailedRequest +) +from cadence.api.v1.workflow_pb2 import DecisionTaskFailedCause +from cadence.client import Client +from cadence.worker._base_task_handler import BaseTaskHandler +from cadence._internal.workflow.workflow_engine import WorkflowEngine, DecisionResult +from cadence.workflow import WorkflowInfo +from cadence.worker._registry import Registry + +logger = logging.getLogger(__name__) + +class DecisionTaskHandler(BaseTaskHandler[PollForDecisionTaskResponse]): + """ + Task handler for processing decision tasks. + + This handler processes decision tasks and generates decisions using the workflow engine. + """ + + def __init__(self, client: Client, task_list: str, registry: Registry, identity: str = "unknown", **options): + """ + Initialize the decision task handler. + + Args: + client: The Cadence client instance + task_list: The task list name + registry: Registry containing workflow functions + identity: The worker identity + **options: Additional options for the handler + """ + super().__init__(client, task_list, identity, **options) + self._registry = registry + self._workflow_engine: WorkflowEngine + + + async def _handle_task_implementation(self, task: PollForDecisionTaskResponse) -> None: + """ + Handle a decision task implementation. + + Args: + task: The decision task to handle + """ + # Extract workflow execution info + workflow_execution = task.workflow_execution + workflow_type = task.workflow_type + + if not workflow_execution or not workflow_type: + logger.error("Decision task missing workflow execution or type. Task: %r", task) + raise ValueError("Missing workflow execution or type") + + workflow_id = workflow_execution.workflow_id + run_id = workflow_execution.run_id + workflow_type_name = workflow_type.name + + logger.info(f"Processing decision task for workflow {workflow_id} (type: {workflow_type_name})") + + try: + workflow_func = self._registry.get_workflow(workflow_type_name) + except KeyError: + logger.error(f"Workflow type '{workflow_type_name}' not found in registry") + raise KeyError(f"Workflow type '{workflow_type_name}' not found") + + # Create workflow info and engine + workflow_info = WorkflowInfo( + workflow_type=workflow_type_name, + workflow_domain=self._client.domain, + workflow_id=workflow_id, + workflow_run_id=run_id + ) + + self._workflow_engine = WorkflowEngine( + info=workflow_info, + client=self._client, + workflow_func=workflow_func + ) + + decision_result = await self._workflow_engine.process_decision(task) + + # Respond with the decisions + await self._respond_decision_task_completed(task, decision_result) + + logger.info(f"Successfully processed decision task for workflow {workflow_id}") + + async def handle_task_failure(self, task: PollForDecisionTaskResponse, error: Exception) -> None: + """ + Handle decision task processing failure. + + Args: + task: The task that failed + error: The exception that occurred + """ + logger.error(f"Decision task failed: {error}") + + # Determine the failure cause + cause = DecisionTaskFailedCause.DECISION_TASK_FAILED_CAUSE_UNHANDLED_DECISION + if isinstance(error, KeyError): + cause = DecisionTaskFailedCause.DECISION_TASK_FAILED_CAUSE_WORKFLOW_WORKER_UNHANDLED_FAILURE + elif isinstance(error, ValueError): + cause = DecisionTaskFailedCause.DECISION_TASK_FAILED_CAUSE_BAD_SCHEDULE_ACTIVITY_ATTRIBUTES + + # Create error details + # TODO: Use a data converter for error details serialization + error_message = str(error).encode('utf-8') + details = Payload(data=error_message) + + # Respond with failure + try: + await self._client.worker_stub.RespondDecisionTaskFailed( + RespondDecisionTaskFailedRequest( + task_token=task.task_token, + cause=cause, + identity=self._identity, + details=details + ) + ) + logger.info("Decision task failure response sent") + except Exception: + logger.exception("Error responding to decision task failure") + + + async def _respond_decision_task_completed(self, task: PollForDecisionTaskResponse, decision_result: DecisionResult) -> None: + """ + Respond to the service that the decision task has been completed. + + Args: + task: The original decision task + decision_result: The result containing decisions and query results + """ + try: + request = RespondDecisionTaskCompletedRequest( + task_token=task.task_token, + decisions=decision_result.decisions, + identity=self._identity, + return_new_decision_task=True, + force_create_new_decision_task=False + ) + + await self._client.worker_stub.RespondDecisionTaskCompleted(request) + logger.debug(f"Decision task completed with {len(decision_result.decisions)} decisions") + + except Exception: + logger.exception("Error responding to decision task completion") + raise diff --git a/tests/cadence/_internal/test_decision_state_machine.py b/tests/cadence/_internal/test_decision_state_machine.py index 1cd0d92..4f61dca 100644 --- a/tests/cadence/_internal/test_decision_state_machine.py +++ b/tests/cadence/_internal/test_decision_state_machine.py @@ -439,7 +439,3 @@ def test_manager_aggregates_and_routes(): ), ) ) - - assert a.status is DecisionState.COMPLETED - assert t.status is DecisionState.COMPLETED - assert c.status is DecisionState.COMPLETED diff --git a/tests/cadence/worker/test_base_task_handler.py b/tests/cadence/worker/test_base_task_handler.py new file mode 100644 index 0000000..d5d48a6 --- /dev/null +++ b/tests/cadence/worker/test_base_task_handler.py @@ -0,0 +1,124 @@ +#!/usr/bin/env python3 +""" +Unit tests for BaseTaskHandler class. +""" + +import pytest +from unittest.mock import Mock + +from cadence.worker._base_task_handler import BaseTaskHandler + + +class ConcreteTaskHandler(BaseTaskHandler[str]): + """Concrete implementation of BaseTaskHandler for testing.""" + + def __init__(self, client, task_list: str, identity: str, **options): + super().__init__(client, task_list, identity, **options) + self._handle_task_implementation_called = False + self._handle_task_failure_called = False + self._last_task: str = "" + self._last_error: Exception | None = None + + async def _handle_task_implementation(self, task: str) -> None: + """Test implementation of task handling.""" + self._handle_task_implementation_called = True + self._last_task = task + if task == "raise_error": + raise ValueError("Test error") + + async def handle_task_failure(self, task: str, error: Exception) -> None: + """Test implementation of task failure handling.""" + self._handle_task_failure_called = True + self._last_task = task + self._last_error = error + + +class TestBaseTaskHandler: + """Test cases for BaseTaskHandler.""" + + def test_initialization(self): + """Test BaseTaskHandler initialization.""" + client = Mock() + handler = ConcreteTaskHandler( + client=client, + task_list="test_task_list", + identity="test_identity", + option1="value1", + option2="value2" + ) + + assert handler._client == client + assert handler._task_list == "test_task_list" + assert handler._identity == "test_identity" + assert handler._options == {"option1": "value1", "option2": "value2"} + + @pytest.mark.asyncio + async def test_handle_task_success(self): + """Test successful task handling.""" + client = Mock() + handler = ConcreteTaskHandler(client, "test_task_list", "test_identity") + + await handler.handle_task("test_task") + + # Verify implementation was called + assert handler._handle_task_implementation_called + assert not handler._handle_task_failure_called + assert handler._last_task == "test_task" + assert handler._last_error is None + + @pytest.mark.asyncio + async def test_handle_task_failure(self): + """Test task handling with error.""" + client = Mock() + handler = ConcreteTaskHandler(client, "test_task_list", "test_identity") + + await handler.handle_task("raise_error") + + # Verify error handling was called + assert handler._handle_task_implementation_called + assert handler._handle_task_failure_called + assert handler._last_task == "raise_error" + assert isinstance(handler._last_error, ValueError) + assert str(handler._last_error) == "Test error" + + + @pytest.mark.asyncio + async def test_abstract_methods_not_implemented(self): + """Test that abstract methods raise NotImplementedError when not implemented.""" + client = Mock() + + class IncompleteHandler(BaseTaskHandler[str]): + async def _handle_task_implementation(self, task: str) -> None: + raise NotImplementedError() + + async def handle_task_failure(self, task: str, error: Exception) -> None: + raise NotImplementedError() + + handler = IncompleteHandler(client, "test_task_list", "test_identity") + + with pytest.raises(NotImplementedError): + await handler._handle_task_implementation("test") + + with pytest.raises(NotImplementedError): + await handler.handle_task_failure("test", Exception("test")) + + + @pytest.mark.asyncio + async def test_generic_type_parameter(self): + """Test that the generic type parameter works correctly.""" + client = Mock() + + class IntHandler(BaseTaskHandler[int]): + async def _handle_task_implementation(self, task: int) -> None: + pass + + async def handle_task_failure(self, task: int, error: Exception) -> None: + pass + + handler = IntHandler(client, "test_task_list", "test_identity") + + # Should accept int tasks + await handler.handle_task(42) + + # Type checker should catch type mismatches (this is more of a static analysis test) + # In runtime, Python won't enforce the type, but the type hints are there for static analysis diff --git a/tests/cadence/worker/test_decision_task_handler.py b/tests/cadence/worker/test_decision_task_handler.py new file mode 100644 index 0000000..2fc98ec --- /dev/null +++ b/tests/cadence/worker/test_decision_task_handler.py @@ -0,0 +1,303 @@ +#!/usr/bin/env python3 +""" +Unit tests for DecisionTaskHandler class. +""" + +import pytest +from unittest.mock import Mock, AsyncMock, patch, PropertyMock + +from cadence.api.v1.common_pb2 import Payload +from cadence.api.v1.service_worker_pb2 import ( + PollForDecisionTaskResponse, + RespondDecisionTaskCompletedRequest +) +from cadence.api.v1.workflow_pb2 import DecisionTaskFailedCause +from cadence.api.v1.decision_pb2 import Decision +from cadence.client import Client +from cadence.worker._decision_task_handler import DecisionTaskHandler +from cadence.worker._registry import Registry +from cadence._internal.workflow.workflow_engine import WorkflowEngine, DecisionResult + + +class TestDecisionTaskHandler: + """Test cases for DecisionTaskHandler.""" + + @pytest.fixture + def mock_client(self): + """Create a mock client.""" + client = Mock(spec=Client) + client.worker_stub = Mock() + client.worker_stub.RespondDecisionTaskCompleted = AsyncMock() + client.worker_stub.RespondDecisionTaskFailed = AsyncMock() + type(client).domain = PropertyMock(return_value="test_domain") + return client + + @pytest.fixture + def mock_registry(self): + """Create a mock registry.""" + registry = Mock(spec=Registry) + return registry + + @pytest.fixture + def handler(self, mock_client, mock_registry): + """Create a DecisionTaskHandler instance.""" + return DecisionTaskHandler( + client=mock_client, + task_list="test_task_list", + registry=mock_registry, + identity="test_identity" + ) + + @pytest.fixture + def sample_decision_task(self): + """Create a sample decision task.""" + task = Mock(spec=PollForDecisionTaskResponse) + task.task_token = b"test_task_token" + task.workflow_execution = Mock() + task.workflow_execution.workflow_id = "test_workflow_id" + task.workflow_execution.run_id = "test_run_id" + task.workflow_type = Mock() + task.workflow_type.name = "TestWorkflow" + return task + + def test_initialization(self, mock_client, mock_registry): + """Test DecisionTaskHandler initialization.""" + handler = DecisionTaskHandler( + client=mock_client, + task_list="test_task_list", + registry=mock_registry, + identity="test_identity", + option1="value1" + ) + + assert handler._client == mock_client + assert handler._task_list == "test_task_list" + assert handler._identity == "test_identity" + assert handler._registry == mock_registry + assert handler._options == {"option1": "value1"} + + @pytest.mark.asyncio + async def test_handle_task_implementation_success(self, handler, sample_decision_task, mock_registry): + """Test successful decision task handling.""" + # Mock workflow function + mock_workflow_func = Mock() + mock_registry.get_workflow.return_value = mock_workflow_func + + # Mock workflow engine + mock_engine = Mock(spec=WorkflowEngine) + mock_decision_result = Mock(spec=DecisionResult) + mock_decision_result.decisions = [Decision()] + mock_decision_result.force_create_new_decision_task = False + mock_decision_result.query_results = {} + mock_engine.process_decision = AsyncMock(return_value=mock_decision_result) + + with patch('cadence.worker._decision_task_handler.WorkflowEngine', return_value=mock_engine): + await handler._handle_task_implementation(sample_decision_task) + + # Verify registry was called + mock_registry.get_workflow.assert_called_once_with("TestWorkflow") + + # Verify workflow engine was created and used + mock_engine.process_decision.assert_called_once_with(sample_decision_task) + + # Verify response was sent + handler._client.worker_stub.RespondDecisionTaskCompleted.assert_called_once() + + @pytest.mark.asyncio + async def test_handle_task_implementation_missing_workflow_execution(self, handler): + """Test decision task handling with missing workflow execution.""" + task = Mock(spec=PollForDecisionTaskResponse) + task.task_token = b"test_task_token" + task.workflow_execution = None + task.workflow_type = Mock() + task.workflow_type.name = "TestWorkflow" + + with pytest.raises(ValueError, match="Missing workflow execution or type"): + await handler._handle_task_implementation(task) + + @pytest.mark.asyncio + async def test_handle_task_implementation_missing_workflow_type(self, handler): + """Test decision task handling with missing workflow type.""" + task = Mock(spec=PollForDecisionTaskResponse) + task.task_token = b"test_task_token" + task.workflow_execution = Mock() + task.workflow_execution.workflow_id = "test_workflow_id" + task.workflow_execution.run_id = "test_run_id" + task.workflow_type = None + + with pytest.raises(ValueError, match="Missing workflow execution or type"): + await handler._handle_task_implementation(task) + + @pytest.mark.asyncio + async def test_handle_task_implementation_workflow_not_found(self, handler, sample_decision_task, mock_registry): + """Test decision task handling when workflow is not found in registry.""" + mock_registry.get_workflow.side_effect = KeyError("Workflow not found") + + with pytest.raises(KeyError, match="Workflow type 'TestWorkflow' not found"): + await handler._handle_task_implementation(sample_decision_task) + + @pytest.mark.asyncio + async def test_handle_task_implementation_creates_new_engine(self, handler, sample_decision_task, mock_registry): + """Test that decision task handler creates new workflow engine for each task.""" + # Mock workflow function + mock_workflow_func = Mock() + mock_registry.get_workflow.return_value = mock_workflow_func + + # Mock workflow engine + mock_engine = Mock(spec=WorkflowEngine) + mock_decision_result = Mock(spec=DecisionResult) + mock_decision_result.decisions = [] + mock_decision_result.force_create_new_decision_task = False + mock_decision_result.query_results = {} + mock_engine.process_decision = AsyncMock(return_value=mock_decision_result) + + with patch('cadence.worker._decision_task_handler.WorkflowEngine', return_value=mock_engine) as mock_engine_class: + # First call - should create new engine + await handler._handle_task_implementation(sample_decision_task) + + # Second call - should create another new engine + await handler._handle_task_implementation(sample_decision_task) + + # Registry should be called for each task + assert mock_registry.get_workflow.call_count == 2 + + # Engine should be created twice and called twice + assert mock_engine_class.call_count == 2 + assert mock_engine.process_decision.call_count == 2 + + @pytest.mark.asyncio + async def test_handle_task_failure_keyerror(self, handler, sample_decision_task): + """Test task failure handling for KeyError.""" + error = KeyError("Workflow not found") + + await handler.handle_task_failure(sample_decision_task, error) + + # Verify the correct failure cause was used + call_args = handler._client.worker_stub.RespondDecisionTaskFailed.call_args[0][0] + assert call_args.cause == DecisionTaskFailedCause.DECISION_TASK_FAILED_CAUSE_WORKFLOW_WORKER_UNHANDLED_FAILURE + assert call_args.task_token == sample_decision_task.task_token + assert call_args.identity == handler._identity + + @pytest.mark.asyncio + async def test_handle_task_failure_valueerror(self, handler, sample_decision_task): + """Test task failure handling for ValueError.""" + error = ValueError("Invalid workflow attributes") + + await handler.handle_task_failure(sample_decision_task, error) + + # Verify the correct failure cause was used + call_args = handler._client.worker_stub.RespondDecisionTaskFailed.call_args[0][0] + assert call_args.cause == DecisionTaskFailedCause.DECISION_TASK_FAILED_CAUSE_BAD_SCHEDULE_ACTIVITY_ATTRIBUTES + assert call_args.task_token == sample_decision_task.task_token + assert call_args.identity == handler._identity + + @pytest.mark.asyncio + async def test_handle_task_failure_generic_error(self, handler, sample_decision_task): + """Test task failure handling for generic error.""" + error = RuntimeError("Generic error") + + await handler.handle_task_failure(sample_decision_task, error) + + # Verify the default failure cause was used + call_args = handler._client.worker_stub.RespondDecisionTaskFailed.call_args[0][0] + assert call_args.cause == DecisionTaskFailedCause.DECISION_TASK_FAILED_CAUSE_UNHANDLED_DECISION + assert call_args.task_token == sample_decision_task.task_token + assert call_args.identity == handler._identity + + @pytest.mark.asyncio + async def test_handle_task_failure_with_error_details(self, handler, sample_decision_task): + """Test task failure handling includes error details.""" + error = ValueError("Test error message") + + await handler.handle_task_failure(sample_decision_task, error) + + call_args = handler._client.worker_stub.RespondDecisionTaskFailed.call_args[0][0] + assert isinstance(call_args.details, Payload) + assert call_args.details.data == b"Test error message" + + @pytest.mark.asyncio + async def test_handle_task_failure_respond_error(self, handler, sample_decision_task): + """Test task failure handling when respond fails.""" + error = ValueError("Test error") + handler._client.worker_stub.RespondDecisionTaskFailed.side_effect = Exception("Respond failed") + + # Should not raise exception, but should log error + with patch('cadence.worker._decision_task_handler.logger') as mock_logger: + await handler.handle_task_failure(sample_decision_task, error) + mock_logger.exception.assert_called_once() + + @pytest.mark.asyncio + async def test_respond_decision_task_completed_success(self, handler, sample_decision_task): + """Test successful decision task completion response.""" + decision_result = Mock(spec=DecisionResult) + decision_result.decisions = [Decision(), Decision()] + + await handler._respond_decision_task_completed(sample_decision_task, decision_result) + + # Verify the request was created correctly + call_args = handler._client.worker_stub.RespondDecisionTaskCompleted.call_args[0][0] + assert isinstance(call_args, RespondDecisionTaskCompletedRequest) + assert call_args.task_token == sample_decision_task.task_token + assert call_args.identity == handler._identity + assert call_args.return_new_decision_task + assert not call_args.force_create_new_decision_task + assert len(call_args.decisions) == 2 + + @pytest.mark.asyncio + async def test_respond_decision_task_completed_no_query_results(self, handler, sample_decision_task): + """Test decision task completion response without query results.""" + decision_result = Mock(spec=DecisionResult) + decision_result.decisions = [] + + await handler._respond_decision_task_completed(sample_decision_task, decision_result) + + call_args = handler._client.worker_stub.RespondDecisionTaskCompleted.call_args[0][0] + assert call_args.return_new_decision_task + assert not call_args.force_create_new_decision_task + assert len(call_args.decisions) == 0 + + @pytest.mark.asyncio + async def test_respond_decision_task_completed_error(self, handler, sample_decision_task): + """Test decision task completion response error handling.""" + decision_result = Mock(spec=DecisionResult) + decision_result.decisions = [] + + handler._client.worker_stub.RespondDecisionTaskCompleted.side_effect = Exception("Respond failed") + + with pytest.raises(Exception, match="Respond failed"): + await handler._respond_decision_task_completed(sample_decision_task, decision_result) + + + @pytest.mark.asyncio + async def test_workflow_engine_creation_with_workflow_info(self, handler, sample_decision_task, mock_registry): + """Test that WorkflowEngine is created with correct WorkflowInfo.""" + mock_workflow_func = Mock() + mock_registry.get_workflow.return_value = mock_workflow_func + + mock_engine = Mock(spec=WorkflowEngine) + mock_decision_result = Mock(spec=DecisionResult) + mock_decision_result.decisions = [] + mock_decision_result.force_create_new_decision_task = False + mock_decision_result.query_results = {} + mock_engine.process_decision = AsyncMock(return_value=mock_decision_result) + + with patch('cadence.worker._decision_task_handler.WorkflowEngine', return_value=mock_engine) as mock_workflow_engine_class: + with patch('cadence.worker._decision_task_handler.WorkflowInfo') as mock_workflow_info_class: + await handler._handle_task_implementation(sample_decision_task) + + # Verify WorkflowInfo was created with correct parameters (called once for engine) + assert mock_workflow_info_class.call_count == 1 + for call in mock_workflow_info_class.call_args_list: + assert call[1] == { + 'workflow_type': "TestWorkflow", + 'workflow_domain': "test_domain", + 'workflow_id': "test_workflow_id", + 'workflow_run_id': "test_run_id" + } + + # Verify WorkflowEngine was created with correct parameters + mock_workflow_engine_class.assert_called_once() + call_args = mock_workflow_engine_class.call_args + assert call_args[1]['info'] is not None + assert call_args[1]['client'] == handler._client + assert call_args[1]['workflow_func'] == mock_workflow_func diff --git a/tests/cadence/worker/test_task_handler_integration.py b/tests/cadence/worker/test_task_handler_integration.py new file mode 100644 index 0000000..64d877f --- /dev/null +++ b/tests/cadence/worker/test_task_handler_integration.py @@ -0,0 +1,261 @@ +#!/usr/bin/env python3 +""" +Integration tests for task handlers. +""" + +import pytest +from contextlib import contextmanager +from unittest.mock import Mock, AsyncMock, patch, PropertyMock + +from cadence.api.v1.service_worker_pb2 import PollForDecisionTaskResponse +from cadence.client import Client +from cadence.worker._decision_task_handler import DecisionTaskHandler +from cadence.worker._registry import Registry +from cadence._internal.workflow.workflow_engine import WorkflowEngine, DecisionResult + + +class TestTaskHandlerIntegration: + """Integration tests for task handlers.""" + + @pytest.fixture + def mock_client(self): + """Create a mock client.""" + client = Mock(spec=Client) + client.worker_stub = Mock() + client.worker_stub.RespondDecisionTaskCompleted = AsyncMock() + client.worker_stub.RespondDecisionTaskFailed = AsyncMock() + type(client).domain = PropertyMock(return_value="test_domain") + return client + + @pytest.fixture + def mock_registry(self): + """Create a mock registry.""" + registry = Mock(spec=Registry) + return registry + + @pytest.fixture + def handler(self, mock_client, mock_registry): + """Create a DecisionTaskHandler instance.""" + return DecisionTaskHandler( + client=mock_client, + task_list="test_task_list", + registry=mock_registry, + identity="test_identity" + ) + + @pytest.fixture + def sample_decision_task(self): + """Create a sample decision task.""" + task = Mock(spec=PollForDecisionTaskResponse) + task.task_token = b"test_task_token" + task.workflow_execution = Mock() + task.workflow_execution.workflow_id = "test_workflow_id" + task.workflow_execution.run_id = "test_run_id" + task.workflow_type = Mock() + task.workflow_type.name = "TestWorkflow" + return task + + @pytest.mark.asyncio + async def test_full_task_handling_flow_success(self, handler, sample_decision_task, mock_registry): + """Test the complete task handling flow from base handler through decision handler.""" + # Mock workflow function + mock_workflow_func = Mock() + mock_registry.get_workflow.return_value = mock_workflow_func + + # Mock workflow engine + mock_engine = Mock(spec=WorkflowEngine) + mock_decision_result = Mock(spec=DecisionResult) + mock_decision_result.decisions = [] + mock_engine.process_decision = AsyncMock(return_value=mock_decision_result) + + with patch('cadence.worker._decision_task_handler.WorkflowEngine', return_value=mock_engine): + # Use the base handler's handle_task method + await handler.handle_task(sample_decision_task) + + # Verify the complete flow + mock_registry.get_workflow.assert_called_once_with("TestWorkflow") + mock_engine.process_decision.assert_called_once_with(sample_decision_task) + handler._client.worker_stub.RespondDecisionTaskCompleted.assert_called_once() + + @pytest.mark.asyncio + async def test_full_task_handling_flow_with_error(self, handler, sample_decision_task, mock_registry): + """Test the complete task handling flow when an error occurs.""" + # Mock workflow function + mock_workflow_func = Mock() + mock_registry.get_workflow.return_value = mock_workflow_func + + # Mock workflow engine to raise an error + mock_engine = Mock(spec=WorkflowEngine) + mock_engine.process_decision = AsyncMock(side_effect=RuntimeError("Workflow processing failed")) + + with patch('cadence.worker._decision_task_handler.WorkflowEngine', return_value=mock_engine): + # Use the base handler's handle_task method + await handler.handle_task(sample_decision_task) + + # Verify error handling + handler._client.worker_stub.RespondDecisionTaskFailed.assert_called_once() + call_args = handler._client.worker_stub.RespondDecisionTaskFailed.call_args[0][0] + assert call_args.task_token == sample_decision_task.task_token + assert call_args.identity == handler._identity + + @pytest.mark.asyncio + async def test_context_activation_integration(self, handler, sample_decision_task, mock_registry): + """Test that context activation works correctly in the integration.""" + # Mock workflow function + mock_workflow_func = Mock() + mock_registry.get_workflow.return_value = mock_workflow_func + + # Mock workflow engine + mock_engine = Mock(spec=WorkflowEngine) + mock_decision_result = Mock(spec=DecisionResult) + mock_decision_result.decisions = [] + mock_engine.process_decision = AsyncMock(return_value=mock_decision_result) + + # Track if context is activated + context_activated = False + + def track_context_activation(): + nonlocal context_activated + context_activated = True + + with patch('cadence.worker._decision_task_handler.WorkflowEngine', return_value=mock_engine): + with patch('cadence._internal.workflow.workflow_engine.Context') as mock_context_class: + mock_context = Mock() + mock_context._activate = Mock(return_value=contextmanager(lambda: track_context_activation())()) + mock_context_class.return_value = mock_context + + await handler.handle_task(sample_decision_task) + + # Verify context was activated + assert context_activated + + @pytest.mark.asyncio + async def test_multiple_workflow_executions(self, handler, mock_registry): + """Test handling multiple workflow executions creates new engines for each.""" + # Mock workflow function + mock_workflow_func = Mock() + mock_registry.get_workflow.return_value = mock_workflow_func + + # Create multiple decision tasks for different workflows + task1 = Mock(spec=PollForDecisionTaskResponse) + task1.task_token = b"task1_token" + task1.workflow_execution = Mock() + task1.workflow_execution.workflow_id = "workflow1" + task1.workflow_execution.run_id = "run1" + task1.workflow_type = Mock() + task1.workflow_type.name = "TestWorkflow" + + task2 = Mock(spec=PollForDecisionTaskResponse) + task2.task_token = b"task2_token" + task2.workflow_execution = Mock() + task2.workflow_execution.workflow_id = "workflow2" + task2.workflow_execution.run_id = "run2" + task2.workflow_type = Mock() + task2.workflow_type.name = "TestWorkflow" + + # Mock workflow engine + mock_engine = Mock(spec=WorkflowEngine) + + mock_decision_result = Mock(spec=DecisionResult) + mock_decision_result.decisions = [] + + mock_engine.process_decision = AsyncMock(return_value=mock_decision_result) + + with patch('cadence.worker._decision_task_handler.WorkflowEngine', return_value=mock_engine) as mock_engine_class: + # Process both tasks + await handler.handle_task(task1) + await handler.handle_task(task2) + + # Verify engines were created for each task + assert mock_engine_class.call_count == 2 + + # Verify both tasks were processed + assert mock_engine.process_decision.call_count == 2 + + @pytest.mark.asyncio + async def test_workflow_engine_creation_integration(self, handler, sample_decision_task, mock_registry): + """Test workflow engine creation integration.""" + # Mock workflow function + mock_workflow_func = Mock() + mock_registry.get_workflow.return_value = mock_workflow_func + + # Mock workflow engine + mock_engine = Mock(spec=WorkflowEngine) + mock_decision_result = Mock(spec=DecisionResult) + mock_decision_result.decisions = [] + mock_engine.process_decision = AsyncMock(return_value=mock_decision_result) + + with patch('cadence.worker._decision_task_handler.WorkflowEngine', return_value=mock_engine) as mock_engine_class: + # Process task to create engine + await handler.handle_task(sample_decision_task) + + # Verify engine was created and used + mock_engine_class.assert_called_once() + mock_engine.process_decision.assert_called_once_with(sample_decision_task) + + @pytest.mark.asyncio + async def test_error_handling_with_context_cleanup(self, handler, sample_decision_task, mock_registry): + """Test that context cleanup happens even when errors occur.""" + # Mock workflow function + mock_workflow_func = Mock() + mock_registry.get_workflow.return_value = mock_workflow_func + + # Mock workflow engine to raise an error + mock_engine = Mock(spec=WorkflowEngine) + mock_engine.process_decision = AsyncMock(side_effect=RuntimeError("Workflow processing failed")) + + # Track context cleanup + context_cleaned_up = False + + def track_context_cleanup(): + nonlocal context_cleaned_up + context_cleaned_up = True + + with patch('cadence.worker._decision_task_handler.WorkflowEngine', return_value=mock_engine): + with patch('cadence._internal.workflow.workflow_engine.Context') as mock_context_class: + mock_context = Mock() + mock_context._activate = Mock(return_value=contextmanager(lambda: track_context_cleanup())()) + mock_context_class.return_value = mock_context + + await handler.handle_task(sample_decision_task) + + # Verify context was cleaned up even after error + assert context_cleaned_up + + # Verify error was handled + handler._client.worker_stub.RespondDecisionTaskFailed.assert_called_once() + + @pytest.mark.asyncio + async def test_concurrent_task_handling(self, handler, mock_registry): + """Test handling multiple tasks concurrently.""" + import asyncio + + # Mock workflow function + mock_workflow_func = Mock() + mock_registry.get_workflow.return_value = mock_workflow_func + + # Create multiple tasks + tasks = [] + for i in range(3): + task = Mock(spec=PollForDecisionTaskResponse) + task.task_token = f"task{i}_token".encode() + task.workflow_execution = Mock() + task.workflow_execution.workflow_id = f"workflow{i}" + task.workflow_execution.run_id = f"run{i}" + task.workflow_type = Mock() + task.workflow_type.name = "TestWorkflow" + tasks.append(task) + + # Mock workflow engine + mock_engine = Mock(spec=WorkflowEngine) + mock_decision_result = Mock(spec=DecisionResult) + mock_decision_result.decisions = [] + mock_engine.process_decision = AsyncMock(return_value=mock_decision_result) + + with patch('cadence.worker._decision_task_handler.WorkflowEngine', return_value=mock_engine): + # Process all tasks concurrently + await asyncio.gather(*[handler.handle_task(task) for task in tasks]) + + # Verify all tasks were processed + assert mock_engine.process_decision.call_count == 3 + assert handler._client.worker_stub.RespondDecisionTaskCompleted.call_count == 3