Skip to content

Commit e361295

Browse files
authored
Introduce ActivityDefinition as a wrapper for Activity Fns (#29)
Signed-off-by: Nate Mortensen <[email protected]>
1 parent 016cc54 commit e361295

File tree

10 files changed

+356
-286
lines changed

10 files changed

+356
-286
lines changed

cadence/_internal/activity/_activity_executor.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import inspect
21
from concurrent.futures import ThreadPoolExecutor
32
from logging import getLogger
43
from traceback import format_exception
@@ -7,7 +6,7 @@
76
from google.protobuf.timestamp import to_datetime
87

98
from cadence._internal.activity._context import _Context, _SyncContext
10-
from cadence.activity import ActivityInfo
9+
from cadence.activity import ActivityInfo, ActivityDefinition, ExecutionStrategy
1110
from cadence.api.v1.common_pb2 import Failure
1211
from cadence.api.v1.service_worker_pb2 import PollForActivityTaskResponse, RespondActivityTaskFailedRequest, \
1312
RespondActivityTaskCompletedRequest
@@ -16,7 +15,7 @@
1615
_logger = getLogger(__name__)
1716

1817
class ActivityExecutor:
19-
def __init__(self, client: Client, task_list: str, identity: str, max_workers: int, registry: Callable[[str], Callable]):
18+
def __init__(self, client: Client, task_list: str, identity: str, max_workers: int, registry: Callable[[str], ActivityDefinition]):
2019
self._client = client
2120
self._data_converter = client.data_converter
2221
self._registry = registry
@@ -36,16 +35,16 @@ async def execute(self, task: PollForActivityTaskResponse):
3635
def _create_context(self, task: PollForActivityTaskResponse) -> _Context:
3736
activity_type = task.activity_type.name
3837
try:
39-
activity_fn = self._registry(activity_type)
38+
activity_def = self._registry(activity_type)
4039
except KeyError:
4140
raise KeyError(f"Activity type not found: {activity_type}") from None
4241

4342
info = self._create_info(task)
4443

45-
if inspect.iscoroutinefunction(activity_fn):
46-
return _Context(self._client, info, activity_fn)
44+
if activity_def.strategy == ExecutionStrategy.ASYNC:
45+
return _Context(self._client, info, activity_def)
4746
else:
48-
return _SyncContext(self._client, info, activity_fn, self._thread_pool)
47+
return _SyncContext(self._client, info, activity_def, self._thread_pool)
4948

5049
async def _report_failure(self, task: PollForActivityTaskResponse, error: Exception):
5150
try:

cadence/_internal/activity/_context.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
import asyncio
22
from concurrent.futures.thread import ThreadPoolExecutor
3-
from typing import Callable, Any
3+
from typing import Any
44

55
from cadence import Client
6-
from cadence._internal.type_utils import get_fn_parameters
7-
from cadence.activity import ActivityInfo, ActivityContext
6+
from cadence.activity import ActivityInfo, ActivityContext, ActivityDefinition
87
from cadence.api.v1.common_pb2 import Payload
98

109

1110
class _Context(ActivityContext):
12-
def __init__(self, client: Client, info: ActivityInfo, activity_fn: Callable[[Any], Any]):
11+
def __init__(self, client: Client, info: ActivityInfo, activity_fn: ActivityDefinition[[Any], Any]):
1312
self._client = client
1413
self._info = info
1514
self._activity_fn = activity_fn
@@ -20,7 +19,7 @@ async def execute(self, payload: Payload) -> Any:
2019
return await self._activity_fn(*params)
2120

2221
async def _to_params(self, payload: Payload) -> list[Any]:
23-
type_hints = get_fn_parameters(self._activity_fn)
22+
type_hints = [param.type_hint for param in self._activity_fn.params]
2423
return await self._client.data_converter.from_data(payload, type_hints)
2524

2625
def client(self) -> Client:
@@ -30,7 +29,7 @@ def info(self) -> ActivityInfo:
3029
return self._info
3130

3231
class _SyncContext(_Context):
33-
def __init__(self, client: Client, info: ActivityInfo, activity_fn: Callable[[Any], Any], executor: ThreadPoolExecutor):
32+
def __init__(self, client: Client, info: ActivityInfo, activity_fn: ActivityDefinition[[Any], Any], executor: ThreadPoolExecutor):
3433
super().__init__(client, info, activity_fn)
3534
self._executor = executor
3635

cadence/_internal/type_utils.py

Lines changed: 0 additions & 19 deletions
This file was deleted.

cadence/activity.py

Lines changed: 102 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
1+
import inspect
12
from abc import ABC, abstractmethod
23
from contextlib import contextmanager
34
from contextvars import ContextVar
45
from dataclasses import dataclass
56
from datetime import timedelta, datetime
6-
from typing import Iterator
7+
from enum import Enum
8+
from functools import update_wrapper
9+
from inspect import signature, Parameter
10+
from typing import Iterator, TypedDict, Unpack, Callable, Type, ParamSpec, TypeVar, Generic, get_type_hints, \
11+
Any, overload
712

813
from cadence import Client
914

@@ -59,3 +64,99 @@ def is_set() -> bool:
5964
@staticmethod
6065
def get() -> 'ActivityContext':
6166
return ActivityContext._var.get()
67+
68+
69+
@dataclass(frozen=True)
70+
class ActivityParameter:
71+
name: str
72+
type_hint: Type | None
73+
default_value: Any | None
74+
75+
class ExecutionStrategy(Enum):
76+
ASYNC = "async"
77+
THREAD_POOL = "thread_pool"
78+
79+
class ActivityDefinitionOptions(TypedDict, total=False):
80+
name: str
81+
82+
P = ParamSpec('P')
83+
T = TypeVar('T')
84+
85+
class ActivityDefinition(Generic[P, T]):
86+
def __init__(self, wrapped: Callable[P, T], name: str, strategy: ExecutionStrategy, params: list[ActivityParameter]):
87+
self._wrapped = wrapped
88+
self._name = name
89+
self._strategy = strategy
90+
self._params = params
91+
update_wrapper(self, wrapped)
92+
93+
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
94+
return self._wrapped(*args, **kwargs)
95+
96+
@property
97+
def name(self) -> str:
98+
return self._name
99+
100+
@property
101+
def strategy(self) -> ExecutionStrategy:
102+
return self._strategy
103+
104+
@property
105+
def params(self) -> list[ActivityParameter]:
106+
return self._params
107+
108+
@staticmethod
109+
def wrap(fn: Callable[P, T], opts: ActivityDefinitionOptions) -> 'ActivityDefinition[P, T]':
110+
name = fn.__qualname__
111+
if "name" in opts and opts["name"]:
112+
name = opts["name"]
113+
114+
strategy = ExecutionStrategy.THREAD_POOL
115+
if inspect.iscoroutinefunction(fn) or inspect.iscoroutinefunction(fn.__call__): # type: ignore
116+
strategy = ExecutionStrategy.ASYNC
117+
118+
params = _get_params(fn)
119+
return ActivityDefinition(fn, name, strategy, params)
120+
121+
122+
ActivityDecorator = Callable[[Callable[P, T]], ActivityDefinition[P, T]]
123+
124+
@overload
125+
def defn(fn: Callable[P, T]) -> ActivityDefinition[P, T]:
126+
...
127+
128+
@overload
129+
def defn(**kwargs: Unpack[ActivityDefinitionOptions]) -> ActivityDecorator:
130+
...
131+
132+
def defn(fn: Callable[P, T] | None = None, **kwargs: Unpack[ActivityDefinitionOptions]) -> ActivityDecorator | ActivityDefinition[P, T]:
133+
options = ActivityDefinitionOptions(**kwargs)
134+
def decorator(inner_fn: Callable[P, T]) -> ActivityDefinition[P, T]:
135+
return ActivityDefinition.wrap(inner_fn, options)
136+
137+
if fn is not None:
138+
return decorator(fn)
139+
140+
return decorator
141+
142+
143+
def _get_params(fn: Callable) -> list[ActivityParameter]:
144+
args = signature(fn).parameters
145+
hints = get_type_hints(fn)
146+
result = []
147+
for name, param in args.items():
148+
# "unbound functions" aren't a thing in the Python spec. Filter out the self parameter and hope they followed
149+
# the convention.
150+
if param.name == "self":
151+
continue
152+
default = None
153+
if param.default != Parameter.empty:
154+
default = param.default
155+
if param.kind in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD):
156+
type_hint = hints.get(name, None)
157+
result.append(ActivityParameter(name, type_hint, default))
158+
159+
else:
160+
raise ValueError(f"Parameters must be positional. {name} is {param.kind}, and not valid")
161+
162+
return result

cadence/worker/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,11 @@
88
from ._registry import (
99
Registry,
1010
RegisterWorkflowOptions,
11-
RegisterActivityOptions,
1211
)
1312

1413
__all__ = [
1514
"Worker",
1615
"WorkerOptions",
1716
'Registry',
1817
'RegisterWorkflowOptions',
19-
'RegisterActivityOptions',
2018
]

0 commit comments

Comments
 (0)