Skip to content

Commit adb578b

Browse files
committed
Add WorkflowDefinition
Signed-off-by: Tim Li <[email protected]>
1 parent 7a49de6 commit adb578b

File tree

7 files changed

+239
-60
lines changed

7 files changed

+239
-60
lines changed

cadence/worker/_decision_task_handler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ async def _handle_task_implementation(self, task: PollForDecisionTaskResponse) -
7676
)
7777

7878
try:
79-
workflow_func = self._registry.get_workflow(workflow_type_name)
79+
workflow_definition = self._registry.get_workflow(workflow_type_name)
8080
except KeyError:
8181
logger.error(
8282
"Workflow type not found in registry",
@@ -105,7 +105,7 @@ async def _handle_task_implementation(self, task: PollForDecisionTaskResponse) -
105105
workflow_engine = WorkflowEngine(
106106
info=workflow_info,
107107
client=self._client,
108-
workflow_func=workflow_func
108+
workflow_func=workflow_definition.fn
109109
)
110110
self._workflow_engines[cache_key] = workflow_engine
111111

cadence/worker/_registry.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import logging
1010
from typing import Callable, Dict, Optional, Unpack, TypedDict, Sequence, overload
1111
from cadence.activity import ActivityDefinitionOptions, ActivityDefinition, ActivityDecorator, P, T
12+
from cadence.workflow import WorkflowDefinition, WorkflowDefinitionOptions
1213

1314
logger = logging.getLogger(__name__)
1415

@@ -28,7 +29,7 @@ class Registry:
2829

2930
def __init__(self) -> None:
3031
"""Initialize the registry."""
31-
self._workflows: Dict[str, Callable] = {}
32+
self._workflows: Dict[str, WorkflowDefinition] = {}
3233
self._activities: Dict[str, ActivityDefinition] = {}
3334
self._workflow_aliases: Dict[str, str] = {} # alias -> name mapping
3435

@@ -60,7 +61,10 @@ def decorator(f: Callable) -> Callable:
6061
if workflow_name in self._workflows:
6162
raise KeyError(f"Workflow '{workflow_name}' is already registered")
6263

63-
self._workflows[workflow_name] = f
64+
# Create WorkflowDefinition with type information
65+
workflow_opts = WorkflowDefinitionOptions(name=workflow_name)
66+
workflow_def = WorkflowDefinition.wrap(f, workflow_opts)
67+
self._workflows[workflow_name] = workflow_def
6468

6569
# Register alias if provided
6670
alias = options.get('alias')
@@ -135,15 +139,15 @@ def _register_activity(self, defn: ActivityDefinition) -> None:
135139
self._activities[defn.name] = defn
136140

137141

138-
def get_workflow(self, name: str) -> Callable:
142+
def get_workflow(self, name: str) -> WorkflowDefinition:
139143
"""
140144
Get a registered workflow by name.
141145
142146
Args:
143147
name: Name or alias of the workflow
144148
145149
Returns:
146-
The workflow function
150+
The workflow definition with type information
147151
148152
Raises:
149153
KeyError: If workflow is not found

cadence/workflow.py

Lines changed: 138 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,147 @@
22
from contextlib import contextmanager
33
from contextvars import ContextVar
44
from dataclasses import dataclass
5-
from typing import Iterator
5+
from functools import update_wrapper
6+
from inspect import signature, Parameter
7+
from typing import Iterator, Callable, TypeVar, ParamSpec, Generic, TypedDict, Unpack, overload, get_type_hints, Type, Any
68

79
from cadence.client import Client
810

11+
12+
@dataclass(frozen=True)
13+
class WorkflowParameter:
14+
"""Parameter information for a workflow function."""
15+
name: str
16+
type_hint: Type | None
17+
default_value: Any | None
18+
19+
20+
class WorkflowDefinitionOptions(TypedDict, total=False):
21+
"""Options for defining a workflow."""
22+
name: str
23+
24+
25+
P = ParamSpec('P')
26+
T = TypeVar('T')
27+
28+
29+
class WorkflowDefinition(Generic[P, T]):
30+
"""
31+
Definition of a workflow function with metadata.
32+
33+
Similar to ActivityDefinition but for workflows.
34+
Provides type safety and metadata for workflow functions.
35+
"""
36+
37+
def __init__(self, wrapped: Callable[P, T], name: str, params: list[WorkflowParameter]):
38+
self._wrapped = wrapped
39+
self._name = name
40+
self._params = params
41+
update_wrapper(self, wrapped)
42+
43+
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
44+
return self._wrapped(*args, **kwargs)
45+
46+
@property
47+
def name(self) -> str:
48+
"""Get the workflow name."""
49+
return self._name
50+
51+
@property
52+
def params(self) -> list[WorkflowParameter]:
53+
"""Get the workflow parameters."""
54+
return self._params
55+
56+
@property
57+
def fn(self) -> Callable[P, T]:
58+
"""Get the underlying workflow function."""
59+
return self._wrapped
60+
61+
@classmethod
62+
def wrap(cls, fn: Callable[P, T], opts: WorkflowDefinitionOptions) -> 'WorkflowDefinition[P, T]':
63+
"""
64+
Wrap a function as a WorkflowDefinition.
65+
66+
Args:
67+
fn: The workflow function to wrap
68+
opts: Options for the workflow definition
69+
70+
Returns:
71+
A WorkflowDefinition instance
72+
"""
73+
name = fn.__qualname__
74+
if "name" in opts and opts["name"]:
75+
name = opts["name"]
76+
77+
params = _get_workflow_params(fn)
78+
return cls(fn, name, params)
79+
80+
81+
WorkflowDecorator = Callable[[Callable[P, T]], WorkflowDefinition[P, T]]
82+
83+
84+
@overload
85+
def defn(fn: Callable[P, T]) -> WorkflowDefinition[P, T]:
86+
...
87+
88+
89+
@overload
90+
def defn(**kwargs: Unpack[WorkflowDefinitionOptions]) -> WorkflowDecorator:
91+
...
92+
93+
94+
def defn(fn: Callable[P, T] | None = None, **kwargs: Unpack[WorkflowDefinitionOptions]) -> WorkflowDecorator | WorkflowDefinition[P, T]:
95+
"""
96+
Decorator to define a workflow function.
97+
98+
Usage:
99+
@defn
100+
def my_workflow(input_data: str) -> str:
101+
return f"processed: {input_data}"
102+
103+
@defn(name="custom_workflow_name")
104+
def my_other_workflow(input_data: str) -> str:
105+
return f"custom: {input_data}"
106+
107+
Args:
108+
fn: The workflow function (when used without parentheses)
109+
**kwargs: Workflow definition options
110+
111+
Returns:
112+
Either a WorkflowDefinition (direct decoration) or a decorator function
113+
"""
114+
opts = WorkflowDefinitionOptions(**kwargs)
115+
116+
def decorator(inner_fn: Callable[P, T]) -> WorkflowDefinition[P, T]:
117+
return WorkflowDefinition.wrap(inner_fn, opts)
118+
119+
if fn is not None:
120+
return decorator(fn)
121+
122+
return decorator
123+
124+
125+
def _get_workflow_params(fn: Callable) -> list[WorkflowParameter]:
126+
"""Extract parameter information from a workflow function."""
127+
args = signature(fn).parameters
128+
hints = get_type_hints(fn)
129+
result = []
130+
for name, param in args.items():
131+
# Filter out self parameter
132+
if param.name == "self":
133+
continue
134+
default = None
135+
if param.default != Parameter.empty:
136+
default = param.default
137+
if param.kind in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD):
138+
type_hint = hints.get(name, None)
139+
result.append(WorkflowParameter(name, type_hint, default))
140+
else:
141+
raise ValueError(f"Parameters must be positional. {name} is {param.kind}, and not valid")
142+
143+
return result
144+
145+
9146
@dataclass
10147
class WorkflowInfo:
11148
workflow_type: str

tests/cadence/worker/test_decision_task_handler.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -82,9 +82,14 @@ def test_initialization(self, mock_client, mock_registry):
8282
@pytest.mark.asyncio
8383
async def test_handle_task_implementation_success(self, handler, sample_decision_task, mock_registry):
8484
"""Test successful decision task handling."""
85-
# Mock workflow function
86-
mock_workflow_func = Mock()
87-
mock_registry.get_workflow.return_value = mock_workflow_func
85+
# Create actual workflow definition
86+
def mock_workflow_func():
87+
return "test_result"
88+
89+
from cadence.workflow import WorkflowDefinition, WorkflowDefinitionOptions
90+
workflow_opts = WorkflowDefinitionOptions(name="test_workflow")
91+
workflow_definition = WorkflowDefinition.wrap(mock_workflow_func, workflow_opts)
92+
mock_registry.get_workflow.return_value = workflow_definition
8893

8994
# Mock workflow engine
9095
mock_engine = Mock(spec=WorkflowEngine)
@@ -142,9 +147,14 @@ async def test_handle_task_implementation_workflow_not_found(self, handler, samp
142147
@pytest.mark.asyncio
143148
async def test_handle_task_implementation_caches_engines(self, handler, sample_decision_task, mock_registry):
144149
"""Test that decision task handler caches workflow engines for same workflow execution."""
145-
# Mock workflow function
146-
mock_workflow_func = Mock()
147-
mock_registry.get_workflow.return_value = mock_workflow_func
150+
# Create actual workflow definition
151+
def mock_workflow_func():
152+
return "test_result"
153+
154+
from cadence.workflow import WorkflowDefinition, WorkflowDefinitionOptions
155+
workflow_opts = WorkflowDefinitionOptions(name="test_workflow")
156+
workflow_definition = WorkflowDefinition.wrap(mock_workflow_func, workflow_opts)
157+
mock_registry.get_workflow.return_value = workflow_definition
148158

149159
# Mock workflow engine
150160
mock_engine = Mock(spec=WorkflowEngine)
@@ -324,18 +334,20 @@ async def test_respond_decision_task_completed_error(self, handler, sample_decis
324334
async def test_workflow_engine_creation_with_workflow_info(self, handler, sample_decision_task, mock_registry):
325335
"""Test that WorkflowEngine is created with correct WorkflowInfo."""
326336
mock_workflow_func = Mock()
327-
mock_registry.get_workflow.return_value = mock_workflow_func
328-
337+
mock_workflow_definition = Mock()
338+
mock_workflow_definition.fn = mock_workflow_func
339+
mock_registry.get_workflow.return_value = mock_workflow_definition
340+
329341
mock_engine = Mock(spec=WorkflowEngine)
330342
mock_engine._is_workflow_complete = False # Add missing attribute
331343
mock_decision_result = Mock(spec=DecisionResult)
332344
mock_decision_result.decisions = []
333345
mock_engine.process_decision = AsyncMock(return_value=mock_decision_result)
334-
346+
335347
with patch('cadence.worker._decision_task_handler.WorkflowEngine', return_value=mock_engine) as mock_workflow_engine_class:
336348
with patch('cadence.worker._decision_task_handler.WorkflowInfo') as mock_workflow_info_class:
337349
await handler._handle_task_implementation(sample_decision_task)
338-
350+
339351
# Verify WorkflowInfo was created with correct parameters (called once for engine)
340352
assert mock_workflow_info_class.call_count == 1
341353
for call in mock_workflow_info_class.call_args_list:
@@ -345,7 +357,7 @@ async def test_workflow_engine_creation_with_workflow_info(self, handler, sample
345357
'workflow_id': "test_workflow_id",
346358
'workflow_run_id': "test_run_id"
347359
}
348-
360+
349361
# Verify WorkflowEngine was created with correct parameters
350362
mock_workflow_engine_class.assert_called_once()
351363
call_args = mock_workflow_engine_class.call_args

tests/cadence/worker/test_decision_task_handler_integration.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,12 @@ def mock_client(self):
3535
def registry(self):
3636
"""Create a registry with a test workflow."""
3737
reg = Registry()
38-
39-
@reg.workflow
38+
39+
@reg.workflow(name="test_workflow")
4040
def test_workflow(input_data):
4141
"""Simple test workflow that returns the input."""
4242
return f"processed: {input_data}"
43-
43+
4444
return reg
4545

4646
@pytest.fixture

tests/cadence/worker/test_registry.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,16 +30,21 @@ def test_basic_registration_and_retrieval(self, registration_type):
3030
@reg.workflow
3131
def test_func():
3232
return "test"
33-
34-
func = reg.get_workflow("test_func")
33+
34+
# Registry stores WorkflowDefinition internally
35+
func_def = reg.get_workflow(test_func.__name__)
36+
# WorkflowDefinition can be called directly
37+
assert func_def() == "test"
38+
# Verify it's actually a WorkflowDefinition
39+
from cadence.workflow import WorkflowDefinition
40+
assert isinstance(func_def, WorkflowDefinition)
3541
else:
3642
@reg.activity
3743
def test_func():
3844
return "test"
3945

4046
func = reg.get_activity(test_func.name)
41-
42-
assert func() == "test"
47+
assert func() == "test"
4348

4449
def test_direct_call_behavior(self):
4550
reg = Registry()

0 commit comments

Comments
 (0)