diff --git a/cadence/_internal/activity/_activity_executor.py b/cadence/_internal/activity/_activity_executor.py index f9efba0..e37e736 100644 --- a/cadence/_internal/activity/_activity_executor.py +++ b/cadence/_internal/activity/_activity_executor.py @@ -1,4 +1,3 @@ -import inspect from concurrent.futures import ThreadPoolExecutor from logging import getLogger from traceback import format_exception @@ -7,7 +6,7 @@ from google.protobuf.timestamp import to_datetime from cadence._internal.activity._context import _Context, _SyncContext -from cadence.activity import ActivityInfo +from cadence.activity import ActivityInfo, ActivityDefinition, ExecutionStrategy from cadence.api.v1.common_pb2 import Failure from cadence.api.v1.service_worker_pb2 import PollForActivityTaskResponse, RespondActivityTaskFailedRequest, \ RespondActivityTaskCompletedRequest @@ -16,7 +15,7 @@ _logger = getLogger(__name__) class ActivityExecutor: - def __init__(self, client: Client, task_list: str, identity: str, max_workers: int, registry: Callable[[str], Callable]): + def __init__(self, client: Client, task_list: str, identity: str, max_workers: int, registry: Callable[[str], ActivityDefinition]): self._client = client self._data_converter = client.data_converter self._registry = registry @@ -36,16 +35,16 @@ async def execute(self, task: PollForActivityTaskResponse): def _create_context(self, task: PollForActivityTaskResponse) -> _Context: activity_type = task.activity_type.name try: - activity_fn = self._registry(activity_type) + activity_def = self._registry(activity_type) except KeyError: raise KeyError(f"Activity type not found: {activity_type}") from None info = self._create_info(task) - if inspect.iscoroutinefunction(activity_fn): - return _Context(self._client, info, activity_fn) + if activity_def.strategy == ExecutionStrategy.ASYNC: + return _Context(self._client, info, activity_def) else: - return _SyncContext(self._client, info, activity_fn, self._thread_pool) + return _SyncContext(self._client, info, activity_def, self._thread_pool) async def _report_failure(self, task: PollForActivityTaskResponse, error: Exception): try: diff --git a/cadence/_internal/activity/_context.py b/cadence/_internal/activity/_context.py index 208b859..ce2f94b 100644 --- a/cadence/_internal/activity/_context.py +++ b/cadence/_internal/activity/_context.py @@ -1,15 +1,14 @@ import asyncio from concurrent.futures.thread import ThreadPoolExecutor -from typing import Callable, Any +from typing import Any from cadence import Client -from cadence._internal.type_utils import get_fn_parameters -from cadence.activity import ActivityInfo, ActivityContext +from cadence.activity import ActivityInfo, ActivityContext, ActivityDefinition from cadence.api.v1.common_pb2 import Payload class _Context(ActivityContext): - def __init__(self, client: Client, info: ActivityInfo, activity_fn: Callable[[Any], Any]): + def __init__(self, client: Client, info: ActivityInfo, activity_fn: ActivityDefinition[[Any], Any]): self._client = client self._info = info self._activity_fn = activity_fn @@ -20,7 +19,7 @@ async def execute(self, payload: Payload) -> Any: return await self._activity_fn(*params) async def _to_params(self, payload: Payload) -> list[Any]: - type_hints = get_fn_parameters(self._activity_fn) + type_hints = [param.type_hint for param in self._activity_fn.params] return await self._client.data_converter.from_data(payload, type_hints) def client(self) -> Client: @@ -30,7 +29,7 @@ def info(self) -> ActivityInfo: return self._info class _SyncContext(_Context): - def __init__(self, client: Client, info: ActivityInfo, activity_fn: Callable[[Any], Any], executor: ThreadPoolExecutor): + def __init__(self, client: Client, info: ActivityInfo, activity_fn: ActivityDefinition[[Any], Any], executor: ThreadPoolExecutor): super().__init__(client, info, activity_fn) self._executor = executor diff --git a/cadence/_internal/type_utils.py b/cadence/_internal/type_utils.py deleted file mode 100644 index 84fd07c..0000000 --- a/cadence/_internal/type_utils.py +++ /dev/null @@ -1,19 +0,0 @@ -from inspect import signature, Parameter -from typing import Callable, List, Type, get_type_hints - -def get_fn_parameters(fn: Callable) -> List[Type | None]: - args = signature(fn).parameters - hints = get_type_hints(fn) - result = [] - for name, param in args.items(): - if param.kind in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD): - type_hint = hints.get(name, None) - result.append(type_hint) - - return result - -def validate_fn_parameters(fn: Callable) -> None: - args = signature(fn).parameters - for name, param in args.items(): - if param.kind not in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD): - raise ValueError(f"Parameters must be positional. {name} is {param.kind}, and not valid") \ No newline at end of file diff --git a/cadence/activity.py b/cadence/activity.py index 0f71fb0..57a9b48 100644 --- a/cadence/activity.py +++ b/cadence/activity.py @@ -1,9 +1,14 @@ +import inspect from abc import ABC, abstractmethod from contextlib import contextmanager from contextvars import ContextVar from dataclasses import dataclass from datetime import timedelta, datetime -from typing import Iterator +from enum import Enum +from functools import update_wrapper +from inspect import signature, Parameter +from typing import Iterator, TypedDict, Unpack, Callable, Type, ParamSpec, TypeVar, Generic, get_type_hints, \ + Any, overload from cadence import Client @@ -59,3 +64,99 @@ def is_set() -> bool: @staticmethod def get() -> 'ActivityContext': return ActivityContext._var.get() + + +@dataclass(frozen=True) +class ActivityParameter: + name: str + type_hint: Type | None + default_value: Any | None + +class ExecutionStrategy(Enum): + ASYNC = "async" + THREAD_POOL = "thread_pool" + +class ActivityDefinitionOptions(TypedDict, total=False): + name: str + +P = ParamSpec('P') +T = TypeVar('T') + +class ActivityDefinition(Generic[P, T]): + def __init__(self, wrapped: Callable[P, T], name: str, strategy: ExecutionStrategy, params: list[ActivityParameter]): + self._wrapped = wrapped + self._name = name + self._strategy = strategy + self._params = params + update_wrapper(self, wrapped) + + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: + return self._wrapped(*args, **kwargs) + + @property + def name(self) -> str: + return self._name + + @property + def strategy(self) -> ExecutionStrategy: + return self._strategy + + @property + def params(self) -> list[ActivityParameter]: + return self._params + + @staticmethod + def wrap(fn: Callable[P, T], opts: ActivityDefinitionOptions) -> 'ActivityDefinition[P, T]': + name = fn.__qualname__ + if "name" in opts and opts["name"]: + name = opts["name"] + + strategy = ExecutionStrategy.THREAD_POOL + if inspect.iscoroutinefunction(fn) or inspect.iscoroutinefunction(fn.__call__): # type: ignore + strategy = ExecutionStrategy.ASYNC + + params = _get_params(fn) + return ActivityDefinition(fn, name, strategy, params) + + +ActivityDecorator = Callable[[Callable[P, T]], ActivityDefinition[P, T]] + +@overload +def defn(fn: Callable[P, T]) -> ActivityDefinition[P, T]: + ... + +@overload +def defn(**kwargs: Unpack[ActivityDefinitionOptions]) -> ActivityDecorator: + ... + +def defn(fn: Callable[P, T] | None = None, **kwargs: Unpack[ActivityDefinitionOptions]) -> ActivityDecorator | ActivityDefinition[P, T]: + options = ActivityDefinitionOptions(**kwargs) + def decorator(inner_fn: Callable[P, T]) -> ActivityDefinition[P, T]: + return ActivityDefinition.wrap(inner_fn, options) + + if fn is not None: + return decorator(fn) + + return decorator + + +def _get_params(fn: Callable) -> list[ActivityParameter]: + args = signature(fn).parameters + hints = get_type_hints(fn) + result = [] + for name, param in args.items(): + # "unbound functions" aren't a thing in the Python spec. Filter out the self parameter and hope they followed + # the convention. + if param.name == "self": + continue + default = None + if param.default != Parameter.empty: + default = param.default + if param.kind in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD): + type_hint = hints.get(name, None) + result.append(ActivityParameter(name, type_hint, default)) + + else: + raise ValueError(f"Parameters must be positional. {name} is {param.kind}, and not valid") + + return result diff --git a/cadence/worker/__init__.py b/cadence/worker/__init__.py index 6249d28..4084e9a 100644 --- a/cadence/worker/__init__.py +++ b/cadence/worker/__init__.py @@ -8,7 +8,6 @@ from ._registry import ( Registry, RegisterWorkflowOptions, - RegisterActivityOptions, ) __all__ = [ @@ -16,5 +15,4 @@ "WorkerOptions", 'Registry', 'RegisterWorkflowOptions', - 'RegisterActivityOptions', ] diff --git a/cadence/worker/_registry.py b/cadence/worker/_registry.py index 1f5d03f..d60521d 100644 --- a/cadence/worker/_registry.py +++ b/cadence/worker/_registry.py @@ -7,9 +7,8 @@ """ import logging -from typing import Callable, Dict, Optional, Unpack, TypedDict -from cadence._internal.type_utils import validate_fn_parameters - +from typing import Callable, Dict, Optional, Unpack, TypedDict, Sequence, overload +from cadence.activity import ActivityDefinitionOptions, ActivityDefinition, ActivityDecorator, P, T logger = logging.getLogger(__name__) @@ -19,13 +18,6 @@ class RegisterWorkflowOptions(TypedDict, total=False): name: Optional[str] alias: Optional[str] - -class RegisterActivityOptions(TypedDict, total=False): - """Options for registering an activity.""" - name: Optional[str] - alias: Optional[str] - - class Registry: """ Registry for managing workflows and activities. @@ -37,10 +29,9 @@ class Registry: def __init__(self) -> None: """Initialize the registry.""" self._workflows: Dict[str, Callable] = {} - self._activities: Dict[str, Callable] = {} + self._activities: Dict[str, ActivityDefinition] = {} self._workflow_aliases: Dict[str, str] = {} # alias -> name mapping - self._activity_aliases: Dict[str, str] = {} # alias -> name mapping - + def workflow( self, func: Optional[Callable] = None, @@ -84,12 +75,16 @@ def decorator(f: Callable) -> Callable: if func is None: return decorator return decorator(func) + + @overload + def activity(self, func: Callable[P, T]) -> ActivityDefinition[P, T]: + ... + + @overload + def activity(self, **kwargs: Unpack[ActivityDefinitionOptions]) -> ActivityDecorator: + ... - def activity( - self, - func: Optional[Callable] = None, - **kwargs: Unpack[RegisterActivityOptions] - ) -> Callable: + def activity(self, func: Callable[P, T] | None = None, **kwargs: Unpack[ActivityDefinitionOptions]) -> ActivityDecorator | ActivityDefinition[P, T]: """ Register an activity function. @@ -105,30 +100,40 @@ def activity( Raises: KeyError: If activity name already exists """ - options = RegisterActivityOptions(**kwargs) - - def decorator(f: Callable) -> Callable: - validate_fn_parameters(f) - activity_name = options.get('name') or f.__name__ - - if activity_name in self._activities: - raise KeyError(f"Activity '{activity_name}' is already registered") - - self._activities[activity_name] = f - - # Register alias if provided - alias = options.get('alias') - if alias: - if alias in self._activity_aliases: - raise KeyError(f"Activity alias '{alias}' is already registered") - self._activity_aliases[alias] = activity_name - - logger.info(f"Registered activity '{activity_name}'") - return f - - if func is None: - return decorator - return decorator(func) + options = ActivityDefinitionOptions(**kwargs) + + def decorator(f: Callable[P, T]) -> ActivityDefinition[P, T]: + defn = ActivityDefinition.wrap(f, options) + + self._register_activity(defn) + + return defn + + if func is not None: + return decorator(func) + + return decorator + + def register_activities(self, obj: object) -> None: + activities = _find_activity_definitions(obj) + if not activities: + raise ValueError(f"No activity definitions found in '{repr(obj)}'") + + for defn in activities: + self._register_activity(defn) + + + def register_activity(self, defn: Callable) -> None: + if not isinstance(defn, ActivityDefinition): + raise ValueError(f"{defn.__qualname__} must have @activity.defn decorator") + self._register_activity(defn) + + def _register_activity(self, defn: ActivityDefinition) -> None: + if defn.name in self._activities: + raise KeyError(f"Activity '{defn.name}' is already registered") + + self._activities[defn.name] = defn + def get_workflow(self, name: str) -> Callable: """ @@ -151,7 +156,7 @@ def get_workflow(self, name: str) -> Callable: return self._workflows[actual_name] - def get_activity(self, name: str) -> Callable: + def get_activity(self, name: str) -> ActivityDefinition: """ Get a registered activity by name. @@ -164,13 +169,45 @@ def get_activity(self, name: str) -> Callable: Raises: KeyError: If activity is not found """ - # Check if it's an alias - actual_name = self._activity_aliases.get(name, name) - - if actual_name not in self._activities: - raise KeyError(f"Activity '{name}' not found in registry") - - return self._activities[actual_name] - + return self._activities[name] + + def __add__(self, other: 'Registry') -> 'Registry': + result = Registry() + for name, fn in self._activities.items(): + result._register_activity(fn) + for name, fn in other._activities.items(): + result._register_activity(fn) + + return result + + @staticmethod + def of(*args: 'Registry') -> 'Registry': + result = Registry() + for other in args: + result += other + + return result + +def _find_activity_definitions(instance: object) -> Sequence[ActivityDefinition]: + attr_to_def = {} + for t in instance.__class__.__mro__: + for attr in dir(t): + if attr.startswith("_"): + continue + value = getattr(t, attr) + if isinstance(value, ActivityDefinition): + if attr in attr_to_def: + raise ValueError(f"'{attr}' was overridden with a duplicate activity definition") + attr_to_def[attr] = value + + # Create new definitions, copying the attributes from the declaring type but using the function + # from the specific object. This allows for the decorator to be applied to the base class and the + # function to be overridden + result = [] + for attr, definition in attr_to_def.items(): + result.append(ActivityDefinition(getattr(instance, attr), definition.name, definition.strategy, definition.params)) + + return result + \ No newline at end of file diff --git a/tests/cadence/_internal/activity/test_activity_executor.py b/tests/cadence/_internal/activity/test_activity_executor.py index 89b95e2..d6aba4d 100644 --- a/tests/cadence/_internal/activity/test_activity_executor.py +++ b/tests/cadence/_internal/activity/test_activity_executor.py @@ -8,11 +8,12 @@ from cadence import activity, Client from cadence._internal.activity import ActivityExecutor -from cadence.activity import ActivityInfo +from cadence.activity import ActivityInfo, ActivityDefinition from cadence.api.v1.common_pb2 import WorkflowExecution, ActivityType, Payload, Failure, WorkflowType from cadence.api.v1.service_worker_pb2 import RespondActivityTaskCompletedResponse, PollForActivityTaskResponse, \ RespondActivityTaskCompletedRequest, RespondActivityTaskFailedResponse, RespondActivityTaskFailedRequest from cadence.data_converter import DefaultDataConverter +from cadence.worker import Registry @pytest.fixture @@ -27,12 +28,14 @@ async def test_activity_async_success(client): worker_stub = client.worker_stub worker_stub.RespondActivityTaskCompleted = AsyncMock(return_value=RespondActivityTaskCompletedResponse()) + reg = Registry() + @reg.activity(name="activity_type") async def activity_fn(): return "success" - executor = ActivityExecutor(client, 'task_list', 'identity', 1, lambda name: activity_fn) + executor = ActivityExecutor(client, 'task_list', 'identity', 1, reg.get_activity) - await executor.execute(fake_task("any", "")) + await executor.execute(fake_task("activity_type", "")) worker_stub.RespondActivityTaskCompleted.assert_called_once_with(RespondActivityTaskCompletedRequest( task_token=b'task_token', @@ -44,12 +47,14 @@ async def test_activity_async_failure(client): worker_stub = client.worker_stub worker_stub.RespondActivityTaskFailed = AsyncMock(return_value=RespondActivityTaskFailedResponse()) + reg = Registry() + @reg.activity(name="activity_type") async def activity_fn(): raise KeyError("failure") - executor = ActivityExecutor(client, 'task_list', 'identity', 1, lambda name: activity_fn) + executor = ActivityExecutor(client, 'task_list', 'identity', 1, reg.get_activity) - await executor.execute(fake_task("any", "")) + await executor.execute(fake_task("activity_type", "")) worker_stub.RespondActivityTaskFailed.assert_called_once() @@ -70,12 +75,14 @@ async def test_activity_args(client): worker_stub = client.worker_stub worker_stub.RespondActivityTaskCompleted = AsyncMock(return_value=RespondActivityTaskCompletedResponse()) + reg = Registry() + @reg.activity(name="activity_type") async def activity_fn(first: str, second: str): return " ".join([first, second]) - executor = ActivityExecutor(client, 'task_list', 'identity', 1, lambda name: activity_fn) + executor = ActivityExecutor(client, 'task_list', 'identity', 1, reg.get_activity) - await executor.execute(fake_task("any", '["hello", "world"]')) + await executor.execute(fake_task("activity_type", '["hello", "world"]')) worker_stub.RespondActivityTaskCompleted.assert_called_once_with(RespondActivityTaskCompletedRequest( task_token=b'task_token', @@ -87,6 +94,8 @@ async def test_activity_sync_success(client): worker_stub = client.worker_stub worker_stub.RespondActivityTaskCompleted = AsyncMock(return_value=RespondActivityTaskCompletedResponse()) + reg = Registry() + @reg.activity(name="activity_type") def activity_fn(): try: asyncio.get_running_loop() @@ -94,9 +103,9 @@ def activity_fn(): return "success" raise RuntimeError("expected to be running outside of the event loop") - executor = ActivityExecutor(client, 'task_list', 'identity', 1, lambda name: activity_fn) + executor = ActivityExecutor(client, 'task_list', 'identity', 1, reg.get_activity) - await executor.execute(fake_task("any", "")) + await executor.execute(fake_task("activity_type", "")) worker_stub.RespondActivityTaskCompleted.assert_called_once_with(RespondActivityTaskCompletedRequest( task_token=b'task_token', @@ -107,13 +116,14 @@ def activity_fn(): async def test_activity_sync_failure(client): worker_stub = client.worker_stub worker_stub.RespondActivityTaskFailed = AsyncMock(return_value=RespondActivityTaskFailedResponse()) - + reg = Registry() + @reg.activity(name="activity_type") def activity_fn(): raise KeyError("failure") - executor = ActivityExecutor(client, 'task_list', 'identity', 1, lambda name: activity_fn) + executor = ActivityExecutor(client, 'task_list', 'identity', 1, reg.get_activity) - await executor.execute(fake_task("any", "")) + await executor.execute(fake_task("activity_type", "")) worker_stub.RespondActivityTaskFailed.assert_called_once() @@ -134,18 +144,18 @@ async def test_activity_unknown(client): worker_stub = client.worker_stub worker_stub.RespondActivityTaskFailed = AsyncMock(return_value=RespondActivityTaskFailedResponse()) - def registry(name: str): + def registry(name: str) -> ActivityDefinition: raise KeyError(f"unknown activity: {name}") executor = ActivityExecutor(client, 'task_list', 'identity', 1, registry) - await executor.execute(fake_task("any", "")) + await executor.execute(fake_task("activity_type", "")) worker_stub.RespondActivityTaskFailed.assert_called_once() call = worker_stub.RespondActivityTaskFailed.call_args[0][0] - assert 'Activity type not found: any' in call.failure.details.decode() + assert 'Activity type not found: activity_type' in call.failure.details.decode() call.failure.details = bytes() assert call == RespondActivityTaskFailedRequest( task_token=b'task_token', @@ -158,14 +168,15 @@ def registry(name: str): async def test_activity_context(client): worker_stub = client.worker_stub worker_stub.RespondActivityTaskCompleted = AsyncMock(return_value=RespondActivityTaskCompletedResponse()) - + reg = Registry() + @reg.activity(name="activity_type") async def activity_fn(): assert fake_info("activity_type") == activity.info() assert activity.in_activity() assert activity.client() is not None return "success" - executor = ActivityExecutor(client, 'task_list', 'identity', 1, lambda name: activity_fn) + executor = ActivityExecutor(client, 'task_list', 'identity', 1, reg.get_activity) await executor.execute(fake_task("activity_type", "")) @@ -179,6 +190,8 @@ async def test_activity_context_sync(client): worker_stub = client.worker_stub worker_stub.RespondActivityTaskCompleted = AsyncMock(return_value=RespondActivityTaskCompletedResponse()) + reg = Registry() + @reg.activity(name="activity_type") def activity_fn(): assert fake_info("activity_type") == activity.info() assert activity.in_activity() @@ -186,7 +199,7 @@ def activity_fn(): activity.client() return "success" - executor = ActivityExecutor(client, 'task_list', 'identity', 1, lambda name: activity_fn) + executor = ActivityExecutor(client, 'task_list', 'identity', 1, reg.get_activity) await executor.execute(fake_task("activity_type", "")) diff --git a/tests/cadence/_internal/test_type_utils.py b/tests/cadence/_internal/test_type_utils.py deleted file mode 100644 index 9e35e81..0000000 --- a/tests/cadence/_internal/test_type_utils.py +++ /dev/null @@ -1,70 +0,0 @@ -from typing import Callable, Type - -import pytest - -from cadence._internal.type_utils import get_fn_parameters, validate_fn_parameters - - -def _single_param(name: str): - ... - -def _multiple_param(name: str, other: 'str'): - ... - -def _with_args(name:str, *args): - ... - -def _with_kwargs(name:str, **kwargs): - ... - -def _strictly_positional(name: str, other: str, *args, **kwargs): - ... - -def _keyword_only(*args, foo: str): - ... - - -@pytest.mark.parametrize( - "fn,expected", - [ - pytest.param( - _single_param, [str], id="single param" - ), - pytest.param( - _multiple_param, [str, str], id="multiple param" - ), - pytest.param( - _strictly_positional, [str, str], id="strictly positional" - ), - pytest.param( - _keyword_only, [], id="keyword only" - ), - ] -) -def test_get_fn_parameters(fn: Callable, expected: list[Type]): - params = get_fn_parameters(fn) - assert params == expected - -@pytest.mark.parametrize( - "fn,expected", - [ - pytest.param( - _single_param, None, id="single param" - ), - pytest.param( - _multiple_param, None, id="multiple param" - ), - pytest.param( - _with_args, ValueError, id="with args" - ), - pytest.param( - _with_kwargs, ValueError, id="with kwargs" - ), - ] -) -def test_validate_fn_parameters(fn: Callable, expected: Type[Exception]): - if expected: - with pytest.raises(expected): - validate_fn_parameters(fn) - else: - validate_fn_parameters(fn) \ No newline at end of file diff --git a/tests/cadence/common_activities.py b/tests/cadence/common_activities.py new file mode 100644 index 0000000..be78c62 --- /dev/null +++ b/tests/cadence/common_activities.py @@ -0,0 +1,51 @@ +from dataclasses import dataclass + +from cadence import activity + + +@activity.defn() +def simple_fn() -> None: + pass + +@activity.defn +def no_parens() -> None: + pass + +@activity.defn() +def echo(incoming: str) -> str: + return incoming + +@activity.defn(name="renamed") +def renamed_fn() -> None: + pass + +@activity.defn() +async def async_fn() -> None: + pass + +class Activities: + + @activity.defn() + def echo_sync(self, incoming: str) -> str: + return incoming + + @activity.defn() + async def echo_async(self, incoming: str) -> str: + return incoming + +class ActivityInterface: + @activity.defn() + def do_something(self) -> str: + ... + +@dataclass +class ActivityImpl(ActivityInterface): + result: str + + def do_something(self) -> str: + return self.result + +class InvalidImpl(ActivityInterface): + @activity.defn(name="something else entirely") + def do_something(self) -> str: + return "hehe" \ No newline at end of file diff --git a/tests/cadence/worker/test_registry.py b/tests/cadence/worker/test_registry.py index 57f345b..4a8973b 100644 --- a/tests/cadence/worker/test_registry.py +++ b/tests/cadence/worker/test_registry.py @@ -5,7 +5,9 @@ import pytest -from cadence.worker import Registry, RegisterWorkflowOptions, RegisterActivityOptions +from cadence import activity +from cadence.worker import Registry +from tests.cadence import common_activities class TestRegistry: @@ -35,74 +37,22 @@ def test_func(): def test_func(): return "test" - func = reg.get_activity("test_func") + func = reg.get_activity(test_func.name) assert func() == "test" - @pytest.mark.parametrize("registration_type", ["workflow", "activity"]) - def test_direct_call_behavior(self, registration_type): - """Test direct function call behavior for both workflows and activities.""" + def test_direct_call_behavior(self): reg = Registry() - + + @activity.defn(name="test_func") def test_func(): return "direct_call" + + reg.register_activity(test_func) + func = reg.get_activity("test_func") - if registration_type == "workflow": - registered_func = reg.workflow(test_func) - func = reg.get_workflow("test_func") - else: - registered_func = reg.activity(test_func) - func = reg.get_activity("test_func") - - assert registered_func == test_func assert func() == "direct_call" - @pytest.mark.parametrize("registration_type", ["workflow", "activity"]) - def test_decorator_with_options(self, registration_type): - """Test decorator with options for both workflows and activities.""" - reg = Registry() - - if registration_type == "workflow": - @reg.workflow(name="custom_name", alias="custom_alias") - def test_func(): - return "decorator_with_options" - - func = reg.get_workflow("custom_name") - func_by_alias = reg.get_workflow("custom_alias") - else: - @reg.activity(name="custom_name", alias="custom_alias") - def test_func(): - return "decorator_with_options" - - func = reg.get_activity("custom_name") - func_by_alias = reg.get_activity("custom_alias") - - assert func() == "decorator_with_options" - assert func_by_alias() == "decorator_with_options" - assert func == func_by_alias - - @pytest.mark.parametrize("registration_type", ["workflow", "activity"]) - def test_direct_call_with_options(self, registration_type): - """Test direct call with options for both workflows and activities.""" - reg = Registry() - - def test_func(): - return "direct_call_with_options" - - if registration_type == "workflow": - registered_func = reg.workflow(test_func, name="custom_name", alias="custom_alias") - func = reg.get_workflow("custom_name") - func_by_alias = reg.get_workflow("custom_alias") - else: - registered_func = reg.activity(test_func, name="custom_name", alias="custom_alias") - func = reg.get_activity("custom_name") - func_by_alias = reg.get_activity("custom_name") - - assert registered_func == test_func - assert func() == "direct_call_with_options" - assert func_by_alias() == "direct_call_with_options" - assert func == func_by_alias - @pytest.mark.parametrize("registration_type", ["workflow", "activity"]) def test_not_found_error(self, registration_type): """Test KeyError is raised when function not found.""" @@ -130,62 +80,73 @@ def test_func(): def test_func(): return "duplicate" else: - @reg.activity + @reg.activity(name="test_func") def test_func(): return "test" - + with pytest.raises(KeyError): - @reg.activity + @reg.activity(name="test_func") def test_func(): return "duplicate" - - @pytest.mark.parametrize("registration_type", ["workflow", "activity"]) - def test_alias_functionality(self, registration_type): - """Test alias functionality for both workflows and activities.""" + + def test_register_activities_instance(self): reg = Registry() - - if registration_type == "workflow": - @reg.workflow(name="custom_name") - def test_func(): - return "test" - - func = reg.get_workflow("custom_name") - else: - @reg.activity(alias="custom_alias") - def test_func(): - return "test" - - func = reg.get_activity("custom_alias") - func_by_name = reg.get_activity("test_func") - assert func_by_name() == "test" - assert func == func_by_name - - assert func() == "test" - - @pytest.mark.parametrize("registration_type", ["workflow", "activity"]) - def test_options_class(self, registration_type): - """Test using options classes for both workflows and activities.""" + + reg.register_activities(common_activities.Activities()) + + assert reg.get_activity("Activities.echo_sync") is not None + assert reg.get_activity("Activities.echo_sync") is not None + + def test_register_activities_interface(self): + impl = common_activities.ActivityImpl("result") reg = Registry() - - if registration_type == "workflow": - options = RegisterWorkflowOptions(name="custom_name", alias="custom_alias") - - @reg.workflow(**options) - def test_func(): - return "test" - - func = reg.get_workflow("custom_name") - func_by_alias = reg.get_workflow("custom_alias") - else: - options = RegisterActivityOptions(name="custom_name", alias="custom_alias") - - @reg.activity(**options) - def test_func(): - return "test" - - func = reg.get_activity("custom_name") - func_by_alias = reg.get_activity("custom_alias") - - assert func() == "test" - assert func_by_alias() == "test" - assert func == func_by_alias + + reg.register_activities(impl) + + assert reg.get_activity(common_activities.ActivityInterface.do_something.name) is not None + assert reg.get_activity("ActivityInterface.do_something") is not None + assert reg.get_activity(common_activities.ActivityInterface.do_something.name)() == "result" + + def test_register_activities_invalid_impl(self): + impl = common_activities.InvalidImpl() + reg = Registry() + + with pytest.raises(ValueError): + reg.register_activities(impl) + + + def test_add(self): + registry = Registry() + registry.register_activity(common_activities.simple_fn) + other = Registry() + other.register_activity(common_activities.echo) + + result = registry + other + + assert result.get_activity("simple_fn") is not None + assert result.get_activity("echo") is not None + with pytest.raises(KeyError): + registry.get_activity("echo") + with pytest.raises(KeyError): + other.get_activity("simple_fn") + + def test_add_duplicate(self): + registry = Registry() + registry.register_activity(common_activities.simple_fn) + other = Registry() + other.register_activity(common_activities.simple_fn) + with pytest.raises(KeyError): + registry + other + + def test_of(self): + first = Registry() + second = Registry() + third = Registry() + first.register_activity(common_activities.simple_fn) + second.register_activity(common_activities.echo) + third.register_activity(common_activities.async_fn) + + result = Registry.of(first, second, third) + assert result.get_activity("simple_fn") is not None + assert result.get_activity("echo") is not None + assert result.get_activity("async_fn") is not None