diff --git a/cadence/_internal/workflow/context.py b/cadence/_internal/workflow/context.py index 5038e48..87184d8 100644 --- a/cadence/_internal/workflow/context.py +++ b/cadence/_internal/workflow/context.py @@ -1,3 +1,4 @@ +from typing import Optional from cadence.client import Client from cadence.workflow import WorkflowContext, WorkflowInfo @@ -7,9 +8,27 @@ class Context(WorkflowContext): def __init__(self, client: Client, info: WorkflowInfo): self._client = client self._info = info + self._replay_mode = True + self._replay_current_time_milliseconds: Optional[int] = None def info(self) -> WorkflowInfo: return self._info def client(self) -> Client: return self._client + + def set_replay_mode(self, replay: bool) -> None: + """Set whether the workflow is currently in replay mode.""" + self._replay_mode = replay + + def is_replay_mode(self) -> bool: + """Check if the workflow is currently in replay mode.""" + return self._replay_mode + + def set_replay_current_time_milliseconds(self, time_millis: int) -> None: + """Set the current replay time in milliseconds.""" + self._replay_current_time_milliseconds = time_millis + + def get_replay_current_time_milliseconds(self) -> Optional[int]: + """Get the current replay time in milliseconds.""" + return self._replay_current_time_milliseconds diff --git a/cadence/_internal/workflow/decision_events_iterator.py b/cadence/_internal/workflow/decision_events_iterator.py index 0758588..cb0020b 100644 --- a/cadence/_internal/workflow/decision_events_iterator.py +++ b/cadence/_internal/workflow/decision_events_iterator.py @@ -157,10 +157,10 @@ async def next_decision_events(self) -> DecisionEvents: decision_events.events.append(decision_task_started) # Update replay time if available - if hasattr(decision_task_started, 'event_time') and decision_task_started.event_time: - self._replay_current_time_milliseconds = getattr( - decision_task_started.event_time, 'seconds', 0 - ) * 1000 + if decision_task_started.event_time: + self._replay_current_time_milliseconds = ( + decision_task_started.event_time.seconds * 1000 + ) decision_events.replay_current_time_milliseconds = self._replay_current_time_milliseconds # Process subsequent events until we find the corresponding DecisionTask completion diff --git a/cadence/_internal/workflow/decisions_helper.py b/cadence/_internal/workflow/decisions_helper.py new file mode 100644 index 0000000..4099150 --- /dev/null +++ b/cadence/_internal/workflow/decisions_helper.py @@ -0,0 +1,314 @@ +""" +DecisionsHelper manages the next decision ID which is used for tracking decision state machines. + +This helper ensures that decision IDs are properly assigned and tracked to maintain +consistency in the workflow execution state. +""" + +import logging +from dataclasses import dataclass +from typing import Dict, Optional + +from cadence._internal.decision_state_machine import DecisionId, DecisionType, DecisionManager + +logger = logging.getLogger(__name__) + + +@dataclass +class DecisionTracker: + """Tracks a decision with its ID and current state.""" + + decision_id: DecisionId + scheduled_event_id: Optional[int] = None + initiated_event_id: Optional[int] = None + started_event_id: Optional[int] = None + is_completed: bool = False + + +class DecisionsHelper: + """ + Helper class to manage decision IDs and work with DecisionManager state machines. + + This class generates unique decision IDs and integrates with the DecisionManager + state machines for proper decision lifecycle tracking. + """ + + def __init__(self, decision_manager: DecisionManager): + """ + Initialize the DecisionsHelper with a DecisionManager reference. + + Args: + decision_manager: The DecisionManager containing the state machines + """ + self._next_decision_counters: Dict[DecisionType, int] = {} + self._tracked_decisions: Dict[str, DecisionTracker] = {} + self._decision_id_to_key: Dict[str, str] = {} + self._decision_manager = decision_manager + logger.debug("DecisionsHelper initialized with DecisionManager integration") + + def _get_next_counter(self, decision_type: DecisionType) -> int: + """ + Get the next counter value for a given decision type. + + Args: + decision_type: The type of decision + + Returns: + The next counter value + """ + if decision_type not in self._next_decision_counters: + self._next_decision_counters[decision_type] = 1 + else: + self._next_decision_counters[decision_type] += 1 + + return self._next_decision_counters[decision_type] + + def generate_activity_id(self, activity_name: str) -> str: + """ + Generate a unique activity ID. + + Args: + activity_name: The name of the activity + + Returns: + A unique activity ID + """ + counter = self._get_next_counter(DecisionType.ACTIVITY) + activity_id = f"{activity_name}_{counter}" + + # Track this decision + decision_id = DecisionId(DecisionType.ACTIVITY, activity_id) + tracker = DecisionTracker(decision_id) + self._tracked_decisions[activity_id] = tracker + self._decision_id_to_key[str(decision_id)] = activity_id + + logger.debug(f"Generated activity ID: {activity_id}") + return activity_id + + def generate_timer_id(self, timer_name: str = "timer") -> str: + """ + Generate a unique timer ID. + + Args: + timer_name: The name/prefix for the timer + + Returns: + A unique timer ID + """ + counter = self._get_next_counter(DecisionType.TIMER) + timer_id = f"{timer_name}_{counter}" + + # Track this decision + decision_id = DecisionId(DecisionType.TIMER, timer_id) + tracker = DecisionTracker(decision_id) + self._tracked_decisions[timer_id] = tracker + self._decision_id_to_key[str(decision_id)] = timer_id + + logger.debug(f"Generated timer ID: {timer_id}") + return timer_id + + def generate_child_workflow_id(self, workflow_name: str) -> str: + """ + Generate a unique child workflow ID. + + Args: + workflow_name: The name of the child workflow + + Returns: + A unique child workflow ID + """ + counter = self._get_next_counter(DecisionType.CHILD_WORKFLOW) + workflow_id = f"{workflow_name}_{counter}" + + # Track this decision + decision_id = DecisionId(DecisionType.CHILD_WORKFLOW, workflow_id) + tracker = DecisionTracker(decision_id) + self._tracked_decisions[workflow_id] = tracker + self._decision_id_to_key[str(decision_id)] = workflow_id + + logger.debug(f"Generated child workflow ID: {workflow_id}") + return workflow_id + + def generate_marker_id(self, marker_name: str) -> str: + """ + Generate a unique marker ID. + + Args: + marker_name: The name of the marker + + Returns: + A unique marker ID + """ + counter = self._get_next_counter(DecisionType.MARKER) + marker_id = f"{marker_name}_{counter}" + + # Track this decision + decision_id = DecisionId(DecisionType.MARKER, marker_id) + tracker = DecisionTracker(decision_id) + self._tracked_decisions[marker_id] = tracker + self._decision_id_to_key[str(decision_id)] = marker_id + + logger.debug(f"Generated marker ID: {marker_id}") + return marker_id + + def get_decision_tracker(self, decision_key: str) -> Optional[DecisionTracker]: + """ + Get the decision tracker for a given decision key. + + Args: + decision_key: The decision key (activity_id, timer_id, etc.) + + Returns: + The DecisionTracker if found, None otherwise + """ + return self._tracked_decisions.get(decision_key) + + def update_decision_scheduled( + self, decision_key: str, scheduled_event_id: int + ) -> None: + """ + Update a decision tracker when it gets scheduled. + + Args: + decision_key: The decision key + scheduled_event_id: The event ID when the decision was scheduled + """ + tracker = self._tracked_decisions.get(decision_key) + if tracker: + tracker.scheduled_event_id = scheduled_event_id + logger.debug( + f"Updated decision {decision_key} with scheduled event ID {scheduled_event_id}" + ) + else: + logger.warning(f"No tracker found for decision key: {decision_key}") + + def update_decision_initiated( + self, decision_key: str, initiated_event_id: int + ) -> None: + """ + Update a decision tracker when it gets initiated. + + Args: + decision_key: The decision key + initiated_event_id: The event ID when the decision was initiated + """ + tracker = self._tracked_decisions.get(decision_key) + if tracker: + tracker.initiated_event_id = initiated_event_id + logger.debug( + f"Updated decision {decision_key} with initiated event ID {initiated_event_id}" + ) + else: + logger.warning(f"No tracker found for decision key: {decision_key}") + + def update_decision_started(self, decision_key: str, started_event_id: int) -> None: + """ + Update a decision tracker when it gets started. + + Args: + decision_key: The decision key + started_event_id: The event ID when the decision was started + """ + tracker = self._tracked_decisions.get(decision_key) + if tracker: + tracker.started_event_id = started_event_id + logger.debug( + f"Updated decision {decision_key} with started event ID {started_event_id}" + ) + else: + logger.warning(f"No tracker found for decision key: {decision_key}") + + def update_decision_completed(self, decision_key: str) -> None: + """ + Mark a decision as completed. + + Args: + decision_key: The decision key + """ + tracker = self._tracked_decisions.get(decision_key) + if tracker: + tracker.is_completed = True + logger.debug(f"Marked decision {decision_key} as completed") + else: + logger.warning(f"No tracker found for decision key: {decision_key}") + + + def _find_decision_by_scheduled_event_id( + self, scheduled_event_id: int + ) -> Optional[str]: + """Find a decision key by its scheduled event ID.""" + for key, tracker in self._tracked_decisions.items(): + if tracker.scheduled_event_id == scheduled_event_id: + return key + return None + + def _find_decision_by_initiated_event_id( + self, initiated_event_id: int + ) -> Optional[str]: + """Find a decision key by its initiated event ID.""" + for key, tracker in self._tracked_decisions.items(): + if tracker.initiated_event_id == initiated_event_id: + return key + return None + + def _find_decision_by_started_event_id( + self, started_event_id: int + ) -> Optional[str]: + """Find a decision key by its started event ID.""" + for key, tracker in self._tracked_decisions.items(): + if tracker.started_event_id == started_event_id: + return key + return None + + def get_pending_decisions_count(self) -> int: + """ + Get the count of decisions that are not yet completed. + + Returns: + The number of pending decisions + """ + return sum( + 1 + for tracker in self._tracked_decisions.values() + if not tracker.is_completed + ) + + def get_completed_decisions_count(self) -> int: + """ + Get the count of decisions that have been completed. + + Returns: + The number of completed decisions + """ + return sum( + 1 for tracker in self._tracked_decisions.values() if tracker.is_completed + ) + + def reset(self) -> None: + """Reset all decision tracking state.""" + self._next_decision_counters.clear() + self._tracked_decisions.clear() + self._decision_id_to_key.clear() + logger.debug("DecisionsHelper reset") + + def get_stats(self) -> Dict[str, int]: + """ + Get statistics about tracked decisions. + + Returns: + Dictionary with decision statistics + """ + stats = { + "total_decisions": len(self._tracked_decisions), + "pending_decisions": self.get_pending_decisions_count(), + "completed_decisions": self.get_completed_decisions_count(), + } + + # Add per-type counts + for decision_type in DecisionType: + type_name = decision_type.name.lower() + stats[f"{type_name}_count"] = self._next_decision_counters.get( + decision_type, 0 + ) + + return stats diff --git a/cadence/_internal/workflow/workflow_engine.py b/cadence/_internal/workflow/workflow_engine.py index 00fac2c..2456cc1 100644 --- a/cadence/_internal/workflow/workflow_engine.py +++ b/cadence/_internal/workflow/workflow_engine.py @@ -1,11 +1,18 @@ +import asyncio +import logging from dataclasses import dataclass -from typing import Optional, Callable, Any +from typing import Callable, Any from cadence._internal.workflow.context import Context +from cadence._internal.workflow.decisions_helper import DecisionsHelper +from cadence._internal.workflow.decision_events_iterator import DecisionEventsIterator from cadence.api.v1.decision_pb2 import Decision from cadence.client import Client from cadence.api.v1.service_worker_pb2 import PollForDecisionTaskResponse from cadence.workflow import WorkflowInfo +from cadence._internal.decision_state_machine import DecisionManager + +logger = logging.getLogger(__name__) @dataclass @@ -13,11 +20,367 @@ class DecisionResult: decisions: list[Decision] class WorkflowEngine: - def __init__(self, info: WorkflowInfo, client: Client, workflow_func: Optional[Callable[..., Any]] = None): + def __init__(self, info: WorkflowInfo, client: Client, workflow_func: Callable[[Any], Any] | None = None): self._context = Context(client, info) self._workflow_func = workflow_func + self._decision_manager = DecisionManager() + self._decisions_helper = DecisionsHelper(self._decision_manager) + self._is_workflow_complete = False - # TODO: Implement this async def process_decision(self, decision_task: PollForDecisionTaskResponse) -> DecisionResult: - with self._context._activate(): - return DecisionResult(decisions=[]) + """ + Process a decision task and generate decisions using DecisionEventsIterator. + + This method follows the Java client pattern of using DecisionEventsIterator + to drive the decision processing pipeline with proper replay handling. + + Args: + decision_task: The PollForDecisionTaskResponse from the service + + Returns: + DecisionResult containing the list of decisions + """ + try: + # Log decision task processing start with full context (matches Java ReplayDecisionTaskHandler) + logger.info( + "Processing decision task for workflow", + extra={ + "workflow_type": self._context.info().workflow_type, + "workflow_id": self._context.info().workflow_id, + "run_id": self._context.info().workflow_run_id, + "started_event_id": decision_task.started_event_id, + "attempt": decision_task.attempt + } + ) + + # Activate workflow context for the entire decision processing + with self._context._activate(): + # Create DecisionEventsIterator for structured event processing + events_iterator = DecisionEventsIterator(decision_task, self._context.client()) + + # Process decision events using iterator-driven approach + await self._process_decision_events(events_iterator, decision_task) + + # Collect all pending decisions from state machines + decisions = self._decision_manager.collect_pending_decisions() + + # Close decider's event loop + self._close_event_loop() + + # Log decision task completion with metrics (matches Java ReplayDecisionTaskHandler) + logger.debug( + "Decision task completed", + extra={ + "workflow_type": self._context.info().workflow_type, + "workflow_id": self._context.info().workflow_id, + "run_id": self._context.info().workflow_run_id, + "started_event_id": decision_task.started_event_id, + "decisions_count": len(decisions), + "replay_mode": self._context.is_replay_mode() + } + ) + + return DecisionResult(decisions=decisions) + + except Exception as e: + # Log decision task failure with full context (matches Java ReplayDecisionTaskHandler) + logger.error( + "Decision task processing failed", + extra={ + "workflow_type": self._context.info().workflow_type, + "workflow_id": self._context.info().workflow_id, + "run_id": self._context.info().workflow_run_id, + "started_event_id": decision_task.started_event_id, + "attempt": decision_task.attempt, + "error_type": type(e).__name__ + }, + exc_info=True + ) + # Re-raise the exception so the handler can properly handle the failure + raise + + async def _process_decision_events(self, events_iterator: DecisionEventsIterator, decision_task: PollForDecisionTaskResponse) -> None: + """ + Process decision events using the iterator-driven approach similar to Java client. + + This method implements the three-phase event processing pattern: + 1. Process markers first (for deterministic replay) + 2. Process regular events (trigger workflow state changes) + 3. Execute workflow logic + 4. Process decision events from previous decisions + + Args: + events_iterator: The DecisionEventsIterator for structured event processing + decision_task: The original decision task + """ + # Track if we processed any decision events + processed_any_decision_events = False + + # Check if there are any decision events to process + while await events_iterator.has_next_decision_events(): + decision_events = await events_iterator.next_decision_events() + processed_any_decision_events = True + + # Log decision events batch processing (matches Go client patterns) + logger.debug( + "Processing decision events batch", + extra={ + "workflow_id": self._context.info().workflow_id, + "events_count": len(decision_events.get_events()), + "markers_count": len(decision_events.get_markers()), + "replay_mode": decision_events.is_replay(), + "replay_time": decision_events.replay_current_time_milliseconds + } + ) + + # Update context with replay information + self._context.set_replay_mode(decision_events.is_replay()) + if decision_events.replay_current_time_milliseconds: + self._context.set_replay_current_time_milliseconds(decision_events.replay_current_time_milliseconds) + + # Phase 1: Process markers first for deterministic replay + for marker_event in decision_events.get_markers(): + try: + logger.debug( + "Processing marker event", + extra={ + "workflow_id": self._context.info().workflow_id, + "marker_name": getattr(marker_event, 'marker_name', 'unknown'), + "event_id": getattr(marker_event, 'event_id', None), + "replay_mode": self._context.is_replay_mode() + } + ) + # Process through state machines (DecisionsHelper now delegates to DecisionManager) + self._decision_manager.handle_history_event(marker_event) + except Exception as e: + # Warning for unexpected markers (matches Java ClockDecisionContext) + logger.warning( + "Unexpected marker event encountered", + extra={ + "workflow_id": self._context.info().workflow_id, + "marker_name": getattr(marker_event, 'marker_name', 'unknown'), + "event_id": getattr(marker_event, 'event_id', None), + "error_type": type(e).__name__ + }, + exc_info=True + ) + + # Phase 2: Process regular events to update workflow state + for event in decision_events.get_events(): + try: + logger.debug( + "Processing history event", + extra={ + "workflow_id": self._context.info().workflow_id, + "event_type": getattr(event, 'event_type', 'unknown'), + "event_id": getattr(event, 'event_id', None), + "replay_mode": self._context.is_replay_mode() + } + ) + # Process through state machines (DecisionsHelper now delegates to DecisionManager) + self._decision_manager.handle_history_event(event) + except Exception as e: + logger.warning( + "Error processing history event", + extra={ + "workflow_id": self._context.info().workflow_id, + "event_type": getattr(event, 'event_type', 'unknown'), + "event_id": getattr(event, 'event_id', None), + "error_type": type(e).__name__ + }, + exc_info=True + ) + + # Phase 3: Execute workflow logic if not in replay mode + if not decision_events.is_replay() and not self._is_workflow_complete: + await self._execute_workflow_function(decision_task) + + # If no decision events were processed but we have history, fall back to direct processing + # This handles edge cases where the iterator doesn't find decision events + if not processed_any_decision_events and decision_task.history and hasattr(decision_task.history, 'events'): + logger.debug( + "No decision events found by iterator, falling back to direct history processing", + extra={ + "workflow_id": self._context.info().workflow_id, + "history_events_count": len(decision_task.history.events) if decision_task.history else 0 + } + ) + self._fallback_process_workflow_history(decision_task.history) + if not self._is_workflow_complete: + await self._execute_workflow_function(decision_task) + + + def _fallback_process_workflow_history(self, history) -> None: + """ + Fallback method to process workflow history events directly. + + This is used when DecisionEventsIterator doesn't find decision events, + maintaining backward compatibility. + + Args: + history: The workflow history from the decision task + """ + if not history or not hasattr(history, 'events'): + return + + logger.debug( + "Processing history events in fallback mode", + extra={ + "workflow_id": self._context.info().workflow_id, + "events_count": len(history.events) + } + ) + + for event in history.events: + try: + # Process through state machines (DecisionsHelper now delegates to DecisionManager) + self._decision_manager.handle_history_event(event) + except Exception as e: + logger.warning( + "Error processing history event in fallback mode", + extra={ + "workflow_id": self._context.info().workflow_id, + "event_type": getattr(event, 'event_type', 'unknown'), + "event_id": getattr(event, 'event_id', None), + "error_type": type(e).__name__ + }, + exc_info=True + ) + + async def _execute_workflow_function(self, decision_task: PollForDecisionTaskResponse) -> None: + """ + Execute the workflow function to generate new decisions. + + This blocks until the workflow schedules an activity or completes. + + Args: + decision_task: The decision task containing workflow context + """ + try: + # Execute the workflow function + # The workflow function should block until it schedules an activity + workflow_func = self._workflow_func + if workflow_func is None: + logger.warning( + "No workflow function available", + extra={ + "workflow_type": self._context.info().workflow_type, + "workflow_id": self._context.info().workflow_id, + "run_id": self._context.info().workflow_run_id + } + ) + return + + # Extract workflow input from history + workflow_input = await self._extract_workflow_input(decision_task) + + # Execute workflow function + result = self._execute_workflow_function_once(workflow_func, workflow_input) + + # Check if workflow is complete + if result is not None: + self._is_workflow_complete = True + # Log workflow completion (matches Go client patterns) + logger.info( + "Workflow execution completed", + extra={ + "workflow_type": self._context.info().workflow_type, + "workflow_id": self._context.info().workflow_id, + "run_id": self._context.info().workflow_run_id, + "completion_type": "success" + } + ) + + except Exception as e: + logger.error( + "Error executing workflow function", + extra={ + "workflow_type": self._context.info().workflow_type, + "workflow_id": self._context.info().workflow_id, + "run_id": self._context.info().workflow_run_id, + "error_type": type(e).__name__ + }, + exc_info=True + ) + raise + + async def _extract_workflow_input(self, decision_task: PollForDecisionTaskResponse) -> Any: + """ + Extract workflow input from the decision task history. + + Args: + decision_task: The decision task containing workflow history + + Returns: + The workflow input data, or None if not found + """ + if not decision_task.history or not hasattr(decision_task.history, 'events'): + logger.warning("No history events found in decision task") + return None + + # Look for WorkflowExecutionStarted event + for event in decision_task.history.events: + if hasattr(event, 'workflow_execution_started_event_attributes'): + started_attrs = event.workflow_execution_started_event_attributes + if started_attrs and hasattr(started_attrs, 'input'): + # Deserialize the input using the client's data converter + try: + # Use from_data method with a single type hint of None (no type conversion) + input_data_list = await self._context.client().data_converter.from_data(started_attrs.input, [None]) + input_data = input_data_list[0] if input_data_list else None + logger.debug(f"Extracted workflow input: {input_data}") + return input_data + except Exception as e: + logger.warning(f"Failed to deserialize workflow input: {e}") + return None + + logger.warning("No WorkflowExecutionStarted event found in history") + return None + + def _execute_workflow_function_once(self, workflow_func: Callable, workflow_input: Any) -> Any: + """ + Execute the workflow function once (not during replay). + + Args: + workflow_func: The workflow function to execute + workflow_input: The input data for the workflow function + + Returns: + The result of the workflow function execution + """ + logger.debug(f"Executing workflow function with input: {workflow_input}") + result = workflow_func(workflow_input) + + # If the workflow function is async, we need to handle it properly + if asyncio.iscoroutine(result): + # For now, use asyncio.run for async workflow functions + # TODO: Implement proper deterministic event loop for workflow execution + try: + result = asyncio.run(result) + except RuntimeError: + # If we're already in an event loop, create a new task + loop = asyncio.get_event_loop() + if loop.is_running(): + # We can't use asyncio.run inside a running loop + # For now, just get the result (this may not be deterministic) + logger.warning("Async workflow function called within running event loop - may not be deterministic") + # This is a workaround - in a real implementation, we'd need proper task scheduling + result = None + else: + result = loop.run_until_complete(result) + + return result + + def _close_event_loop(self) -> None: + """ + Close the decider's event loop. + """ + try: + # Get the current event loop + loop = asyncio.get_event_loop() + if loop.is_running(): + # Schedule the loop to stop + loop.call_soon_threadsafe(loop.stop) + logger.debug("Scheduled event loop to stop") + except Exception as e: + logger.warning(f"Error closing event loop: {e}") diff --git a/cadence/worker/_decision.py b/cadence/worker/_decision.py index 47e0817..64f31c7 100644 --- a/cadence/worker/_decision.py +++ b/cadence/worker/_decision.py @@ -1,46 +1,58 @@ import asyncio from typing import Optional -from cadence.api.v1.common_pb2 import Payload -from cadence.api.v1.service_worker_pb2 import PollForDecisionTaskRequest, PollForDecisionTaskResponse, \ - RespondDecisionTaskFailedRequest +from cadence.api.v1.service_worker_pb2 import ( + PollForDecisionTaskRequest, + PollForDecisionTaskResponse, +) from cadence.api.v1.tasklist_pb2 import TaskList, TaskListKind -from cadence.api.v1.workflow_pb2 import DecisionTaskFailedCause from cadence.client import Client +from cadence.worker._decision_task_handler import DecisionTaskHandler from cadence.worker._poller import Poller -from cadence.worker._types import WorkerOptions, _LONG_POLL_TIMEOUT +from cadence.worker._registry import Registry +from cadence.worker._types import _LONG_POLL_TIMEOUT, WorkerOptions class DecisionWorker: - def __init__(self, client: Client, task_list: str, options: WorkerOptions) -> None: + def __init__( + self, client: Client, task_list: str, registry: Registry, options: WorkerOptions + ) -> None: self._client = client self._task_list = task_list + self._registry = registry self._identity = options["identity"] - permits = asyncio.Semaphore(options["max_concurrent_decision_task_execution_size"]) - self._poller = Poller[PollForDecisionTaskResponse](options["decision_task_pollers"], permits, self._poll, self._execute) + permits = asyncio.Semaphore( + options["max_concurrent_decision_task_execution_size"] + ) + self._decision_handler = DecisionTaskHandler( + client, task_list, registry, **options + ) + self._poller = Poller[PollForDecisionTaskResponse]( + options["decision_task_pollers"], permits, self._poll, self._execute + ) # TODO: Sticky poller, actually running workflows, etc. async def run(self) -> None: await self._poller.run() async def _poll(self) -> Optional[PollForDecisionTaskResponse]: - task: PollForDecisionTaskResponse = await self._client.worker_stub.PollForDecisionTask(PollForDecisionTaskRequest( - domain=self._client.domain, - task_list=TaskList(name=self._task_list,kind=TaskListKind.TASK_LIST_KIND_NORMAL), - identity=self._identity, - ), timeout=_LONG_POLL_TIMEOUT) - - if task.task_token: + task: PollForDecisionTaskResponse = ( + await self._client.worker_stub.PollForDecisionTask( + PollForDecisionTaskRequest( + domain=self._client.domain, + task_list=TaskList( + name=self._task_list, kind=TaskListKind.TASK_LIST_KIND_NORMAL + ), + identity=self._identity, + ), + timeout=_LONG_POLL_TIMEOUT, + ) + ) + + if task and task.task_token: return task else: return None - async def _execute(self, task: PollForDecisionTaskResponse) -> None: - await self._client.worker_stub.RespondDecisionTaskFailed(RespondDecisionTaskFailedRequest( - task_token=task.task_token, - cause=DecisionTaskFailedCause.DECISION_TASK_FAILED_CAUSE_UNHANDLED_DECISION, - identity=self._identity, - details=Payload(data=b'not implemented') - )) - + await self._decision_handler.handle_task(task) diff --git a/cadence/worker/_decision_task_handler.py b/cadence/worker/_decision_task_handler.py index 636505f..d35ee66 100644 --- a/cadence/worker/_decision_task_handler.py +++ b/cadence/worker/_decision_task_handler.py @@ -1,4 +1,6 @@ import logging +import threading +from typing import Dict, Tuple from cadence.api.v1.common_pb2 import Payload from cadence.api.v1.service_worker_pb2 import ( @@ -19,7 +21,8 @@ class DecisionTaskHandler(BaseTaskHandler[PollForDecisionTaskResponse]): """ Task handler for processing decision tasks. - This handler processes decision tasks and generates decisions using the workflow engine. + This handler processes decision tasks and generates decisions using workflow engines. + Uses a thread-safe cache to hold workflow engines for concurrent decision task handling. """ def __init__(self, client: Client, task_list: str, registry: Registry, identity: str = "unknown", **options): @@ -35,7 +38,9 @@ def __init__(self, client: Client, task_list: str, registry: Registry, identity: """ super().__init__(client, task_list, identity, **options) self._registry = registry - self._workflow_engine: WorkflowEngine + # Thread-safe cache to hold workflow engines keyed by (workflow_id, run_id) + self._workflow_engines: Dict[Tuple[str, str], WorkflowEngine] = {} + self._cache_lock = threading.RLock() async def _handle_task_implementation(self, task: PollForDecisionTaskResponse) -> None: @@ -57,15 +62,34 @@ async def _handle_task_implementation(self, task: PollForDecisionTaskResponse) - run_id = workflow_execution.run_id workflow_type_name = workflow_type.name - logger.info(f"Processing decision task for workflow {workflow_id} (type: {workflow_type_name})") + # This log matches the WorkflowEngine but at task handler level (like Java ReplayDecisionTaskHandler) + logger.info( + "Received decision task for workflow", + extra={ + "workflow_type": workflow_type_name, + "workflow_id": workflow_id, + "run_id": run_id, + "started_event_id": task.started_event_id, + "attempt": task.attempt, + "task_token": task.task_token[:16].hex() if task.task_token else None # Log partial token for debugging + } + ) try: workflow_func = self._registry.get_workflow(workflow_type_name) except KeyError: - logger.error(f"Workflow type '{workflow_type_name}' not found in registry") + logger.error( + "Workflow type not found in registry", + extra={ + "workflow_type": workflow_type_name, + "workflow_id": workflow_id, + "run_id": run_id, + "error_type": "workflow_not_registered" + } + ) raise KeyError(f"Workflow type '{workflow_type_name}' not found") - # Create workflow info and engine + # Create workflow info and get or create workflow engine from cache workflow_info = WorkflowInfo( workflow_type=workflow_type_name, workflow_domain=self._client.domain, @@ -73,18 +97,45 @@ async def _handle_task_implementation(self, task: PollForDecisionTaskResponse) - workflow_run_id=run_id ) - self._workflow_engine = WorkflowEngine( - info=workflow_info, - client=self._client, - workflow_func=workflow_func - ) + # Use thread-safe cache to get or create workflow engine + cache_key = (workflow_id, run_id) + with self._cache_lock: + workflow_engine = self._workflow_engines.get(cache_key) + if workflow_engine is None: + workflow_engine = WorkflowEngine( + info=workflow_info, + client=self._client, + workflow_func=workflow_func + ) + self._workflow_engines[cache_key] = workflow_engine - decision_result = await self._workflow_engine.process_decision(task) + decision_result = await workflow_engine.process_decision(task) + + # Clean up completed workflows from cache to prevent memory leaks + if workflow_engine._is_workflow_complete: + with self._cache_lock: + self._workflow_engines.pop(cache_key, None) + logger.debug( + "Removed completed workflow from cache", + extra={ + "workflow_id": workflow_id, + "run_id": run_id, + "cache_size": len(self._workflow_engines) + } + ) # Respond with the decisions await self._respond_decision_task_completed(task, decision_result) - logger.info(f"Successfully processed decision task for workflow {workflow_id}") + logger.info( + "Successfully processed decision task", + extra={ + "workflow_type": workflow_type_name, + "workflow_id": workflow_id, + "run_id": run_id, + "started_event_id": task.started_event_id + } + ) async def handle_task_failure(self, task: PollForDecisionTaskResponse, error: Exception) -> None: """ @@ -94,7 +145,26 @@ async def handle_task_failure(self, task: PollForDecisionTaskResponse, error: Ex task: The task that failed error: The exception that occurred """ - logger.error(f"Decision task failed: {error}") + # Extract workflow context for error logging (matches Java ReplayDecisionTaskHandler error patterns) + workflow_execution = task.workflow_execution + workflow_id = workflow_execution.workflow_id if workflow_execution else "unknown" + run_id = workflow_execution.run_id if workflow_execution else "unknown" + workflow_type = task.workflow_type.name if task.workflow_type else "unknown" + + # Log task failure with full context (matches Java error logging) + logger.error( + "Decision task processing failure", + extra={ + "workflow_type": workflow_type, + "workflow_id": workflow_id, + "run_id": run_id, + "started_event_id": task.started_event_id, + "attempt": task.attempt, + "error_type": type(error).__name__, + "error_message": str(error) + }, + exc_info=True + ) # Determine the failure cause cause = DecisionTaskFailedCause.DECISION_TASK_FAILED_CAUSE_UNHANDLED_DECISION @@ -118,10 +188,26 @@ async def handle_task_failure(self, task: PollForDecisionTaskResponse, error: Ex details=details ) ) - logger.info("Decision task failure response sent") - except Exception: - logger.exception("Error responding to decision task failure") - + logger.info( + "Decision task failure response sent", + extra={ + "workflow_id": workflow_id, + "run_id": run_id, + "cause": cause, + "task_token": task.task_token[:16].hex() if task.task_token else None + } + ) + except Exception as e: + logger.error( + "Error responding to decision task failure", + extra={ + "workflow_id": workflow_id, + "run_id": run_id, + "original_error": type(error).__name__, + "response_error": type(e).__name__ + }, + exc_info=True + ) async def _respond_decision_task_completed(self, task: PollForDecisionTaskResponse, decision_result: DecisionResult) -> None: """ @@ -136,13 +222,38 @@ async def _respond_decision_task_completed(self, task: PollForDecisionTaskRespon task_token=task.task_token, decisions=decision_result.decisions, identity=self._identity, - return_new_decision_task=True, - force_create_new_decision_task=False + return_new_decision_task=True ) await self._client.worker_stub.RespondDecisionTaskCompleted(request) - logger.debug(f"Decision task completed with {len(decision_result.decisions)} decisions") + + # Log completion response (matches Java ReplayDecisionTaskHandler trace/debug patterns) + workflow_execution = task.workflow_execution + logger.debug( + "Decision task completion response sent", + extra={ + "workflow_type": task.workflow_type.name if task.workflow_type else "unknown", + "workflow_id": workflow_execution.workflow_id if workflow_execution else "unknown", + "run_id": workflow_execution.run_id if workflow_execution else "unknown", + "started_event_id": task.started_event_id, + "decisions_count": len(decision_result.decisions), + "return_new_decision_task": True, + "task_token": task.task_token[:16].hex() if task.task_token else None + } + ) - except Exception: - logger.exception("Error responding to decision task completion") + except Exception as e: + workflow_execution = task.workflow_execution + logger.error( + "Error responding to decision task completion", + extra={ + "workflow_type": task.workflow_type.name if task.workflow_type else "unknown", + "workflow_id": workflow_execution.workflow_id if workflow_execution else "unknown", + "run_id": workflow_execution.run_id if workflow_execution else "unknown", + "started_event_id": task.started_event_id, + "decisions_count": len(decision_result.decisions), + "error_type": type(e).__name__ + }, + exc_info=True + ) raise diff --git a/cadence/worker/_worker.py b/cadence/worker/_worker.py index 70ce364..ff273ad 100644 --- a/cadence/worker/_worker.py +++ b/cadence/worker/_worker.py @@ -19,7 +19,7 @@ def __init__(self, client: Client, task_list: str, registry: Registry, **kwargs: _validate_and_copy_defaults(client, task_list, options) self._options = options self._activity_worker = ActivityWorker(client, task_list, registry, options) - self._decision_worker = DecisionWorker(client, task_list, options) + self._decision_worker = DecisionWorker(client, task_list, registry, options) async def run(self) -> None: diff --git a/tests/cadence/_internal/workflow/test_workflow_engine_integration.py b/tests/cadence/_internal/workflow/test_workflow_engine_integration.py new file mode 100644 index 0000000..cb1f449 --- /dev/null +++ b/tests/cadence/_internal/workflow/test_workflow_engine_integration.py @@ -0,0 +1,306 @@ +#!/usr/bin/env python3 +""" +Integration tests for WorkflowEngine. +""" + +import pytest +from unittest.mock import Mock, AsyncMock, patch +from cadence.api.v1.service_worker_pb2 import PollForDecisionTaskResponse +from cadence.api.v1.common_pb2 import Payload, WorkflowExecution, WorkflowType +from cadence.api.v1.history_pb2 import History, HistoryEvent, WorkflowExecutionStartedEventAttributes +from cadence._internal.workflow.workflow_engine import WorkflowEngine, DecisionResult +from cadence.workflow import WorkflowInfo +from cadence.client import Client + + +class TestWorkflowEngineIntegration: + """Integration tests for WorkflowEngine.""" + + @pytest.fixture + def mock_client(self): + """Create a mock Cadence client.""" + client = Mock(spec=Client) + client.domain = "test-domain" + client.data_converter = Mock() + client.data_converter.from_data = AsyncMock(return_value=["test-input"]) + return client + + @pytest.fixture + def workflow_info(self): + """Create workflow info.""" + return WorkflowInfo( + workflow_type="test_workflow", + workflow_domain="test-domain", + workflow_id="test-workflow-id", + workflow_run_id="test-run-id" + ) + + @pytest.fixture + def mock_workflow_func(self): + """Create a mock workflow function.""" + def workflow_func(input_data): + return f"processed: {input_data}" + return workflow_func + + @pytest.fixture + def workflow_engine(self, mock_client, workflow_info, mock_workflow_func): + """Create a WorkflowEngine instance.""" + return WorkflowEngine( + info=workflow_info, + client=mock_client, + workflow_func=mock_workflow_func + ) + + def create_mock_decision_task(self, workflow_id="test-workflow", run_id="test-run", workflow_type="test_workflow"): + """Create a mock decision task with history.""" + # Create workflow execution + workflow_execution = WorkflowExecution() + workflow_execution.workflow_id = workflow_id + workflow_execution.run_id = run_id + + # Create workflow type + workflow_type_obj = WorkflowType() + workflow_type_obj.name = workflow_type + + # Create workflow execution started event + started_event = WorkflowExecutionStartedEventAttributes() + input_payload = Payload(data=b'"test-input"') + started_event.input.CopyFrom(input_payload) + + history_event = HistoryEvent() + history_event.workflow_execution_started_event_attributes.CopyFrom(started_event) + + # Create history + history = History() + history.events.append(history_event) + + # Create decision task + decision_task = PollForDecisionTaskResponse() + decision_task.task_token = b"test-task-token" + decision_task.workflow_execution.CopyFrom(workflow_execution) + decision_task.workflow_type.CopyFrom(workflow_type_obj) + decision_task.history.CopyFrom(history) + + return decision_task + + @pytest.mark.asyncio + async def test_process_decision_success(self, workflow_engine, mock_client): + """Test successful decision processing.""" + decision_task = self.create_mock_decision_task() + + # Mock the decision manager to return some decisions + with patch.object(workflow_engine._decision_manager, 'collect_pending_decisions', return_value=[Mock()]): + # Process the decision + result = await workflow_engine.process_decision(decision_task) + + # Verify the result + assert isinstance(result, DecisionResult) + assert len(result.decisions) == 1 + + @pytest.mark.asyncio + async def test_process_decision_with_history(self, workflow_engine, mock_client): + """Test decision processing with history events.""" + decision_task = self.create_mock_decision_task() + + # Mock the decision manager + with patch.object(workflow_engine._decision_manager, 'handle_history_event') as mock_handle: + with patch.object(workflow_engine._decision_manager, 'collect_pending_decisions', return_value=[]): + # Process the decision + await workflow_engine.process_decision(decision_task) + + # Verify history events were processed + mock_handle.assert_called() + + @pytest.mark.asyncio + async def test_process_decision_workflow_complete(self, workflow_engine, mock_client): + """Test decision processing when workflow is already complete.""" + # Mark workflow as complete + workflow_engine._is_workflow_complete = True + + decision_task = self.create_mock_decision_task() + + with patch.object(workflow_engine._decision_manager, 'collect_pending_decisions', return_value=[]): + # Process the decision + result = await workflow_engine.process_decision(decision_task) + + # Verify the result + assert isinstance(result, DecisionResult) + assert len(result.decisions) == 0 + + @pytest.mark.asyncio + async def test_process_decision_error_handling(self, workflow_engine, mock_client): + """Test decision processing error handling.""" + decision_task = self.create_mock_decision_task() + + # Mock the decision manager to raise an exception + with patch.object(workflow_engine._decision_manager, 'handle_history_event', side_effect=Exception("Test error")): + # Process the decision + result = await workflow_engine.process_decision(decision_task) + + # Verify error handling - should return empty decisions + assert isinstance(result, DecisionResult) + assert len(result.decisions) == 0 + + @pytest.mark.asyncio + async def test_extract_workflow_input_success(self, workflow_engine, mock_client): + """Test successful workflow input extraction.""" + decision_task = self.create_mock_decision_task() + + # Extract workflow input + input_data = await workflow_engine._extract_workflow_input(decision_task) + + # Verify the input was extracted + assert input_data == "test-input" + mock_client.data_converter.from_data.assert_called_once() + + @pytest.mark.asyncio + async def test_extract_workflow_input_no_history(self, workflow_engine, mock_client): + """Test workflow input extraction with no history.""" + decision_task = PollForDecisionTaskResponse() + decision_task.task_token = b"test-task-token" + # No history set + + # Extract workflow input + input_data = await workflow_engine._extract_workflow_input(decision_task) + + # Verify no input was extracted + assert input_data is None + + @pytest.mark.asyncio + async def test_extract_workflow_input_no_started_event(self, workflow_engine, mock_client): + """Test workflow input extraction with no WorkflowExecutionStarted event.""" + # Create a decision task with no started event + decision_task = PollForDecisionTaskResponse() + decision_task.task_token = b"test-task-token" + + # Create workflow execution + workflow_execution = WorkflowExecution() + workflow_execution.workflow_id = "test-workflow" + workflow_execution.run_id = "test-run" + decision_task.workflow_execution.CopyFrom(workflow_execution) + + # Create workflow type + workflow_type_obj = WorkflowType() + workflow_type_obj.name = "test_workflow" + decision_task.workflow_type.CopyFrom(workflow_type_obj) + + # Create history with no events + history = History() + decision_task.history.CopyFrom(history) + + # Extract workflow input + input_data = await workflow_engine._extract_workflow_input(decision_task) + + # Verify no input was extracted + assert input_data is None + + @pytest.mark.asyncio + async def test_extract_workflow_input_deserialization_error(self, workflow_engine, mock_client): + """Test workflow input extraction with deserialization error.""" + decision_task = self.create_mock_decision_task() + + # Mock data converter to raise an exception + mock_client.data_converter.from_data = AsyncMock(side_effect=Exception("Deserialization error")) + + # Extract workflow input + input_data = await workflow_engine._extract_workflow_input(decision_task) + + # Verify no input was extracted due to error + assert input_data is None + + def test_execute_workflow_function_sync(self, workflow_engine): + """Test synchronous workflow function execution.""" + input_data = "test-input" + + # Execute the workflow function + result = workflow_engine._execute_workflow_function_once(workflow_engine._workflow_func, input_data) + + # Verify the result + assert result == "processed: test-input" + + def test_execute_workflow_function_async(self, workflow_engine): + """Test asynchronous workflow function execution.""" + async def async_workflow_func(input_data): + return f"async-processed: {input_data}" + + input_data = "test-input" + + # Execute the async workflow function + result = workflow_engine._execute_workflow_function_once(async_workflow_func, input_data) + + # Verify the result + assert result == "async-processed: test-input" + + def test_execute_workflow_function_none(self, workflow_engine): + """Test workflow function execution with None function.""" + input_data = "test-input" + + # Execute with None workflow function - should raise TypeError + with pytest.raises(TypeError, match="'NoneType' object is not callable"): + workflow_engine._execute_workflow_function_once(None, input_data) + + def test_workflow_engine_initialization(self, workflow_engine, workflow_info, mock_client, mock_workflow_func): + """Test WorkflowEngine initialization.""" + assert workflow_engine._context is not None + assert workflow_engine._workflow_func == mock_workflow_func + assert workflow_engine._decision_manager is not None + assert workflow_engine._is_workflow_complete is False + + @pytest.mark.asyncio + async def test_workflow_engine_without_workflow_func(self, mock_client, workflow_info): + """Test WorkflowEngine without workflow function.""" + engine = WorkflowEngine( + info=workflow_info, + client=mock_client, + workflow_func=None + ) + + decision_task = self.create_mock_decision_task() + + with patch.object(engine._decision_manager, 'collect_pending_decisions', return_value=[]): + # Process the decision + result = await engine.process_decision(decision_task) + + # Verify the result + assert isinstance(result, DecisionResult) + assert len(result.decisions) == 0 + + @pytest.mark.asyncio + async def test_workflow_engine_workflow_completion(self, workflow_engine, mock_client): + """Test workflow completion detection.""" + decision_task = self.create_mock_decision_task() + + # Mock workflow function to return a result (indicating completion) + def completing_workflow_func(input_data): + return "workflow-completed" + + workflow_engine._workflow_func = completing_workflow_func + + with patch.object(workflow_engine._decision_manager, 'collect_pending_decisions', return_value=[]): + # Process the decision + await workflow_engine.process_decision(decision_task) + + # Verify workflow is marked as complete + assert workflow_engine._is_workflow_complete is True + + def test_close_event_loop(self, workflow_engine): + """Test event loop closing.""" + # This should not raise an exception + workflow_engine._close_event_loop() + + @pytest.mark.asyncio + async def test_process_decision_with_query_results(self, workflow_engine, mock_client): + """Test decision processing with query results.""" + decision_task = self.create_mock_decision_task() + + # Mock the decision manager to return decisions with query results + mock_decisions = [Mock()] + + with patch.object(workflow_engine._decision_manager, 'collect_pending_decisions', return_value=mock_decisions): + # Process the decision + result = await workflow_engine.process_decision(decision_task) + + # Verify the result + assert isinstance(result, DecisionResult) + assert len(result.decisions) == 1 + # Not set in this test diff --git a/tests/cadence/worker/test_decision_task_handler.py b/tests/cadence/worker/test_decision_task_handler.py index 2fc98ec..cd2b210 100644 --- a/tests/cadence/worker/test_decision_task_handler.py +++ b/tests/cadence/worker/test_decision_task_handler.py @@ -58,6 +58,9 @@ def sample_decision_task(self): task.workflow_execution.run_id = "test_run_id" task.workflow_type = Mock() task.workflow_type.name = "TestWorkflow" + # Add the missing attributes that are now accessed directly + task.started_event_id = 1 + task.attempt = 1 return task def test_initialization(self, mock_client, mock_registry): @@ -85,10 +88,10 @@ async def test_handle_task_implementation_success(self, handler, sample_decision # Mock workflow engine mock_engine = Mock(spec=WorkflowEngine) + mock_engine._is_workflow_complete = False # Add missing attribute + mock_engine._is_workflow_complete = False # Add missing attribute mock_decision_result = Mock(spec=DecisionResult) mock_decision_result.decisions = [Decision()] - mock_decision_result.force_create_new_decision_task = False - mock_decision_result.query_results = {} mock_engine.process_decision = AsyncMock(return_value=mock_decision_result) with patch('cadence.worker._decision_task_handler.WorkflowEngine', return_value=mock_engine): @@ -137,32 +140,82 @@ async def test_handle_task_implementation_workflow_not_found(self, handler, samp await handler._handle_task_implementation(sample_decision_task) @pytest.mark.asyncio - async def test_handle_task_implementation_creates_new_engine(self, handler, sample_decision_task, mock_registry): - """Test that decision task handler creates new workflow engine for each task.""" + async def test_handle_task_implementation_caches_engines(self, handler, sample_decision_task, mock_registry): + """Test that decision task handler caches workflow engines for same workflow execution.""" # Mock workflow function mock_workflow_func = Mock() mock_registry.get_workflow.return_value = mock_workflow_func # Mock workflow engine mock_engine = Mock(spec=WorkflowEngine) + mock_engine._is_workflow_complete = False # Add missing attribute mock_decision_result = Mock(spec=DecisionResult) mock_decision_result.decisions = [] - mock_decision_result.force_create_new_decision_task = False - mock_decision_result.query_results = {} mock_engine.process_decision = AsyncMock(return_value=mock_decision_result) with patch('cadence.worker._decision_task_handler.WorkflowEngine', return_value=mock_engine) as mock_engine_class: # First call - should create new engine await handler._handle_task_implementation(sample_decision_task) - # Second call - should create another new engine + # Second call with same workflow_id and run_id - should reuse cached engine await handler._handle_task_implementation(sample_decision_task) + # Registry should be called for each task (to get workflow function) + assert mock_registry.get_workflow.call_count == 2 + + # Engine should be created only once (cached for second call) + assert mock_engine_class.call_count == 1 + + # But process_decision should be called twice + assert mock_engine.process_decision.call_count == 2 + + @pytest.mark.asyncio + async def test_handle_task_implementation_different_executions_get_separate_engines(self, handler, mock_registry): + """Test that different workflow executions get separate engines.""" + # Mock workflow function + mock_workflow_func = Mock() + mock_registry.get_workflow.return_value = mock_workflow_func + + # Create two different decision tasks + task1 = Mock(spec=PollForDecisionTaskResponse) + task1.task_token = b"test_task_token_1" + task1.workflow_execution = Mock() + task1.workflow_execution.workflow_id = "workflow_1" + task1.workflow_execution.run_id = "run_1" + task1.workflow_type = Mock() + task1.workflow_type.name = "TestWorkflow" + task1.started_event_id = 1 + task1.attempt = 1 + + task2 = Mock(spec=PollForDecisionTaskResponse) + task2.task_token = b"test_task_token_2" + task2.workflow_execution = Mock() + task2.workflow_execution.workflow_id = "workflow_2" # Different workflow + task2.workflow_execution.run_id = "run_2" # Different run + task2.workflow_type = Mock() + task2.workflow_type.name = "TestWorkflow" + task2.started_event_id = 2 + task2.attempt = 1 + + # Mock workflow engine + mock_engine = Mock(spec=WorkflowEngine) + mock_engine._is_workflow_complete = False # Add missing attribute + mock_decision_result = Mock(spec=DecisionResult) + mock_decision_result.decisions = [] + mock_engine.process_decision = AsyncMock(return_value=mock_decision_result) + + with patch('cadence.worker._decision_task_handler.WorkflowEngine', return_value=mock_engine) as mock_engine_class: + # Process different workflow executions + await handler._handle_task_implementation(task1) + await handler._handle_task_implementation(task2) + # Registry should be called for each task assert mock_registry.get_workflow.call_count == 2 - # Engine should be created twice and called twice + # Engine should be created twice (different executions) assert mock_engine_class.call_count == 2 + + # Process_decision should be called twice assert mock_engine.process_decision.call_count == 2 @pytest.mark.asyncio @@ -224,7 +277,8 @@ async def test_handle_task_failure_respond_error(self, handler, sample_decision_ # Should not raise exception, but should log error with patch('cadence.worker._decision_task_handler.logger') as mock_logger: await handler.handle_task_failure(sample_decision_task, error) - mock_logger.exception.assert_called_once() + # Now uses logger.error with exc_info=True instead of logger.exception + mock_logger.error.assert_called() @pytest.mark.asyncio async def test_respond_decision_task_completed_success(self, handler, sample_decision_task): @@ -240,7 +294,6 @@ async def test_respond_decision_task_completed_success(self, handler, sample_dec assert call_args.task_token == sample_decision_task.task_token assert call_args.identity == handler._identity assert call_args.return_new_decision_task - assert not call_args.force_create_new_decision_task assert len(call_args.decisions) == 2 @pytest.mark.asyncio @@ -253,7 +306,6 @@ async def test_respond_decision_task_completed_no_query_results(self, handler, s call_args = handler._client.worker_stub.RespondDecisionTaskCompleted.call_args[0][0] assert call_args.return_new_decision_task - assert not call_args.force_create_new_decision_task assert len(call_args.decisions) == 0 @pytest.mark.asyncio @@ -275,10 +327,9 @@ async def test_workflow_engine_creation_with_workflow_info(self, handler, sample mock_registry.get_workflow.return_value = mock_workflow_func mock_engine = Mock(spec=WorkflowEngine) + mock_engine._is_workflow_complete = False # Add missing attribute mock_decision_result = Mock(spec=DecisionResult) mock_decision_result.decisions = [] - mock_decision_result.force_create_new_decision_task = False - mock_decision_result.query_results = {} mock_engine.process_decision = AsyncMock(return_value=mock_decision_result) with patch('cadence.worker._decision_task_handler.WorkflowEngine', return_value=mock_engine) as mock_workflow_engine_class: diff --git a/tests/cadence/worker/test_decision_task_handler_integration.py b/tests/cadence/worker/test_decision_task_handler_integration.py new file mode 100644 index 0000000..b513a14 --- /dev/null +++ b/tests/cadence/worker/test_decision_task_handler_integration.py @@ -0,0 +1,230 @@ +#!/usr/bin/env python3 +""" +Integration tests for DecisionTaskHandler and WorkflowEngine. +""" + +import pytest +from unittest.mock import Mock, AsyncMock, patch +from cadence.api.v1.service_worker_pb2 import ( + PollForDecisionTaskResponse +) +from cadence.api.v1.common_pb2 import Payload, WorkflowExecution, WorkflowType +from cadence.api.v1.history_pb2 import History, HistoryEvent, WorkflowExecutionStartedEventAttributes +from cadence.api.v1.decision_pb2 import Decision +from cadence.worker._decision_task_handler import DecisionTaskHandler +from cadence.worker._registry import Registry +from cadence.client import Client + + +class TestDecisionTaskHandlerIntegration: + """Integration tests for DecisionTaskHandler.""" + + @pytest.fixture + def mock_client(self): + """Create a mock Cadence client.""" + client = Mock(spec=Client) + client.domain = "test-domain" + client.data_converter = Mock() + client.data_converter.from_data = AsyncMock(return_value=["test-input"]) + client.worker_stub = Mock() + client.worker_stub.RespondDecisionTaskCompleted = AsyncMock() + client.worker_stub.RespondDecisionTaskFailed = AsyncMock() + return client + + @pytest.fixture + def registry(self): + """Create a registry with a test workflow.""" + reg = Registry() + + @reg.workflow + def test_workflow(input_data): + """Simple test workflow that returns the input.""" + return f"processed: {input_data}" + + return reg + + @pytest.fixture + def decision_task_handler(self, mock_client, registry): + """Create a DecisionTaskHandler instance.""" + return DecisionTaskHandler( + client=mock_client, + task_list="test-task-list", + registry=registry, + identity="test-worker" + ) + + def create_mock_decision_task(self, workflow_id="test-workflow", run_id="test-run", workflow_type="test_workflow"): + """Create a mock decision task with history.""" + # Create workflow execution + workflow_execution = WorkflowExecution() + workflow_execution.workflow_id = workflow_id + workflow_execution.run_id = run_id + + # Create workflow type + workflow_type_obj = WorkflowType() + workflow_type_obj.name = workflow_type + + # Create workflow execution started event + started_event = WorkflowExecutionStartedEventAttributes() + input_payload = Payload(data=b'"test-input"') + started_event.input.CopyFrom(input_payload) + + history_event = HistoryEvent() + history_event.workflow_execution_started_event_attributes.CopyFrom(started_event) + + # Create history + history = History() + history.events.append(history_event) + + # Create decision task + decision_task = PollForDecisionTaskResponse() + decision_task.task_token = b"test-task-token" + decision_task.workflow_execution.CopyFrom(workflow_execution) + decision_task.workflow_type.CopyFrom(workflow_type_obj) + decision_task.history.CopyFrom(history) + + return decision_task + + @pytest.mark.asyncio + async def test_handle_decision_task_success(self, decision_task_handler, mock_client): + """Test successful decision task handling.""" + # Create a mock decision task + decision_task = self.create_mock_decision_task() + + # Mock the workflow engine to return some decisions + # Mock the workflow engine creation and execution + mock_engine = Mock() + # Create a proper Decision object + decision = Decision() + mock_engine.process_decision = AsyncMock(return_value=Mock( + decisions=[decision], # Proper Decision object + )) + + with patch('cadence.worker._decision_task_handler.WorkflowEngine', return_value=mock_engine): + # Handle the decision task + await decision_task_handler._handle_task_implementation(decision_task) + + # Verify the workflow engine was called + mock_engine.process_decision.assert_called_once_with(decision_task) + + # Verify the response was sent + mock_client.worker_stub.RespondDecisionTaskCompleted.assert_called_once() + + @pytest.mark.asyncio + async def test_handle_decision_task_workflow_not_found(self, decision_task_handler, mock_client): + """Test decision task handling when workflow is not found in registry.""" + # Create a decision task with unknown workflow type + decision_task = self.create_mock_decision_task(workflow_type="unknown_workflow") + + # Handle the decision task + await decision_task_handler.handle_task(decision_task) + + # Verify failure response was sent + mock_client.worker_stub.RespondDecisionTaskFailed.assert_called_once() + + # Verify the failure request has the correct cause + call_args = mock_client.worker_stub.RespondDecisionTaskFailed.call_args[0][0] + assert call_args.cause == 14 # DECISION_TASK_FAILED_CAUSE_WORKFLOW_WORKER_UNHANDLED_FAILURE + + @pytest.mark.asyncio + async def test_handle_decision_task_missing_workflow_execution(self, decision_task_handler, mock_client): + """Test decision task handling when workflow execution is missing.""" + # Create a decision task without workflow execution + decision_task = PollForDecisionTaskResponse() + decision_task.task_token = b"test-task-token" + # No workflow_execution set + + # Handle the decision task + await decision_task_handler.handle_task(decision_task) + + # Verify failure response was sent + mock_client.worker_stub.RespondDecisionTaskFailed.assert_called_once() + + # Verify the failure request has the correct cause + call_args = mock_client.worker_stub.RespondDecisionTaskFailed.call_args[0][0] + assert call_args.cause == 14 # DECISION_TASK_FAILED_CAUSE_WORKFLOW_WORKER_UNHANDLED_FAILURE + + @pytest.mark.asyncio + async def test_workflow_engine_creation_each_task(self, decision_task_handler, mock_client): + """Test that workflow engines are created for each task.""" + decision_task = self.create_mock_decision_task() + + with patch('cadence.worker._decision_task_handler.WorkflowEngine') as mock_engine_class: + mock_engine = Mock() + mock_engine.process_decision = AsyncMock(return_value=Mock( + decisions=[], + )) + mock_engine_class.return_value = mock_engine + + # Handle the same decision task twice + await decision_task_handler._handle_task_implementation(decision_task) + await decision_task_handler._handle_task_implementation(decision_task) + + # Verify engine was created twice (once for each task) + assert mock_engine_class.call_count == 2 + + # Verify engine was called twice + assert mock_engine.process_decision.call_count == 2 + + + @pytest.mark.asyncio + async def test_decision_task_failure_handling(self, decision_task_handler, mock_client): + """Test decision task failure handling.""" + decision_task = self.create_mock_decision_task() + + # Mock the workflow engine to raise an exception + with patch('cadence.worker._decision_task_handler.WorkflowEngine') as mock_engine_class: + mock_engine = Mock() + mock_engine.process_decision = AsyncMock(side_effect=Exception("Test error")) + mock_engine_class.return_value = mock_engine + + # Handle the decision task - this should catch the exception + await decision_task_handler.handle_task(decision_task) + + # Verify failure response was sent + mock_client.worker_stub.RespondDecisionTaskFailed.assert_called_once() + + def test_decision_task_handler_initialization(self, decision_task_handler): + """Test DecisionTaskHandler initialization.""" + assert decision_task_handler._registry is not None + assert decision_task_handler._identity == "test-worker" + + @pytest.mark.asyncio + async def test_respond_decision_task_completed(self, decision_task_handler, mock_client): + """Test decision task completion response.""" + decision_task = self.create_mock_decision_task() + + # Create mock decision result + decision_result = Mock() + decision_result.decisions = [Decision()] # Proper Decision object + + # Call the response method + await decision_task_handler._respond_decision_task_completed(decision_task, decision_result) + + # Verify the response was sent + mock_client.worker_stub.RespondDecisionTaskCompleted.assert_called_once() + + # Verify the request parameters + call_args = mock_client.worker_stub.RespondDecisionTaskCompleted.call_args[0][0] + assert call_args.task_token == b"test-task-token" + assert call_args.identity == "test-worker" + assert len(call_args.decisions) == 1 + + @pytest.mark.asyncio + async def test_respond_decision_task_failed(self, decision_task_handler, mock_client): + """Test decision task failure response.""" + decision_task = self.create_mock_decision_task() + error = ValueError("Test error") + + # Call the failure method + await decision_task_handler.handle_task_failure(decision_task, error) + + # Verify the failure response was sent + mock_client.worker_stub.RespondDecisionTaskFailed.assert_called_once() + + # Verify the request parameters + call_args = mock_client.worker_stub.RespondDecisionTaskFailed.call_args[0][0] + assert call_args.task_token == b"test-task-token" + assert call_args.identity == "test-worker" + assert call_args.cause == 2 # BAD_SCHEDULE_ACTIVITY_ATTRIBUTES for ValueError + assert b"Test error" in call_args.details.data diff --git a/tests/cadence/worker/test_decision_worker_integration.py b/tests/cadence/worker/test_decision_worker_integration.py new file mode 100644 index 0000000..85c55d2 --- /dev/null +++ b/tests/cadence/worker/test_decision_worker_integration.py @@ -0,0 +1,285 @@ +#!/usr/bin/env python3 +""" +Integration tests for DecisionWorker with DecisionTaskHandler. +""" + +import asyncio +import pytest +from unittest.mock import Mock, AsyncMock, patch +from cadence.api.v1.service_worker_pb2 import PollForDecisionTaskResponse +from cadence.api.v1.common_pb2 import Payload, WorkflowExecution, WorkflowType +from cadence.api.v1.history_pb2 import History, HistoryEvent, WorkflowExecutionStartedEventAttributes +from cadence.worker._decision import DecisionWorker +from cadence.worker._registry import Registry +from cadence.client import Client + + +class TestDecisionWorkerIntegration: + """Integration tests for DecisionWorker with DecisionTaskHandler.""" + + @pytest.fixture + def mock_client(self): + """Create a mock Cadence client.""" + client = Mock(spec=Client) + client.domain = "test-domain" + client.data_converter = Mock() + client.data_converter.from_data = AsyncMock(return_value=["test-input"]) + client.worker_stub = Mock() + client.worker_stub.PollForDecisionTask = AsyncMock() + client.worker_stub.RespondDecisionTaskCompleted = AsyncMock() + client.worker_stub.RespondDecisionTaskFailed = AsyncMock() + return client + + @pytest.fixture + def registry(self): + """Create a registry with a test workflow.""" + reg = Registry() + + @reg.workflow + def test_workflow(input_data): + """Simple test workflow that returns the input.""" + return f"processed: {input_data}" + + return reg + + @pytest.fixture + def decision_worker(self, mock_client, registry): + """Create a DecisionWorker instance.""" + options = { + "identity": "test-worker", + "max_concurrent_decision_task_execution_size": 1, + "decision_task_pollers": 1 + } + return DecisionWorker( + client=mock_client, + task_list="test-task-list", + registry=registry, + options=options + ) + + def create_mock_decision_task(self, workflow_id="test-workflow", run_id="test-run", workflow_type="test_workflow"): + """Create a mock decision task with history.""" + # Create workflow execution + workflow_execution = WorkflowExecution() + workflow_execution.workflow_id = workflow_id + workflow_execution.run_id = run_id + + # Create workflow type + workflow_type_obj = WorkflowType() + workflow_type_obj.name = workflow_type + + # Create workflow execution started event + started_event = WorkflowExecutionStartedEventAttributes() + input_payload = Payload(data=b'"test-input"') + started_event.input.CopyFrom(input_payload) + + history_event = HistoryEvent() + history_event.workflow_execution_started_event_attributes.CopyFrom(started_event) + + # Create history + history = History() + history.events.append(history_event) + + # Create decision task + decision_task = PollForDecisionTaskResponse() + decision_task.task_token = b"test-task-token" + decision_task.workflow_execution.CopyFrom(workflow_execution) + decision_task.workflow_type.CopyFrom(workflow_type_obj) + decision_task.history.CopyFrom(history) + + return decision_task + + @pytest.mark.asyncio + async def test_decision_worker_poll_and_execute(self, decision_worker, mock_client): + """Test decision worker polling and executing tasks.""" + # Create a mock decision task + decision_task = self.create_mock_decision_task() + + # Mock the poll to return the decision task + mock_client.worker_stub.PollForDecisionTask.return_value = decision_task + + # Mock the decision handler + with patch.object(decision_worker, '_decision_handler') as mock_handler: + mock_handler.handle_task = AsyncMock() + + # Run the poll and execute + await decision_worker._poll() + await decision_worker._execute(decision_task) + + # Verify the poll was called + mock_client.worker_stub.PollForDecisionTask.assert_called_once() + + # Verify the handler was called + mock_handler.handle_task.assert_called_once_with(decision_task) + + @pytest.mark.asyncio + async def test_decision_worker_poll_no_task(self, decision_worker, mock_client): + """Test decision worker polling when no task is available.""" + # Mock the poll to return None (no task) + mock_client.worker_stub.PollForDecisionTask.return_value = None + + # Run the poll + result = await decision_worker._poll() + + # Verify no task was returned + assert result is None + + @pytest.mark.asyncio + async def test_decision_worker_poll_with_task_token(self, decision_worker, mock_client): + """Test decision worker polling when task has token.""" + # Create a decision task with token + decision_task = self.create_mock_decision_task() + decision_task.task_token = b"valid-token" + + # Mock the poll to return the decision task + mock_client.worker_stub.PollForDecisionTask.return_value = decision_task + + # Run the poll + result = await decision_worker._poll() + + # Verify the task was returned + assert result == decision_task + + @pytest.mark.asyncio + async def test_decision_worker_poll_without_task_token(self, decision_worker, mock_client): + """Test decision worker polling when task has no token.""" + # Create a decision task without token + decision_task = self.create_mock_decision_task() + decision_task.task_token = b"" # Empty token + + # Mock the poll to return the decision task + mock_client.worker_stub.PollForDecisionTask.return_value = decision_task + + # Run the poll + result = await decision_worker._poll() + + # Verify no task was returned + assert result is None + + @pytest.mark.asyncio + async def test_decision_worker_execute_success(self, decision_worker, mock_client): + """Test successful decision task execution.""" + decision_task = self.create_mock_decision_task() + + # Mock the decision handler + with patch.object(decision_worker, '_decision_handler') as mock_handler: + mock_handler.handle_task = AsyncMock() + + # Execute the task + await decision_worker._execute(decision_task) + + # Verify the handler was called + mock_handler.handle_task.assert_called_once_with(decision_task) + + @pytest.mark.asyncio + async def test_decision_worker_execute_handler_error(self, decision_worker, mock_client): + """Test decision task execution when handler raises an error.""" + decision_task = self.create_mock_decision_task() + + # Mock the decision handler to raise an error + with patch.object(decision_worker, '_decision_handler') as mock_handler: + mock_handler.handle_task = AsyncMock(side_effect=Exception("Handler error")) + + # Execute the task - should raise the exception + with pytest.raises(Exception, match="Handler error"): + await decision_worker._execute(decision_task) + + # Verify the handler was called + mock_handler.handle_task.assert_called_once_with(decision_task) + + def test_decision_worker_initialization(self, decision_worker, mock_client, registry): + """Test DecisionWorker initialization.""" + assert decision_worker._client == mock_client + assert decision_worker._task_list == "test-task-list" + assert decision_worker._identity == "test-worker" + assert decision_worker._registry == registry + assert decision_worker._decision_handler is not None + assert decision_worker._poller is not None + + @pytest.mark.asyncio + async def test_decision_worker_run(self, decision_worker, mock_client): + """Test DecisionWorker run method.""" + # Mock the poller to complete immediately + with patch.object(decision_worker._poller, 'run', new_callable=AsyncMock) as mock_poller_run: + # Run the worker + await decision_worker.run() + + # Verify the poller was run + mock_poller_run.assert_called_once() + + @pytest.mark.asyncio + async def test_decision_worker_integration_flow(self, decision_worker, mock_client): + """Test the complete integration flow from poll to execute.""" + # Create a mock decision task + decision_task = self.create_mock_decision_task() + + # Mock the poll to return the decision task + mock_client.worker_stub.PollForDecisionTask.return_value = decision_task + + # Mock the decision handler + with patch.object(decision_worker, '_decision_handler') as mock_handler: + mock_handler.handle_task = AsyncMock() + + # Test the complete flow + # 1. Poll for task + polled_task = await decision_worker._poll() + assert polled_task == decision_task + + # 2. Execute the task + await decision_worker._execute(polled_task) + + # 3. Verify the handler was called + mock_handler.handle_task.assert_called_once_with(decision_task) + + @pytest.mark.asyncio + async def test_decision_worker_with_different_workflow_types(self, decision_worker, mock_client, registry): + """Test decision worker with different workflow types.""" + # Add another workflow to the registry + @registry.workflow + def another_workflow(input_data): + return f"another-processed: {input_data}" + + # Create decision tasks for different workflow types + task1 = self.create_mock_decision_task(workflow_type="test_workflow") + task2 = self.create_mock_decision_task(workflow_type="another_workflow") + + # Mock the decision handler + with patch.object(decision_worker, '_decision_handler') as mock_handler: + mock_handler.handle_task = AsyncMock() + + # Execute both tasks + await decision_worker._execute(task1) + await decision_worker._execute(task2) + + # Verify both tasks were handled + assert mock_handler.handle_task.call_count == 2 + + @pytest.mark.asyncio + async def test_decision_worker_poll_timeout(self, decision_worker, mock_client): + """Test decision worker polling with timeout.""" + # Mock the poll to raise a timeout exception + mock_client.worker_stub.PollForDecisionTask.side_effect = asyncio.TimeoutError("Poll timeout") + + # Run the poll - should handle timeout gracefully + with pytest.raises(asyncio.TimeoutError): + await decision_worker._poll() + + def test_decision_worker_options_handling(self, mock_client, registry): + """Test DecisionWorker with various options.""" + options = { + "identity": "custom-worker", + "max_concurrent_decision_task_execution_size": 5, + "decision_task_pollers": 3 + } + + worker = DecisionWorker( + client=mock_client, + task_list="custom-task-list", + registry=registry, + options=options + ) + + # Verify options were applied + assert worker._identity == "custom-worker" + assert worker._task_list == "custom-task-list" + assert worker._registry == registry diff --git a/tests/cadence/worker/test_task_handler_integration.py b/tests/cadence/worker/test_task_handler_integration.py index 64d877f..8e6aef9 100644 --- a/tests/cadence/worker/test_task_handler_integration.py +++ b/tests/cadence/worker/test_task_handler_integration.py @@ -53,17 +53,23 @@ def sample_decision_task(self): task.workflow_execution.run_id = "test_run_id" task.workflow_type = Mock() task.workflow_type.name = "TestWorkflow" + # Add the missing attributes that are now accessed directly + task.started_event_id = 1 + task.attempt = 1 return task @pytest.mark.asyncio async def test_full_task_handling_flow_success(self, handler, sample_decision_task, mock_registry): """Test the complete task handling flow from base handler through decision handler.""" # Mock workflow function - mock_workflow_func = Mock() + def mock_workflow_func(input_data): + return f"processed: {input_data}" + mock_registry.get_workflow.return_value = mock_workflow_func # Mock workflow engine mock_engine = Mock(spec=WorkflowEngine) + mock_engine._is_workflow_complete = False # Add missing attribute mock_decision_result = Mock(spec=DecisionResult) mock_decision_result.decisions = [] mock_engine.process_decision = AsyncMock(return_value=mock_decision_result) @@ -81,11 +87,14 @@ async def test_full_task_handling_flow_success(self, handler, sample_decision_ta async def test_full_task_handling_flow_with_error(self, handler, sample_decision_task, mock_registry): """Test the complete task handling flow when an error occurs.""" # Mock workflow function - mock_workflow_func = Mock() + def mock_workflow_func(input_data): + return f"processed: {input_data}" + mock_registry.get_workflow.return_value = mock_workflow_func # Mock workflow engine to raise an error mock_engine = Mock(spec=WorkflowEngine) + mock_engine._is_workflow_complete = False # Add missing attribute mock_engine.process_decision = AsyncMock(side_effect=RuntimeError("Workflow processing failed")) with patch('cadence.worker._decision_task_handler.WorkflowEngine', return_value=mock_engine): @@ -102,11 +111,14 @@ async def test_full_task_handling_flow_with_error(self, handler, sample_decision async def test_context_activation_integration(self, handler, sample_decision_task, mock_registry): """Test that context activation works correctly in the integration.""" # Mock workflow function - mock_workflow_func = Mock() + def mock_workflow_func(input_data): + return f"processed: {input_data}" + mock_registry.get_workflow.return_value = mock_workflow_func # Mock workflow engine mock_engine = Mock(spec=WorkflowEngine) + mock_engine._is_workflow_complete = False # Add missing attribute mock_decision_result = Mock(spec=DecisionResult) mock_decision_result.decisions = [] mock_engine.process_decision = AsyncMock(return_value=mock_decision_result) @@ -133,7 +145,9 @@ def track_context_activation(): async def test_multiple_workflow_executions(self, handler, mock_registry): """Test handling multiple workflow executions creates new engines for each.""" # Mock workflow function - mock_workflow_func = Mock() + def mock_workflow_func(input_data): + return f"processed: {input_data}" + mock_registry.get_workflow.return_value = mock_workflow_func # Create multiple decision tasks for different workflows @@ -144,6 +158,8 @@ async def test_multiple_workflow_executions(self, handler, mock_registry): task1.workflow_execution.run_id = "run1" task1.workflow_type = Mock() task1.workflow_type.name = "TestWorkflow" + task1.started_event_id = 1 + task1.attempt = 1 task2 = Mock(spec=PollForDecisionTaskResponse) task2.task_token = b"task2_token" @@ -152,9 +168,12 @@ async def test_multiple_workflow_executions(self, handler, mock_registry): task2.workflow_execution.run_id = "run2" task2.workflow_type = Mock() task2.workflow_type.name = "TestWorkflow" + task2.started_event_id = 2 + task2.attempt = 1 # Mock workflow engine mock_engine = Mock(spec=WorkflowEngine) + mock_engine._is_workflow_complete = False # Add missing attribute mock_decision_result = Mock(spec=DecisionResult) mock_decision_result.decisions = [] @@ -176,11 +195,14 @@ async def test_multiple_workflow_executions(self, handler, mock_registry): async def test_workflow_engine_creation_integration(self, handler, sample_decision_task, mock_registry): """Test workflow engine creation integration.""" # Mock workflow function - mock_workflow_func = Mock() + def mock_workflow_func(input_data): + return f"processed: {input_data}" + mock_registry.get_workflow.return_value = mock_workflow_func # Mock workflow engine mock_engine = Mock(spec=WorkflowEngine) + mock_engine._is_workflow_complete = False # Add missing attribute mock_decision_result = Mock(spec=DecisionResult) mock_decision_result.decisions = [] mock_engine.process_decision = AsyncMock(return_value=mock_decision_result) @@ -197,11 +219,14 @@ async def test_workflow_engine_creation_integration(self, handler, sample_decisi async def test_error_handling_with_context_cleanup(self, handler, sample_decision_task, mock_registry): """Test that context cleanup happens even when errors occur.""" # Mock workflow function - mock_workflow_func = Mock() + def mock_workflow_func(input_data): + return f"processed: {input_data}" + mock_registry.get_workflow.return_value = mock_workflow_func # Mock workflow engine to raise an error mock_engine = Mock(spec=WorkflowEngine) + mock_engine._is_workflow_complete = False # Add missing attribute mock_engine.process_decision = AsyncMock(side_effect=RuntimeError("Workflow processing failed")) # Track context cleanup @@ -231,7 +256,9 @@ async def test_concurrent_task_handling(self, handler, mock_registry): import asyncio # Mock workflow function - mock_workflow_func = Mock() + def mock_workflow_func(input_data): + return f"processed: {input_data}" + mock_registry.get_workflow.return_value = mock_workflow_func # Create multiple tasks @@ -244,10 +271,13 @@ async def test_concurrent_task_handling(self, handler, mock_registry): task.workflow_execution.run_id = f"run{i}" task.workflow_type = Mock() task.workflow_type.name = "TestWorkflow" + task.started_event_id = i + 1 + task.attempt = 1 tasks.append(task) # Mock workflow engine mock_engine = Mock(spec=WorkflowEngine) + mock_engine._is_workflow_complete = False # Add missing attribute mock_decision_result = Mock(spec=DecisionResult) mock_decision_result.decisions = [] mock_engine.process_decision = AsyncMock(return_value=mock_decision_result)