From 3fd5244bf3bf66d62f4e8c4c2b7801f6d81bd2e3 Mon Sep 17 00:00:00 2001
From: Tim Li
Date: Fri, 12 Sep 2025 17:03:08 -0700
Subject: [PATCH 01/10] add base and decision task handler

Signed-off-by: Tim Li
---
 cadence/worker/_base_task_handler.py     |  94 ++++++++++++
 cadence/worker/_decision_task_handler.py | 174 +++++++++++++++++++++++
 2 files changed, 268 insertions(+)
 create mode 100644 cadence/worker/_base_task_handler.py
 create mode 100644 cadence/worker/_decision_task_handler.py

diff --git a/cadence/worker/_base_task_handler.py b/cadence/worker/_base_task_handler.py
new file mode 100644
index 0000000..751b2bb
--- /dev/null
+++ b/cadence/worker/_base_task_handler.py
@@ -0,0 +1,94 @@
+import logging
+from abc import ABC, abstractmethod
+from typing import TypeVar, Generic
+
+logger = logging.getLogger(__name__)
+
+T = TypeVar('T')
+
+class BaseTaskHandler(ABC, Generic[T]):
+    """
+    Base task handler that provides common functionality for processing tasks.
+
+    This abstract class defines the interface and common behavior for task handlers
+    that process different types of tasks (workflow decisions, activities, etc.).
+    """
+
+    def __init__(self, client, task_list: str, identity: str, **options):
+        """
+        Initialize the base task handler.
+
+        Args:
+            client: The Cadence client instance
+            task_list: The task list name
+            identity: Worker identity
+            **options: Additional options for the handler
+        """
+        self._client = client
+        self._task_list = task_list
+        self._identity = identity
+        self._options = options
+
+    async def handle_task(self, task: T) -> None:
+        """
+        Handle a single task.
+
+        This method provides the base implementation for task handling that includes:
+        - Context propagation
+        - Error handling
+        - Cleanup
+
+        Args:
+            task: The task to handle
+        """
+        try:
+            # Propagate context from task parameters
+            await self._propagate_context(task)
+
+            # Handle the task
+            await self._handle_task_implementation(task)
+
+        except Exception as e:
+            logger.exception(f"Error handling task: {e}")
+            await self.handle_task_failure(task, e)
+        finally:
+            # Clean up context
+            await self._unset_current_context()
+
+    @abstractmethod
+    async def _handle_task_implementation(self, task: T) -> None:
+        """
+        Handle the actual task implementation.
+
+        Args:
+            task: The task to handle
+        """
+        pass
+
+    @abstractmethod
+    async def handle_task_failure(self, task: T, error: Exception) -> None:
+        """
+        Handle task processing failure.
+
+        Args:
+            task: The task that failed
+            error: The exception that occurred
+        """
+        pass
+
+    async def _propagate_context(self, task: T) -> None:
+        """
+        Propagate context from task parameters.
+
+        Args:
+            task: The task containing context information
+        """
+        # Default implementation - subclasses should override if needed
+        pass
+
+    async def _unset_current_context(self) -> None:
+        """
+        Unset the current context after task completion.
+        """
+        # Default implementation - subclasses should override if needed
+        pass
diff --git a/cadence/worker/_decision_task_handler.py b/cadence/worker/_decision_task_handler.py
new file mode 100644
index 0000000..4ebc719
--- /dev/null
+++ b/cadence/worker/_decision_task_handler.py
@@ -0,0 +1,174 @@
+import logging
+from typing import Dict, Any
+
+from cadence.api.v1.common_pb2 import Payload
+from cadence.api.v1.service_worker_pb2 import (
+    PollForDecisionTaskResponse,
+    RespondDecisionTaskCompletedRequest,
+    RespondDecisionTaskFailedRequest
+)
+from cadence.api.v1.workflow_pb2 import DecisionTaskFailedCause
+from cadence.client import Client
+from cadence.worker._base_task_handler import BaseTaskHandler
+from cadence._internal.workflow.workflow_engine import WorkflowEngine, DecisionResult
+from cadence.workflow import WorkflowInfo
+from cadence.worker._registry import Registry
+
+logger = logging.getLogger(__name__)
+
+class DecisionTaskHandler(BaseTaskHandler[PollForDecisionTaskResponse]):
+    """
+    Task handler for processing decision tasks.
+
+    This handler processes decision tasks and generates decisions using the workflow engine.
+    """
+
+    def __init__(self, client: Client, task_list: str, registry: Registry, identity: str = "unknown", **options):
+        """
+        Initialize the decision task handler.
+
+        Args:
+            client: The Cadence client instance
+            task_list: The task list name
+            registry: Registry containing workflow functions
+            identity: The worker identity
+            **options: Additional options for the handler
+        """
+        super().__init__(client, task_list, identity, **options)
+        self._registry = registry
+        self._workflow_engines: Dict[str, WorkflowEngine] = {}
+
+
+    async def _handle_task_implementation(self, task: PollForDecisionTaskResponse) -> None:
+        """
+        Handle a decision task implementation.
+ + Args: + task: The decision task to handle + """ + # Extract workflow execution info + workflow_execution = task.workflow_execution + workflow_type = task.workflow_type + + if not workflow_execution or not workflow_type: + logger.error("Decision task missing workflow execution or type") + await self.handle_task_failure(task, ValueError("Missing workflow execution or type")) + return + + workflow_id = workflow_execution.workflow_id + run_id = workflow_execution.run_id + workflow_type_name = workflow_type.name + + logger.info(f"Processing decision task for workflow {workflow_id} (type: {workflow_type_name})") + + # Get or create workflow engine for this workflow execution + engine_key = f"{workflow_id}:{run_id}" + if engine_key not in self._workflow_engines: + # Get the workflow function from registry + try: + workflow_func = self._registry.get_workflow(workflow_type_name) + except KeyError: + logger.error(f"Workflow type '{workflow_type_name}' not found in registry") + await self.handle_task_failure(task, KeyError(f"Workflow type '{workflow_type_name}' not found")) + return + + # Create workflow info and engine + workflow_info = WorkflowInfo( + workflow_type=workflow_type_name, + workflow_domain=self._client.domain, + workflow_id=workflow_id, + workflow_run_id=run_id + ) + + self._workflow_engines[engine_key] = WorkflowEngine( + info=workflow_info, + client=self._client, + workflow_func=workflow_func + ) + + # Process the decision using the workflow engine + workflow_engine = self._workflow_engines[engine_key] + decision_result = await workflow_engine.process_decision(task) + + # Respond with the decisions + await self._respond_decision_task_completed(task, decision_result) + + logger.info(f"Successfully processed decision task for workflow {workflow_id}") + + async def handle_task_failure(self, task: PollForDecisionTaskResponse, error: Exception) -> None: + """ + Handle decision task processing failure. + + Args: + task: The task that failed + error: The exception that occurred + """ + try: + logger.error(f"Decision task failed: {error}") + + # Determine the failure cause + cause = DecisionTaskFailedCause.DECISION_TASK_FAILED_CAUSE_UNHANDLED_DECISION + if isinstance(error, KeyError): + cause = DecisionTaskFailedCause.DECISION_TASK_FAILED_CAUSE_WORKFLOW_WORKER_UNHANDLED_FAILURE + elif isinstance(error, ValueError): + cause = DecisionTaskFailedCause.DECISION_TASK_FAILED_CAUSE_BAD_SCHEDULE_ACTIVITY_ATTRIBUTES + + # Create error details + error_message = str(error).encode('utf-8') + details = Payload(data=error_message) + + # Respond with failure + await self._client.worker_stub.RespondDecisionTaskFailed( + RespondDecisionTaskFailedRequest( + task_token=task.task_token, + cause=cause, + identity=self._identity, + details=details + ) + ) + + logger.info("Decision task failure response sent") + + except Exception: + logger.exception("Error handling decision task failure") + + async def _respond_decision_task_completed(self, task: PollForDecisionTaskResponse, decision_result: DecisionResult) -> None: + """ + Respond to the service that the decision task has been completed. 
+ + Args: + task: The original decision task + decision_result: The result containing decisions and query results + """ + try: + request = RespondDecisionTaskCompletedRequest( + task_token=task.task_token, + decisions=decision_result.decisions, + identity=self._identity, + return_new_decision_task=decision_result.force_create_new_decision_task, + force_create_new_decision_task=decision_result.force_create_new_decision_task + ) + + # Add query results if present + if decision_result.query_results: + request.query_results.update(decision_result.query_results) + + await self._client.worker_stub.RespondDecisionTaskCompleted(request) + logger.debug(f"Decision task completed with {len(decision_result.decisions)} decisions") + + except Exception: + logger.exception("Error responding to decision task completion") + raise + + def cleanup_workflow_engine(self, workflow_id: str, run_id: str) -> None: + """ + Clean up a workflow engine when workflow execution is complete. + + Args: + workflow_id: The workflow ID + run_id: The run ID + """ + engine_key = f"{workflow_id}:{run_id}" + if engine_key in self._workflow_engines: + del self._workflow_engines[engine_key] + logger.debug(f"Cleaned up workflow engine for {workflow_id}:{run_id}") From 7818de95ba408233b5c706c13df506c252a49e53 Mon Sep 17 00:00:00 2001 From: Tim Li Date: Fri, 12 Sep 2025 17:16:19 -0700 Subject: [PATCH 02/10] add unit test Signed-off-by: Tim Li --- .../cadence/worker/test_base_task_handler.py | 209 ++++++++++ .../worker/test_decision_task_handler.py | 359 ++++++++++++++++++ .../worker/test_task_handler_integration.py | 293 ++++++++++++++ 3 files changed, 861 insertions(+) create mode 100644 tests/cadence/worker/test_base_task_handler.py create mode 100644 tests/cadence/worker/test_decision_task_handler.py create mode 100644 tests/cadence/worker/test_task_handler_integration.py diff --git a/tests/cadence/worker/test_base_task_handler.py b/tests/cadence/worker/test_base_task_handler.py new file mode 100644 index 0000000..d8b0004 --- /dev/null +++ b/tests/cadence/worker/test_base_task_handler.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +""" +Unit tests for BaseTaskHandler class. 
+""" + +import pytest +from unittest.mock import Mock, AsyncMock, patch +from typing import Any + +from cadence.worker._base_task_handler import BaseTaskHandler + + +class ConcreteTaskHandler(BaseTaskHandler[str]): + """Concrete implementation of BaseTaskHandler for testing.""" + + def __init__(self, client, task_list: str, identity: str, **options): + super().__init__(client, task_list, identity, **options) + self._handle_task_implementation_called = False + self._handle_task_failure_called = False + self._propagate_context_called = False + self._unset_current_context_called = False + self._last_task: str = "" + self._last_error: Exception | None = None + + async def _handle_task_implementation(self, task: str) -> None: + """Test implementation of task handling.""" + self._handle_task_implementation_called = True + self._last_task = task + if task == "raise_error": + raise ValueError("Test error") + + async def handle_task_failure(self, task: str, error: Exception) -> None: + """Test implementation of task failure handling.""" + self._handle_task_failure_called = True + self._last_task = task + self._last_error = error + + async def _propagate_context(self, task: str) -> None: + """Test implementation of context propagation.""" + self._propagate_context_called = True + self._last_task = task + + async def _unset_current_context(self) -> None: + """Test implementation of context cleanup.""" + self._unset_current_context_called = True + + +class TestBaseTaskHandler: + """Test cases for BaseTaskHandler.""" + + def test_initialization(self): + """Test BaseTaskHandler initialization.""" + client = Mock() + handler = ConcreteTaskHandler( + client=client, + task_list="test_task_list", + identity="test_identity", + option1="value1", + option2="value2" + ) + + assert handler._client == client + assert handler._task_list == "test_task_list" + assert handler._identity == "test_identity" + assert handler._options == {"option1": "value1", "option2": "value2"} + + @pytest.mark.asyncio + async def test_handle_task_success(self): + """Test successful task handling.""" + client = Mock() + handler = ConcreteTaskHandler(client, "test_task_list", "test_identity") + + await handler.handle_task("test_task") + + # Verify all methods were called in correct order + assert handler._propagate_context_called + assert handler._handle_task_implementation_called + assert handler._unset_current_context_called + assert not handler._handle_task_failure_called + assert handler._last_task == "test_task" + assert handler._last_error is None + + @pytest.mark.asyncio + async def test_handle_task_failure(self): + """Test task handling with error.""" + client = Mock() + handler = ConcreteTaskHandler(client, "test_task_list", "test_identity") + + await handler.handle_task("raise_error") + + # Verify error handling was called + assert handler._propagate_context_called + assert handler._handle_task_implementation_called + assert handler._handle_task_failure_called + assert handler._unset_current_context_called + assert handler._last_task == "raise_error" + assert isinstance(handler._last_error, ValueError) + assert str(handler._last_error) == "Test error" + + @pytest.mark.asyncio + async def test_handle_task_with_context_propagation_error(self): + """Test task handling when context propagation fails.""" + client = Mock() + handler = ConcreteTaskHandler(client, "test_task_list", "test_identity") + + # Override _propagate_context to raise an error + async def failing_propagate_context(task): + raise RuntimeError("Context propagation 
failed") + + # Use setattr to avoid mypy error about method assignment + setattr(handler, '_propagate_context', failing_propagate_context) + + await handler.handle_task("test_task") + + # Verify error handling was called + assert handler._handle_task_failure_called + assert handler._unset_current_context_called + assert isinstance(handler._last_error, RuntimeError) + assert str(handler._last_error) == "Context propagation failed" + + @pytest.mark.asyncio + async def test_handle_task_with_cleanup_error(self): + """Test task handling when cleanup fails.""" + client = Mock() + handler = ConcreteTaskHandler(client, "test_task_list", "test_identity") + + # Override _unset_current_context to raise an error + async def failing_unset_context(): + raise RuntimeError("Cleanup failed") + + # Use setattr to avoid mypy error about method assignment + setattr(handler, '_unset_current_context', failing_unset_context) + + # Cleanup errors in finally block will propagate + with pytest.raises(RuntimeError, match="Cleanup failed"): + await handler.handle_task("test_task") + + @pytest.mark.asyncio + async def test_handle_task_with_implementation_and_cleanup_errors(self): + """Test task handling when both implementation and cleanup fail.""" + client = Mock() + handler = ConcreteTaskHandler(client, "test_task_list", "test_identity") + + # Override _unset_current_context to raise an error + async def failing_unset_context(): + raise RuntimeError("Cleanup failed") + + # Use setattr to avoid mypy error about method assignment + setattr(handler, '_unset_current_context', failing_unset_context) + + # The implementation error should be handled, but cleanup error will propagate + with pytest.raises(RuntimeError, match="Cleanup failed"): + await handler.handle_task("raise_error") + + # Verify the implementation error was handled before cleanup error + assert handler._handle_task_failure_called + assert isinstance(handler._last_error, ValueError) + + @pytest.mark.asyncio + async def test_abstract_methods_not_implemented(self): + """Test that abstract methods raise NotImplementedError when not implemented.""" + client = Mock() + + class IncompleteHandler(BaseTaskHandler[str]): + async def _handle_task_implementation(self, task: str) -> None: + raise NotImplementedError() + + async def handle_task_failure(self, task: str, error: Exception) -> None: + raise NotImplementedError() + + handler = IncompleteHandler(client, "test_task_list", "test_identity") + + with pytest.raises(NotImplementedError): + await handler._handle_task_implementation("test") + + with pytest.raises(NotImplementedError): + await handler.handle_task_failure("test", Exception("test")) + + @pytest.mark.asyncio + async def test_default_context_methods(self): + """Test default implementations of context methods.""" + client = Mock() + handler = ConcreteTaskHandler(client, "test_task_list", "test_identity") + + # Test default _propagate_context (should not raise) + await handler._propagate_context("test_task") + + # Test default _unset_current_context (should not raise) + await handler._unset_current_context() + + @pytest.mark.asyncio + async def test_generic_type_parameter(self): + """Test that the generic type parameter works correctly.""" + client = Mock() + + class IntHandler(BaseTaskHandler[int]): + async def _handle_task_implementation(self, task: int) -> None: + pass + + async def handle_task_failure(self, task: int, error: Exception) -> None: + pass + + handler = IntHandler(client, "test_task_list", "test_identity") + + # Should accept int tasks + 
await handler.handle_task(42) + + # Type checker should catch type mismatches (this is more of a static analysis test) + # In runtime, Python won't enforce the type, but the type hints are there for static analysis diff --git a/tests/cadence/worker/test_decision_task_handler.py b/tests/cadence/worker/test_decision_task_handler.py new file mode 100644 index 0000000..96e7846 --- /dev/null +++ b/tests/cadence/worker/test_decision_task_handler.py @@ -0,0 +1,359 @@ +#!/usr/bin/env python3 +""" +Unit tests for DecisionTaskHandler class. +""" + +import pytest +from unittest.mock import Mock, AsyncMock, patch, PropertyMock +from typing import Dict, Any + +from cadence.api.v1.common_pb2 import Payload +from cadence.api.v1.service_worker_pb2 import ( + PollForDecisionTaskResponse, + RespondDecisionTaskCompletedRequest, + RespondDecisionTaskFailedRequest +) +from cadence.api.v1.workflow_pb2 import DecisionTaskFailedCause +from cadence.api.v1.decision_pb2 import Decision +from cadence.client import Client +from cadence.worker._decision_task_handler import DecisionTaskHandler +from cadence.worker._registry import Registry +from cadence.workflow import WorkflowInfo +from cadence._internal.workflow.workflow_engine import WorkflowEngine, DecisionResult + + +class TestDecisionTaskHandler: + """Test cases for DecisionTaskHandler.""" + + @pytest.fixture + def mock_client(self): + """Create a mock client.""" + client = Mock(spec=Client) + client.worker_stub = Mock() + client.worker_stub.RespondDecisionTaskCompleted = AsyncMock() + client.worker_stub.RespondDecisionTaskFailed = AsyncMock() + type(client).domain = PropertyMock(return_value="test_domain") + return client + + @pytest.fixture + def mock_registry(self): + """Create a mock registry.""" + registry = Mock(spec=Registry) + return registry + + @pytest.fixture + def handler(self, mock_client, mock_registry): + """Create a DecisionTaskHandler instance.""" + return DecisionTaskHandler( + client=mock_client, + task_list="test_task_list", + registry=mock_registry, + identity="test_identity" + ) + + @pytest.fixture + def sample_decision_task(self): + """Create a sample decision task.""" + task = Mock(spec=PollForDecisionTaskResponse) + task.task_token = b"test_task_token" + task.workflow_execution = Mock() + task.workflow_execution.workflow_id = "test_workflow_id" + task.workflow_execution.run_id = "test_run_id" + task.workflow_type = Mock() + task.workflow_type.name = "TestWorkflow" + return task + + def test_initialization(self, mock_client, mock_registry): + """Test DecisionTaskHandler initialization.""" + handler = DecisionTaskHandler( + client=mock_client, + task_list="test_task_list", + registry=mock_registry, + identity="test_identity", + option1="value1" + ) + + assert handler._client == mock_client + assert handler._task_list == "test_task_list" + assert handler._identity == "test_identity" + assert handler._registry == mock_registry + assert handler._options == {"option1": "value1"} + assert isinstance(handler._workflow_engines, dict) + assert len(handler._workflow_engines) == 0 + + @pytest.mark.asyncio + async def test_handle_task_implementation_success(self, handler, sample_decision_task, mock_registry): + """Test successful decision task handling.""" + # Mock workflow function + mock_workflow_func = Mock() + mock_registry.get_workflow.return_value = mock_workflow_func + + # Mock workflow engine + mock_engine = Mock(spec=WorkflowEngine) + mock_decision_result = Mock(spec=DecisionResult) + mock_decision_result.decisions = [Decision()] + 
mock_decision_result.force_create_new_decision_task = False + mock_decision_result.query_results = {} + mock_engine.process_decision = AsyncMock(return_value=mock_decision_result) + + with patch('cadence.worker._decision_task_handler.WorkflowEngine', return_value=mock_engine): + await handler._handle_task_implementation(sample_decision_task) + + # Verify registry was called + mock_registry.get_workflow.assert_called_once_with("TestWorkflow") + + # Verify workflow engine was created and used + mock_engine.process_decision.assert_called_once_with(sample_decision_task) + + # Verify response was sent + handler._client.worker_stub.RespondDecisionTaskCompleted.assert_called_once() + + @pytest.mark.asyncio + async def test_handle_task_implementation_missing_workflow_execution(self, handler): + """Test decision task handling with missing workflow execution.""" + task = Mock(spec=PollForDecisionTaskResponse) + task.task_token = b"test_task_token" + task.workflow_execution = None + task.workflow_type = Mock() + task.workflow_type.name = "TestWorkflow" + + with patch.object(handler, 'handle_task_failure', new_callable=AsyncMock) as mock_handle_failure: + await handler._handle_task_implementation(task) + + mock_handle_failure.assert_called_once() + args = mock_handle_failure.call_args[0] + assert args[0] == task + assert isinstance(args[1], ValueError) + assert "Missing workflow execution or type" in str(args[1]) + + @pytest.mark.asyncio + async def test_handle_task_implementation_missing_workflow_type(self, handler): + """Test decision task handling with missing workflow type.""" + task = Mock(spec=PollForDecisionTaskResponse) + task.task_token = b"test_task_token" + task.workflow_execution = Mock() + task.workflow_execution.workflow_id = "test_workflow_id" + task.workflow_execution.run_id = "test_run_id" + task.workflow_type = None + + with patch.object(handler, 'handle_task_failure', new_callable=AsyncMock) as mock_handle_failure: + await handler._handle_task_implementation(task) + + mock_handle_failure.assert_called_once() + args = mock_handle_failure.call_args[0] + assert args[0] == task + assert isinstance(args[1], ValueError) + assert "Missing workflow execution or type" in str(args[1]) + + @pytest.mark.asyncio + async def test_handle_task_implementation_workflow_not_found(self, handler, sample_decision_task, mock_registry): + """Test decision task handling when workflow is not found in registry.""" + mock_registry.get_workflow.side_effect = KeyError("Workflow not found") + + with patch.object(handler, 'handle_task_failure', new_callable=AsyncMock) as mock_handle_failure: + await handler._handle_task_implementation(sample_decision_task) + + mock_handle_failure.assert_called_once() + args = mock_handle_failure.call_args[0] + assert args[0] == sample_decision_task + assert isinstance(args[1], KeyError) + assert "Workflow type 'TestWorkflow' not found" in str(args[1]) + + @pytest.mark.asyncio + async def test_handle_task_implementation_reuses_existing_engine(self, handler, sample_decision_task, mock_registry): + """Test that decision task handler reuses existing workflow engine.""" + # Mock workflow function + mock_workflow_func = Mock() + mock_registry.get_workflow.return_value = mock_workflow_func + + # Mock workflow engine + mock_engine = Mock(spec=WorkflowEngine) + mock_decision_result = Mock(spec=DecisionResult) + mock_decision_result.decisions = [] + mock_decision_result.force_create_new_decision_task = False + mock_decision_result.query_results = {} + mock_engine.process_decision = 
AsyncMock(return_value=mock_decision_result) + + with patch('cadence.worker._decision_task_handler.WorkflowEngine', return_value=mock_engine): + # First call - should create new engine + await handler._handle_task_implementation(sample_decision_task) + + # Second call - should reuse existing engine + await handler._handle_task_implementation(sample_decision_task) + + # Registry should only be called once + mock_registry.get_workflow.assert_called_once_with("TestWorkflow") + + # Engine should be called twice + assert mock_engine.process_decision.call_count == 2 + + # Should have one engine in the cache + assert len(handler._workflow_engines) == 1 + engine_key = "test_workflow_id:test_run_id" + assert engine_key in handler._workflow_engines + + @pytest.mark.asyncio + async def test_handle_task_failure_keyerror(self, handler, sample_decision_task): + """Test task failure handling for KeyError.""" + error = KeyError("Workflow not found") + + await handler.handle_task_failure(sample_decision_task, error) + + # Verify the correct failure cause was used + call_args = handler._client.worker_stub.RespondDecisionTaskFailed.call_args[0][0] + assert call_args.cause == DecisionTaskFailedCause.DECISION_TASK_FAILED_CAUSE_WORKFLOW_WORKER_UNHANDLED_FAILURE + assert call_args.task_token == sample_decision_task.task_token + assert call_args.identity == handler._identity + + @pytest.mark.asyncio + async def test_handle_task_failure_valueerror(self, handler, sample_decision_task): + """Test task failure handling for ValueError.""" + error = ValueError("Invalid workflow attributes") + + await handler.handle_task_failure(sample_decision_task, error) + + # Verify the correct failure cause was used + call_args = handler._client.worker_stub.RespondDecisionTaskFailed.call_args[0][0] + assert call_args.cause == DecisionTaskFailedCause.DECISION_TASK_FAILED_CAUSE_BAD_SCHEDULE_ACTIVITY_ATTRIBUTES + assert call_args.task_token == sample_decision_task.task_token + assert call_args.identity == handler._identity + + @pytest.mark.asyncio + async def test_handle_task_failure_generic_error(self, handler, sample_decision_task): + """Test task failure handling for generic error.""" + error = RuntimeError("Generic error") + + await handler.handle_task_failure(sample_decision_task, error) + + # Verify the default failure cause was used + call_args = handler._client.worker_stub.RespondDecisionTaskFailed.call_args[0][0] + assert call_args.cause == DecisionTaskFailedCause.DECISION_TASK_FAILED_CAUSE_UNHANDLED_DECISION + assert call_args.task_token == sample_decision_task.task_token + assert call_args.identity == handler._identity + + @pytest.mark.asyncio + async def test_handle_task_failure_with_error_details(self, handler, sample_decision_task): + """Test task failure handling includes error details.""" + error = ValueError("Test error message") + + await handler.handle_task_failure(sample_decision_task, error) + + call_args = handler._client.worker_stub.RespondDecisionTaskFailed.call_args[0][0] + assert isinstance(call_args.details, Payload) + assert call_args.details.data == b"Test error message" + + @pytest.mark.asyncio + async def test_handle_task_failure_respond_error(self, handler, sample_decision_task): + """Test task failure handling when respond fails.""" + error = ValueError("Test error") + handler._client.worker_stub.RespondDecisionTaskFailed.side_effect = Exception("Respond failed") + + # Should not raise exception, but should log error + with patch('cadence.worker._decision_task_handler.logger') as mock_logger: + await 
handler.handle_task_failure(sample_decision_task, error) + mock_logger.exception.assert_called_once() + + @pytest.mark.asyncio + async def test_respond_decision_task_completed_success(self, handler, sample_decision_task): + """Test successful decision task completion response.""" + decision_result = Mock(spec=DecisionResult) + decision_result.decisions = [Decision(), Decision()] + decision_result.force_create_new_decision_task = True + decision_result.query_results = None # Test without query results first + + await handler._respond_decision_task_completed(sample_decision_task, decision_result) + + # Verify the request was created correctly + call_args = handler._client.worker_stub.RespondDecisionTaskCompleted.call_args[0][0] + assert isinstance(call_args, RespondDecisionTaskCompletedRequest) + assert call_args.task_token == sample_decision_task.task_token + assert call_args.identity == handler._identity + assert call_args.return_new_decision_task == True + assert call_args.force_create_new_decision_task == True + assert len(call_args.decisions) == 2 + # query_results should not be set when None + assert not hasattr(call_args, 'query_results') or len(call_args.query_results) == 0 + + @pytest.mark.asyncio + async def test_respond_decision_task_completed_no_query_results(self, handler, sample_decision_task): + """Test decision task completion response without query results.""" + decision_result = Mock(spec=DecisionResult) + decision_result.decisions = [] + decision_result.force_create_new_decision_task = False + decision_result.query_results = None + + await handler._respond_decision_task_completed(sample_decision_task, decision_result) + + call_args = handler._client.worker_stub.RespondDecisionTaskCompleted.call_args[0][0] + assert call_args.return_new_decision_task == False + assert call_args.force_create_new_decision_task == False + assert len(call_args.decisions) == 0 + # query_results should not be set when None + assert not hasattr(call_args, 'query_results') or len(call_args.query_results) == 0 + + @pytest.mark.asyncio + async def test_respond_decision_task_completed_error(self, handler, sample_decision_task): + """Test decision task completion response error handling.""" + decision_result = Mock(spec=DecisionResult) + decision_result.decisions = [] + decision_result.force_create_new_decision_task = False + decision_result.query_results = {} + + handler._client.worker_stub.RespondDecisionTaskCompleted.side_effect = Exception("Respond failed") + + with pytest.raises(Exception, match="Respond failed"): + await handler._respond_decision_task_completed(sample_decision_task, decision_result) + + def test_cleanup_workflow_engine(self, handler): + """Test workflow engine cleanup.""" + # Add some mock engines + handler._workflow_engines["workflow1:run1"] = Mock() + handler._workflow_engines["workflow2:run2"] = Mock() + + # Clean up one engine + handler.cleanup_workflow_engine("workflow1", "run1") + + # Verify only one engine was removed + assert len(handler._workflow_engines) == 1 + assert "workflow1:run1" not in handler._workflow_engines + assert "workflow2:run2" in handler._workflow_engines + + def test_cleanup_workflow_engine_not_found(self, handler): + """Test cleanup of non-existent workflow engine.""" + # Should not raise error + handler.cleanup_workflow_engine("nonexistent", "run") + + # Should not affect existing engines + assert len(handler._workflow_engines) == 0 + + @pytest.mark.asyncio + async def test_workflow_engine_creation_with_workflow_info(self, handler, 
sample_decision_task, mock_registry): + """Test that WorkflowEngine is created with correct WorkflowInfo.""" + mock_workflow_func = Mock() + mock_registry.get_workflow.return_value = mock_workflow_func + + mock_engine = Mock(spec=WorkflowEngine) + mock_decision_result = Mock(spec=DecisionResult) + mock_decision_result.decisions = [] + mock_decision_result.force_create_new_decision_task = False + mock_decision_result.query_results = {} + mock_engine.process_decision = AsyncMock(return_value=mock_decision_result) + + with patch('cadence.worker._decision_task_handler.WorkflowEngine', return_value=mock_engine) as mock_workflow_engine_class: + with patch('cadence.worker._decision_task_handler.WorkflowInfo') as mock_workflow_info_class: + await handler._handle_task_implementation(sample_decision_task) + + # Verify WorkflowInfo was created with correct parameters + mock_workflow_info_class.assert_called_once_with( + workflow_type="TestWorkflow", + workflow_domain="test_domain", + workflow_id="test_workflow_id", + workflow_run_id="test_run_id" + ) + + # Verify WorkflowEngine was created with correct parameters + mock_workflow_engine_class.assert_called_once() + call_args = mock_workflow_engine_class.call_args + assert call_args[1]['info'] is not None + assert call_args[1]['client'] == handler._client + assert call_args[1]['workflow_func'] == mock_workflow_func diff --git a/tests/cadence/worker/test_task_handler_integration.py b/tests/cadence/worker/test_task_handler_integration.py new file mode 100644 index 0000000..def715f --- /dev/null +++ b/tests/cadence/worker/test_task_handler_integration.py @@ -0,0 +1,293 @@ +#!/usr/bin/env python3 +""" +Integration tests for task handlers. +""" + +import pytest +from unittest.mock import Mock, AsyncMock, patch, PropertyMock + +from cadence.api.v1.service_worker_pb2 import PollForDecisionTaskResponse +from cadence.client import Client +from cadence.worker._decision_task_handler import DecisionTaskHandler +from cadence.worker._registry import Registry +from cadence._internal.workflow.workflow_engine import WorkflowEngine, DecisionResult + + +class TestTaskHandlerIntegration: + """Integration tests for task handlers.""" + + @pytest.fixture + def mock_client(self): + """Create a mock client.""" + client = Mock(spec=Client) + client.worker_stub = Mock() + client.worker_stub.RespondDecisionTaskCompleted = AsyncMock() + client.worker_stub.RespondDecisionTaskFailed = AsyncMock() + type(client).domain = PropertyMock(return_value="test_domain") + return client + + @pytest.fixture + def mock_registry(self): + """Create a mock registry.""" + registry = Mock(spec=Registry) + return registry + + @pytest.fixture + def handler(self, mock_client, mock_registry): + """Create a DecisionTaskHandler instance.""" + return DecisionTaskHandler( + client=mock_client, + task_list="test_task_list", + registry=mock_registry, + identity="test_identity" + ) + + @pytest.fixture + def sample_decision_task(self): + """Create a sample decision task.""" + task = Mock(spec=PollForDecisionTaskResponse) + task.task_token = b"test_task_token" + task.workflow_execution = Mock() + task.workflow_execution.workflow_id = "test_workflow_id" + task.workflow_execution.run_id = "test_run_id" + task.workflow_type = Mock() + task.workflow_type.name = "TestWorkflow" + return task + + @pytest.mark.asyncio + async def test_full_task_handling_flow_success(self, handler, sample_decision_task, mock_registry): + """Test the complete task handling flow from base handler through decision handler.""" + # Mock 
workflow function + mock_workflow_func = Mock() + mock_registry.get_workflow.return_value = mock_workflow_func + + # Mock workflow engine + mock_engine = Mock(spec=WorkflowEngine) + mock_decision_result = Mock(spec=DecisionResult) + mock_decision_result.decisions = [] + mock_decision_result.force_create_new_decision_task = False + mock_decision_result.query_results = {} + mock_engine.process_decision = AsyncMock(return_value=mock_decision_result) + + with patch('cadence.worker._decision_task_handler.WorkflowEngine', return_value=mock_engine): + # Use the base handler's handle_task method + await handler.handle_task(sample_decision_task) + + # Verify the complete flow + mock_registry.get_workflow.assert_called_once_with("TestWorkflow") + mock_engine.process_decision.assert_called_once_with(sample_decision_task) + handler._client.worker_stub.RespondDecisionTaskCompleted.assert_called_once() + + @pytest.mark.asyncio + async def test_full_task_handling_flow_with_error(self, handler, sample_decision_task, mock_registry): + """Test the complete task handling flow when an error occurs.""" + # Mock workflow function + mock_workflow_func = Mock() + mock_registry.get_workflow.return_value = mock_workflow_func + + # Mock workflow engine to raise an error + mock_engine = Mock(spec=WorkflowEngine) + mock_engine.process_decision = AsyncMock(side_effect=RuntimeError("Workflow processing failed")) + + with patch('cadence.worker._decision_task_handler.WorkflowEngine', return_value=mock_engine): + # Use the base handler's handle_task method + await handler.handle_task(sample_decision_task) + + # Verify error handling + handler._client.worker_stub.RespondDecisionTaskFailed.assert_called_once() + call_args = handler._client.worker_stub.RespondDecisionTaskFailed.call_args[0][0] + assert call_args.task_token == sample_decision_task.task_token + assert call_args.identity == handler._identity + + @pytest.mark.asyncio + async def test_context_propagation_integration(self, handler, sample_decision_task, mock_registry): + """Test that context propagation works correctly in the integration.""" + # Mock workflow function + mock_workflow_func = Mock() + mock_registry.get_workflow.return_value = mock_workflow_func + + # Mock workflow engine + mock_engine = Mock(spec=WorkflowEngine) + mock_decision_result = Mock(spec=DecisionResult) + mock_decision_result.decisions = [] + mock_decision_result.force_create_new_decision_task = False + mock_decision_result.query_results = {} + mock_engine.process_decision = AsyncMock(return_value=mock_decision_result) + + # Track if context methods are called + context_propagated = False + context_unset = False + + async def track_propagate_context(task): + nonlocal context_propagated + context_propagated = True + + async def track_unset_current_context(): + nonlocal context_unset + context_unset = True + + handler._propagate_context = track_propagate_context + handler._unset_current_context = track_unset_current_context + + with patch('cadence.worker._decision_task_handler.WorkflowEngine', return_value=mock_engine): + await handler.handle_task(sample_decision_task) + + # Verify context methods were called + assert context_propagated + assert context_unset + + @pytest.mark.asyncio + async def test_multiple_workflow_executions(self, handler, mock_registry): + """Test handling multiple workflow executions with different engines.""" + # Mock workflow function + mock_workflow_func = Mock() + mock_registry.get_workflow.return_value = mock_workflow_func + + # Create multiple decision tasks for 
different workflows + task1 = Mock(spec=PollForDecisionTaskResponse) + task1.task_token = b"task1_token" + task1.workflow_execution = Mock() + task1.workflow_execution.workflow_id = "workflow1" + task1.workflow_execution.run_id = "run1" + task1.workflow_type = Mock() + task1.workflow_type.name = "TestWorkflow" + + task2 = Mock(spec=PollForDecisionTaskResponse) + task2.task_token = b"task2_token" + task2.workflow_execution = Mock() + task2.workflow_execution.workflow_id = "workflow2" + task2.workflow_execution.run_id = "run2" + task2.workflow_type = Mock() + task2.workflow_type.name = "TestWorkflow" + + # Mock workflow engines + mock_engine1 = Mock(spec=WorkflowEngine) + mock_engine2 = Mock(spec=WorkflowEngine) + + mock_decision_result = Mock(spec=DecisionResult) + mock_decision_result.decisions = [] + mock_decision_result.force_create_new_decision_task = False + mock_decision_result.query_results = {} + + mock_engine1.process_decision = AsyncMock(return_value=mock_decision_result) + mock_engine2.process_decision = AsyncMock(return_value=mock_decision_result) + + def mock_workflow_engine_creator(*args, **kwargs): + # Return different engines based on workflow info + workflow_info = kwargs.get('info') + if workflow_info and workflow_info.workflow_id == "workflow1": + return mock_engine1 + else: + return mock_engine2 + + with patch('cadence.worker._decision_task_handler.WorkflowEngine', side_effect=mock_workflow_engine_creator): + # Process both tasks + await handler.handle_task(task1) + await handler.handle_task(task2) + + # Verify both engines were created and used + assert len(handler._workflow_engines) == 2 + assert "workflow1:run1" in handler._workflow_engines + assert "workflow2:run2" in handler._workflow_engines + + # Verify both engines were called + mock_engine1.process_decision.assert_called_once_with(task1) + mock_engine2.process_decision.assert_called_once_with(task2) + + @pytest.mark.asyncio + async def test_workflow_engine_cleanup_integration(self, handler, sample_decision_task, mock_registry): + """Test workflow engine cleanup integration.""" + # Mock workflow function + mock_workflow_func = Mock() + mock_registry.get_workflow.return_value = mock_workflow_func + + # Mock workflow engine + mock_engine = Mock(spec=WorkflowEngine) + mock_decision_result = Mock(spec=DecisionResult) + mock_decision_result.decisions = [] + mock_decision_result.force_create_new_decision_task = False + mock_decision_result.query_results = {} + mock_engine.process_decision = AsyncMock(return_value=mock_decision_result) + + with patch('cadence.worker._decision_task_handler.WorkflowEngine', return_value=mock_engine): + # Process task to create engine + await handler.handle_task(sample_decision_task) + + # Verify engine was created + assert len(handler._workflow_engines) == 1 + assert "test_workflow_id:test_run_id" in handler._workflow_engines + + # Clean up engine + handler.cleanup_workflow_engine("test_workflow_id", "test_run_id") + + # Verify engine was cleaned up + assert len(handler._workflow_engines) == 0 + + @pytest.mark.asyncio + async def test_error_handling_with_context_cleanup(self, handler, sample_decision_task, mock_registry): + """Test that context cleanup happens even when errors occur.""" + # Mock workflow function + mock_workflow_func = Mock() + mock_registry.get_workflow.return_value = mock_workflow_func + + # Mock workflow engine to raise an error + mock_engine = Mock(spec=WorkflowEngine) + mock_engine.process_decision = AsyncMock(side_effect=RuntimeError("Workflow processing failed")) + 
+ # Track context cleanup + context_unset = False + + async def track_unset_current_context(): + nonlocal context_unset + context_unset = True + + handler._unset_current_context = track_unset_current_context + + with patch('cadence.worker._decision_task_handler.WorkflowEngine', return_value=mock_engine): + await handler.handle_task(sample_decision_task) + + # Verify context was cleaned up even after error + assert context_unset + + # Verify error was handled + handler._client.worker_stub.RespondDecisionTaskFailed.assert_called_once() + + @pytest.mark.asyncio + async def test_concurrent_task_handling(self, handler, mock_registry): + """Test handling multiple tasks concurrently.""" + import asyncio + + # Mock workflow function + mock_workflow_func = Mock() + mock_registry.get_workflow.return_value = mock_workflow_func + + # Create multiple tasks + tasks = [] + for i in range(3): + task = Mock(spec=PollForDecisionTaskResponse) + task.task_token = f"task{i}_token".encode() + task.workflow_execution = Mock() + task.workflow_execution.workflow_id = f"workflow{i}" + task.workflow_execution.run_id = f"run{i}" + task.workflow_type = Mock() + task.workflow_type.name = "TestWorkflow" + tasks.append(task) + + # Mock workflow engine + mock_engine = Mock(spec=WorkflowEngine) + mock_decision_result = Mock(spec=DecisionResult) + mock_decision_result.decisions = [] + mock_decision_result.force_create_new_decision_task = False + mock_decision_result.query_results = {} + mock_engine.process_decision = AsyncMock(return_value=mock_decision_result) + + with patch('cadence.worker._decision_task_handler.WorkflowEngine', return_value=mock_engine): + # Process all tasks concurrently + await asyncio.gather(*[handler.handle_task(task) for task in tasks]) + + # Verify all tasks were processed + assert mock_engine.process_decision.call_count == 3 + assert handler._client.worker_stub.RespondDecisionTaskCompleted.call_count == 3 + + # Verify engines were created for each workflow + assert len(handler._workflow_engines) == 3 From 5d759c3a75534e0eb7fdd26afd6d5687e0c35e5d Mon Sep 17 00:00:00 2001 From: Tim Li Date: Fri, 12 Sep 2025 17:22:40 -0700 Subject: [PATCH 03/10] lint Signed-off-by: Tim Li --- cadence/_internal/workflow/workflow_engine.py | 8 ++++++-- cadence/worker/_base_task_handler.py | 2 +- cadence/worker/_decision_task_handler.py | 2 +- tests/cadence/worker/test_base_task_handler.py | 3 +-- tests/cadence/worker/test_decision_task_handler.py | 13 +++++-------- 5 files changed, 14 insertions(+), 14 deletions(-) diff --git a/cadence/_internal/workflow/workflow_engine.py b/cadence/_internal/workflow/workflow_engine.py index 8c2f8da..07ee08f 100644 --- a/cadence/_internal/workflow/workflow_engine.py +++ b/cadence/_internal/workflow/workflow_engine.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from typing import Optional, Callable, Any from cadence._internal.workflow.context import Context from cadence.api.v1.decision_pb2 import Decision @@ -10,12 +11,15 @@ @dataclass class DecisionResult: decisions: list[Decision] + force_create_new_decision_task: bool = False + query_results: Optional[dict] = None class WorkflowEngine: - def __init__(self, info: WorkflowInfo, client: Client): + def __init__(self, info: WorkflowInfo, client: Client, workflow_func: Optional[Callable[..., Any]] = None): self._context = Context(client, info) + self._workflow_func = workflow_func # TODO: Implement this - def process_decision(self, decision_task: PollForDecisionTaskResponse) -> DecisionResult: + async def process_decision(self, 
decision_task: PollForDecisionTaskResponse) -> DecisionResult: with self._context._activate(): return DecisionResult(decisions=[]) diff --git a/cadence/worker/_base_task_handler.py b/cadence/worker/_base_task_handler.py index 751b2bb..844ef5d 100644 --- a/cadence/worker/_base_task_handler.py +++ b/cadence/worker/_base_task_handler.py @@ -1,6 +1,6 @@ import logging from abc import ABC, abstractmethod -from typing import TypeVar, Generic +from typing import Any, Dict, TypeVar, Generic logger = logging.getLogger(__name__) diff --git a/cadence/worker/_decision_task_handler.py b/cadence/worker/_decision_task_handler.py index 4ebc719..462d362 100644 --- a/cadence/worker/_decision_task_handler.py +++ b/cadence/worker/_decision_task_handler.py @@ -1,5 +1,5 @@ import logging -from typing import Dict, Any +from typing import Dict from cadence.api.v1.common_pb2 import Payload from cadence.api.v1.service_worker_pb2 import ( diff --git a/tests/cadence/worker/test_base_task_handler.py b/tests/cadence/worker/test_base_task_handler.py index d8b0004..55a5bbb 100644 --- a/tests/cadence/worker/test_base_task_handler.py +++ b/tests/cadence/worker/test_base_task_handler.py @@ -4,8 +4,7 @@ """ import pytest -from unittest.mock import Mock, AsyncMock, patch -from typing import Any +from unittest.mock import Mock from cadence.worker._base_task_handler import BaseTaskHandler diff --git a/tests/cadence/worker/test_decision_task_handler.py b/tests/cadence/worker/test_decision_task_handler.py index 96e7846..bd6704d 100644 --- a/tests/cadence/worker/test_decision_task_handler.py +++ b/tests/cadence/worker/test_decision_task_handler.py @@ -5,20 +5,17 @@ import pytest from unittest.mock import Mock, AsyncMock, patch, PropertyMock -from typing import Dict, Any from cadence.api.v1.common_pb2 import Payload from cadence.api.v1.service_worker_pb2 import ( PollForDecisionTaskResponse, - RespondDecisionTaskCompletedRequest, - RespondDecisionTaskFailedRequest + RespondDecisionTaskCompletedRequest ) from cadence.api.v1.workflow_pb2 import DecisionTaskFailedCause from cadence.api.v1.decision_pb2 import Decision from cadence.client import Client from cadence.worker._decision_task_handler import DecisionTaskHandler from cadence.worker._registry import Registry -from cadence.workflow import WorkflowInfo from cadence._internal.workflow.workflow_engine import WorkflowEngine, DecisionResult @@ -268,8 +265,8 @@ async def test_respond_decision_task_completed_success(self, handler, sample_dec assert isinstance(call_args, RespondDecisionTaskCompletedRequest) assert call_args.task_token == sample_decision_task.task_token assert call_args.identity == handler._identity - assert call_args.return_new_decision_task == True - assert call_args.force_create_new_decision_task == True + assert call_args.return_new_decision_task + assert call_args.force_create_new_decision_task assert len(call_args.decisions) == 2 # query_results should not be set when None assert not hasattr(call_args, 'query_results') or len(call_args.query_results) == 0 @@ -285,8 +282,8 @@ async def test_respond_decision_task_completed_no_query_results(self, handler, s await handler._respond_decision_task_completed(sample_decision_task, decision_result) call_args = handler._client.worker_stub.RespondDecisionTaskCompleted.call_args[0][0] - assert call_args.return_new_decision_task == False - assert call_args.force_create_new_decision_task == False + assert not call_args.return_new_decision_task + assert not call_args.force_create_new_decision_task assert len(call_args.decisions) == 
0 # query_results should not be set when None assert not hasattr(call_args, 'query_results') or len(call_args.query_results) == 0 From 75b86641b60edd0bd7f26c1b257e8e1b7d4a1b25 Mon Sep 17 00:00:00 2001 From: Tim Li Date: Tue, 16 Sep 2025 10:02:56 -0700 Subject: [PATCH 04/10] improve context management Signed-off-by: Tim Li --- cadence/worker/_base_task_handler.py | 26 +----- cadence/worker/_decision_task_handler.py | 25 ++++-- .../cadence/worker/test_base_task_handler.py | 86 +------------------ .../worker/test_decision_task_handler.py | 40 +++------ .../worker/test_task_handler_integration.py | 56 ++++++------ 5 files changed, 59 insertions(+), 174 deletions(-) diff --git a/cadence/worker/_base_task_handler.py b/cadence/worker/_base_task_handler.py index 844ef5d..2c23ed0 100644 --- a/cadence/worker/_base_task_handler.py +++ b/cadence/worker/_base_task_handler.py @@ -34,7 +34,6 @@ async def handle_task(self, task: T) -> None: Handle a single task. This method provides the base implementation for task handling that includes: - - Context propagation - Error handling - Cleanup @@ -42,18 +41,12 @@ async def handle_task(self, task: T) -> None: task: The task to handle """ try: - # Propagate context from task parameters - await self._propagate_context(task) - - # Handle the task + # Handle the task implementation await self._handle_task_implementation(task) except Exception as e: logger.exception(f"Error handling task: {e}") await self.handle_task_failure(task, e) - finally: - # Clean up context - await self._unset_current_context() @abstractmethod async def _handle_task_implementation(self, task: T) -> None: @@ -75,20 +68,3 @@ async def handle_task_failure(self, task: T, error: Exception) -> None: error: The exception that occurred """ pass - - async def _propagate_context(self, task: T) -> None: - """ - Propagate context from task parameters. - - Args: - task: The task containing context information - """ - # Default implementation - subclasses should override if needed - pass - - async def _unset_current_context(self) -> None: - """ - Unset the current context after task completion. 
- """ - # Default implementation - subclasses should override if needed - pass diff --git a/cadence/worker/_decision_task_handler.py b/cadence/worker/_decision_task_handler.py index 462d362..4f35212 100644 --- a/cadence/worker/_decision_task_handler.py +++ b/cadence/worker/_decision_task_handler.py @@ -11,6 +11,7 @@ from cadence.client import Client from cadence.worker._base_task_handler import BaseTaskHandler from cadence._internal.workflow.workflow_engine import WorkflowEngine, DecisionResult +from cadence._internal.workflow.context import Context from cadence.workflow import WorkflowInfo from cadence.worker._registry import Registry @@ -52,8 +53,7 @@ async def _handle_task_implementation(self, task: PollForDecisionTaskResponse) - if not workflow_execution or not workflow_type: logger.error("Decision task missing workflow execution or type") - await self.handle_task_failure(task, ValueError("Missing workflow execution or type")) - return + raise ValueError("Missing workflow execution or type") workflow_id = workflow_execution.workflow_id run_id = workflow_execution.run_id @@ -69,8 +69,7 @@ async def _handle_task_implementation(self, task: PollForDecisionTaskResponse) - workflow_func = self._registry.get_workflow(workflow_type_name) except KeyError: logger.error(f"Workflow type '{workflow_type_name}' not found in registry") - await self.handle_task_failure(task, KeyError(f"Workflow type '{workflow_type_name}' not found")) - return + raise KeyError(f"Workflow type '{workflow_type_name}' not found") # Create workflow info and engine workflow_info = WorkflowInfo( @@ -86,12 +85,22 @@ async def _handle_task_implementation(self, task: PollForDecisionTaskResponse) - workflow_func=workflow_func ) - # Process the decision using the workflow engine + # Create workflow context and execute with it active workflow_engine = self._workflow_engines[engine_key] - decision_result = await workflow_engine.process_decision(task) + workflow_info = WorkflowInfo( + workflow_type=workflow_type_name, + workflow_domain=self._client.domain, + workflow_id=workflow_id, + workflow_run_id=run_id + ) - # Respond with the decisions - await self._respond_decision_task_completed(task, decision_result) + context = Context(client=self._client, info=workflow_info) + with context._activate(): + # Process the decision using the workflow engine + decision_result = await workflow_engine.process_decision(task) + + # Respond with the decisions + await self._respond_decision_task_completed(task, decision_result) logger.info(f"Successfully processed decision task for workflow {workflow_id}") diff --git a/tests/cadence/worker/test_base_task_handler.py b/tests/cadence/worker/test_base_task_handler.py index 55a5bbb..d5d48a6 100644 --- a/tests/cadence/worker/test_base_task_handler.py +++ b/tests/cadence/worker/test_base_task_handler.py @@ -16,8 +16,6 @@ def __init__(self, client, task_list: str, identity: str, **options): super().__init__(client, task_list, identity, **options) self._handle_task_implementation_called = False self._handle_task_failure_called = False - self._propagate_context_called = False - self._unset_current_context_called = False self._last_task: str = "" self._last_error: Exception | None = None @@ -33,15 +31,6 @@ async def handle_task_failure(self, task: str, error: Exception) -> None: self._handle_task_failure_called = True self._last_task = task self._last_error = error - - async def _propagate_context(self, task: str) -> None: - """Test implementation of context propagation.""" - self._propagate_context_called = 
True
-        self._last_task = task
-
-    async def _unset_current_context(self) -> None:
-        """Test implementation of context cleanup."""
-        self._unset_current_context_called = True


class TestBaseTaskHandler:
@@ -71,10 +60,8 @@ async def test_handle_task_success(self):

         await handler.handle_task("test_task")

-        # Verify all methods were called in correct order
-        assert handler._propagate_context_called
+        # Verify implementation was called
         assert handler._handle_task_implementation_called
-        assert handler._unset_current_context_called
         assert not handler._handle_task_failure_called
         assert handler._last_task == "test_task"
         assert handler._last_error is None
@@ -88,72 +75,12 @@ async def test_handle_task_failure(self):

         await handler.handle_task("raise_error")

         # Verify error handling was called
-        assert handler._propagate_context_called
         assert handler._handle_task_implementation_called
         assert handler._handle_task_failure_called
-        assert handler._unset_current_context_called
         assert handler._last_task == "raise_error"
         assert isinstance(handler._last_error, ValueError)
         assert str(handler._last_error) == "Test error"

-    @pytest.mark.asyncio
-    async def test_handle_task_with_context_propagation_error(self):
-        """Test task handling when context propagation fails."""
-        client = Mock()
-        handler = ConcreteTaskHandler(client, "test_task_list", "test_identity")
-
-        # Override _propagate_context to raise an error
-        async def failing_propagate_context(task):
-            raise RuntimeError("Context propagation failed")
-
-        # Use setattr to avoid mypy error about method assignment
-        setattr(handler, '_propagate_context', failing_propagate_context)
-
-        await handler.handle_task("test_task")
-
-        # Verify error handling was called
-        assert handler._handle_task_failure_called
-        assert handler._unset_current_context_called
-        assert isinstance(handler._last_error, RuntimeError)
-        assert str(handler._last_error) == "Context propagation failed"
-
-    @pytest.mark.asyncio
-    async def test_handle_task_with_cleanup_error(self):
-        """Test task handling when cleanup fails."""
-        client = Mock()
-        handler = ConcreteTaskHandler(client, "test_task_list", "test_identity")
-
-        # Override _unset_current_context to raise an error
-        async def failing_unset_context():
-            raise RuntimeError("Cleanup failed")
-
-        # Use setattr to avoid mypy error about method assignment
-        setattr(handler, '_unset_current_context', failing_unset_context)
-
-        # Cleanup errors in finally block will propagate
-        with pytest.raises(RuntimeError, match="Cleanup failed"):
-            await handler.handle_task("test_task")
-
-    @pytest.mark.asyncio
-    async def test_handle_task_with_implementation_and_cleanup_errors(self):
-        """Test task handling when both implementation and cleanup fail."""
-        client = Mock()
-        handler = ConcreteTaskHandler(client, "test_task_list", "test_identity")
-
-        # Override _unset_current_context to raise an error
-        async def failing_unset_context():
-            raise RuntimeError("Cleanup failed")
-
-        # Use setattr to avoid mypy error about method assignment
-        setattr(handler, '_unset_current_context', failing_unset_context)
-
-        # The implementation error should be handled, but cleanup error will propagate
-        with pytest.raises(RuntimeError, match="Cleanup failed"):
-            await handler.handle_task("raise_error")
-
-        # Verify the implementation error was handled before cleanup error
-        assert handler._handle_task_failure_called
-        assert isinstance(handler._last_error, ValueError)

     @pytest.mark.asyncio
     async def test_abstract_methods_not_implemented(self):
@@ -175,17 +102,6 @@ async def handle_task_failure(self, task: str, error: Exception) -> None:
         with pytest.raises(NotImplementedError):
             await handler.handle_task_failure("test", Exception("test"))

-    @pytest.mark.asyncio
-    async def test_default_context_methods(self):
-        """Test default implementations of context methods."""
-        client = Mock()
-        handler = ConcreteTaskHandler(client, "test_task_list", "test_identity")
-
-        # Test default _propagate_context (should not raise)
-        await handler._propagate_context("test_task")
-
-        # Test default _unset_current_context (should not raise)
-        await handler._unset_current_context()

     @pytest.mark.asyncio
     async def test_generic_type_parameter(self):
diff --git a/tests/cadence/worker/test_decision_task_handler.py b/tests/cadence/worker/test_decision_task_handler.py
index bd6704d..b50a50a 100644
--- a/tests/cadence/worker/test_decision_task_handler.py
+++ b/tests/cadence/worker/test_decision_task_handler.py
@@ -114,14 +114,8 @@ async def test_handle_task_implementation_missing_workflow_execution(self, handl
         task.workflow_type = Mock()
         task.workflow_type.name = "TestWorkflow"

-        with patch.object(handler, 'handle_task_failure', new_callable=AsyncMock) as mock_handle_failure:
+        with pytest.raises(ValueError, match="Missing workflow execution or type"):
             await handler._handle_task_implementation(task)
-
-            mock_handle_failure.assert_called_once()
-            args = mock_handle_failure.call_args[0]
-            assert args[0] == task
-            assert isinstance(args[1], ValueError)
-            assert "Missing workflow execution or type" in str(args[1])

     @pytest.mark.asyncio
     async def test_handle_task_implementation_missing_workflow_type(self, handler):
@@ -133,28 +127,16 @@
         task.workflow_execution.workflow_id = "test_workflow_id"
         task.workflow_execution.run_id = "test_run_id"
         task.workflow_type = None

-        with patch.object(handler, 'handle_task_failure', new_callable=AsyncMock) as mock_handle_failure:
+        with pytest.raises(ValueError, match="Missing workflow execution or type"):
             await handler._handle_task_implementation(task)
-
-            mock_handle_failure.assert_called_once()
-            args = mock_handle_failure.call_args[0]
-            assert args[0] == task
-            assert isinstance(args[1], ValueError)
-            assert "Missing workflow execution or type" in str(args[1])

     @pytest.mark.asyncio
     async def test_handle_task_implementation_workflow_not_found(self, handler, sample_decision_task, mock_registry):
         """Test decision task handling when workflow is not found in registry."""
         mock_registry.get_workflow.side_effect = KeyError("Workflow not found")

-        with patch.object(handler, 'handle_task_failure', new_callable=AsyncMock) as mock_handle_failure:
+        with pytest.raises(KeyError, match="Workflow type 'TestWorkflow' not found"):
             await handler._handle_task_implementation(sample_decision_task)
-
-            mock_handle_failure.assert_called_once()
-            args = mock_handle_failure.call_args[0]
-            assert args[0] == sample_decision_task
-            assert isinstance(args[1], KeyError)
-            assert "Workflow type 'TestWorkflow' not found" in str(args[1])

     @pytest.mark.asyncio
     async def test_handle_task_implementation_reuses_existing_engine(self, handler, sample_decision_task, mock_registry):
@@ -340,13 +322,15 @@ async def test_workflow_engine_creation_with_workflow_info(self, handler, sample
             with patch('cadence.worker._decision_task_handler.WorkflowInfo') as mock_workflow_info_class:
                 await handler._handle_task_implementation(sample_decision_task)

-                # Verify WorkflowInfo was created with correct parameters
-                mock_workflow_info_class.assert_called_once_with(
-                    workflow_type="TestWorkflow",
-                    workflow_domain="test_domain",
-                    workflow_id="test_workflow_id",
-                    workflow_run_id="test_run_id"
-                )
+                # Verify WorkflowInfo was created with correct parameters (called twice - once for engine, once for context)
+                assert mock_workflow_info_class.call_count == 2
+                for call in mock_workflow_info_class.call_args_list:
+                    assert call[1] == {
+                        'workflow_type': "TestWorkflow",
+                        'workflow_domain': "test_domain",
+                        'workflow_id': "test_workflow_id",
+                        'workflow_run_id': "test_run_id"
+                    }

                 # Verify WorkflowEngine was created with correct parameters
                 mock_workflow_engine_class.assert_called_once()
diff --git a/tests/cadence/worker/test_task_handler_integration.py b/tests/cadence/worker/test_task_handler_integration.py
index def715f..58837a3 100644
--- a/tests/cadence/worker/test_task_handler_integration.py
+++ b/tests/cadence/worker/test_task_handler_integration.py
@@ -4,6 +4,7 @@
 """

 import pytest
+from contextlib import contextmanager
 from unittest.mock import Mock, AsyncMock, patch, PropertyMock

 from cadence.api.v1.service_worker_pb2 import PollForDecisionTaskResponse
@@ -100,8 +101,8 @@ async def test_full_task_handling_flow_with_error(self, handler, sample_decision
         assert call_args.identity == handler._identity

     @pytest.mark.asyncio
-    async def test_context_propagation_integration(self, handler, sample_decision_task, mock_registry):
-        """Test that context propagation works correctly in the integration."""
+    async def test_context_activation_integration(self, handler, sample_decision_task, mock_registry):
+        """Test that context activation works correctly in the integration."""
         # Mock workflow function
         mock_workflow_func = Mock()
         mock_registry.get_workflow.return_value = mock_workflow_func
@@ -114,27 +115,23 @@ async def test_context_propagation_integration(self, handler, sample_decision_ta
         mock_decision_result.query_results = {}
         mock_engine.process_decision = AsyncMock(return_value=mock_decision_result)

-        # Track if context methods are called
-        context_propagated = False
-        context_unset = False
+        # Track if context is activated
+        context_activated = False

-        async def track_propagate_context(task):
-            nonlocal context_propagated
-            context_propagated = True
-
-        async def track_unset_current_context():
-            nonlocal context_unset
-            context_unset = True
-
-        handler._propagate_context = track_propagate_context
-        handler._unset_current_context = track_unset_current_context
+        def track_context_activation():
+            nonlocal context_activated
+            context_activated = True

         with patch('cadence.worker._decision_task_handler.WorkflowEngine', return_value=mock_engine):
-            await handler.handle_task(sample_decision_task)
-
-        # Verify context methods were called
-        assert context_propagated
-        assert context_unset
+            with patch('cadence.worker._decision_task_handler.Context') as mock_context_class:
+                mock_context = Mock()
+                mock_context._activate = Mock(return_value=contextmanager(lambda: track_context_activation())())
+                mock_context_class.return_value = mock_context
+
+                await handler.handle_task(sample_decision_task)
+
+        # Verify context was activated
+        assert context_activated

     @pytest.mark.asyncio
     async def test_multiple_workflow_executions(self, handler, mock_registry):
@@ -235,19 +232,22 @@ async def test_error_handling_with_context_cleanup(self, handler, sample_decisio
         mock_engine.process_decision = AsyncMock(side_effect=RuntimeError("Workflow processing failed"))

         # Track context cleanup
-        context_unset = False
-
-        async def track_unset_current_context():
-            nonlocal context_unset
-            context_unset = True
+        context_cleaned_up = False

-        handler._unset_current_context = track_unset_current_context
+        def track_context_cleanup():
+            nonlocal context_cleaned_up
+            context_cleaned_up = True

         with patch('cadence.worker._decision_task_handler.WorkflowEngine', return_value=mock_engine):
-            await handler.handle_task(sample_decision_task)
+            with patch('cadence.worker._decision_task_handler.Context') as mock_context_class:
+                mock_context = Mock()
+                mock_context._activate = Mock(return_value=contextmanager(lambda: track_context_cleanup())())
+                mock_context_class.return_value = mock_context
+
+                await handler.handle_task(sample_decision_task)

         # Verify context was cleaned up even after error
-        assert context_unset
+        assert context_cleaned_up

         # Verify error was handled
         handler._client.worker_stub.RespondDecisionTaskFailed.assert_called_once()

From ea56208f31baa92c9b7722b03ffb845eb0f2879a Mon Sep 17 00:00:00 2001
From: Tim Li
Date: Tue, 16 Sep 2025 10:35:08 -0700
Subject: [PATCH 05/10] lint

Signed-off-by: Tim Li
---
 cadence/worker/_base_task_handler.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/cadence/worker/_base_task_handler.py b/cadence/worker/_base_task_handler.py
index 2c23ed0..3fda7e7 100644
--- a/cadence/worker/_base_task_handler.py
+++ b/cadence/worker/_base_task_handler.py
@@ -1,6 +1,6 @@
 import logging
 from abc import ABC, abstractmethod
-from typing import Any, Dict, TypeVar, Generic
+from typing import TypeVar, Generic

 logger = logging.getLogger(__name__)

From 08b0e3b462b9070fffc69bf13eaedba7f9ea6b6c Mon Sep 17 00:00:00 2001
From: Tim Li
Date: Tue, 16 Sep 2025 13:24:46 -0700
Subject: [PATCH 06/10] remove context activate in handler

Signed-off-by: Tim Li
---
 cadence/worker/_decision_task_handler.py | 17 ++++-------------
 .../worker/test_decision_task_handler.py  |  4 ++--
 2 files changed, 6 insertions(+), 15 deletions(-)

diff --git a/cadence/worker/_decision_task_handler.py b/cadence/worker/_decision_task_handler.py
index 4f35212..ffa1089 100644
--- a/cadence/worker/_decision_task_handler.py
+++ b/cadence/worker/_decision_task_handler.py
@@ -87,20 +87,11 @@ async def _handle_task_implementation(self, task: PollForDecisionTaskResponse) -
         # Create workflow context and execute with it active
         workflow_engine = self._workflow_engines[engine_key]
-        workflow_info = WorkflowInfo(
-            workflow_type=workflow_type_name,
-            workflow_domain=self._client.domain,
-            workflow_id=workflow_id,
-            workflow_run_id=run_id
-        )
-        context = Context(client=self._client, info=workflow_info)
-        with context._activate():
-            # Process the decision using the workflow engine
-            decision_result = await workflow_engine.process_decision(task)
-
-            # Respond with the decisions
-            await self._respond_decision_task_completed(task, decision_result)
+        decision_result = await workflow_engine.process_decision(task)
+
+        # Respond with the decisions
+        await self._respond_decision_task_completed(task, decision_result)

         logger.info(f"Successfully processed decision task for workflow {workflow_id}")

diff --git a/tests/cadence/worker/test_decision_task_handler.py b/tests/cadence/worker/test_decision_task_handler.py
index b50a50a..7b6f0c5 100644
--- a/tests/cadence/worker/test_decision_task_handler.py
+++ b/tests/cadence/worker/test_decision_task_handler.py
@@ -322,8 +322,8 @@ async def test_workflow_engine_creation_with_workflow_info(self, handler, sample
             with patch('cadence.worker._decision_task_handler.WorkflowInfo') as mock_workflow_info_class:
                 await handler._handle_task_implementation(sample_decision_task)

-                # Verify WorkflowInfo was created with correct parameters (called twice - once for engine, once for context)
-                assert mock_workflow_info_class.call_count == 2
+                # Verify WorkflowInfo was created with correct parameters (called once for engine)
+                assert mock_workflow_info_class.call_count == 1
                 for call in mock_workflow_info_class.call_args_list:
                     assert call[1] == {
                         'workflow_type': "TestWorkflow",

From 918d128a303f4382f8547bafc799c6662c9cced0 Mon Sep 17 00:00:00 2001
From: Tim Li
Date: Tue, 16 Sep 2025 13:27:20 -0700
Subject: [PATCH 07/10] lint

Signed-off-by: Tim Li
---
 cadence/worker/_decision_task_handler.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/cadence/worker/_decision_task_handler.py b/cadence/worker/_decision_task_handler.py
index ffa1089..96f7c97 100644
--- a/cadence/worker/_decision_task_handler.py
+++ b/cadence/worker/_decision_task_handler.py
@@ -11,7 +11,6 @@
 from cadence.client import Client
 from cadence.worker._base_task_handler import BaseTaskHandler
 from cadence._internal.workflow.workflow_engine import WorkflowEngine, DecisionResult
-from cadence._internal.workflow.context import Context
 from cadence.workflow import WorkflowInfo
 from cadence.worker._registry import Registry

From 9f9c2895ff90de1581368b0a373ab29c409caabd Mon Sep 17 00:00:00 2001
From: Tim Li
Date: Tue, 16 Sep 2025 13:30:25 -0700
Subject: [PATCH 08/10] fix test

Signed-off-by: Tim Li
---
 tests/cadence/worker/test_task_handler_integration.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/cadence/worker/test_task_handler_integration.py b/tests/cadence/worker/test_task_handler_integration.py
index 58837a3..c94370f 100644
--- a/tests/cadence/worker/test_task_handler_integration.py
+++ b/tests/cadence/worker/test_task_handler_integration.py
@@ -123,7 +123,7 @@ def track_context_activation():
             context_activated = True

         with patch('cadence.worker._decision_task_handler.WorkflowEngine', return_value=mock_engine):
-            with patch('cadence.worker._decision_task_handler.Context') as mock_context_class:
+            with patch('cadence._internal.workflow.workflow_engine.Context') as mock_context_class:
                 mock_context = Mock()
                 mock_context._activate = Mock(return_value=contextmanager(lambda: track_context_activation())())
                 mock_context_class.return_value = mock_context
@@ -239,7 +239,7 @@ def track_context_cleanup():
             context_cleaned_up = True

         with patch('cadence.worker._decision_task_handler.WorkflowEngine', return_value=mock_engine):
-            with patch('cadence.worker._decision_task_handler.Context') as mock_context_class:
+            with patch('cadence._internal.workflow.workflow_engine.Context') as mock_context_class:
                 mock_context = Mock()
                 mock_context._activate = Mock(return_value=contextmanager(lambda: track_context_cleanup())())
                 mock_context_class.return_value = mock_context

From 7d6eaecbb0b358ac6cd999ab90feb4e6c26a379e Mon Sep 17 00:00:00 2001
From: Tim Li
Date: Tue, 16 Sep 2025 16:12:49 -0700
Subject: [PATCH 09/10] respond to comments

Signed-off-by: Tim Li
---
 cadence/_internal/workflow/workflow_engine.py |   2 -
 cadence/worker/_decision_task_handler.py      | 104 +++++-----
 .../_internal/test_decision_state_machine.py  |   4 -
 .../worker/test_decision_task_handler.py      |  57 ++--------
 .../worker/test_task_handler_integration.py   |  62 +++--------
 5 files changed, 65 insertions(+), 164 deletions(-)

diff --git a/cadence/_internal/workflow/workflow_engine.py b/cadence/_internal/workflow/workflow_engine.py
index 07ee08f..00fac2c 100644
--- a/cadence/_internal/workflow/workflow_engine.py
+++ b/cadence/_internal/workflow/workflow_engine.py
@@ -11,8 +11,6 @@ @dataclass
 class DecisionResult:
     decisions: list[Decision]
-    force_create_new_decision_task: bool = False
-    query_results: Optional[dict] = None

 class WorkflowEngine:
     def __init__(self, info: WorkflowInfo, client: Client, workflow_func: Optional[Callable[..., Any]] = None):
diff --git a/cadence/worker/_decision_task_handler.py b/cadence/worker/_decision_task_handler.py
index 96f7c97..5d4c4ae 100644
--- a/cadence/worker/_decision_task_handler.py
+++ b/cadence/worker/_decision_task_handler.py
@@ -36,7 +36,7 @@ def __init__(self, client: Client, task_list: str, registry: Registry, identity:
         """
         super().__init__(client, task_list, identity, **options)
         self._registry = registry
-        self._workflow_engines: Dict[str, WorkflowEngine] = {}
+        self._workflow_engine: WorkflowEngine

     async def _handle_task_implementation(self, task: PollForDecisionTaskResponse) -> None:
@@ -51,7 +51,7 @@ async def _handle_task_implementation(self, task: PollForDecisionTaskResponse) -
         workflow_type = task.workflow_type

         if not workflow_execution or not workflow_type:
-            logger.error("Decision task missing workflow execution or type")
+            logger.error("Decision task missing workflow execution or type. Task: %r", task)
             raise ValueError("Missing workflow execution or type")

         workflow_id = workflow_execution.workflow_id
@@ -60,34 +60,27 @@ async def _handle_task_implementation(self, task: PollForDecisionTaskResponse) -
         logger.info(f"Processing decision task for workflow {workflow_id} (type: {workflow_type_name})")

-        # Get or create workflow engine for this workflow execution
-        engine_key = f"{workflow_id}:{run_id}"
-        if engine_key not in self._workflow_engines:
-            # Get the workflow function from registry
-            try:
-                workflow_func = self._registry.get_workflow(workflow_type_name)
-            except KeyError:
-                logger.error(f"Workflow type '{workflow_type_name}' not found in registry")
-                raise KeyError(f"Workflow type '{workflow_type_name}' not found")
-
-            # Create workflow info and engine
-            workflow_info = WorkflowInfo(
-                workflow_type=workflow_type_name,
-                workflow_domain=self._client.domain,
-                workflow_id=workflow_id,
-                workflow_run_id=run_id
-            )
-
-            self._workflow_engines[engine_key] = WorkflowEngine(
-                info=workflow_info,
-                client=self._client,
-                workflow_func=workflow_func
-            )
+        try:
+            workflow_func = self._registry.get_workflow(workflow_type_name)
+        except KeyError:
+            logger.error(f"Workflow type '{workflow_type_name}' not found in registry")
+            raise KeyError(f"Workflow type '{workflow_type_name}' not found")
+
+        # Create workflow info and engine
+        workflow_info = WorkflowInfo(
+            workflow_type=workflow_type_name,
+            workflow_domain=self._client.domain,
+            workflow_id=workflow_id,
+            workflow_run_id=run_id
+        )

-        # Create workflow context and execute with it active
-        workflow_engine = self._workflow_engines[engine_key]
+        self._workflow_engine = WorkflowEngine(
+            info=workflow_info,
+            client=self._client,
+            workflow_func=workflow_func
+        )

-        decision_result = await workflow_engine.process_decision(task)
+        decision_result = await self._workflow_engine.process_decision(task)

         # Respond with the decisions
         await self._respond_decision_task_completed(task, decision_result)
@@ -102,21 +95,22 @@ async def handle_task_failure(self, task: PollForDecisionTaskResponse, error: Ex
             task: The task that failed
             error: The exception that occurred
         """
+        logger.error(f"Decision task failed: {error}")
+
+        # Determine the failure cause
+        cause = DecisionTaskFailedCause.DECISION_TASK_FAILED_CAUSE_UNHANDLED_DECISION
+        if isinstance(error, KeyError):
+            cause = DecisionTaskFailedCause.DECISION_TASK_FAILED_CAUSE_WORKFLOW_WORKER_UNHANDLED_FAILURE
+        elif isinstance(error, ValueError):
+            cause = DecisionTaskFailedCause.DECISION_TASK_FAILED_CAUSE_BAD_SCHEDULE_ACTIVITY_ATTRIBUTES
+
+        # Create error details
+        # TODO: Use a data converter for error details serialization
+        error_message = str(error).encode('utf-8')
+        details = Payload(data=error_message)
+
+        # Respond with failure
         try:
-            logger.error(f"Decision task failed: {error}")
-
-            # Determine the failure cause
-            cause = DecisionTaskFailedCause.DECISION_TASK_FAILED_CAUSE_UNHANDLED_DECISION
-            if isinstance(error, KeyError):
-                cause = DecisionTaskFailedCause.DECISION_TASK_FAILED_CAUSE_WORKFLOW_WORKER_UNHANDLED_FAILURE
-            elif isinstance(error, ValueError):
-                cause = DecisionTaskFailedCause.DECISION_TASK_FAILED_CAUSE_BAD_SCHEDULE_ACTIVITY_ATTRIBUTES
-
-            # Create error details
-            error_message = str(error).encode('utf-8')
-            details = Payload(data=error_message)
-
-            # Respond with failure
             await self._client.worker_stub.RespondDecisionTaskFailed(
                 RespondDecisionTaskFailedRequest(
                     task_token=task.task_token,
@@ -125,11 +119,10 @@ async def handle_task_failure(self, task: PollForDecisionTaskResponse, error: Ex
                     details=details
                 )
             )
-
             logger.info("Decision task failure response sent")
         except Exception:
-            logger.exception("Error handling decision task failure")
+            logger.exception("Error responding to decision task failure")
+

     async def _respond_decision_task_completed(self, task: PollForDecisionTaskResponse, decision_result: DecisionResult) -> None:
         """
@@ -144,30 +137,13 @@ async def _respond_decision_task_completed(self, task: PollForDecisionTaskRespon
                 task_token=task.task_token,
                 decisions=decision_result.decisions,
                 identity=self._identity,
-                return_new_decision_task=decision_result.force_create_new_decision_task,
-                force_create_new_decision_task=decision_result.force_create_new_decision_task
+                return_new_decision_task=True,
+                force_create_new_decision_task=False
             )

-            # Add query results if present
-            if decision_result.query_results:
-                request.query_results.update(decision_result.query_results)
-
             await self._client.worker_stub.RespondDecisionTaskCompleted(request)
             logger.debug(f"Decision task completed with {len(decision_result.decisions)} decisions")

         except Exception:
             logger.exception("Error responding to decision task completion")
             raise
-
-    def cleanup_workflow_engine(self, workflow_id: str, run_id: str) -> None:
-        """
-        Clean up a workflow engine when workflow execution is complete.
-
-        Args:
-            workflow_id: The workflow ID
-            run_id: The run ID
-        """
-        engine_key = f"{workflow_id}:{run_id}"
-        if engine_key in self._workflow_engines:
-            del self._workflow_engines[engine_key]
-            logger.debug(f"Cleaned up workflow engine for {workflow_id}:{run_id}")
diff --git a/tests/cadence/_internal/test_decision_state_machine.py b/tests/cadence/_internal/test_decision_state_machine.py
index 1cd0d92..4f61dca 100644
--- a/tests/cadence/_internal/test_decision_state_machine.py
+++ b/tests/cadence/_internal/test_decision_state_machine.py
@@ -439,7 +439,3 @@ def test_manager_aggregates_and_routes():
             ),
         )
     )
-
-    assert a.status is DecisionState.COMPLETED
-    assert t.status is DecisionState.COMPLETED
-    assert c.status is DecisionState.COMPLETED
diff --git a/tests/cadence/worker/test_decision_task_handler.py b/tests/cadence/worker/test_decision_task_handler.py
index 7b6f0c5..2fc98ec 100644
--- a/tests/cadence/worker/test_decision_task_handler.py
+++ b/tests/cadence/worker/test_decision_task_handler.py
@@ -75,8 +75,6 @@ def test_initialization(self, mock_client, mock_registry):
         assert handler._identity == "test_identity"
         assert handler._registry == mock_registry
         assert handler._options == {"option1": "value1"}
-        assert isinstance(handler._workflow_engines, dict)
-        assert len(handler._workflow_engines) == 0

     @pytest.mark.asyncio
     async def test_handle_task_implementation_success(self, handler, sample_decision_task, mock_registry):
@@ -139,8 +137,8 @@ async def test_handle_task_implementation_workflow_not_found(self, handler, samp
             await handler._handle_task_implementation(sample_decision_task)

     @pytest.mark.asyncio
-    async def test_handle_task_implementation_reuses_existing_engine(self, handler, sample_decision_task, mock_registry):
-        """Test that decision task handler reuses existing workflow engine."""
+    async def test_handle_task_implementation_creates_new_engine(self, handler, sample_decision_task, mock_registry):
+        """Test that decision task handler creates new workflow engine for each task."""
         # Mock workflow function
         mock_workflow_func = Mock()
         mock_registry.get_workflow.return_value = mock_workflow_func
@@ -153,23 +151,19 @@ async def test_handle_task_implementation_creates_new_engine(self, handler,
         mock_decision_result.query_results = {}
         mock_engine.process_decision = AsyncMock(return_value=mock_decision_result)

-        with patch('cadence.worker._decision_task_handler.WorkflowEngine', return_value=mock_engine):
+        with patch('cadence.worker._decision_task_handler.WorkflowEngine', return_value=mock_engine) as mock_engine_class:
             # First call - should create new engine
             await handler._handle_task_implementation(sample_decision_task)

-            # Second call - should reuse existing engine
+            # Second call - should create another new engine
             await handler._handle_task_implementation(sample_decision_task)

-            # Registry should only be called once
-            mock_registry.get_workflow.assert_called_once_with("TestWorkflow")
+            # Registry should be called for each task
+            assert mock_registry.get_workflow.call_count == 2

-            # Engine should be called twice
+            # Engine should be created twice and called twice
+            assert mock_engine_class.call_count == 2
             assert mock_engine.process_decision.call_count == 2
-
-            # Should have one engine in the cache
-            assert len(handler._workflow_engines) == 1
-            engine_key = "test_workflow_id:test_run_id"
-            assert engine_key in handler._workflow_engines

     @pytest.mark.asyncio
     async def test_handle_task_failure_keyerror(self, handler, sample_decision_task):
@@ -237,8 +231,6 @@ async def test_respond_decision_task_completed_success(self, handler, sample_dec
         """Test successful decision task completion response."""
         decision_result = Mock(spec=DecisionResult)
         decision_result.decisions = [Decision(), Decision()]
-        decision_result.force_create_new_decision_task = True
-        decision_result.query_results = None

         # Test without query results first
         await handler._respond_decision_task_completed(sample_decision_task, decision_result)
@@ -248,62 +240,33 @@ async def test_respond_decision_task_completed_success(self, handler, sample_dec
         assert call_args.task_token == sample_decision_task.task_token
         assert call_args.identity == handler._identity
         assert call_args.return_new_decision_task
-        assert call_args.force_create_new_decision_task
+        assert not call_args.force_create_new_decision_task
         assert len(call_args.decisions) == 2
-        # query_results should not be set when None
-        assert not hasattr(call_args, 'query_results') or len(call_args.query_results) == 0

     @pytest.mark.asyncio
     async def test_respond_decision_task_completed_no_query_results(self, handler, sample_decision_task):
         """Test decision task completion response without query results."""
         decision_result = Mock(spec=DecisionResult)
         decision_result.decisions = []
-        decision_result.force_create_new_decision_task = False
-        decision_result.query_results = None

         await handler._respond_decision_task_completed(sample_decision_task, decision_result)

         call_args = handler._client.worker_stub.RespondDecisionTaskCompleted.call_args[0][0]
-        assert not call_args.return_new_decision_task
+        assert call_args.return_new_decision_task
         assert not call_args.force_create_new_decision_task
         assert len(call_args.decisions) == 0
-        # query_results should not be set when None
-        assert not hasattr(call_args, 'query_results') or len(call_args.query_results) == 0

     @pytest.mark.asyncio
     async def test_respond_decision_task_completed_error(self, handler, sample_decision_task):
         """Test decision task completion response error handling."""
         decision_result = Mock(spec=DecisionResult)
         decision_result.decisions = []
-        decision_result.force_create_new_decision_task = False
-        decision_result.query_results = {}

         handler._client.worker_stub.RespondDecisionTaskCompleted.side_effect = Exception("Respond failed")

         with pytest.raises(Exception, match="Respond failed"):
             await handler._respond_decision_task_completed(sample_decision_task, decision_result)

-    def test_cleanup_workflow_engine(self, handler):
-        """Test workflow engine cleanup."""
-        # Add some mock engines
-        handler._workflow_engines["workflow1:run1"] = Mock()
-        handler._workflow_engines["workflow2:run2"] = Mock()
-
-        # Clean up one engine
-        handler.cleanup_workflow_engine("workflow1", "run1")
-
-        # Verify only one engine was removed
-        assert len(handler._workflow_engines) == 1
-        assert "workflow1:run1" not in handler._workflow_engines
-        assert "workflow2:run2" in handler._workflow_engines
-
-    def test_cleanup_workflow_engine_not_found(self, handler):
-        """Test cleanup of non-existent workflow engine."""
-        # Should not raise error
-        handler.cleanup_workflow_engine("nonexistent", "run")
-
-        # Should not affect existing engines
-        assert len(handler._workflow_engines) == 0

     @pytest.mark.asyncio
     async def test_workflow_engine_creation_with_workflow_info(self, handler, sample_decision_task, mock_registry):
diff --git a/tests/cadence/worker/test_task_handler_integration.py b/tests/cadence/worker/test_task_handler_integration.py
index c94370f..64d877f 100644
--- a/tests/cadence/worker/test_task_handler_integration.py
+++ b/tests/cadence/worker/test_task_handler_integration.py
@@ -66,8 +66,6 @@ async def test_full_task_handling_flow_success(self, handler, sample_decision_ta
         mock_engine = Mock(spec=WorkflowEngine)
         mock_decision_result = Mock(spec=DecisionResult)
         mock_decision_result.decisions = []
-        mock_decision_result.force_create_new_decision_task = False
-        mock_decision_result.query_results = {}
         mock_engine.process_decision = AsyncMock(return_value=mock_decision_result)

         with patch('cadence.worker._decision_task_handler.WorkflowEngine', return_value=mock_engine):
@@ -111,8 +109,6 @@ async def test_context_activation_integration(self, handler, sample_decision_tas
         mock_engine = Mock(spec=WorkflowEngine)
         mock_decision_result = Mock(spec=DecisionResult)
         mock_decision_result.decisions = []
-        mock_decision_result.force_create_new_decision_task = False
-        mock_decision_result.query_results = {}
         mock_engine.process_decision = AsyncMock(return_value=mock_decision_result)

         # Track if context is activated
@@ -135,7 +131,7 @@ def track_context_activation():

     @pytest.mark.asyncio
     async def test_multiple_workflow_executions(self, handler, mock_registry):
-        """Test handling multiple workflow executions with different engines."""
+        """Test handling multiple workflow executions creates new engines for each."""
         # Mock workflow function
         mock_workflow_func = Mock()
         mock_registry.get_workflow.return_value = mock_workflow_func
@@ -157,43 +153,28 @@ async def test_multiple_workflow_executions(self, handler, mock_registry):
         task2.workflow_type = Mock()
         task2.workflow_type.name = "TestWorkflow"

-        # Mock workflow engines
-        mock_engine1 = Mock(spec=WorkflowEngine)
-        mock_engine2 = Mock(spec=WorkflowEngine)
+        # Mock workflow engine
+        mock_engine = Mock(spec=WorkflowEngine)
         mock_decision_result = Mock(spec=DecisionResult)
         mock_decision_result.decisions = []
-        mock_decision_result.force_create_new_decision_task = False
-        mock_decision_result.query_results = {}
-        mock_engine1.process_decision = AsyncMock(return_value=mock_decision_result)
-        mock_engine2.process_decision = AsyncMock(return_value=mock_decision_result)
-
-        def mock_workflow_engine_creator(*args, **kwargs):
-            # Return different engines based on workflow info
-            workflow_info = kwargs.get('info')
-            if workflow_info and workflow_info.workflow_id == "workflow1":
-                return mock_engine1
-            else:
-                return mock_engine2
+        mock_engine.process_decision = AsyncMock(return_value=mock_decision_result)

-        with patch('cadence.worker._decision_task_handler.WorkflowEngine', side_effect=mock_workflow_engine_creator):
+        with patch('cadence.worker._decision_task_handler.WorkflowEngine', return_value=mock_engine) as mock_engine_class:
             # Process both tasks
             await handler.handle_task(task1)
             await handler.handle_task(task2)

-            # Verify both engines were created and used
-            assert len(handler._workflow_engines) == 2
-            assert "workflow1:run1" in handler._workflow_engines
-            assert "workflow2:run2" in handler._workflow_engines
+            # Verify engines were created for each task
+            assert mock_engine_class.call_count == 2

-            # Verify both engines were called
-            mock_engine1.process_decision.assert_called_once_with(task1)
-            mock_engine2.process_decision.assert_called_once_with(task2)
+            # Verify both tasks were processed
+            assert mock_engine.process_decision.call_count == 2

     @pytest.mark.asyncio
-    async def test_workflow_engine_cleanup_integration(self, handler, sample_decision_task, mock_registry):
-        """Test workflow engine cleanup integration."""
+    async def test_workflow_engine_creation_integration(self, handler, sample_decision_task, mock_registry):
+        """Test workflow engine creation integration."""
         # Mock workflow function
         mock_workflow_func = Mock()
         mock_registry.get_workflow.return_value = mock_workflow_func
@@ -202,23 +183,15 @@ async def test_workflow_engine_cleanup_integration(self, handler, sample_decisio
         mock_engine = Mock(spec=WorkflowEngine)
         mock_decision_result = Mock(spec=DecisionResult)
         mock_decision_result.decisions = []
-        mock_decision_result.force_create_new_decision_task = False
-        mock_decision_result.query_results = {}
         mock_engine.process_decision = AsyncMock(return_value=mock_decision_result)

-        with patch('cadence.worker._decision_task_handler.WorkflowEngine', return_value=mock_engine):
+        with patch('cadence.worker._decision_task_handler.WorkflowEngine', return_value=mock_engine) as mock_engine_class:
             # Process task to create engine
             await handler.handle_task(sample_decision_task)

-            # Verify engine was created
-            assert len(handler._workflow_engines) == 1
-            assert "test_workflow_id:test_run_id" in handler._workflow_engines
-
-            # Clean up engine
-            handler.cleanup_workflow_engine("test_workflow_id", "test_run_id")
-
-            # Verify engine was cleaned up
-            assert len(handler._workflow_engines) == 0
+            # Verify engine was created and used
+            mock_engine_class.assert_called_once()
+            mock_engine.process_decision.assert_called_once_with(sample_decision_task)

     @pytest.mark.asyncio
     async def test_error_handling_with_context_cleanup(self, handler, sample_decision_task, mock_registry):
@@ -277,8 +250,6 @@ async def test_concurrent_task_handling(self, handler, mock_registry):
         mock_engine = Mock(spec=WorkflowEngine)
         mock_decision_result = Mock(spec=DecisionResult)
         mock_decision_result.decisions = []
-        mock_decision_result.force_create_new_decision_task = False
-        mock_decision_result.query_results = {}
         mock_engine.process_decision = AsyncMock(return_value=mock_decision_result)

         with patch('cadence.worker._decision_task_handler.WorkflowEngine', return_value=mock_engine):
@@ -288,6 +259,3 @@ async def test_concurrent_task_handling(self, handler, mock_registry):
             # Verify all tasks were processed
             assert mock_engine.process_decision.call_count == 3
             assert handler._client.worker_stub.RespondDecisionTaskCompleted.call_count == 3
-
-            # Verify engines were created for each workflow
-            assert len(handler._workflow_engines) == 3

From 27277e41af216550dd15f016c6722e5d3ae8b62c Mon Sep 17 00:00:00 2001
From: Tim Li
Date: Tue, 16 Sep 2025 16:19:56 -0700
Subject: [PATCH 10/10] lint

Signed-off-by: Tim Li
---
 cadence/worker/_decision_task_handler.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/cadence/worker/_decision_task_handler.py b/cadence/worker/_decision_task_handler.py
index 5d4c4ae..636505f 100644
--- a/cadence/worker/_decision_task_handler.py
+++ b/cadence/worker/_decision_task_handler.py
@@ -1,5 +1,4 @@
 import logging
-from typing import Dict

 from cadence.api.v1.common_pb2 import Payload
 from cadence.api.v1.service_worker_pb2 import (