Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions cadence/_internal/workflow/workflow_engine.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from dataclasses import dataclass
from typing import Optional, Callable, Any

from cadence._internal.workflow.context import Context
from cadence.api.v1.decision_pb2 import Decision
Expand All @@ -10,12 +11,15 @@
@dataclass
class DecisionResult:
decisions: list[Decision]
force_create_new_decision_task: bool = False
query_results: Optional[dict] = None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not needed at the moment. This is related to local activities and workflow query. We might have better solutions later than copying java legacy

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

those fields are used in RespondDecisionTaskCompletedRequest but we can remove them for now


class WorkflowEngine:
def __init__(self, info: WorkflowInfo, client: Client):
def __init__(self, info: WorkflowInfo, client: Client, workflow_func: Optional[Callable[..., Any]] = None):
self._context = Context(client, info)
self._workflow_func = workflow_func

# TODO: Implement this
def process_decision(self, decision_task: PollForDecisionTaskResponse) -> DecisionResult:
async def process_decision(self, decision_task: PollForDecisionTaskResponse) -> DecisionResult:
with self._context._activate():
return DecisionResult(decisions=[])
70 changes: 70 additions & 0 deletions cadence/worker/_base_task_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import logging
from abc import ABC, abstractmethod
from typing import TypeVar, Generic

logger = logging.getLogger(__name__)

T = TypeVar('T')

class BaseTaskHandler(ABC, Generic[T]):
"""
Base task handler that provides common functionality for processing tasks.

This abstract class defines the interface and common behavior for task handlers
that process different types of tasks (workflow decisions, activities, etc.).
"""

def __init__(self, client, task_list: str, identity: str, **options):
"""
Initialize the base task handler.

Args:
client: The Cadence client instance
task_list: The task list name
identity: Worker identity
**options: Additional options for the handler
"""
self._client = client
self._task_list = task_list
self._identity = identity
self._options = options

async def handle_task(self, task: T) -> None:
"""
Handle a single task.

This method provides the base implementation for task handling that includes:
- Error handling
- Cleanup

Args:
task: The task to handle
"""
try:
# Handle the task implementation
await self._handle_task_implementation(task)

except Exception as e:
logger.exception(f"Error handling task: {e}")
await self.handle_task_failure(task, e)

@abstractmethod
async def _handle_task_implementation(self, task: T) -> None:
"""
Handle the actual task implementation.

Args:
task: The task to handle
"""
pass

@abstractmethod
async def handle_task_failure(self, task: T, error: Exception) -> None:
"""
Handle task processing failure.

Args:
task: The task that failed
error: The exception that occurred
"""
pass
173 changes: 173 additions & 0 deletions cadence/worker/_decision_task_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
import logging
from typing import Dict

from cadence.api.v1.common_pb2 import Payload
from cadence.api.v1.service_worker_pb2 import (
PollForDecisionTaskResponse,
RespondDecisionTaskCompletedRequest,
RespondDecisionTaskFailedRequest
)
from cadence.api.v1.workflow_pb2 import DecisionTaskFailedCause
from cadence.client import Client
from cadence.worker._base_task_handler import BaseTaskHandler
from cadence._internal.workflow.workflow_engine import WorkflowEngine, DecisionResult
from cadence.workflow import WorkflowInfo
from cadence.worker._registry import Registry

logger = logging.getLogger(__name__)

class DecisionTaskHandler(BaseTaskHandler[PollForDecisionTaskResponse]):
"""
Task handler for processing decision tasks.
This handler processes decision tasks and generates decisions using the workflow engine.
"""

def __init__(self, client: Client, task_list: str, registry: Registry, identity: str = "unknown", **options):
"""
Initialize the decision task handler.
Args:
client: The Cadence client instance
task_list: The task list name
registry: Registry containing workflow functions
identity: The worker identity
**options: Additional options for the handler
"""
super().__init__(client, task_list, identity, **options)
self._registry = registry
self._workflow_engines: Dict[str, WorkflowEngine] = {}


async def _handle_task_implementation(self, task: PollForDecisionTaskResponse) -> None:
"""
Handle a decision task implementation.
Args:
task: The decision task to handle
"""
# Extract workflow execution info
workflow_execution = task.workflow_execution
workflow_type = task.workflow_type

if not workflow_execution or not workflow_type:
logger.error("Decision task missing workflow execution or type")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: log task?

raise ValueError("Missing workflow execution or type")

workflow_id = workflow_execution.workflow_id
run_id = workflow_execution.run_id
workflow_type_name = workflow_type.name

logger.info(f"Processing decision task for workflow {workflow_id} (type: {workflow_type_name})")

# Get or create workflow engine for this workflow execution
engine_key = f"{workflow_id}:{run_id}"
if engine_key not in self._workflow_engines:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since we don't support sticky cache yet. This will cause memory leak without exit logic. Maybe just remove _workflow_engines for now and add it when sticky cache is implemented.

# Get the workflow function from registry
try:
workflow_func = self._registry.get_workflow(workflow_type_name)
except KeyError:
logger.error(f"Workflow type '{workflow_type_name}' not found in registry")
raise KeyError(f"Workflow type '{workflow_type_name}' not found")

# Create workflow info and engine
workflow_info = WorkflowInfo(
workflow_type=workflow_type_name,
workflow_domain=self._client.domain,
workflow_id=workflow_id,
workflow_run_id=run_id
)

self._workflow_engines[engine_key] = WorkflowEngine(
info=workflow_info,
client=self._client,
workflow_func=workflow_func
)

# Create workflow context and execute with it active
workflow_engine = self._workflow_engines[engine_key]

decision_result = await workflow_engine.process_decision(task)

# Respond with the decisions
await self._respond_decision_task_completed(task, decision_result)

logger.info(f"Successfully processed decision task for workflow {workflow_id}")

async def handle_task_failure(self, task: PollForDecisionTaskResponse, error: Exception) -> None:
"""
Handle decision task processing failure.
Args:
task: The task that failed
error: The exception that occurred
"""
try:
logger.error(f"Decision task failed: {error}")

# Determine the failure cause
cause = DecisionTaskFailedCause.DECISION_TASK_FAILED_CAUSE_UNHANDLED_DECISION
if isinstance(error, KeyError):
cause = DecisionTaskFailedCause.DECISION_TASK_FAILED_CAUSE_WORKFLOW_WORKER_UNHANDLED_FAILURE
elif isinstance(error, ValueError):
cause = DecisionTaskFailedCause.DECISION_TASK_FAILED_CAUSE_BAD_SCHEDULE_ACTIVITY_ATTRIBUTES

# Create error details
error_message = str(error).encode('utf-8')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add a TODO here. I think we need a data converter for this.

details = Payload(data=error_message)

# Respond with failure
await self._client.worker_stub.RespondDecisionTaskFailed(
RespondDecisionTaskFailedRequest(
task_token=task.task_token,
cause=cause,
identity=self._identity,
details=details
)
)

logger.info("Decision task failure response sent")

except Exception:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe just don't handle exception at all and let worker to handle it

logger.exception("Error handling decision task failure")

async def _respond_decision_task_completed(self, task: PollForDecisionTaskResponse, decision_result: DecisionResult) -> None:
"""
Respond to the service that the decision task has been completed.
Args:
task: The original decision task
decision_result: The result containing decisions and query results
"""
try:
request = RespondDecisionTaskCompletedRequest(
task_token=task.task_token,
decisions=decision_result.decisions,
identity=self._identity,
return_new_decision_task=decision_result.force_create_new_decision_task,
force_create_new_decision_task=decision_result.force_create_new_decision_task
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: these are for locally dispatched activities. We can add it later so it's cleaner. Same thing for query related logic

)

# Add query results if present
if decision_result.query_results:
request.query_results.update(decision_result.query_results)

await self._client.worker_stub.RespondDecisionTaskCompleted(request)
logger.debug(f"Decision task completed with {len(decision_result.decisions)} decisions")

except Exception:
logger.exception("Error responding to decision task completion")
raise

def cleanup_workflow_engine(self, workflow_id: str, run_id: str) -> None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method is not used anywhere

"""
Clean up a workflow engine when workflow execution is complete.
Args:
workflow_id: The workflow ID
run_id: The run ID
"""
engine_key = f"{workflow_id}:{run_id}"
if engine_key in self._workflow_engines:
del self._workflow_engines[engine_key]
logger.debug(f"Cleaned up workflow engine for {workflow_id}:{run_id}")
124 changes: 124 additions & 0 deletions tests/cadence/worker/test_base_task_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
#!/usr/bin/env python3
"""
Unit tests for BaseTaskHandler class.
"""

import pytest
from unittest.mock import Mock

from cadence.worker._base_task_handler import BaseTaskHandler


class ConcreteTaskHandler(BaseTaskHandler[str]):
"""Concrete implementation of BaseTaskHandler for testing."""

def __init__(self, client, task_list: str, identity: str, **options):
super().__init__(client, task_list, identity, **options)
self._handle_task_implementation_called = False
self._handle_task_failure_called = False
self._last_task: str = ""
self._last_error: Exception | None = None

async def _handle_task_implementation(self, task: str) -> None:
"""Test implementation of task handling."""
self._handle_task_implementation_called = True
self._last_task = task
if task == "raise_error":
raise ValueError("Test error")

async def handle_task_failure(self, task: str, error: Exception) -> None:
"""Test implementation of task failure handling."""
self._handle_task_failure_called = True
self._last_task = task
self._last_error = error


class TestBaseTaskHandler:
"""Test cases for BaseTaskHandler."""

def test_initialization(self):
"""Test BaseTaskHandler initialization."""
client = Mock()
handler = ConcreteTaskHandler(
client=client,
task_list="test_task_list",
identity="test_identity",
option1="value1",
option2="value2"
)

assert handler._client == client
assert handler._task_list == "test_task_list"
assert handler._identity == "test_identity"
assert handler._options == {"option1": "value1", "option2": "value2"}

@pytest.mark.asyncio
async def test_handle_task_success(self):
"""Test successful task handling."""
client = Mock()
handler = ConcreteTaskHandler(client, "test_task_list", "test_identity")

await handler.handle_task("test_task")

# Verify implementation was called
assert handler._handle_task_implementation_called
assert not handler._handle_task_failure_called
assert handler._last_task == "test_task"
assert handler._last_error is None

@pytest.mark.asyncio
async def test_handle_task_failure(self):
"""Test task handling with error."""
client = Mock()
handler = ConcreteTaskHandler(client, "test_task_list", "test_identity")

await handler.handle_task("raise_error")

# Verify error handling was called
assert handler._handle_task_implementation_called
assert handler._handle_task_failure_called
assert handler._last_task == "raise_error"
assert isinstance(handler._last_error, ValueError)
assert str(handler._last_error) == "Test error"


@pytest.mark.asyncio
async def test_abstract_methods_not_implemented(self):
"""Test that abstract methods raise NotImplementedError when not implemented."""
client = Mock()

class IncompleteHandler(BaseTaskHandler[str]):
async def _handle_task_implementation(self, task: str) -> None:
raise NotImplementedError()

async def handle_task_failure(self, task: str, error: Exception) -> None:
raise NotImplementedError()

handler = IncompleteHandler(client, "test_task_list", "test_identity")

with pytest.raises(NotImplementedError):
await handler._handle_task_implementation("test")

with pytest.raises(NotImplementedError):
await handler.handle_task_failure("test", Exception("test"))


@pytest.mark.asyncio
async def test_generic_type_parameter(self):
"""Test that the generic type parameter works correctly."""
client = Mock()

class IntHandler(BaseTaskHandler[int]):
async def _handle_task_implementation(self, task: int) -> None:
pass

async def handle_task_failure(self, task: int, error: Exception) -> None:
pass

handler = IntHandler(client, "test_task_list", "test_identity")

# Should accept int tasks
await handler.handle_task(42)

# Type checker should catch type mismatches (this is more of a static analysis test)
# In runtime, Python won't enforce the type, but the type hints are there for static analysis
Loading