diff --git a/cadence/__init__.py b/cadence/__init__.py index 175f01b..abf4bbd 100644 --- a/cadence/__init__.py +++ b/cadence/__init__.py @@ -6,9 +6,13 @@ # Import main client functionality from .client import Client +from .worker import Registry +from .workflow import workflow __version__ = "0.1.0" __all__ = [ "Client", + "Registry", + "workflow", ] diff --git a/cadence/_internal/workflow/workflow_engine.py b/cadence/_internal/workflow/workflow_engine.py index 2456cc1..7eff5c2 100644 --- a/cadence/_internal/workflow/workflow_engine.py +++ b/cadence/_internal/workflow/workflow_engine.py @@ -20,9 +20,12 @@ class DecisionResult: decisions: list[Decision] class WorkflowEngine: - def __init__(self, info: WorkflowInfo, client: Client, workflow_func: Callable[[Any], Any] | None = None): + def __init__(self, info: WorkflowInfo, client: Client, workflow_definition=None): self._context = Context(client, info) - self._workflow_func = workflow_func + self._workflow_definition = workflow_definition + self._workflow_instance = None + if workflow_definition: + self._workflow_instance = workflow_definition.cls() self._decision_manager = DecisionManager() self._decisions_helper = DecisionsHelper(self._decision_manager) self._is_workflow_complete = False @@ -250,19 +253,17 @@ def _fallback_process_workflow_history(self, history) -> None: async def _execute_workflow_function(self, decision_task: PollForDecisionTaskResponse) -> None: """ Execute the workflow function to generate new decisions. - + This blocks until the workflow schedules an activity or completes. - + Args: decision_task: The decision task containing workflow context """ try: - # Execute the workflow function - # The workflow function should block until it schedules an activity - workflow_func = self._workflow_func - if workflow_func is None: + # Execute the workflow function from the workflow instance + if self._workflow_definition is None or self._workflow_instance is None: logger.warning( - "No workflow function available", + "No workflow definition or instance available", extra={ "workflow_type": self._context.info().workflow_type, "workflow_id": self._context.info().workflow_id, @@ -271,6 +272,9 @@ async def _execute_workflow_function(self, decision_task: PollForDecisionTaskRes ) return + # Get the workflow run method from the instance + workflow_func = self._workflow_definition.get_run_method(self._workflow_instance) + # Extract workflow input from history workflow_input = await self._extract_workflow_input(decision_task) @@ -290,7 +294,7 @@ async def _execute_workflow_function(self, decision_task: PollForDecisionTaskRes "completion_type": "success" } ) - + except Exception as e: logger.error( "Error executing workflow function", diff --git a/cadence/worker/_decision_task_handler.py b/cadence/worker/_decision_task_handler.py index d35ee66..62f0edb 100644 --- a/cadence/worker/_decision_task_handler.py +++ b/cadence/worker/_decision_task_handler.py @@ -76,7 +76,7 @@ async def _handle_task_implementation(self, task: PollForDecisionTaskResponse) - ) try: - workflow_func = self._registry.get_workflow(workflow_type_name) + workflow_definition = self._registry.get_workflow(workflow_type_name) except KeyError: logger.error( "Workflow type not found in registry", @@ -103,9 +103,9 @@ async def _handle_task_implementation(self, task: PollForDecisionTaskResponse) - workflow_engine = self._workflow_engines.get(cache_key) if workflow_engine is None: workflow_engine = WorkflowEngine( - info=workflow_info, - client=self._client, - workflow_func=workflow_func + info=workflow_info, + client=self._client, + workflow_definition=workflow_definition ) self._workflow_engines[cache_key] = workflow_engine diff --git a/cadence/worker/_registry.py b/cadence/worker/_registry.py index d60521d..816caad 100644 --- a/cadence/worker/_registry.py +++ b/cadence/worker/_registry.py @@ -7,8 +7,9 @@ """ import logging -from typing import Callable, Dict, Optional, Unpack, TypedDict, Sequence, overload +from typing import Callable, Dict, Optional, Unpack, TypedDict, Sequence, overload, Type, Union from cadence.activity import ActivityDefinitionOptions, ActivityDefinition, ActivityDecorator, P, T +from cadence.workflow import WorkflowDefinition, WorkflowDefinitionOptions logger = logging.getLogger(__name__) @@ -28,53 +29,58 @@ class Registry: def __init__(self) -> None: """Initialize the registry.""" - self._workflows: Dict[str, Callable] = {} + self._workflows: Dict[str, WorkflowDefinition] = {} self._activities: Dict[str, ActivityDefinition] = {} self._workflow_aliases: Dict[str, str] = {} # alias -> name mapping def workflow( self, - func: Optional[Callable] = None, + cls: Optional[Type] = None, **kwargs: Unpack[RegisterWorkflowOptions] - ) -> Callable: + ) -> Union[Type, Callable[[Type], Type]]: """ - Register a workflow function. - + Register a workflow class. + This method can be used as a decorator or called directly. - + Only supports class-based workflows. + Args: - func: The workflow function to register + cls: The workflow class to register **kwargs: Options for registration (name, alias) - + Returns: - The decorated function or the function itself - + The decorated class + Raises: KeyError: If workflow name already exists + ValueError: If class workflow is invalid """ options = RegisterWorkflowOptions(**kwargs) - - def decorator(f: Callable) -> Callable: - workflow_name = options.get('name') or f.__name__ - + + def decorator(target: Type) -> Type: + workflow_name = options.get('name') or target.__name__ + if workflow_name in self._workflows: raise KeyError(f"Workflow '{workflow_name}' is already registered") - - self._workflows[workflow_name] = f - + + # Create WorkflowDefinition with type information + workflow_opts = WorkflowDefinitionOptions(name=workflow_name) + workflow_def = WorkflowDefinition.wrap(target, workflow_opts) + self._workflows[workflow_name] = workflow_def + # Register alias if provided alias = options.get('alias') if alias: if alias in self._workflow_aliases: raise KeyError(f"Workflow alias '{alias}' is already registered") self._workflow_aliases[alias] = workflow_name - + logger.info(f"Registered workflow '{workflow_name}'") - return f - - if func is None: + return target + + if cls is None: return decorator - return decorator(func) + return decorator(cls) @overload def activity(self, func: Callable[P, T]) -> ActivityDefinition[P, T]: @@ -135,25 +141,25 @@ def _register_activity(self, defn: ActivityDefinition) -> None: self._activities[defn.name] = defn - def get_workflow(self, name: str) -> Callable: + def get_workflow(self, name: str) -> WorkflowDefinition: """ Get a registered workflow by name. - + Args: name: Name or alias of the workflow - + Returns: - The workflow function - + The workflow definition + Raises: KeyError: If workflow is not found """ # Check if it's an alias actual_name = self._workflow_aliases.get(name, name) - + if actual_name not in self._workflows: raise KeyError(f"Workflow '{name}' not found in registry") - + return self._workflows[actual_name] def get_activity(self, name: str) -> ActivityDefinition: diff --git a/cadence/workflow.py b/cadence/workflow.py index 51b968f..22fd866 100644 --- a/cadence/workflow.py +++ b/cadence/workflow.py @@ -2,10 +2,118 @@ from contextlib import contextmanager from contextvars import ContextVar from dataclasses import dataclass -from typing import Iterator +from typing import Iterator, Callable, TypeVar, TypedDict, Type, cast, Any +from functools import wraps from cadence.client import Client +T = TypeVar('T') + + +class WorkflowDefinitionOptions(TypedDict, total=False): + """Options for defining a workflow.""" + name: str + + +class WorkflowDefinition: + """ + Definition of a workflow class with metadata. + + Similar to ActivityDefinition but for workflow classes. + Provides type safety and metadata for workflow classes. + """ + + def __init__(self, cls: Type, name: str): + self._cls = cls + self._name = name + + @property + def name(self) -> str: + """Get the workflow name.""" + return self._name + + @property + def cls(self) -> Type: + """Get the workflow class.""" + return self._cls + + def get_run_method(self, instance: Any) -> Callable: + """Get the workflow run method from an instance of the workflow class.""" + for attr_name in dir(instance): + if attr_name.startswith('_'): + continue + attr = getattr(instance, attr_name) + if callable(attr) and hasattr(attr, '_workflow_run'): + return cast(Callable, attr) + raise ValueError(f"No @workflow.run method found in class {self._cls.__name__}") + + @staticmethod + def wrap(cls: Type, opts: WorkflowDefinitionOptions) -> 'WorkflowDefinition': + """ + Wrap a class as a WorkflowDefinition. + + Args: + cls: The workflow class to wrap + opts: Options for the workflow definition + + Returns: + A WorkflowDefinition instance + + Raises: + ValueError: If no run method is found or multiple run methods exist + """ + name = cls.__name__ + if "name" in opts and opts["name"]: + name = opts["name"] + + # Validate that the class has exactly one run method + run_method_count = 0 + for attr_name in dir(cls): + if attr_name.startswith('_'): + continue + + attr = getattr(cls, attr_name) + if not callable(attr): + continue + + # Check for workflow run method + if hasattr(attr, '_workflow_run'): + run_method_count += 1 + + if run_method_count == 0: + raise ValueError(f"No @workflow.run method found in class {cls.__name__}") + elif run_method_count > 1: + raise ValueError(f"Multiple @workflow.run methods found in class {cls.__name__}") + + return WorkflowDefinition(cls, name) + + +def run(func: Callable[..., T]) -> Callable[..., T]: + """ + Decorator to mark a method as the main workflow run method. + + Args: + func: The method to mark as the workflow run method + + Returns: + The decorated method with workflow run metadata + """ + @wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + # Attach metadata to the function + wrapper._workflow_run = True # type: ignore + return wrapper + + +# Create a simple namespace object for the workflow decorators +class _WorkflowNamespace: + run = staticmethod(run) + +workflow = _WorkflowNamespace() + + @dataclass class WorkflowInfo: workflow_type: str diff --git a/tests/cadence/_internal/workflow/test_workflow_engine_integration.py b/tests/cadence/_internal/workflow/test_workflow_engine_integration.py index cb1f449..ecf7c13 100644 --- a/tests/cadence/_internal/workflow/test_workflow_engine_integration.py +++ b/tests/cadence/_internal/workflow/test_workflow_engine_integration.py @@ -36,19 +36,25 @@ def workflow_info(self): ) @pytest.fixture - def mock_workflow_func(self): - """Create a mock workflow function.""" - def workflow_func(input_data): - return f"processed: {input_data}" - return workflow_func + def mock_workflow_definition(self): + """Create a mock workflow definition.""" + from cadence.workflow import WorkflowDefinition, WorkflowDefinitionOptions, workflow + + class TestWorkflow: + @workflow.run + def weird_name(self, input_data): + return f"processed: {input_data}" + + workflow_opts = WorkflowDefinitionOptions(name="test_workflow") + return WorkflowDefinition.wrap(TestWorkflow, workflow_opts) @pytest.fixture - def workflow_engine(self, mock_client, workflow_info, mock_workflow_func): + def workflow_engine(self, mock_client, workflow_info, mock_workflow_definition): """Create a WorkflowEngine instance.""" return WorkflowEngine( info=workflow_info, client=mock_client, - workflow_func=mock_workflow_func + workflow_definition=mock_workflow_definition ) def create_mock_decision_task(self, workflow_id="test-workflow", run_id="test-run", workflow_type="test_workflow"): @@ -211,10 +217,13 @@ async def test_extract_workflow_input_deserialization_error(self, workflow_engin def test_execute_workflow_function_sync(self, workflow_engine): """Test synchronous workflow function execution.""" input_data = "test-input" - + + # Get the workflow function from the instance + workflow_func = workflow_engine._workflow_definition.get_run_method(workflow_engine._workflow_instance) + # Execute the workflow function - result = workflow_engine._execute_workflow_function_once(workflow_engine._workflow_func, input_data) - + result = workflow_engine._execute_workflow_function_once(workflow_func, input_data) + # Verify the result assert result == "processed: test-input" @@ -239,20 +248,21 @@ def test_execute_workflow_function_none(self, workflow_engine): with pytest.raises(TypeError, match="'NoneType' object is not callable"): workflow_engine._execute_workflow_function_once(None, input_data) - def test_workflow_engine_initialization(self, workflow_engine, workflow_info, mock_client, mock_workflow_func): + def test_workflow_engine_initialization(self, workflow_engine, workflow_info, mock_client, mock_workflow_definition): """Test WorkflowEngine initialization.""" assert workflow_engine._context is not None - assert workflow_engine._workflow_func == mock_workflow_func + assert workflow_engine._workflow_definition == mock_workflow_definition + assert workflow_engine._workflow_instance is not None assert workflow_engine._decision_manager is not None assert workflow_engine._is_workflow_complete is False @pytest.mark.asyncio - async def test_workflow_engine_without_workflow_func(self, mock_client, workflow_info): - """Test WorkflowEngine without workflow function.""" + async def test_workflow_engine_without_workflow_definition(self, mock_client, workflow_info): + """Test WorkflowEngine without workflow definition.""" engine = WorkflowEngine( info=workflow_info, client=mock_client, - workflow_func=None + workflow_definition=None ) decision_task = self.create_mock_decision_task() @@ -269,12 +279,21 @@ async def test_workflow_engine_without_workflow_func(self, mock_client, workflow async def test_workflow_engine_workflow_completion(self, workflow_engine, mock_client): """Test workflow completion detection.""" decision_task = self.create_mock_decision_task() - - # Mock workflow function to return a result (indicating completion) - def completing_workflow_func(input_data): - return "workflow-completed" - - workflow_engine._workflow_func = completing_workflow_func + + # Create a workflow definition that returns a result (indicating completion) + from cadence.workflow import WorkflowDefinition, WorkflowDefinitionOptions, workflow + + class CompletingWorkflow: + @workflow.run + def run(self, input_data): + return "workflow-completed" + + workflow_opts = WorkflowDefinitionOptions(name="completing_workflow") + completing_definition = WorkflowDefinition.wrap(CompletingWorkflow, workflow_opts) + + # Replace the workflow definition and instance + workflow_engine._workflow_definition = completing_definition + workflow_engine._workflow_instance = completing_definition.cls() with patch.object(workflow_engine._decision_manager, 'collect_pending_decisions', return_value=[]): # Process the decision diff --git a/tests/cadence/worker/test_decision_task_handler.py b/tests/cadence/worker/test_decision_task_handler.py index cd2b210..55b1e1f 100644 --- a/tests/cadence/worker/test_decision_task_handler.py +++ b/tests/cadence/worker/test_decision_task_handler.py @@ -82,9 +82,17 @@ def test_initialization(self, mock_client, mock_registry): @pytest.mark.asyncio async def test_handle_task_implementation_success(self, handler, sample_decision_task, mock_registry): """Test successful decision task handling.""" - # Mock workflow function - mock_workflow_func = Mock() - mock_registry.get_workflow.return_value = mock_workflow_func + # Create actual workflow definition + from cadence.workflow import WorkflowDefinition, WorkflowDefinitionOptions, workflow + + class MockWorkflow: + @workflow.run + async def run(self): + return "test_result" + + workflow_opts = WorkflowDefinitionOptions(name="test_workflow") + workflow_definition = WorkflowDefinition.wrap(MockWorkflow, workflow_opts) + mock_registry.get_workflow.return_value = workflow_definition # Mock workflow engine mock_engine = Mock(spec=WorkflowEngine) @@ -142,9 +150,17 @@ async def test_handle_task_implementation_workflow_not_found(self, handler, samp @pytest.mark.asyncio async def test_handle_task_implementation_caches_engines(self, handler, sample_decision_task, mock_registry): """Test that decision task handler caches workflow engines for same workflow execution.""" - # Mock workflow function - mock_workflow_func = Mock() - mock_registry.get_workflow.return_value = mock_workflow_func + # Create actual workflow definition + from cadence.workflow import WorkflowDefinition, WorkflowDefinitionOptions, workflow + + class MockWorkflow: + @workflow.run + async def run(self): + return "test_result" + + workflow_opts = WorkflowDefinitionOptions(name="test_workflow") + workflow_definition = WorkflowDefinition.wrap(MockWorkflow, workflow_opts) + mock_registry.get_workflow.return_value = workflow_definition # Mock workflow engine mock_engine = Mock(spec=WorkflowEngine) @@ -172,9 +188,17 @@ async def test_handle_task_implementation_caches_engines(self, handler, sample_d @pytest.mark.asyncio async def test_handle_task_implementation_different_executions_get_separate_engines(self, handler, mock_registry): """Test that different workflow executions get separate engines.""" - # Mock workflow function - mock_workflow_func = Mock() - mock_registry.get_workflow.return_value = mock_workflow_func + # Create actual workflow definition + from cadence.workflow import WorkflowDefinition, WorkflowDefinitionOptions, workflow + + class MockWorkflow: + @workflow.run + async def run(self): + return "test_result" + + workflow_opts = WorkflowDefinitionOptions(name="test_workflow") + workflow_definition = WorkflowDefinition.wrap(MockWorkflow, workflow_opts) + mock_registry.get_workflow.return_value = workflow_definition # Create two different decision tasks task1 = Mock(spec=PollForDecisionTaskResponse) @@ -323,19 +347,28 @@ async def test_respond_decision_task_completed_error(self, handler, sample_decis @pytest.mark.asyncio async def test_workflow_engine_creation_with_workflow_info(self, handler, sample_decision_task, mock_registry): """Test that WorkflowEngine is created with correct WorkflowInfo.""" - mock_workflow_func = Mock() - mock_registry.get_workflow.return_value = mock_workflow_func - + # Create actual workflow definition + from cadence.workflow import WorkflowDefinition, WorkflowDefinitionOptions, workflow + + class MockWorkflow: + @workflow.run + async def run(self): + return "test_result" + + workflow_opts = WorkflowDefinitionOptions(name="test_workflow") + workflow_definition = WorkflowDefinition.wrap(MockWorkflow, workflow_opts) + mock_registry.get_workflow.return_value = workflow_definition + mock_engine = Mock(spec=WorkflowEngine) mock_engine._is_workflow_complete = False # Add missing attribute mock_decision_result = Mock(spec=DecisionResult) mock_decision_result.decisions = [] mock_engine.process_decision = AsyncMock(return_value=mock_decision_result) - + with patch('cadence.worker._decision_task_handler.WorkflowEngine', return_value=mock_engine) as mock_workflow_engine_class: with patch('cadence.worker._decision_task_handler.WorkflowInfo') as mock_workflow_info_class: await handler._handle_task_implementation(sample_decision_task) - + # Verify WorkflowInfo was created with correct parameters (called once for engine) assert mock_workflow_info_class.call_count == 1 for call in mock_workflow_info_class.call_args_list: @@ -345,10 +378,10 @@ async def test_workflow_engine_creation_with_workflow_info(self, handler, sample 'workflow_id': "test_workflow_id", 'workflow_run_id': "test_run_id" } - + # Verify WorkflowEngine was created with correct parameters mock_workflow_engine_class.assert_called_once() call_args = mock_workflow_engine_class.call_args assert call_args[1]['info'] is not None assert call_args[1]['client'] == handler._client - assert call_args[1]['workflow_func'] == mock_workflow_func + assert call_args[1]['workflow_definition'] == workflow_definition diff --git a/tests/cadence/worker/test_decision_task_handler_integration.py b/tests/cadence/worker/test_decision_task_handler_integration.py index b513a14..0327485 100644 --- a/tests/cadence/worker/test_decision_task_handler_integration.py +++ b/tests/cadence/worker/test_decision_task_handler_integration.py @@ -13,6 +13,7 @@ from cadence.api.v1.decision_pb2 import Decision from cadence.worker._decision_task_handler import DecisionTaskHandler from cadence.worker._registry import Registry +from cadence.workflow import workflow from cadence.client import Client @@ -35,12 +36,14 @@ def mock_client(self): def registry(self): """Create a registry with a test workflow.""" reg = Registry() - - @reg.workflow - def test_workflow(input_data): - """Simple test workflow that returns the input.""" - return f"processed: {input_data}" - + + @reg.workflow(name="test_workflow") + class TestWorkflow: + @workflow.run + async def run(self, input_data): + """Simple test workflow that returns the input.""" + return f"processed: {input_data}" + return reg @pytest.fixture diff --git a/tests/cadence/worker/test_decision_worker_integration.py b/tests/cadence/worker/test_decision_worker_integration.py index 85c55d2..712f312 100644 --- a/tests/cadence/worker/test_decision_worker_integration.py +++ b/tests/cadence/worker/test_decision_worker_integration.py @@ -11,6 +11,7 @@ from cadence.api.v1.history_pb2 import History, HistoryEvent, WorkflowExecutionStartedEventAttributes from cadence.worker._decision import DecisionWorker from cadence.worker._registry import Registry +from cadence.workflow import workflow from cadence.client import Client @@ -34,12 +35,14 @@ def mock_client(self): def registry(self): """Create a registry with a test workflow.""" reg = Registry() - + @reg.workflow - def test_workflow(input_data): - """Simple test workflow that returns the input.""" - return f"processed: {input_data}" - + class TestWorkflow: + @workflow.run + async def run(self, input_data): + """Simple test workflow that returns the input.""" + return f"processed: {input_data}" + return reg @pytest.fixture @@ -236,8 +239,10 @@ async def test_decision_worker_with_different_workflow_types(self, decision_work """Test decision worker with different workflow types.""" # Add another workflow to the registry @registry.workflow - def another_workflow(input_data): - return f"another-processed: {input_data}" + class AnotherWorkflow: + @workflow.run + async def run(self, input_data): + return f"another-processed: {input_data}" # Create decision tasks for different workflow types task1 = self.create_mock_decision_task(workflow_type="test_workflow") diff --git a/tests/cadence/worker/test_registry.py b/tests/cadence/worker/test_registry.py index 4a8973b..53c16f0 100644 --- a/tests/cadence/worker/test_registry.py +++ b/tests/cadence/worker/test_registry.py @@ -7,6 +7,7 @@ from cadence import activity from cadence.worker import Registry +from cadence.workflow import workflow, WorkflowDefinition from tests.cadence import common_activities @@ -21,24 +22,32 @@ def test_basic_registry_creation(self): with pytest.raises(KeyError): reg.get_activity("nonexistent") - @pytest.mark.parametrize("registration_type", ["workflow", "activity"]) - def test_basic_registration_and_retrieval(self, registration_type): - """Test basic registration and retrieval for both workflows and activities.""" + def test_basic_workflow_registration_and_retrieval(self): + """Test basic registration and retrieval for class-based workflows.""" reg = Registry() - - if registration_type == "workflow": - @reg.workflow - def test_func(): - return "test" - - func = reg.get_workflow("test_func") - else: - @reg.activity - def test_func(): + + @reg.workflow + class TestWorkflow: + @workflow.run + async def run(self): return "test" - - func = reg.get_activity(test_func.name) - + + # Registry stores WorkflowDefinition internally + workflow_def = reg.get_workflow("TestWorkflow") + # Verify it's actually a WorkflowDefinition + assert isinstance(workflow_def, WorkflowDefinition) + assert workflow_def.name == "TestWorkflow" + assert workflow_def.cls == TestWorkflow + + def test_basic_activity_registration_and_retrieval(self): + """Test basic registration and retrieval for activities.""" + reg = Registry() + + @reg.activity + def test_func(): + return "test" + + func = reg.get_activity(test_func.name) assert func() == "test" def test_direct_call_behavior(self): @@ -53,41 +62,47 @@ def test_func(): assert func() == "direct_call" - @pytest.mark.parametrize("registration_type", ["workflow", "activity"]) - def test_not_found_error(self, registration_type): - """Test KeyError is raised when function not found.""" + def test_workflow_not_found_error(self): + """Test KeyError is raised when workflow not found.""" reg = Registry() - - if registration_type == "workflow": - with pytest.raises(KeyError): - reg.get_workflow("nonexistent") - else: - with pytest.raises(KeyError): - reg.get_activity("nonexistent") - - @pytest.mark.parametrize("registration_type", ["workflow", "activity"]) - def test_duplicate_registration_error(self, registration_type): - """Test KeyError is raised for duplicate registrations.""" + with pytest.raises(KeyError): + reg.get_workflow("nonexistent") + + def test_activity_not_found_error(self): + """Test KeyError is raised when activity not found.""" reg = Registry() - - if registration_type == "workflow": - @reg.workflow - def test_func(): + with pytest.raises(KeyError): + reg.get_activity("nonexistent") + + def test_duplicate_workflow_registration_error(self): + """Test KeyError is raised for duplicate workflow registrations.""" + reg = Registry() + + @reg.workflow(name="duplicate_test") + class TestWorkflow: + @workflow.run + async def run(self): return "test" - - with pytest.raises(KeyError): - @reg.workflow - def test_func(): + + with pytest.raises(KeyError): + @reg.workflow(name="duplicate_test") + class TestWorkflow2: + @workflow.run + async def run(self): return "duplicate" - else: + + def test_duplicate_activity_registration_error(self): + """Test KeyError is raised for duplicate activity registrations.""" + reg = Registry() + + @reg.activity(name="test_func") + def test_func(): + return "test" + + with pytest.raises(KeyError): @reg.activity(name="test_func") def test_func(): - return "test" - - with pytest.raises(KeyError): - @reg.activity(name="test_func") - def test_func(): - return "duplicate" + return "duplicate" def test_register_activities_instance(self): reg = Registry() @@ -150,3 +165,40 @@ def test_of(self): assert result.get_activity("simple_fn") is not None assert result.get_activity("echo") is not None assert result.get_activity("async_fn") is not None + + def test_class_workflow_validation_errors(self): + """Test validation errors for class-based workflows.""" + reg = Registry() + + # Test missing run method + with pytest.raises(ValueError, match="No @workflow.run method found"): + @reg.workflow + class MissingRunWorkflow: + def some_method(self): + pass + + # Test duplicate run methods + with pytest.raises(ValueError, match="Multiple @workflow.run methods found"): + @reg.workflow + class DuplicateRunWorkflow: + @workflow.run + async def run1(self): + pass + + @workflow.run + async def run2(self): + pass + + def test_class_workflow_with_custom_name(self): + """Test class-based workflow with custom name.""" + reg = Registry() + + @reg.workflow(name="custom_workflow_name") + class CustomWorkflow: + @workflow.run + async def run(self, input: str) -> str: + return f"processed: {input}" + + workflow_def = reg.get_workflow("custom_workflow_name") + assert workflow_def.name == "custom_workflow_name" + assert workflow_def.cls == CustomWorkflow diff --git a/tests/cadence/worker/test_task_handler_integration.py b/tests/cadence/worker/test_task_handler_integration.py index 8e6aef9..9d52c66 100644 --- a/tests/cadence/worker/test_task_handler_integration.py +++ b/tests/cadence/worker/test_task_handler_integration.py @@ -61,11 +61,17 @@ def sample_decision_task(self): @pytest.mark.asyncio async def test_full_task_handling_flow_success(self, handler, sample_decision_task, mock_registry): """Test the complete task handling flow from base handler through decision handler.""" - # Mock workflow function - def mock_workflow_func(input_data): - return f"processed: {input_data}" - - mock_registry.get_workflow.return_value = mock_workflow_func + # Create actual workflow definition + from cadence.workflow import WorkflowDefinition, WorkflowDefinitionOptions, workflow + + class MockWorkflow: + @workflow.run + async def run(self): + return "test_result" + + workflow_opts = WorkflowDefinitionOptions(name="test_workflow") + workflow_definition = WorkflowDefinition.wrap(MockWorkflow, workflow_opts) + mock_registry.get_workflow.return_value = workflow_definition # Mock workflow engine mock_engine = Mock(spec=WorkflowEngine) @@ -86,11 +92,17 @@ def mock_workflow_func(input_data): @pytest.mark.asyncio async def test_full_task_handling_flow_with_error(self, handler, sample_decision_task, mock_registry): """Test the complete task handling flow when an error occurs.""" - # Mock workflow function - def mock_workflow_func(input_data): - return f"processed: {input_data}" - - mock_registry.get_workflow.return_value = mock_workflow_func + # Create actual workflow definition + from cadence.workflow import WorkflowDefinition, WorkflowDefinitionOptions, workflow + + class MockWorkflow: + @workflow.run + async def run(self): + return "test_result" + + workflow_opts = WorkflowDefinitionOptions(name="test_workflow") + workflow_definition = WorkflowDefinition.wrap(MockWorkflow, workflow_opts) + mock_registry.get_workflow.return_value = workflow_definition # Mock workflow engine to raise an error mock_engine = Mock(spec=WorkflowEngine) @@ -110,11 +122,17 @@ def mock_workflow_func(input_data): @pytest.mark.asyncio async def test_context_activation_integration(self, handler, sample_decision_task, mock_registry): """Test that context activation works correctly in the integration.""" - # Mock workflow function - def mock_workflow_func(input_data): - return f"processed: {input_data}" - - mock_registry.get_workflow.return_value = mock_workflow_func + # Create actual workflow definition + from cadence.workflow import WorkflowDefinition, WorkflowDefinitionOptions, workflow + + class MockWorkflow: + @workflow.run + async def run(self): + return "test_result" + + workflow_opts = WorkflowDefinitionOptions(name="test_workflow") + workflow_definition = WorkflowDefinition.wrap(MockWorkflow, workflow_opts) + mock_registry.get_workflow.return_value = workflow_definition # Mock workflow engine mock_engine = Mock(spec=WorkflowEngine) @@ -144,11 +162,17 @@ def track_context_activation(): @pytest.mark.asyncio async def test_multiple_workflow_executions(self, handler, mock_registry): """Test handling multiple workflow executions creates new engines for each.""" - # Mock workflow function - def mock_workflow_func(input_data): - return f"processed: {input_data}" - - mock_registry.get_workflow.return_value = mock_workflow_func + # Create actual workflow definition + from cadence.workflow import WorkflowDefinition, WorkflowDefinitionOptions, workflow + + class MockWorkflow: + @workflow.run + async def run(self): + return "test_result" + + workflow_opts = WorkflowDefinitionOptions(name="test_workflow") + workflow_definition = WorkflowDefinition.wrap(MockWorkflow, workflow_opts) + mock_registry.get_workflow.return_value = workflow_definition # Create multiple decision tasks for different workflows task1 = Mock(spec=PollForDecisionTaskResponse) @@ -194,11 +218,17 @@ def mock_workflow_func(input_data): @pytest.mark.asyncio async def test_workflow_engine_creation_integration(self, handler, sample_decision_task, mock_registry): """Test workflow engine creation integration.""" - # Mock workflow function - def mock_workflow_func(input_data): - return f"processed: {input_data}" - - mock_registry.get_workflow.return_value = mock_workflow_func + # Create actual workflow definition + from cadence.workflow import WorkflowDefinition, WorkflowDefinitionOptions, workflow + + class MockWorkflow: + @workflow.run + async def run(self): + return "test_result" + + workflow_opts = WorkflowDefinitionOptions(name="test_workflow") + workflow_definition = WorkflowDefinition.wrap(MockWorkflow, workflow_opts) + mock_registry.get_workflow.return_value = workflow_definition # Mock workflow engine mock_engine = Mock(spec=WorkflowEngine) @@ -218,11 +248,17 @@ def mock_workflow_func(input_data): @pytest.mark.asyncio async def test_error_handling_with_context_cleanup(self, handler, sample_decision_task, mock_registry): """Test that context cleanup happens even when errors occur.""" - # Mock workflow function - def mock_workflow_func(input_data): - return f"processed: {input_data}" - - mock_registry.get_workflow.return_value = mock_workflow_func + # Create actual workflow definition + from cadence.workflow import WorkflowDefinition, WorkflowDefinitionOptions, workflow + + class MockWorkflow: + @workflow.run + async def run(self): + return "test_result" + + workflow_opts = WorkflowDefinitionOptions(name="test_workflow") + workflow_definition = WorkflowDefinition.wrap(MockWorkflow, workflow_opts) + mock_registry.get_workflow.return_value = workflow_definition # Mock workflow engine to raise an error mock_engine = Mock(spec=WorkflowEngine) @@ -255,11 +291,17 @@ async def test_concurrent_task_handling(self, handler, mock_registry): """Test handling multiple tasks concurrently.""" import asyncio - # Mock workflow function - def mock_workflow_func(input_data): - return f"processed: {input_data}" - - mock_registry.get_workflow.return_value = mock_workflow_func + # Create actual workflow definition + from cadence.workflow import WorkflowDefinition, WorkflowDefinitionOptions, workflow + + class MockWorkflow: + @workflow.run + async def run(self): + return "test_result" + + workflow_opts = WorkflowDefinitionOptions(name="test_workflow") + workflow_definition = WorkflowDefinition.wrap(MockWorkflow, workflow_opts) + mock_registry.get_workflow.return_value = workflow_definition # Create multiple tasks tasks = []