diff --git a/dapr/clients/grpc/client.py b/dapr/clients/grpc/client.py index a0a886d0..e84f000d 100644 --- a/dapr/clients/grpc/client.py +++ b/dapr/clients/grpc/client.py @@ -87,7 +87,7 @@ from dapr.clients.retry import RetryPolicy from dapr.common.pubsub.subscription import StreamCancelledError from dapr.conf import settings -from dapr.conf.helpers import GrpcEndpoint +from dapr.conf.helpers import GrpcEndpoint, build_grpc_channel_options from dapr.proto import api_service_v1, api_v1, common_v1 from dapr.proto.runtime.v1.dapr_pb2 import UnsubscribeConfigurationResponse from dapr.version import __version__ @@ -146,11 +146,9 @@ def __init__( useragent = f'dapr-sdk-python/{__version__}' if not max_grpc_message_length: - options = [ - ('grpc.primary_user_agent', useragent), - ] + base_options = [('grpc.primary_user_agent', useragent)] else: - options = [ + base_options = [ ('grpc.max_send_message_length', max_grpc_message_length), # type: ignore ('grpc.max_receive_message_length', max_grpc_message_length), # type: ignore ('grpc.primary_user_agent', useragent), @@ -166,6 +164,9 @@ def __init__( except ValueError as error: raise DaprInternalError(f'{error}') from error + # Merge standard + keepalive + retry options + options = build_grpc_channel_options(base_options) + if self._uri.tls: self._channel = grpc.secure_channel( # type: ignore self._uri.endpoint, diff --git a/dapr/conf/__init__.py b/dapr/conf/__init__.py index 7fbe5f2f..0959311a 100644 --- a/dapr/conf/__init__.py +++ b/dapr/conf/__init__.py @@ -24,9 +24,7 @@ def __init__(self): default_value = getattr(global_settings, setting) env_variable = os.environ.get(setting) if env_variable: - val = ( - type(default_value)(env_variable) if default_value is not None else env_variable - ) + val = self._coerce_env_value(default_value, env_variable) setattr(self, setting, val) else: setattr(self, setting, default_value) @@ -36,5 +34,27 @@ def __getattr__(self, name): raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'") return getattr(self, name) + @staticmethod + def _coerce_env_value(default_value, env_variable: str): + if default_value is None: + return env_variable + # Handle booleans explicitly to avoid bool('false') == True + if isinstance(default_value, bool): + s = env_variable.strip().lower() + if s in ('1', 'true', 't', 'yes', 'y', 'on'): + return True + if s in ('0', 'false', 'f', 'no', 'n', 'off'): + return False + # Fallback: non-empty -> True for backward-compat + return bool(s) + # Integers + if isinstance(default_value, int) and not isinstance(default_value, bool): + return int(env_variable) + # Floats + if isinstance(default_value, float): + return float(env_variable) + # Other types: try to cast as before + return type(default_value)(env_variable) + settings = Settings() diff --git a/dapr/conf/global_settings.py b/dapr/conf/global_settings.py index 5a64e5d4..dd2bae86 100644 --- a/dapr/conf/global_settings.py +++ b/dapr/conf/global_settings.py @@ -34,6 +34,23 @@ DAPR_HTTP_TIMEOUT_SECONDS = 60 +# gRPC keepalive (disabled by default; enable via env to help with idle debugging sessions) +DAPR_GRPC_KEEPALIVE_ENABLED: bool = False +DAPR_GRPC_KEEPALIVE_TIME_MS: int = 120000 # send keepalive pings every 120s +DAPR_GRPC_KEEPALIVE_TIMEOUT_MS: int = ( + 20000 # wait 20s for ack before considering the connection dead +) +DAPR_GRPC_KEEPALIVE_PERMIT_WITHOUT_CALLS: bool = False # allow pings when there are no active calls + +# gRPC retries (disabled by default; enable via env to apply channel service config) 
+DAPR_GRPC_RETRY_ENABLED: bool = False +DAPR_GRPC_RETRY_MAX_ATTEMPTS: int = 4 +DAPR_GRPC_RETRY_INITIAL_BACKOFF_MS: int = 100 +DAPR_GRPC_RETRY_MAX_BACKOFF_MS: int = 1000 +DAPR_GRPC_RETRY_BACKOFF_MULTIPLIER: float = 2.0 +# Comma-separated list of status codes, e.g., 'UNAVAILABLE,DEADLINE_EXCEEDED' +DAPR_GRPC_RETRY_CODES: str = 'UNAVAILABLE,DEADLINE_EXCEEDED' + # ----- Conversation API settings ------ # Configuration for handling large enums to avoid massive JSON schemas that can exceed LLM token limits diff --git a/dapr/conf/helpers.py b/dapr/conf/helpers.py index d2d18762..266ecb50 100644 --- a/dapr/conf/helpers.py +++ b/dapr/conf/helpers.py @@ -1,6 +1,22 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import json from urllib.parse import ParseResult, parse_qs, urlparse from warnings import warn +from dapr.conf import settings + class URIParseConfig: DEFAULT_SCHEME = 'dns' @@ -189,3 +205,68 @@ def _validate_path_and_query(self) -> None: f'query parameters are not supported for gRPC endpoints:' f" '{self._parsed_url.query}'" ) + + +# ------------------------------ +# gRPC channel options helpers +# ------------------------------ + + +def get_grpc_keepalive_options(): + """Return a list of keepalive channel options if enabled, else empty list. + + Options are tuples suitable for passing to grpc.{secure,insecure}_channel. + """ + if not settings.DAPR_GRPC_KEEPALIVE_ENABLED: + return [] + return [ + ('grpc.keepalive_time_ms', int(settings.DAPR_GRPC_KEEPALIVE_TIME_MS)), + ('grpc.keepalive_timeout_ms', int(settings.DAPR_GRPC_KEEPALIVE_TIMEOUT_MS)), + ( + 'grpc.keepalive_permit_without_calls', + 1 if settings.DAPR_GRPC_KEEPALIVE_PERMIT_WITHOUT_CALLS else 0, + ), + ] + + +def get_grpc_retry_service_config_option(): + """Return ('grpc.service_config', json) option if retry is enabled, else None. + + Applies a universal retry policy via gRPC service config. + """ + if not getattr(settings, 'DAPR_GRPC_RETRY_ENABLED', False): + return None + retry_policy = { + 'maxAttempts': int(settings.DAPR_GRPC_RETRY_MAX_ATTEMPTS), + 'initialBackoff': f'{int(settings.DAPR_GRPC_RETRY_INITIAL_BACKOFF_MS) / 1000.0}s', + 'maxBackoff': f'{int(settings.DAPR_GRPC_RETRY_MAX_BACKOFF_MS) / 1000.0}s', + 'backoffMultiplier': float(settings.DAPR_GRPC_RETRY_BACKOFF_MULTIPLIER), + 'retryableStatusCodes': [ + c.strip() for c in str(settings.DAPR_GRPC_RETRY_CODES).split(',') if c.strip() + ], + } + service_config = { + 'methodConfig': [ + { + 'name': [{'service': ''}], # apply to all services + 'retryPolicy': retry_policy, + } + ] + } + return ('grpc.service_config', json.dumps(service_config)) + + +def build_grpc_channel_options(base_options=None): + """Combine base options with keepalive and retry policy options. + + Args: + base_options: optional iterable of (key, value) tuples. + Returns: + list of (key, value) tuples. 
+ """ + options = list(base_options or []) + options.extend(get_grpc_keepalive_options()) + retry_opt = get_grpc_retry_service_config_option() + if retry_opt is not None: + options.append(retry_opt) + return options diff --git a/dev-requirements.txt b/dev-requirements.txt index 828ef8aa..ee12592f 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -2,7 +2,9 @@ mypy>=1.2.0 mypy-extensions>=0.4.3 mypy-protobuf>=2.9 tox>=4.3.0 +pip>=23.0.0 coverage>=5.3 +pytest wheel # used in unit test only opentelemetry-sdk diff --git a/examples/grpc_proxying/invoke-receiver.py b/examples/grpc_proxying/invoke-receiver.py index 0a140ff7..0b31b53a 100644 --- a/examples/grpc_proxying/invoke-receiver.py +++ b/examples/grpc_proxying/invoke-receiver.py @@ -1,3 +1,4 @@ +import json import logging import grpc diff --git a/examples/workflow-async/README.md b/examples/workflow-async/README.md new file mode 100644 index 00000000..4ce670df --- /dev/null +++ b/examples/workflow-async/README.md @@ -0,0 +1,15 @@ +# Dapr Workflow Async Examples (Python) + +These examples mirror `examples/workflow/` but author orchestrators with `async def` using the +async workflow APIs. Activities remain regular functions unless noted. + +How to run: +- Ensure a Dapr sidecar is running locally. If needed, set `DURABLETASK_GRPC_ENDPOINT`, or + `DURABLETASK_GRPC_HOST/PORT`. +- Install requirements: `pip install -r requirements.txt` +- Run any example: `python simple.py` + +Notes: +- Orchestrators use `await ctx.activity(...)`, `await ctx.sleep(...)`, `await ctx.when_all/when_any(...)`, etc. +- No event loop is started manually; the Durable Task worker drives the async orchestrators. +- You can also launch instances using `DaprWorkflowClient` as in the non-async examples. diff --git a/examples/workflow-async/child_workflow.py b/examples/workflow-async/child_workflow.py new file mode 100644 index 00000000..9e89a752 --- /dev/null +++ b/examples/workflow-async/child_workflow.py @@ -0,0 +1,47 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the specific language governing permissions and +limitations under the License. 
+""" + +from dapr.ext.workflow import ( + AsyncWorkflowContext, + DaprWorkflowClient, + WorkflowRuntime, +) + +wfr = WorkflowRuntime() + + +@wfr.async_workflow(name='child_async') +async def child(ctx: AsyncWorkflowContext, n: int) -> int: + return n * 2 + + +@wfr.async_workflow(name='parent_async') +async def parent(ctx: AsyncWorkflowContext, n: int) -> int: + r = await ctx.call_child_workflow(child, input=n) + print(f'Child workflow returned {r}') + return r + 1 + + +def main(): + wfr.start() + client = DaprWorkflowClient() + instance_id = 'parent_async_instance' + client.schedule_new_workflow(workflow=parent, input=5, instance_id=instance_id) + client.wait_for_workflow_completion(instance_id, timeout_in_seconds=60) + wfr.shutdown() + + +if __name__ == '__main__': + main() diff --git a/examples/workflow-async/fan_out_fan_in.py b/examples/workflow-async/fan_out_fan_in.py new file mode 100644 index 00000000..16e6e48d --- /dev/null +++ b/examples/workflow-async/fan_out_fan_in.py @@ -0,0 +1,49 @@ +# -*- coding: utf-8 -*- +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the specific language governing permissions and +limitations under the License. +""" + +from dapr.ext.workflow import ( + AsyncWorkflowContext, + DaprWorkflowClient, + WorkflowActivityContext, + WorkflowRuntime, +) + +wfr = WorkflowRuntime() + + +@wfr.activity(name='square') +def square(ctx: WorkflowActivityContext, x: int) -> int: + return x * x + + +@wfr.async_workflow(name='fan_out_fan_in_async') +async def orchestrator(ctx: AsyncWorkflowContext): + tasks = [ctx.call_activity(square, input=i) for i in range(1, 6)] + results = await ctx.when_all(tasks) + total = sum(results) + return total + + +def main(): + wfr.start() + client = DaprWorkflowClient() + instance_id = 'fofi_async' + client.schedule_new_workflow(workflow=orchestrator, instance_id=instance_id) + wf_state = client.wait_for_workflow_completion(instance_id, timeout_in_seconds=60) + print(f'Workflow state: {wf_state}') + wfr.shutdown() + + +if __name__ == '__main__': + main() diff --git a/examples/workflow-async/human_approval.py b/examples/workflow-async/human_approval.py new file mode 100644 index 00000000..7ce17722 --- /dev/null +++ b/examples/workflow-async/human_approval.py @@ -0,0 +1,47 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the specific language governing permissions and +limitations under the License. 
+""" + +from dapr.ext.workflow import AsyncWorkflowContext, DaprWorkflowClient, WorkflowRuntime + +wfr = WorkflowRuntime() + + +@wfr.async_workflow(name='human_approval_async') +async def orchestrator(ctx: AsyncWorkflowContext, request_id: str): + decision = await ctx.when_any( + [ + ctx.wait_for_external_event(f'approve:{request_id}'), + ctx.wait_for_external_event(f'reject:{request_id}'), + ctx.create_timer(300.0), + ] + ) + if isinstance(decision, dict) and decision.get('approved'): + return 'APPROVED' + if isinstance(decision, dict) and decision.get('rejected'): + return 'REJECTED' + return 'TIMEOUT' + + +def main(): + wfr.start() + client = DaprWorkflowClient() + instance_id = 'human_approval_async_1' + client.schedule_new_workflow(workflow=orchestrator, input='REQ-1', instance_id=instance_id) + # In a real scenario, raise approve/reject event from another service. + wfr.shutdown() + + +if __name__ == '__main__': + main() diff --git a/examples/workflow-async/requirements.txt b/examples/workflow-async/requirements.txt new file mode 100644 index 00000000..e220036d --- /dev/null +++ b/examples/workflow-async/requirements.txt @@ -0,0 +1,2 @@ +dapr-ext-workflow-dev>=1.15.0.dev +dapr-dev>=1.15.0.dev diff --git a/examples/workflow-async/simple.py b/examples/workflow-async/simple.py new file mode 100644 index 00000000..1e7cf039 --- /dev/null +++ b/examples/workflow-async/simple.py @@ -0,0 +1,136 @@ +# -*- coding: utf-8 -*- +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the specific language governing permissions and +limitations under the License. +""" + +from datetime import timedelta +from time import sleep + +from dapr.ext.workflow import ( + AsyncWorkflowContext, + DaprWorkflowClient, + RetryPolicy, + WorkflowActivityContext, + WorkflowRuntime, +) + +counter = 0 +retry_count = 0 +child_orchestrator_string = '' +instance_id = 'asyncExampleInstanceID' +child_instance_id = 'asyncChildInstanceID' +workflow_name = 'async_hello_world_wf' +child_workflow_name = 'async_child_wf' +input_data = 'Hi Async Counter!' 
+event_name = 'event1' +event_data = 'eventData' + +retry_policy = RetryPolicy( + first_retry_interval=timedelta(seconds=1), + max_number_of_attempts=3, + backoff_coefficient=2, + max_retry_interval=timedelta(seconds=10), + retry_timeout=timedelta(seconds=100), +) + +wfr = WorkflowRuntime() + + +@wfr.async_workflow(name=workflow_name) +async def hello_world_wf(ctx: AsyncWorkflowContext, wf_input): + # activities + result_1 = await ctx.call_activity(hello_act, input=1) + print(f'Activity 1 returned {result_1}') + result_2 = await ctx.call_activity(hello_act, input=10) + print(f'Activity 2 returned {result_2}') + result_3 = await ctx.call_activity(hello_retryable_act, retry_policy=retry_policy) + print(f'Activity 3 returned {result_3}') + result_4 = await ctx.call_child_workflow(child_retryable_wf, retry_policy=retry_policy) + print(f'Child workflow returned {result_4}') + + # Event vs timeout using when_any + first = await ctx.when_any( + [ + ctx.wait_for_external_event(event_name), + ctx.create_timer(timedelta(seconds=30)), + ] + ) + + # Proceed only if event won + if isinstance(first, dict) and 'event' in first: + await ctx.call_activity(hello_act, input=100) + await ctx.call_activity(hello_act, input=1000) + return 'Completed' + return 'Timeout' + + +@wfr.activity(name='async_hello_act') +def hello_act(ctx: WorkflowActivityContext, wf_input): + global counter + counter += wf_input + return f'Activity returned {wf_input}' + + +@wfr.activity(name='async_hello_retryable_act') +def hello_retryable_act(ctx: WorkflowActivityContext): + global retry_count + if (retry_count % 2) == 0: + retry_count += 1 + raise ValueError('Retryable Error') + retry_count += 1 + return f'Activity returned {retry_count}' + + +@wfr.async_workflow(name=child_workflow_name) +async def child_retryable_wf(ctx: AsyncWorkflowContext): + # Call activity with retry and simulate retryable workflow failure until certain state + child_activity_result = await ctx.call_activity( + act_for_child_wf, input='x', retry_policy=retry_policy + ) + print(f'Child activity returned {child_activity_result}') + # In a real sample, you might check state and raise to trigger retry + return 'ok' + + +@wfr.activity(name='async_act_for_child_wf') +def act_for_child_wf(ctx: WorkflowActivityContext, inp): + global child_orchestrator_string + child_orchestrator_string += inp + + +def main(): + wfr.start() + wf_client = DaprWorkflowClient() + + wf_client.schedule_new_workflow( + workflow=hello_world_wf, input=input_data, instance_id=instance_id + ) + + wf_client.wait_for_workflow_start(instance_id) + + # Let initial activities run + sleep(5) + + # Raise event to continue + wf_client.raise_workflow_event( + instance_id=instance_id, event_name=event_name, data={'ok': True} + ) + + # Wait for completion + state = wf_client.wait_for_workflow_completion(instance_id, timeout_in_seconds=60) + print(f'Workflow status: {state.runtime_status.name}') + + wfr.shutdown() + + +if __name__ == '__main__': + main() diff --git a/examples/workflow-async/task_chaining.py b/examples/workflow-async/task_chaining.py new file mode 100644 index 00000000..c9c92add --- /dev/null +++ b/examples/workflow-async/task_chaining.py @@ -0,0 +1,48 @@ +# -*- coding: utf-8 -*- +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the specific language governing permissions and +limitations under the License. +""" + +from dapr.ext.workflow import ( + AsyncWorkflowContext, + DaprWorkflowClient, + WorkflowActivityContext, + WorkflowRuntime, +) + +wfr = WorkflowRuntime() + + +@wfr.activity(name='sum') +def sum_act(ctx: WorkflowActivityContext, nums): + return sum(nums) + + +@wfr.async_workflow(name='task_chaining_async') +async def orchestrator(ctx: AsyncWorkflowContext): + a = await ctx.call_activity(sum_act, input=[1, 2]) + b = await ctx.call_activity(sum_act, input=[a, 3]) + c = await ctx.call_activity(sum_act, input=[b, 4]) + return c + + +def main(): + wfr.start() + client = DaprWorkflowClient() + instance_id = 'task_chain_async' + client.schedule_new_workflow(workflow=orchestrator, instance_id=instance_id) + client.wait_for_workflow_completion(instance_id, timeout_in_seconds=60) + wfr.shutdown() + + +if __name__ == '__main__': + main() diff --git a/examples/workflow/README.md b/examples/workflow/README.md index 2e09eeef..3cb71f0d 100644 --- a/examples/workflow/README.md +++ b/examples/workflow/README.md @@ -12,6 +12,8 @@ This directory contains examples of using the [Dapr Workflow](https://docs.dapr. You can install dapr SDK package using pip command: ```sh +python3 -m venv .venv +source .venv/bin/activate pip3 install -r requirements.txt ``` diff --git a/examples/workflow/aio/async_activity_sequence.py b/examples/workflow/aio/async_activity_sequence.py new file mode 100644 index 00000000..8eecd1f8 --- /dev/null +++ b/examples/workflow/aio/async_activity_sequence.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the specific language governing permissions and +limitations under the License. +""" + +from dapr.ext.workflow import AsyncWorkflowContext, WorkflowRuntime + + +def main(): + rt = WorkflowRuntime() + + @rt.activity(name='add') + def add(ctx, xy): + return xy[0] + xy[1] + + @rt.workflow(name='sum_three') + async def sum_three(ctx: AsyncWorkflowContext, nums): + a = await ctx.call_activity(add, input=[nums[0], nums[1]]) + b = await ctx.call_activity(add, input=[a, nums[2]]) + return b + + rt.start() + print("Registered async workflow 'sum_three' and activity 'add'") + + # This example registers only; use Dapr client to start instances externally. + + +if __name__ == '__main__': + main() diff --git a/examples/workflow/aio/async_external_event.py b/examples/workflow/aio/async_external_event.py new file mode 100644 index 00000000..90531422 --- /dev/null +++ b/examples/workflow/aio/async_external_event.py @@ -0,0 +1,32 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the specific language governing permissions and +limitations under the License. +""" + +from dapr.ext.workflow import AsyncWorkflowContext, WorkflowRuntime + + +def main(): + rt = WorkflowRuntime() + + @rt.async_workflow(name='wait_event') + async def wait_event(ctx: AsyncWorkflowContext): + data = await ctx.wait_for_external_event('go') + return {'event': data} + + rt.start() + print("Registered async workflow 'wait_event'") + + +if __name__ == '__main__': + main() diff --git a/examples/workflow/aio/async_sub_orchestrator.py b/examples/workflow/aio/async_sub_orchestrator.py new file mode 100644 index 00000000..c00d9ca9 --- /dev/null +++ b/examples/workflow/aio/async_sub_orchestrator.py @@ -0,0 +1,36 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the specific language governing permissions and +limitations under the License. +""" + +from dapr.ext.workflow import AsyncWorkflowContext, WorkflowRuntime + + +def main(): + rt = WorkflowRuntime() + + @rt.async_workflow(name='child') + async def child(ctx: AsyncWorkflowContext, n): + return n * 2 + + @rt.async_workflow(name='parent') + async def parent(ctx: AsyncWorkflowContext, n): + r = await ctx.call_child_workflow(child, input=n) + return r + 1 + + rt.start() + print("Registered async workflows 'parent' and 'child'") + + +if __name__ == '__main__': + main() diff --git a/examples/workflow/aio/context_interceptors_example.py b/examples/workflow/aio/context_interceptors_example.py new file mode 100644 index 00000000..d005bca1 --- /dev/null +++ b/examples/workflow/aio/context_interceptors_example.py @@ -0,0 +1,152 @@ +# -*- coding: utf-8 -*- + +""" +Example: Interceptors for context propagation (client + runtime). + +This example shows how to: + - Define a small context (dict) carried via contextvars + - Implement ClientInterceptor to inject that context into outbound inputs + - Implement RuntimeInterceptor to restore the context before user code runs + - Wire interceptors into WorkflowRuntime and DaprWorkflowClient + +Note: Scheduling/running requires a Dapr sidecar. This file focuses on the wiring pattern. 
+""" + +from __future__ import annotations + +import contextvars +from typing import Any, Callable + +from dapr.ext.workflow import ( + BaseClientInterceptor, + BaseRuntimeInterceptor, + BaseWorkflowOutboundInterceptor, + CallActivityInput, + CallChildWorkflowInput, + DaprWorkflowClient, + ExecuteActivityInput, + ExecuteWorkflowInput, + ScheduleWorkflowInput, + WorkflowRuntime, +) + +# A simple context carried across boundaries +_current_ctx: contextvars.ContextVar[dict[str, Any] | None] = contextvars.ContextVar( + 'wf_ctx', default=None +) + + +def set_ctx(ctx: dict[str, Any] | None) -> None: + _current_ctx.set(ctx) + + +def get_ctx() -> dict[str, Any] | None: + return _current_ctx.get() + + +def _merge_ctx(args: Any) -> Any: + ctx = get_ctx() + if ctx and isinstance(args, dict) and 'context' not in args: + return {**args, 'context': ctx} + return args + + +class ContextClientInterceptor(BaseClientInterceptor): + def schedule_new_workflow( + self, input: ScheduleWorkflowInput, nxt: Callable[[ScheduleWorkflowInput], Any] + ) -> Any: # type: ignore[override] + input = ScheduleWorkflowInput( + workflow_name=input.workflow_name, + args=_merge_ctx(input.args), + instance_id=input.instance_id, + start_at=input.start_at, + reuse_id_policy=input.reuse_id_policy, + ) + return nxt(input) + + +class ContextWorkflowOutboundInterceptor(BaseWorkflowOutboundInterceptor): + def call_child_workflow( + self, input: CallChildWorkflowInput, nxt: Callable[[CallChildWorkflowInput], Any] + ) -> Any: + return nxt( + CallChildWorkflowInput( + workflow_name=input.workflow_name, + args=_merge_ctx(input.args), + instance_id=input.instance_id, + workflow_ctx=input.workflow_ctx, + metadata=input.metadata, + local_context=input.local_context, + ) + ) + + def call_activity( + self, input: CallActivityInput, nxt: Callable[[CallActivityInput], Any] + ) -> Any: + return nxt( + CallActivityInput( + activity_name=input.activity_name, + args=_merge_ctx(input.args), + retry_policy=input.retry_policy, + workflow_ctx=input.workflow_ctx, + metadata=input.metadata, + local_context=input.local_context, + ) + ) + + +class ContextRuntimeInterceptor(BaseRuntimeInterceptor): + def execute_workflow( + self, input: ExecuteWorkflowInput, nxt: Callable[[ExecuteWorkflowInput], Any] + ) -> Any: # type: ignore[override] + if isinstance(input.input, dict) and 'context' in input.input: + set_ctx(input.input['context']) + try: + return nxt(input) + finally: + set_ctx(None) + + def execute_activity( + self, input: ExecuteActivityInput, nxt: Callable[[ExecuteActivityInput], Any] + ) -> Any: # type: ignore[override] + if isinstance(input.input, dict) and 'context' in input.input: + set_ctx(input.input['context']) + try: + return nxt(input) + finally: + set_ctx(None) + + +# Example workflow and activity +def activity_log(ctx, data: dict[str, Any]) -> str: # noqa: ANN001 (example) + # Access restored context inside activity via contextvars + return f'ok:{get_ctx()}' + + +def workflow_example(ctx, x: int): # noqa: ANN001 (example) + y = yield ctx.call_activity(activity_log, input={'msg': 'hello'}) + return y + + +def wire_up() -> tuple[WorkflowRuntime, DaprWorkflowClient]: + runtime = WorkflowRuntime( + runtime_interceptors=[ContextRuntimeInterceptor()], + workflow_outbound_interceptors=[ContextWorkflowOutboundInterceptor()], + ) + client = DaprWorkflowClient(interceptors=[ContextClientInterceptor()]) + + # Register workflow/activity + runtime.workflow(name='example')(workflow_example) + runtime.activity(name='activity_log')(activity_log) + return 
runtime, client + + +if __name__ == '__main__': + # This section demonstrates how you would set a context and schedule a workflow. + # Requires a running Dapr sidecar to actually execute. + rt, cli = wire_up() + set_ctx({'tenant': 'acme', 'request_id': 'r-123'}) + # instance_id = cli.schedule_new_workflow(workflow_example, input={'x': 1}) + # print('scheduled:', instance_id) + # rt.start(); rt.wait_for_ready(); ... + pass diff --git a/examples/workflow/aio/model_tool_serialization_example.py b/examples/workflow/aio/model_tool_serialization_example.py new file mode 100644 index 00000000..2c1bdf4c --- /dev/null +++ b/examples/workflow/aio/model_tool_serialization_example.py @@ -0,0 +1,66 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +from typing import Any, Dict + +from dapr.ext.workflow import ensure_canonical_json + +""" +Example of implementing provider-specific model/tool serialization OUTSIDE the core package. + +This demonstrates how to build and use your own contracts using the generic helpers from +`dapr.ext.workflow.serializers`. +""" + + +def to_model_request(payload: Dict[str, Any]) -> Dict[str, Any]: + req = { + 'schema_version': 'model_req@v1', + 'model_name': payload.get('model_name'), + 'system_instructions': payload.get('system_instructions'), + 'input': payload.get('input'), + 'model_settings': payload.get('model_settings') or {}, + 'tools': payload.get('tools') or [], + } + return ensure_canonical_json(req, strict=True) + + +def from_model_response(obj: Any) -> Dict[str, Any]: + if isinstance(obj, dict): + content = obj.get('content') + tool_calls = obj.get('tool_calls') or [] + out = {'schema_version': 'model_res@v1', 'content': content, 'tool_calls': tool_calls} + return ensure_canonical_json(out, strict=False) + return ensure_canonical_json( + {'schema_version': 'model_res@v1', 'content': str(obj), 'tool_calls': []}, strict=False + ) + + +def to_tool_request(name: str, args: list | None, kwargs: dict | None) -> Dict[str, Any]: + req = { + 'schema_version': 'tool_req@v1', + 'tool_name': name, + 'args': args or [], + 'kwargs': kwargs or {}, + } + return ensure_canonical_json(req, strict=True) + + +def from_tool_result(obj: Any) -> Dict[str, Any]: + if isinstance(obj, dict) and ('result' in obj or 'error' in obj): + return ensure_canonical_json({'schema_version': 'tool_res@v1', **obj}, strict=False) + return ensure_canonical_json( + {'schema_version': 'tool_res@v1', 'result': obj, 'error': None}, strict=False + ) diff --git a/examples/workflow/aio/tracing_interceptors_example.py b/examples/workflow/aio/tracing_interceptors_example.py new file mode 100644 index 00000000..ea4834fb --- /dev/null +++ b/examples/workflow/aio/tracing_interceptors_example.py @@ -0,0 +1,120 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +from typing import Any, Callable + +from dapr.ext.workflow import ( + BaseClientInterceptor, + BaseRuntimeInterceptor, + BaseWorkflowOutboundInterceptor, + CallActivityInput, + 
CallChildWorkflowInput, + DaprWorkflowClient, + ExecuteActivityInput, + ExecuteWorkflowInput, + ScheduleWorkflowInput, + WorkflowRuntime, +) + +TRACE_ID_KEY = 'otel.trace_id' +SPAN_ID_KEY = 'otel.span_id' + + +class TracingClientInterceptor(BaseClientInterceptor): + def __init__(self, get_current_trace: Callable[[], tuple[str, str]]): + self._get = get_current_trace + + def schedule_new_workflow(self, input: ScheduleWorkflowInput, next): # type: ignore[override] + trace_id, span_id = self._get() + md = dict(input.metadata or {}) + md[TRACE_ID_KEY] = trace_id + md[SPAN_ID_KEY] = span_id + return next( + ScheduleWorkflowInput( + workflow_name=input.workflow_name, + args=input.args, + instance_id=input.instance_id, + start_at=input.start_at, + reuse_id_policy=input.reuse_id_policy, + metadata=md, + local_context=input.local_context, + ) + ) + + +class TracingRuntimeInterceptor(BaseRuntimeInterceptor): + def __init__(self, on_span: Callable[[str, dict[str, str]], Any]): + self._on_span = on_span + + def execute_workflow(self, input: ExecuteWorkflowInput, next): # type: ignore[override] + # Suppress spans during replay + if not input.ctx.is_replaying: + self._on_span('dapr:executeWorkflow', input.metadata or {}) + return next(input) + + def execute_activity(self, input: ExecuteActivityInput, next): # type: ignore[override] + self._on_span('dapr:executeActivity', input.metadata or {}) + return next(input) + + +class TracingWorkflowOutboundInterceptor(BaseWorkflowOutboundInterceptor): + def __init__(self, get_current_trace: Callable[[], tuple[str, str]]): + self._get = get_current_trace + + def call_activity(self, input: CallActivityInput, next): # type: ignore[override] + trace_id, span_id = self._get() + md = dict((input.metadata or {}) or {}) + md[TRACE_ID_KEY] = md.get(TRACE_ID_KEY, trace_id) + md[SPAN_ID_KEY] = span_id + return next( + type(input)( + activity_name=input.activity_name, + args=input.args, + retry_policy=input.retry_policy, + workflow_ctx=input.workflow_ctx, + metadata=md, + local_context=input.local_context, + ) + ) + + def call_child_workflow(self, input: CallChildWorkflowInput, next): # type: ignore[override] + trace_id, span_id = self._get() + md = dict((input.metadata or {}) or {}) + md[TRACE_ID_KEY] = md.get(TRACE_ID_KEY, trace_id) + md[SPAN_ID_KEY] = span_id + return next( + type(input)( + workflow_name=input.workflow_name, + args=input.args, + instance_id=input.instance_id, + workflow_ctx=input.workflow_ctx, + metadata=md, + local_context=input.local_context, + ) + ) + + +def example_usage(): + # Simplified trace getter and span recorder + def _get_trace(): + return ('trace-123', 'span-abc') + + spans: list[tuple[str, dict[str, str]]] = [] + + def _on_span(name: str, attrs: dict[str, str]): + spans.append((name, attrs)) + + runtime = WorkflowRuntime( + runtime_interceptors=[TracingRuntimeInterceptor(_on_span)], + workflow_outbound_interceptors=[TracingWorkflowOutboundInterceptor(_get_trace)], + ) + + client = DaprWorkflowClient(interceptors=[TracingClientInterceptor(_get_trace)]) + + # Register and run as you would normally; spans list can be asserted in tests + return runtime, client, spans + + +if __name__ == '__main__': # pragma: no cover + example_usage() diff --git a/examples/workflow/e2e_execinfo.py b/examples/workflow/e2e_execinfo.py new file mode 100644 index 00000000..91f0b295 --- /dev/null +++ b/examples/workflow/e2e_execinfo.py @@ -0,0 +1,58 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import time + +from dapr.ext.workflow import 
DaprWorkflowClient, WorkflowRuntime + + +def main(): + port = '50001' + + rt = WorkflowRuntime(port=port) + + def activity_noop(ctx): + ei = ctx.execution_info + # Return attempt (may be None if engine doesn't set it) + return { + 'attempt': ei.attempt if ei else None, + 'workflow_id': ei.workflow_id if ei else None, + } + + @rt.workflow(name='child-to-parent') + def child(ctx, x): + ei = ctx.execution_info + out = yield ctx.call_activity(activity_noop, input=None) + return { + 'child_workflow_name': ei.workflow_name if ei else None, + 'parent_instance_id': ei.parent_instance_id if ei else None, + 'activity': out, + } + + @rt.workflow(name='parent') + def parent(ctx, x): + res = yield ctx.call_child_workflow(child, input={'x': x}) + return res + + rt.register_activity(activity_noop, name='activity_noop') + + rt.start() + try: + # Wait for the worker to be ready to accept work + rt.wait_for_ready(timeout=10) + + client = DaprWorkflowClient(port=port) + instance_id = client.schedule_new_workflow(parent, input=1) + state = client.wait_for_workflow_completion(instance_id, timeout_in_seconds=30) + print('instance:', instance_id) + print('runtime_status:', state.runtime_status if state else None) + print('state:', state) + finally: + # Give a moment for logs to flush then shutdown + time.sleep(0.5) + rt.shutdown() + + +if __name__ == '__main__': + main() diff --git a/examples/workflow/requirements.txt b/examples/workflow/requirements.txt index c5af70b9..367e80be 100644 --- a/examples/workflow/requirements.txt +++ b/examples/workflow/requirements.txt @@ -1,2 +1,12 @@ -dapr-ext-workflow>=1.16.0.dev -dapr>=1.16.0.dev +# dapr-ext-workflow-dev>=1.16.0.dev +# dapr-dev>=1.16.0.dev + +# local development: install local packages in editable mode + +# if using dev version of durabletask-python +-e ../../../durabletask-python + +# if using dev version of dapr-ext-workflow +-e ../../ext/dapr-ext-workflow +-e ../.. + diff --git a/ext/dapr-ext-grpc/dapr/ext/grpc/__init__.py b/ext/dapr-ext-grpc/dapr/ext/grpc/__init__.py index 5324c617..bf3ad7b5 100644 --- a/ext/dapr-ext-grpc/dapr/ext/grpc/__init__.py +++ b/ext/dapr-ext-grpc/dapr/ext/grpc/__init__.py @@ -13,11 +13,10 @@ limitations under the License. 
""" -from dapr.ext.grpc.app import App, Rule # type:ignore - from dapr.clients.grpc._jobs import ConstantFailurePolicy, DropFailurePolicy, FailurePolicy, Job from dapr.clients.grpc._request import BindingRequest, InvokeMethodRequest, JobEvent from dapr.clients.grpc._response import InvokeMethodResponse, TopicEventResponse +from dapr.ext.grpc.app import App, Rule # type:ignore __all__ = [ 'App', diff --git a/ext/dapr-ext-grpc/dapr/ext/grpc/app.py b/ext/dapr-ext-grpc/dapr/ext/grpc/app.py index 58e0cdf2..87543bdf 100644 --- a/ext/dapr-ext-grpc/dapr/ext/grpc/app.py +++ b/ext/dapr-ext-grpc/dapr/ext/grpc/app.py @@ -16,11 +16,10 @@ from concurrent import futures from typing import Dict, Optional -from dapr.ext.grpc._health_servicer import _HealthCheckServicer # type: ignore -from dapr.ext.grpc._servicer import Rule, _CallbackServicer # type: ignore - import grpc from dapr.conf import settings +from dapr.ext.grpc._health_servicer import _HealthCheckServicer # type: ignore +from dapr.ext.grpc._servicer import Rule, _CallbackServicer # type: ignore from dapr.proto import appcallback_service_v1 diff --git a/ext/dapr-ext-workflow/README.rst b/ext/dapr-ext-workflow/README.rst index aa0003c6..1a99ad39 100644 --- a/ext/dapr-ext-workflow/README.rst +++ b/ext/dapr-ext-workflow/README.rst @@ -16,6 +16,488 @@ Installation pip install dapr-ext-workflow +Async authoring (experimental) +------------------------------ + +This package supports authoring workflows with ``async def`` in addition to the existing generator-based orchestrators. + +- Register async workflows using ``WorkflowRuntime.workflow`` (auto-detects coroutine) or ``async_workflow`` / ``register_async_workflow``. +- Use ``AsyncWorkflowContext`` for deterministic operations: + + - Activities: ``await ctx.call_activity(activity_fn, input=...)`` + - Child workflows: ``await ctx.call_child_workflow(workflow_fn, input=...)`` + - Timers: ``await ctx.create_timer(seconds|timedelta)`` + - External events: ``await ctx.wait_for_external_event(name)`` + - Concurrency: ``await ctx.when_all([...])``, ``await ctx.when_any([...])`` + - Deterministic utils: ``ctx.now()``, ``ctx.random()``, ``ctx.uuid4()``, ``ctx.new_guid()``, ``ctx.random_string(length)`` + +Interceptors (client/runtime/outbound) +-------------------------------------- + +Interceptors provide a simple, composable way to apply cross-cutting behavior with a single +enter/exit per call. There are three types: + +- Client interceptors: wrap outbound scheduling from the client (schedule_new_workflow). +- Workflow outbound interceptors: wrap calls made inside workflows (call_activity, call_child_workflow). +- Runtime interceptors: wrap inbound execution of workflows and activities (before user code). + +Use cases include context propagation, request metadata stamping, replay-aware logging, validation, +and policy enforcement. + +Response/output shaping +~~~~~~~~~~~~~~~~~~~~~~~ + +Interceptors are "around" hooks: they can shape inputs before calling ``next(...)`` and may also +shape the returned value (or map exceptions) after ``next(...)`` returns. This mirrors gRPC +interceptors and keeps the surface simple – one hook per interception point. + +- Client interceptors can transform schedule/query/signal responses. +- Runtime interceptors can transform workflow/activity results (with guardrails below). +- Workflow-outbound interceptors remain input-only to keep awaitable composition simple. 
+ +Examples +^^^^^^^^ + +Client schedule response shaping:: + + from dapr.ext.workflow import ( + DaprWorkflowClient, ClientInterceptor, ScheduleWorkflowRequest + ) + + class ShapeId(ClientInterceptor): + def schedule_new_workflow(self, input: ScheduleWorkflowRequest, next): + raw = next(input) + return f"tenant-A:{raw}" + + client = DaprWorkflowClient(interceptors=[ShapeId()]) + instance_id = client.schedule_new_workflow(my_workflow, input={}) + # instance_id == "tenant-A:" + +Runtime activity result shaping:: + + from dapr.ext.workflow import WorkflowRuntime, RuntimeInterceptor, ExecuteActivityRequest + + class WrapResult(RuntimeInterceptor): + def execute_activity(self, input: ExecuteActivityRequest, next): + res = next(input) + return {"value": res} + + rt = WorkflowRuntime(runtime_interceptors=[WrapResult()]) + @rt.activity + def echo(ctx, x): + return x + # echo(...) returns {"value": x} + +Determinism guardrails (workflows) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +- Workflow response shaping must be replay-safe: pure transforms only (no I/O, time, RNG). +- Base the transform solely on (input, metadata, original_result). Map errors to typed exceptions. +- Activities are not replayed, so result shaping may perform I/O, but keep it lightweight. + +Quick start +~~~~~~~~~~~ + +.. code-block:: python + + from __future__ import annotations + import contextvars + from typing import Any, Callable, List + + from dapr.ext.workflow import ( + WorkflowRuntime, + DaprWorkflowClient, + ClientInterceptor, + WorkflowOutboundInterceptor, + RuntimeInterceptor, + ScheduleWorkflowRequest, + CallActivityRequest, + CallChildWorkflowRequest, + ExecuteWorkflowRequest, + ExecuteActivityRequest, + ) + + # Example: propagate a lightweight context dict through inputs + _current_ctx: contextvars.ContextVar[Optional[dict[str, Any]]] = contextvars.ContextVar( + 'wf_ctx', default=None + ) + + def set_ctx(ctx: Optional[dict[str, Any]]): + _current_ctx.set(ctx) + + def _merge_ctx(args: Any) -> Any: + ctx = _current_ctx.get() + if ctx and isinstance(args, dict) and 'context' not in args: + return {**args, 'context': ctx} + return args + + # Typed payloads + class MyWorkflowInput: + def __init__(self, question: str, tags: List[str] | None = None): + self.question = question + self.tags = tags or [] + + class MyActivityInput: + def __init__(self, name: str, count: int): + self.name = name + self.count = count + + class ContextClientInterceptor(ClientInterceptor[MyWorkflowInput]): + def schedule_new_workflow(self, input: ScheduleWorkflowRequest[MyWorkflowInput], nxt: Callable[[ScheduleWorkflowRequest[MyWorkflowInput]], Any]) -> Any: + input = ScheduleWorkflowRequest( + workflow_name=input.workflow_name, + input=_merge_ctx(input.input), + instance_id=input.instance_id, + start_at=input.start_at, + reuse_id_policy=input.reuse_id_policy, + ) + return nxt(input) + + class ContextWorkflowOutboundInterceptor(WorkflowOutboundInterceptor[MyWorkflowInput, MyActivityInput]): + def call_child_workflow(self, input: CallChildWorkflowRequest[MyWorkflowInput], nxt: Callable[[CallChildWorkflowRequest[MyWorkflowInput]], Any]) -> Any: + return nxt(CallChildWorkflowRequest[MyWorkflowInput]( + workflow_name=input.workflow_name, + input=_merge_ctx(input.input), + instance_id=input.instance_id, + workflow_ctx=input.workflow_ctx, + metadata=input.metadata, + )) + + def call_activity(self, input: CallActivityRequest[MyActivityInput], nxt: Callable[[CallActivityRequest[MyActivityInput]], Any]) -> Any: + return 
nxt(CallActivityRequest[MyActivityInput]( + activity_name=input.activity_name, + input=_merge_ctx(input.input), + retry_policy=input.retry_policy, + workflow_ctx=input.workflow_ctx, + metadata=input.metadata, + )) + + class ContextRuntimeInterceptor(RuntimeInterceptor[MyWorkflowInput, MyActivityInput]): + def execute_workflow(self, input: ExecuteWorkflowRequest[MyWorkflowInput], nxt: Callable[[ExecuteWorkflowRequest[MyWorkflowInput]], Any]) -> Any: + # Restore context from input if present (no I/O, replay-safe) + if isinstance(input.input, dict) and 'context' in input.input: + set_ctx(input.input['context']) + try: + return nxt(input) + finally: + set_ctx(None) + + def execute_activity(self, input: ExecuteActivityRequest[MyActivityInput], nxt: Callable[[ExecuteActivityRequest[MyActivityInput]], Any]) -> Any: + if isinstance(input.input, dict) and 'context' in input.input: + set_ctx(input.input['context']) + try: + return nxt(input) + finally: + set_ctx(None) + + # Wire into client and runtime + runtime = WorkflowRuntime( + runtime_interceptors=[ContextRuntimeInterceptor()], + workflow_outbound_interceptors=[ContextWorkflowOutboundInterceptor()], + ) + + client = DaprWorkflowClient(interceptors=[ContextClientInterceptor()]) + +Context metadata (durable propagation) +------------------------------------- + +Interceptors support a durable context channel: + +- ``metadata``: a string-only dict that is durably persisted and propagated across workflow + boundaries (schedule, child workflows, activities). Typical use: tracing and correlation ids + (e.g., ``otel.trace_id``), tenancy, request ids. This is provider-agnostic and does not require + changes to your workflow/activities. + +How it works +~~~~~~~~~~~~ + +- Client interceptors can set ``metadata`` when scheduling a workflow or calling activities/children. +- Runtime unwraps a reserved envelope before user code runs and exposes the metadata to + ``RuntimeInterceptor`` via ``ExecuteWorkflowRequest.metadata`` / ``ExecuteActivityRequest.metadata``, + while delivering only the original payload to the user function. +- Outbound calls made inside a workflow use client interceptors; when ``metadata`` is present on the + call input, the runtime re-wraps the payload to persist and propagate it. + +Envelope (backward compatible) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Internally, the runtime persists metadata by wrapping inputs in an envelope: + +:: + + { + "__dapr_meta__": { "v": 1, "metadata": { "otel.trace_id": "abc" } }, + "__dapr_payload__": { ... original user input ... } + } + +- The runtime unwraps this automatically so user code continues to receive the exact original input + structure and types. +- The version field (``v``) is reserved for forward compatibility. + +Minimal input guidance (SDK-facing) +----------------------------------- + +- Workflow input SHOULD be JSON serializable and a preferably a single dict carried under ``ExecuteWorkflowRequest.input``. Prefer a + single object over positional ``input`` to avoid shape ambiguity and ease future evolution. This is + a recommendation for consistency and versioning; the SDK accepts any JSON-serializable input type + (dict, list, or scalar) and preserves the original shape when unwrapping the envelope. + +- For contextual data, you can use "headers" (aliases for metadata) on the workflow context: + ``set_headers``/``get_headers`` behave the same as ``set_metadata``/``get_metadata`` and are + provided for familiarity with systems that use header terminology. 
``continue_as_new`` also + supports ``carryover_headers`` as an alias to ``carryover_metadata``. +- If your app needs a tracing or correlation fallback, include a small ``trace_context`` dict in + your input envelope. Interceptors should restore from ``metadata`` first (see below), then + optionally fall back to this field when present. + +Example (generic): + +.. code-block:: json + + { + "schema_version": "your-app:workflow_input@v1", + "trace_context": { "trace_id": "...", "span_id": "..." }, + "payload": { } + } + +Determinism and safety +~~~~~~~~~~~~~~~~~~~~~~ + +- In workflows, read metadata and avoid non-deterministic operations inside interceptors. Do not + perform network I/O in orchestrators. +- Activities may read/modify metadata and perform I/O inside the activity function if desired. + +Metadata persistence lifecycle +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +- ``ctx.set_metadata()`` attaches a string-only dict to the current workflow activation. The runtime + persists it by wrapping inputs in the envelope shown above. Set metadata before yielding or + returning from an activation to ensure it is durably recorded. +- ``continue_as_new``: metadata is not implicitly carried. Use + ``ctx.continue_as_new(new_input, carryover_metadata=True)`` to carry current metadata or provide a + dict to merge/override: ``carryover_metadata={"key": "value"}``. +- Child workflows and activities: metadata is propagated when set on the outbound call input by + interceptors. If you maintain a baseline via ``ctx.set_metadata(...)``, your + ``WorkflowOutboundInterceptor`` can merge it into call-specific metadata. + +Tracing interceptors (example) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +You can implement tracing as interceptors that stamp/propagate IDs in ``metadata`` and suppress +spans during replay. A minimal sketch: + +.. 
code-block:: python + + from typing import Any, Callable + from dapr.ext.workflow import ( + BaseClientInterceptor, BaseWorkflowOutboundInterceptor, BaseRuntimeInterceptor, + WorkflowRuntime, DaprWorkflowClient, + ScheduleWorkflowRequest, CallActivityRequest, CallChildWorkflowRequest, + ExecuteWorkflowRequest, ExecuteActivityRequest, + ) + + TRACE_ID_KEY = 'otel.trace_id' + + class TracingClientInterceptor(BaseClientInterceptor): + def __init__(self, get_trace: Callable[[], str]): + self._get = get_trace + def schedule_new_workflow(self, input: ScheduleWorkflowRequest, next): + md = dict(input.metadata or {}) + md.setdefault(TRACE_ID_KEY, self._get()) + return next(ScheduleWorkflowRequest( + workflow_name=input.workflow_name, + input=input.input, + instance_id=input.instance_id, + start_at=input.start_at, + reuse_id_policy=input.reuse_id_policy, + metadata=md, + )) + + class TracingWorkflowOutboundInterceptor(BaseWorkflowOutboundInterceptor): + def __init__(self, get_trace: Callable[[], str]): + self._get = get_trace + def call_activity(self, input: CallActivityRequest, next): + md = dict(input.metadata or {}) + md.setdefault(TRACE_ID_KEY, self._get()) + return next(type(input)( + activity_name=input.activity_name, + input=input.input, + retry_policy=input.retry_policy, + workflow_ctx=input.workflow_ctx, + metadata=md, + )) + def call_child_workflow(self, input: CallChildWorkflowRequest, next): + md = dict(input.metadata or {}) + md.setdefault(TRACE_ID_KEY, self._get()) + return next(type(input)( + workflow_name=input.workflow_name, + input=input.input, + instance_id=input.instance_id, + workflow_ctx=input.workflow_ctx, + metadata=md, + )) + + class TracingRuntimeInterceptor(BaseRuntimeInterceptor): + def execute_workflow(self, input: ExecuteWorkflowRequest, next): + if not input.ctx.is_replaying: + _trace_id = (input.metadata or {}).get(TRACE_ID_KEY) + # start workflow span here + return next(input) + def execute_activity(self, input: ExecuteActivityRequest, next): + _trace_id = (input.metadata or {}).get(TRACE_ID_KEY) + # start activity span here + return next(input) + + rt = WorkflowRuntime( + runtime_interceptors=[TracingRuntimeInterceptor()], + workflow_outbound_interceptors=[TracingWorkflowOutboundInterceptor(lambda: 'trace-123')], + ) + client = DaprWorkflowClient(interceptors=[TracingClientInterceptor(lambda: 'trace-123')]) + +See the full runnable example in ``ext/dapr-ext-workflow/examples/tracing_interceptors_example.py``. + +Recommended tracing restoration +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +- Restore tracing from ``ExecuteWorkflowRequest.metadata`` first (e.g., a key like ``otel.trace_id``) + to preserve determinism and cross-activation continuity without touching user payloads. +- If no tracing metadata is present, optionally fall back to ``input.trace_context`` in your + application-defined input envelope. +- Suppress workflow spans during replay by checking ``input.ctx.is_replaying`` in runtime + interceptors. + +Engine-provided tracing +~~~~~~~~~~~~~~~~~~~~~~~ + +- When available from the runtime, use engine-provided fields surfaced on the contexts instead of + reconstructing from headers/metadata: + + - ``ctx.trace_parent`` / ``ctx.trace_state`` (and the same on ``activity_ctx``) + - ``ctx.workflow_span_id`` (identifier for the workflow span) + +- Interceptors should prefer these fields. Use headers/metadata only as a fallback or for + application-specific context. 
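+
+A minimal sketch of this preference order (assuming the runtime exposes ``trace_parent`` on the
+workflow context; the metadata key is illustrative)::
+
+    class EngineAwareTracingInterceptor(BaseRuntimeInterceptor):
+        def execute_workflow(self, input: ExecuteWorkflowRequest, next):
+            # Prefer the engine-provided field when available ...
+            trace_parent = getattr(input.ctx, 'trace_parent', None)
+            if trace_parent is None:
+                # ... and fall back to durable metadata only when it is not.
+                trace_parent = (input.metadata or {}).get('otel.trace_id')
+            if not input.ctx.is_replaying and trace_parent:
+                pass  # start the workflow span here, tagged with trace_parent
+            return next(input)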
+ +Execution info (minimal) and context properties +----------------------------------------------- + +``execution_info`` is now minimal and only includes the durable ``inbound_metadata`` that was +propagated into this activation. Use context properties directly for all engine fields: + +- ``ctx.trace_parent``, ``ctx.workflow_span_id``, ``ctx.workflow_attempt`` on workflow contexts. +- Manage outbound propagation via ``ctx.set_metadata(...)`` / ``ctx.get_metadata()``. The runtime + persists and propagates these values through the metadata envelope. + +Example: + +.. code-block:: python + + # In a workflow function + inbound = ctx.execution_info.inbound_metadata if ctx.execution_info else None + # Prepare outbound propagation + baseline = ctx.get_metadata() or {} + ctx.set_metadata({**baseline, 'tenant': 'acme'}) + +Notes +~~~~~ + +- User functions never see the envelope keys; they get the same input as before. +- Only string keys/values should be stored in headers/metadata; enforce size limits and redaction + policies as needed. +- With newer durabletask-python, the engine provides deterministic context fields directly on + ``OrchestrationContext`` (accessible via ``ctx.workflow_name``, ``ctx.parent_instance_id``, + and ``ctx.history_event_sequence``). Note that ``ctx.execution_info`` only contains + ``inbound_metadata``; use context properties directly for engine fields. +- Interceptors are synchronous and must not perform I/O in orchestrators. Activities may perform + I/O inside the user function; interceptor code should remain fast and replay-safe. +- Client interceptors are applied when calling ``DaprWorkflowClient.schedule_new_workflow(...)`` and + when orchestrators call ``ctx.call_activity(...)`` or ``ctx.call_child_workflow(...)``. + + +Best-effort sandbox +~~~~~~~~~~~~~~~~~~~ + +Opt-in scoped compatibility mode maps ``asyncio.sleep``, ``random``, ``uuid.uuid4``, and ``time.time`` to deterministic equivalents during workflow execution. Use ``sandbox_mode="best_effort"`` or ``"strict"`` when registering async workflows. Strict mode blocks ``asyncio.create_task`` in orchestrators. + +Examples +~~~~~~~~ + +See ``ext/dapr-ext-workflow/examples`` for: + +- ``async_activity_sequence.py`` +- ``async_external_event.py`` +- ``async_sub_orchestrator.py`` + +Determinism and semantics +~~~~~~~~~~~~~~~~~~~~~~~~~ + +- ``when_any`` losers: the first-completer result is returned; non-winning awaitables are ignored deterministically (no additional commands are emitted by the orchestrator for cancellation). This ensures replay stability. Integration behavior with the sidecar is subject to the Durable Task scheduler; the orchestrator does not actively cancel losers. +- Suspension and termination: when an instance is suspended, only new external events are buffered while replay continues to reconstruct state; async orchestrators can inspect ``ctx.is_suspended`` if exposed by the runtime. Termination completes the orchestrator with TERMINATED status and does not raise into the coroutine. End-to-end confirmation requires running against a sidecar; unit tests in this repo do not start a sidecar. + +Async patterns +~~~~~~~~~~~~~~ + +- Activities + + - Call: ``await ctx.call_activity(activity_fn, input=..., retry_policy=...)`` + - Activity functions can be ``def`` or ``async def``. When ``async def`` is used, the runtime awaits them. 
+ +- Timers + + - Create a durable timer: ``await ctx.create_timer(seconds|timedelta)`` + +- External events + + - Wait: ``await ctx.wait_for_external_event(name)`` + - Raise (from client): ``DaprWorkflowClient.raise_workflow_event(instance_id, name, data)`` + +- Concurrency + + - All: ``results = await ctx.when_all([ ...awaitables... ])`` + - Any: ``first = await ctx.when_any([ ...awaitables... ])`` (non-winning awaitables are ignored deterministically) + +- Child workflows + + - Call: ``await ctx.call_child_workflow(workflow_fn, input=..., retry_policy=...)`` + +- Deterministic utilities + + - ``ctx.now()`` returns orchestration time from history + - ``ctx.random()`` returns a deterministic PRNG + - ``ctx.uuid4()`` returns a PRNG-derived deterministic UUID + +Runtime compatibility +--------------------- + +- ``ctx.is_suspended`` is surfaced if provided by the underlying runtime/context version; behavior may vary by Durable Task build. Integration tests that validate suspension semantics are gated behind a sidecar harness. + +when_any losers diagnostics (integration) +----------------------------------------- + +- When the sidecar exposes command diagnostics, you can assert only a single command set is emitted for a ``when_any`` (the orchestrator completes after the first winner without emitting cancels). Until then, unit tests assert single-yield behavior and README documents the expected semantics. + +Micro-bench guidance +-------------------- + +- The coroutine-to-generator driver yields at each deterministic suspension point and avoids polling. In practice, overhead vs. generator orchestrators is negligible relative to activity I/O. To measure locally: + + - Create paired generator/async orchestrators that call N no-op activities and 1 timer. + - Drive them against a local sidecar and compare wall-clock per activation and total completion time. + - Ensure identical history/inputs; differences should be within noise vs. activity latency. + +Notes +----- + +- Orchestrators authored as ``async def`` are not driven by a global event loop you start. The Durable Task worker drives them via a coroutine-to-generator bridge; do not call ``asyncio.run`` around orchestrators. +- Use ``WorkflowRuntime.workflow`` with an ``async def`` (auto-detected) or ``WorkflowRuntime.async_workflow`` to register async orchestrators. + +Why async without an event loop? +-------------------------------- + +- Each ``await`` in an async orchestrator corresponds to a deterministic Durable Task decision (activity, timer, external event, ``when_all/any``). The worker advances the coroutine by sending results/exceptions back in, preserving replay and ordering. +- This gives you the readability and structure of ``async/await`` while enforcing workflow determinism (no ad-hoc I/O in orchestrators; all I/O happens in activities). +- The pattern follows other workflow engines (e.g., Durable Functions/Temporal): async authoring for clarity, runtime-driven scheduling for correctness. 
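
To tie the patterns above together, here is a minimal, hypothetical sketch (the workflow and
activity names are invented for illustration) of an async orchestrator that registers with the
runtime and combines an activity call, a durable timer, an external event, and fan-out/fan-in:

.. code-block:: python

    from datetime import timedelta
    from dapr.ext.workflow import WorkflowRuntime

    rt = WorkflowRuntime()

    @rt.activity(name='notify')
    def notify(ctx, message: str) -> str:
        # Activities may perform real I/O; this one just echoes its input.
        return f'sent: {message}'

    @rt.workflow(name='order_workflow')  # async def is auto-detected
    async def order_workflow(ctx, order_id: str):
        # Activity call: awaited deterministically, driven by the Durable Task worker.
        receipt = await ctx.call_activity(notify, input=f'processing {order_id}')

        # Durable timer: no asyncio event loop is involved.
        await ctx.create_timer(timedelta(seconds=30))

        # Wait for an approval event raised from a client.
        approval = await ctx.wait_for_external_event('approval')

        # Fan-out/fan-in over two more activity calls.
        steps = await ctx.when_all([
            ctx.call_activity(notify, input='step-1'),
            ctx.call_activity(notify, input='step-2'),
        ])
        return {'receipt': receipt, 'approval': approval, 'steps': steps}

The ``approval`` event above would be delivered from a client with
``DaprWorkflowClient.raise_workflow_event(instance_id, 'approval', data)`` as described earlier.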
+ References ---------- diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/__init__.py b/ext/dapr-ext-workflow/dapr/ext/workflow/__init__.py index dd2d45b7..b6a75e47 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/__init__.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/__init__.py @@ -14,9 +14,39 @@ """ # Import your main classes here +from dapr.ext.workflow.aio import AsyncWorkflowContext from dapr.ext.workflow.dapr_workflow_client import DaprWorkflowClient from dapr.ext.workflow.dapr_workflow_context import DaprWorkflowContext, when_all, when_any +from dapr.ext.workflow.execution_info import ActivityExecutionInfo, WorkflowExecutionInfo +from dapr.ext.workflow.interceptors import ( + BaseClientInterceptor, + BaseRuntimeInterceptor, + BaseWorkflowOutboundInterceptor, + CallActivityRequest, + CallChildWorkflowRequest, + ClientInterceptor, + ExecuteActivityRequest, + ExecuteWorkflowRequest, + RuntimeInterceptor, + ScheduleWorkflowRequest, + WorkflowOutboundInterceptor, + compose_runtime_chain, + compose_workflow_outbound_chain, +) from dapr.ext.workflow.retry_policy import RetryPolicy +from dapr.ext.workflow.serializers import ( + ActivityIOAdapter, + CanonicalSerializable, + GenericSerializer, + ensure_canonical_json, + get_activity_adapter, + get_serializer, + register_activity_adapter, + register_serializer, + serialize_activity_input, + serialize_activity_output, + use_activity_adapter, +) from dapr.ext.workflow.workflow_activity_context import WorkflowActivityContext from dapr.ext.workflow.workflow_runtime import WorkflowRuntime, alternate_name from dapr.ext.workflow.workflow_state import WorkflowState, WorkflowStatus @@ -25,6 +55,7 @@ 'WorkflowRuntime', 'DaprWorkflowClient', 'DaprWorkflowContext', + 'AsyncWorkflowContext', 'WorkflowActivityContext', 'WorkflowState', 'WorkflowStatus', @@ -32,4 +63,32 @@ 'when_any', 'alternate_name', 'RetryPolicy', + # interceptors + 'ClientInterceptor', + 'BaseClientInterceptor', + 'WorkflowOutboundInterceptor', + 'BaseWorkflowOutboundInterceptor', + 'RuntimeInterceptor', + 'BaseRuntimeInterceptor', + 'ScheduleWorkflowRequest', + 'CallChildWorkflowRequest', + 'CallActivityRequest', + 'ExecuteWorkflowRequest', + 'ExecuteActivityRequest', + 'compose_workflow_outbound_chain', + 'compose_runtime_chain', + 'WorkflowExecutionInfo', + 'ActivityExecutionInfo', + # serializers + 'CanonicalSerializable', + 'GenericSerializer', + 'ActivityIOAdapter', + 'ensure_canonical_json', + 'register_serializer', + 'get_serializer', + 'register_activity_adapter', + 'get_activity_adapter', + 'use_activity_adapter', + 'serialize_activity_input', + 'serialize_activity_output', ] diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/aio/__init__.py b/ext/dapr-ext-workflow/dapr/ext/workflow/aio/__init__.py new file mode 100644 index 00000000..a195bc0a --- /dev/null +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/aio/__init__.py @@ -0,0 +1,43 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +from __future__ import annotations + +# Note: Do not import WorkflowRuntime here to avoid circular imports +# Re-export async context and awaitables +from .async_context import AsyncWorkflowContext # noqa: F401 +from .async_driver import CoroutineOrchestratorRunner # noqa: F401 +from .awaitables import ( # noqa: F401 + ActivityAwaitable, + SleepAwaitable, + SubOrchestratorAwaitable, + WhenAllAwaitable, + WhenAnyAwaitable, +) + +""" +Async I/O surface for Dapr Workflow extension. + +This package provides explicit async-focused imports that mirror the top-level +exports, improving discoverability and aligning with dapr.aio patterns. +""" + +__all__ = [ + 'AsyncWorkflowContext', + 'CoroutineOrchestratorRunner', + 'ActivityAwaitable', + 'SubOrchestratorAwaitable', + 'SleepAwaitable', + 'WhenAllAwaitable', + 'WhenAnyAwaitable', +] diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/aio/async_context.py b/ext/dapr-ext-workflow/dapr/ext/workflow/aio/async_context.py new file mode 100644 index 00000000..ec68dc1c --- /dev/null +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/aio/async_context.py @@ -0,0 +1,189 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +from datetime import datetime, timedelta +from typing import Any, Awaitable, Callable, Sequence + +from durabletask import task +from durabletask.aio.awaitables import gather as _dt_gather # type: ignore[import-not-found] +from durabletask.deterministic import ( # type: ignore[F401] + DeterministicContextMixin, +) + +from .awaitables import ( + ActivityAwaitable, + ExternalEventAwaitable, + SleepAwaitable, + SubOrchestratorAwaitable, + WhenAllAwaitable, + WhenAnyAwaitable, +) + +""" +Async workflow context that exposes deterministic awaitables for activities, timers, +external events, and concurrency, along with deterministic utilities. 
+""" + + +class AsyncWorkflowContext(DeterministicContextMixin): + def __init__(self, base_ctx: task.OrchestrationContext): + self._base_ctx = base_ctx + + # Core workflow metadata parity with sync context + @property + def instance_id(self) -> str: + return self._base_ctx.instance_id + + @property + def current_utc_datetime(self) -> datetime: + return self._base_ctx.current_utc_datetime + + # Activities & Sub-orchestrations + def call_activity( + self, + activity_fn: Callable[..., Any], + *, + input: Any = None, + retry_policy: Any = None, + metadata: dict[str, str] | None = None, + ) -> Awaitable[Any]: + return ActivityAwaitable( + self._base_ctx, activity_fn, input=input, retry_policy=retry_policy, metadata=metadata + ) + + def call_child_workflow( + self, + workflow_fn: Callable[..., Any], + *, + input: Any = None, + instance_id: str | None = None, + retry_policy: Any = None, + metadata: dict[str, str] | None = None, + ) -> Awaitable[Any]: + return SubOrchestratorAwaitable( + self._base_ctx, + workflow_fn, + input=input, + instance_id=instance_id, + retry_policy=retry_policy, + metadata=metadata, + ) + + @property + def is_replaying(self) -> bool: + return self._base_ctx.is_replaying + + # Tracing (engine-provided) pass-throughs when available + @property + def trace_parent(self) -> str | None: + return self._base_ctx.trace_parent + + @property + def trace_state(self) -> str | None: + return self._base_ctx.trace_state + + @property + def workflow_span_id(self) -> str | None: + return self._base_ctx.orchestration_span_id + + @property + def workflow_attempt(self) -> int | None: + getter = getattr(self._base_ctx, 'workflow_attempt', None) + return ( + getter + if isinstance(getter, int) or getter is None + else getattr(self._base_ctx, 'workflow_attempt', None) + ) + + # Timers & Events + def create_timer(self, fire_at: float | timedelta | datetime) -> Awaitable[None]: + # If float provided, interpret as seconds + if isinstance(fire_at, (int, float)): + fire_at = timedelta(seconds=float(fire_at)) + return SleepAwaitable(self._base_ctx, fire_at) + + def sleep(self, duration: float | timedelta | datetime) -> Awaitable[None]: + return self.create_timer(duration) + + def wait_for_external_event(self, name: str) -> Awaitable[Any]: + return ExternalEventAwaitable(self._base_ctx, name) + + # Concurrency + def when_all(self, awaitables: Sequence[Awaitable[Any]]) -> Awaitable[list[Any]]: + return WhenAllAwaitable(awaitables) + + def when_any(self, awaitables: Sequence[Awaitable[Any]]) -> Awaitable[Any]: + return WhenAnyAwaitable(awaitables) + + def gather(self, *aws: Awaitable[Any], return_exceptions: bool = False) -> Awaitable[list[Any]]: + return _dt_gather(*aws, return_exceptions=return_exceptions) + + # Deterministic utilities are provided by mixin (now, random, uuid4, new_guid) + + @property + def is_suspended(self) -> bool: + # Placeholder; will be wired when Durable Task exposes this state in context + return self._base_ctx.is_suspended + + # Pass-throughs for completeness + def set_custom_status(self, custom_status: str) -> None: + if hasattr(self._base_ctx, 'set_custom_status'): + self._base_ctx.set_custom_status(custom_status) + + def continue_as_new( + self, + new_input: Any, + *, + save_events: bool = False, + carryover_metadata: bool | dict[str, str] = False, + carryover_headers: bool | dict[str, str] | None = None, + ) -> None: + effective_carryover = ( + carryover_headers if carryover_headers is not None else carryover_metadata + ) + # Try extended signature; fall back to minimal 
for older fakes/contexts + try: + self._base_ctx.continue_as_new( + new_input, save_events=save_events, carryover_metadata=effective_carryover + ) + except TypeError: + self._base_ctx.continue_as_new(new_input, save_events=save_events) + + # Metadata parity + def set_metadata(self, metadata: dict[str, str] | None) -> None: + setter = getattr(self._base_ctx, 'set_metadata', None) + if callable(setter): + setter(metadata) + + def get_metadata(self) -> dict[str, str] | None: + getter = getattr(self._base_ctx, 'get_metadata', None) + return getter() if callable(getter) else None + + # Header aliases (ergonomic alias for users familiar with Temporal terminology) + def set_headers(self, headers: dict[str, str] | None) -> None: + self.set_metadata(headers) + + def get_headers(self) -> dict[str, str] | None: + return self.get_metadata() + + # Execution info parity + @property + def execution_info(self): # type: ignore[override] + return getattr(self._base_ctx, 'execution_info', None) + + +__all__ = [ + 'AsyncWorkflowContext', +] diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/aio/async_driver.py b/ext/dapr-ext-workflow/dapr/ext/workflow/aio/async_driver.py new file mode 100644 index 00000000..7d964174 --- /dev/null +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/aio/async_driver.py @@ -0,0 +1,95 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +from __future__ import annotations + +from typing import Any, Awaitable, Callable, Generator, Optional + +from durabletask import task +from durabletask.aio.sandbox import SandboxMode, sandbox_scope + + +class CoroutineOrchestratorRunner: + """Wraps an async orchestrator into a generator-compatible runner.""" + + def __init__( + self, + async_orchestrator: Callable[..., Awaitable[Any]], + *, + sandbox_mode: SandboxMode = SandboxMode.OFF, + ): + self._async_orchestrator = async_orchestrator + self._sandbox_mode = sandbox_mode + + def to_generator( + self, async_ctx: Any, input_data: Optional[Any] + ) -> Generator[task.Task, Any, Any]: + # Instantiate the coroutine with or without input depending on signature/usage + try: + if input_data is None: + coro = self._async_orchestrator(async_ctx) + else: + coro = self._async_orchestrator(async_ctx, input_data) + except TypeError: + # Fallback for orchestrators that only accept a single ctx arg + coro = self._async_orchestrator(async_ctx) + + # Prime the coroutine + try: + if self._sandbox_mode == SandboxMode.OFF: + awaited = coro.send(None) + else: + with sandbox_scope(async_ctx, self._sandbox_mode): + awaited = coro.send(None) + except StopIteration as stop: + return stop.value # type: ignore[misc] + + # Drive the coroutine by yielding the underlying Durable Task(s) + while True: + try: + result = yield awaited + if self._sandbox_mode == SandboxMode.OFF: + awaited = coro.send(result) + else: + with sandbox_scope(async_ctx, self._sandbox_mode): + awaited = coro.send(result) + except StopIteration as stop: + return stop.value + except Exception as exc: + try: + if self._sandbox_mode == SandboxMode.OFF: + awaited = coro.throw(exc) + else: + with sandbox_scope(async_ctx, self._sandbox_mode): + awaited = coro.throw(exc) + except StopIteration as stop: + return stop.value + except BaseException as base_exc: + # Handle cancellation that may not derive from Exception in some environments + try: + import asyncio as _asyncio # local import to avoid hard dep at module import time + + is_cancel = isinstance(base_exc, _asyncio.CancelledError) + except Exception: + is_cancel = False + if is_cancel: + try: + if self._sandbox_mode == SandboxMode.OFF: + awaited = coro.throw(base_exc) + else: + with sandbox_scope(async_ctx, self._sandbox_mode): + awaited = coro.throw(base_exc) + except StopIteration as stop: + return stop.value + continue + raise diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/aio/awaitables.py b/ext/dapr-ext-workflow/dapr/ext/workflow/aio/awaitables.py new file mode 100644 index 00000000..27c770d1 --- /dev/null +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/aio/awaitables.py @@ -0,0 +1,125 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +from __future__ import annotations + +from collections.abc import Iterable +from typing import Any, Callable + +from durabletask import task +from durabletask.aio.awaitables import ( + AwaitableBase as _BaseAwaitable, # type: ignore[import-not-found] +) +from durabletask.aio.awaitables import ( + ExternalEventAwaitable as _DTExternalEventAwaitable, +) +from durabletask.aio.awaitables import ( + SleepAwaitable as _DTSleepAwaitable, +) +from durabletask.aio.awaitables import ( + WhenAllAwaitable as _DTWhenAllAwaitable, +) + +AwaitableBase = _BaseAwaitable + + +class ActivityAwaitable(AwaitableBase): + def __init__( + self, + ctx: Any, + activity_fn: Callable[..., Any], + *, + input: Any = None, + retry_policy: Any = None, + metadata: dict[str, str] | None = None, + ): + self._ctx = ctx + self._activity_fn = activity_fn + self._input = input + self._retry_policy = retry_policy + self._metadata = metadata + + def _to_task(self) -> task.Task: + if self._retry_policy is None: + return self._ctx.call_activity( + self._activity_fn, input=self._input, metadata=self._metadata + ) + return self._ctx.call_activity( + self._activity_fn, + input=self._input, + retry_policy=self._retry_policy, + metadata=self._metadata, + ) + + +class SubOrchestratorAwaitable(AwaitableBase): + def __init__( + self, + ctx: Any, + workflow_fn: Callable[..., Any], + *, + input: Any = None, + instance_id: str | None = None, + retry_policy: Any = None, + metadata: dict[str, str] | None = None, + ): + self._ctx = ctx + self._workflow_fn = workflow_fn + self._input = input + self._instance_id = instance_id + self._retry_policy = retry_policy + self._metadata = metadata + + def _to_task(self) -> task.Task: + if self._retry_policy is None: + return self._ctx.call_child_workflow( + self._workflow_fn, + input=self._input, + instance_id=self._instance_id, + metadata=self._metadata, + ) + return self._ctx.call_child_workflow( + self._workflow_fn, + input=self._input, + instance_id=self._instance_id, + retry_policy=self._retry_policy, + metadata=self._metadata, + ) + + +class SleepAwaitable(_DTSleepAwaitable): + pass + + +class ExternalEventAwaitable(_DTExternalEventAwaitable): + pass + + +class WhenAllAwaitable(_DTWhenAllAwaitable): + pass + + +class WhenAnyAwaitable(AwaitableBase): + def __init__(self, tasks_like: Iterable[AwaitableBase | task.Task]): + self._tasks_like = list(tasks_like) + + def _to_task(self) -> task.Task: + underlying: list[task.Task] = [] + for a in self._tasks_like: + if isinstance(a, AwaitableBase): + underlying.append(a._to_task()) # type: ignore[attr-defined] + elif isinstance(a, task.Task): + underlying.append(a) + else: + raise TypeError('when_any expects AwaitableBase or durabletask.task.Task') + return task.when_any(underlying) diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/aio/sandbox.py b/ext/dapr-ext-workflow/dapr/ext/workflow/aio/sandbox.py new file mode 100644 index 00000000..3d887e17 --- /dev/null +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/aio/sandbox.py @@ -0,0 +1,233 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +import asyncio as _asyncio +import random as _random +import time as _time +import uuid as _uuid +from contextlib import ContextDecorator +from typing import Any + +from durabletask.aio.sandbox import SandboxMode +from durabletask.deterministic import deterministic_random, deterministic_uuid4 + +""" +Scoped sandbox patching for async workflows (best-effort, strict). +""" + + +def _ctx_instance_id(async_ctx: Any) -> str: + if hasattr(async_ctx, 'instance_id'): + return getattr(async_ctx, 'instance_id') + if hasattr(async_ctx, '_base_ctx') and hasattr(async_ctx._base_ctx, 'instance_id'): + return async_ctx._base_ctx.instance_id + return '' + + +def _ctx_now(async_ctx: Any): + if hasattr(async_ctx, 'now'): + try: + return async_ctx.now() + except Exception: + pass + if hasattr(async_ctx, 'current_utc_datetime'): + return async_ctx.current_utc_datetime + if hasattr(async_ctx, '_base_ctx') and hasattr(async_ctx._base_ctx, 'current_utc_datetime'): + return async_ctx._base_ctx.current_utc_datetime + import datetime as _dt + + return _dt.datetime.utcfromtimestamp(0) + + +class _Sandbox(ContextDecorator): + def __init__(self, async_ctx: Any, mode: str): + self._async_ctx = async_ctx + self._mode = mode + self._saved: dict[str, Any] = {} + + def __enter__(self): + self._saved['asyncio.sleep'] = _asyncio.sleep + self._saved['asyncio.gather'] = getattr(_asyncio, 'gather', None) + self._saved['asyncio.create_task'] = getattr(_asyncio, 'create_task', None) + self._saved['random.random'] = _random.random + self._saved['random.randrange'] = _random.randrange + self._saved['random.randint'] = _random.randint + self._saved['uuid.uuid4'] = _uuid.uuid4 + self._saved['time.time'] = _time.time + self._saved['time.time_ns'] = getattr(_time, 'time_ns', None) + + rnd = deterministic_random(_ctx_instance_id(self._async_ctx), _ctx_now(self._async_ctx)) + + async def _sleep_patched(delay: float, result: Any = None): # type: ignore[override] + try: + if float(delay) <= 0: + return await self._saved['asyncio.sleep'](0) + except Exception: + return await self._saved['asyncio.sleep'](delay) # type: ignore[arg-type] + + await self._async_ctx.sleep(delay) + return result + + def _random_patched() -> float: + return rnd.random() + + def _randrange_patched(start, stop=None, step=1): + return rnd.randrange(start, stop, step) if stop is not None else rnd.randrange(start) + + def _randint_patched(a, b): + return rnd.randint(a, b) + + def _uuid4_patched(): + return deterministic_uuid4(rnd) + + def _time_patched() -> float: + return float(_ctx_now(self._async_ctx).timestamp()) + + def _time_ns_patched() -> int: + return int(_ctx_now(self._async_ctx).timestamp() * 1_000_000_000) + + def _create_task_blocked(coro, *args, **kwargs): + try: + close = getattr(coro, 'close', None) + if callable(close): + try: + close() + except Exception: + pass + finally: + raise RuntimeError( + 'asyncio.create_task is not allowed inside workflow (strict mode)' + ) + + def _is_workflow_awaitable(obj: Any) -> bool: + try: + if hasattr(obj, '_to_dapr_task') or hasattr(obj, '_to_task'): + return True + except Exception: + pass + try: + from durabletask import task as _dt + + if isinstance(obj, _dt.Task): + return True + except Exception: + pass + return False + + class _OneShot: + def __init__(self, factory): + self._factory = factory + self._done = False + self._res: Any = None + self._exc: BaseException | 
None = None + + def __await__(self): # type: ignore[override] + if self._done: + + async def _replay(): + if self._exc is not None: + raise self._exc + return self._res + + return _replay().__await__() + + async def _compute(): + try: + out = await self._factory() + self._res = out + self._done = True + return out + except BaseException as e: # noqa: BLE001 + self._exc = e + self._done = True + raise + + return _compute().__await__() + + def _patched_gather(*aws: Any, return_exceptions: bool = False): # type: ignore[override] + if not aws: + + async def _empty(): + return [] + + return _OneShot(_empty) + + if all(_is_workflow_awaitable(a) for a in aws): + + async def _await_when_all(): + from dapr.ext.workflow.aio.awaitables import WhenAllAwaitable # local import + + combined = WhenAllAwaitable(list(aws)) + return await combined + + return _OneShot(_await_when_all) + + async def _run_mixed(): + results = [] + for a in aws: + try: + results.append(await a) + except Exception as e: # noqa: BLE001 + if return_exceptions: + results.append(e) + else: + raise + return results + + return _OneShot(_run_mixed) + + _asyncio.sleep = _sleep_patched # type: ignore[assignment] + if self._saved['asyncio.gather'] is not None: + _asyncio.gather = _patched_gather # type: ignore[assignment] + _random.random = _random_patched # type: ignore[assignment] + _random.randrange = _randrange_patched # type: ignore[assignment] + _random.randint = _randint_patched # type: ignore[assignment] + _uuid.uuid4 = _uuid4_patched # type: ignore[assignment] + _time.time = _time_patched # type: ignore[assignment] + if self._saved['time.time_ns'] is not None: + _time.time_ns = _time_ns_patched # type: ignore[assignment] + if self._mode == 'strict' and self._saved['asyncio.create_task'] is not None: + _asyncio.create_task = _create_task_blocked # type: ignore[assignment] + + return self + + def __exit__(self, exc_type, exc, tb): + _asyncio.sleep = self._saved['asyncio.sleep'] # type: ignore[assignment] + if self._saved['asyncio.gather'] is not None: + _asyncio.gather = self._saved['asyncio.gather'] # type: ignore[assignment] + if self._saved['asyncio.create_task'] is not None: + _asyncio.create_task = self._saved['asyncio.create_task'] # type: ignore[assignment] + _random.random = self._saved['random.random'] # type: ignore[assignment] + _random.randrange = self._saved['random.randrange'] # type: ignore[assignment] + _random.randint = self._saved['random.randint'] # type: ignore[assignment] + _uuid.uuid4 = self._saved['uuid.uuid4'] # type: ignore[assignment] + _time.time = self._saved['time.time'] # type: ignore[assignment] + if self._saved['time.time_ns'] is not None: + _time.time_ns = self._saved['time.time_ns'] # type: ignore[assignment] + return False + + +def sandbox_scope(async_ctx: Any, mode: SandboxMode): + if mode == SandboxMode.OFF: + + class _Null(ContextDecorator): + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + return _Null() + return _Sandbox(async_ctx, 'strict' if mode == SandboxMode.STRICT else 'best_effort') diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/dapr_workflow_client.py b/ext/dapr-ext-workflow/dapr/ext/workflow/dapr_workflow_client.py index 461bfd43..7fbbbef5 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/dapr_workflow_client.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/dapr_workflow_client.py @@ -19,6 +19,12 @@ from typing import Any, Optional, TypeVar import durabletask.internal.orchestrator_service_pb2 as pb +from 
dapr.ext.workflow.interceptors import ( + ClientInterceptor, + ScheduleWorkflowRequest, + compose_client_chain, + wrap_payload_with_metadata, +) from dapr.ext.workflow.logger import Logger, LoggerOptions from dapr.ext.workflow.util import getAddress from dapr.ext.workflow.workflow_context import Workflow @@ -29,7 +35,7 @@ from dapr.clients import DaprInternalError from dapr.clients.http.client import DAPR_API_TOKEN_HEADER from dapr.conf import settings -from dapr.conf.helpers import GrpcEndpoint +from dapr.conf.helpers import GrpcEndpoint, build_grpc_channel_options T = TypeVar('T') TInput = TypeVar('TInput') @@ -51,6 +57,8 @@ def __init__( host: Optional[str] = None, port: Optional[str] = None, logger_options: Optional[LoggerOptions] = None, + *, + interceptors: list[ClientInterceptor] | None = None, ): address = getAddress(host, port) @@ -61,18 +69,31 @@ def __init__( self._logger = Logger('DaprWorkflowClient', logger_options) - metadata = tuple() + metadata = () if settings.DAPR_API_TOKEN: metadata = ((DAPR_API_TOKEN_HEADER, settings.DAPR_API_TOKEN),) options = self._logger.get_options() + # Optional gRPC channel options (keepalive, retry policy) via helpers + channel_options = build_grpc_channel_options() + + # Construct base kwargs for TaskHubGrpcClient + base_kwargs = { + 'host_address': uri.endpoint, + 'metadata': metadata, + 'secure_channel': uri.tls, + 'log_handler': options.log_handler, + 'log_formatter': options.log_formatter, + } + + # Initialize TaskHubGrpcClient (DurableTask supports options) self.__obj = client.TaskHubGrpcClient( - host_address=uri.endpoint, - metadata=metadata, - secure_channel=uri.tls, - log_handler=options.log_handler, - log_formatter=options.log_formatter, + **base_kwargs, + channel_options=channel_options, ) + # Interceptors + self._client_interceptors: list[ClientInterceptor] = list(interceptors or []) + def schedule_new_workflow( self, workflow: Workflow, @@ -81,6 +102,7 @@ def schedule_new_workflow( instance_id: Optional[str] = None, start_at: Optional[datetime] = None, reuse_id_policy: Optional[pb.OrchestrationIdReusePolicy] = None, + metadata: dict[str, str] | None = None, ) -> str: """Schedules a new workflow instance for execution. @@ -99,21 +121,33 @@ def schedule_new_workflow( Returns: The ID of the scheduled workflow instance. 
""" - if hasattr(workflow, '_dapr_alternate_name'): + wf_name = ( + workflow.__dict__['_dapr_alternate_name'] + if hasattr(workflow, '_dapr_alternate_name') + else workflow.__name__ + ) + + # Build interceptor chain around schedule call + def terminal(term_req: ScheduleWorkflowRequest) -> str: + payload = wrap_payload_with_metadata(term_req.input, term_req.metadata) return self.__obj.schedule_new_orchestration( - workflow.__dict__['_dapr_alternate_name'], - input=input, - instance_id=instance_id, - start_at=start_at, - reuse_id_policy=reuse_id_policy, + term_req.workflow_name, + input=payload, + instance_id=term_req.instance_id, + start_at=term_req.start_at, + reuse_id_policy=term_req.reuse_id_policy, ) - return self.__obj.schedule_new_orchestration( - workflow.__name__, + + chain = compose_client_chain(self._client_interceptors, terminal) + schedule_req = ScheduleWorkflowRequest( + workflow_name=wf_name, input=input, instance_id=instance_id, start_at=start_at, reuse_id_policy=reuse_id_policy, + metadata=metadata, ) + return chain(schedule_req) def get_workflow_state( self, instance_id: str, *, fetch_payloads: bool = True diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/dapr_workflow_context.py b/ext/dapr-ext-workflow/dapr/ext/workflow/dapr_workflow_context.py index 714def3f..58fd9a18 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/dapr_workflow_context.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/dapr_workflow_context.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """ Copyright 2023 The Dapr Authors Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,28 +11,61 @@ limitations under the License. """ +import enum from datetime import datetime, timedelta from typing import Any, Callable, List, Optional, TypeVar, Union +from dapr.ext.workflow.execution_info import WorkflowExecutionInfo +from dapr.ext.workflow.interceptors import unwrap_payload_with_metadata, wrap_payload_with_metadata from dapr.ext.workflow.logger import Logger, LoggerOptions from dapr.ext.workflow.retry_policy import RetryPolicy from dapr.ext.workflow.workflow_activity_context import WorkflowActivityContext from dapr.ext.workflow.workflow_context import Workflow, WorkflowContext from durabletask import task +from durabletask.deterministic import ( # type: ignore[F401] + DeterministicContextMixin, +) T = TypeVar('T') TInput = TypeVar('TInput') TOutput = TypeVar('TOutput') -class DaprWorkflowContext(WorkflowContext): - """DaprWorkflowContext that provides proxy access to internal OrchestrationContext instance.""" +class Handlers(enum.Enum): + CALL_ACTIVITY = 'call_activity' + CALL_CHILD_WORKFLOW = 'call_child_workflow' + CONTINUE_AS_NEW = 'continue_as_new' + + +class DaprWorkflowContext(WorkflowContext, DeterministicContextMixin): + """Workflow context wrapper with deterministic utilities and metadata helpers. + + Purpose + ------- + - Proxy to the underlying ``durabletask.task.OrchestrationContext`` (engine fields like + ``trace_parent``, ``orchestration_span_id``, and ``workflow_attempt`` pass through). + - Provide SDK-level helpers for durable metadata propagation via interceptors. + - Expose ``execution_info`` as a per-activation snapshot complementing live properties. + + Tips + ---- + - Use ``ctx.get_metadata()/set_metadata()`` to manage outbound propagation. + - Use ``ctx.execution_info.inbound_metadata`` to inspect what arrived on this activation. 
+ - Prefer engine-backed properties for tracing/attempts when available (not yet available in dapr sidecar); fall back to + metadata only for app-specific context. + """ def __init__( - self, ctx: task.OrchestrationContext, logger_options: Optional[LoggerOptions] = None + self, + ctx: task.OrchestrationContext, + logger_options: Optional[LoggerOptions] = None, + *, + outbound_handlers: dict[Handlers, Any] | None = None, ): self.__obj = ctx self._logger = Logger('DaprWorkflowContext', logger_options) + self._outbound_handlers = outbound_handlers or {} + self._metadata: dict[str, str] | None = None # provide proxy access to regular attributes of wrapped object def __getattr__(self, name): @@ -52,10 +83,34 @@ def current_utc_datetime(self) -> datetime: def is_replaying(self) -> bool: return self.__obj.is_replaying + # Deterministic utilities are provided by mixin (now, random, uuid4, new_guid) + + # Metadata API + def set_metadata(self, metadata: dict[str, str] | None) -> None: + self._metadata = dict(metadata) if metadata else None + + def get_metadata(self) -> dict[str, str] | None: + return dict(self._metadata) if self._metadata else None + + # Header aliases (ergonomic alias for users familiar with Temporal terminology) + def set_headers(self, headers: dict[str, str] | None) -> None: + self.set_metadata(headers) + + def get_headers(self) -> dict[str, str] | None: + return self.get_metadata() + def set_custom_status(self, custom_status: str) -> None: self._logger.debug(f'{self.instance_id}: Setting custom status to {custom_status}') self.__obj.set_custom_status(custom_status) + # Execution info (populated by runtime when available) + @property + def execution_info(self) -> WorkflowExecutionInfo | None: + return getattr(self, '_execution_info', None) + + def _set_execution_info(self, info: WorkflowExecutionInfo) -> None: + self._execution_info = info + def create_timer(self, fire_at: Union[datetime, timedelta]) -> task.Task: self._logger.debug(f'{self.instance_id}: Creating timer to fire at {fire_at} time') return self.__obj.create_timer(fire_at) @@ -67,6 +122,7 @@ def call_activity( input: TInput = None, retry_policy: Optional[RetryPolicy] = None, app_id: Optional[str] = None, + metadata: dict[str, str] | None = None, ) -> task.Task[TOutput]: # Handle string activity names for cross-app scenarios if isinstance(activity, str): @@ -91,10 +147,18 @@ def call_activity( else: # this case should ideally never happen act = activity.__name__ + # Apply outbound client interceptor transformations if provided via runtime wiring + transformed_input: Any = input + if Handlers.CALL_ACTIVITY in self._outbound_handlers and callable( + self._outbound_handlers[Handlers.CALL_ACTIVITY] + ): + transformed_input = self._outbound_handlers[Handlers.CALL_ACTIVITY]( + self, activity, input, retry_policy, metadata or self.get_metadata() + ) if retry_policy is None: - return self.__obj.call_activity(activity=act, input=input, app_id=app_id) + return self.__obj.call_activity(activity=act, input=transformed_input, app_id=app_id) return self.__obj.call_activity( - activity=act, input=input, retry_policy=retry_policy.obj, app_id=app_id + activity=act, input=transformed_input, retry_policy=retry_policy.obj, app_id=app_id ) def call_child_workflow( @@ -105,6 +169,7 @@ def call_child_workflow( instance_id: Optional[str] = None, retry_policy: Optional[RetryPolicy] = None, app_id: Optional[str] = None, + metadata: dict[str, str] | None = None, ) -> task.Task[TOutput]: # Handle string workflow names for cross-app scenarios if 
isinstance(workflow, str): @@ -127,8 +192,8 @@ def call_child_workflow( self._logger.debug(f'{self.instance_id}: Creating child workflow {workflow.__name__}') def wf(ctx: task.OrchestrationContext, inp: TInput): - daprWfContext = DaprWorkflowContext(ctx, self._logger.get_options()) - return workflow(daprWfContext, inp) + dapr_wf_context = DaprWorkflowContext(ctx, self._logger.get_options()) + return workflow(dapr_wf_context, inp) # copy workflow name so durabletask.worker can find the orchestrator in its registry @@ -137,24 +202,71 @@ def wf(ctx: task.OrchestrationContext, inp: TInput): else: # this case should ideally never happen wf.__name__ = workflow.__name__ + # Apply outbound client interceptor transformations if provided via runtime wiring + transformed_input: Any = input + if Handlers.CALL_CHILD_WORKFLOW in self._outbound_handlers and callable( + self._outbound_handlers[Handlers.CALL_CHILD_WORKFLOW] + ): + transformed_input = self._outbound_handlers[Handlers.CALL_CHILD_WORKFLOW]( + self, workflow, input, metadata or self.get_metadata() + ) if retry_policy is None: return self.__obj.call_sub_orchestrator( - wf, input=input, instance_id=instance_id, app_id=app_id + wf, input=transformed_input, instance_id=instance_id, app_id=app_id ) return self.__obj.call_sub_orchestrator( - wf, input=input, instance_id=instance_id, retry_policy=retry_policy.obj, app_id=app_id + wf, + input=transformed_input, + instance_id=instance_id, + retry_policy=retry_policy.obj, + app_id=app_id, ) def wait_for_external_event(self, name: str) -> task.Task: self._logger.debug(f'{self.instance_id}: Waiting for external event {name}') return self.__obj.wait_for_external_event(name) - def continue_as_new(self, new_input: Any, *, save_events: bool = False) -> None: + def continue_as_new( + self, + new_input: Any, + *, + save_events: bool = False, + carryover_metadata: bool | dict[str, str] = False, + carryover_headers: bool | dict[str, str] | None = None, + ) -> None: self._logger.debug(f'{self.instance_id}: Continuing as new') - self.__obj.continue_as_new(new_input, save_events=save_events) + # Allow workflow outbound interceptors (wired via runtime) to modify payload/metadata + transformed_input: Any = new_input + if Handlers.CONTINUE_AS_NEW in self._outbound_handlers and callable( + self._outbound_handlers[Handlers.CONTINUE_AS_NEW] + ): + transformed_input = self._outbound_handlers[Handlers.CONTINUE_AS_NEW]( + self, new_input, self.get_metadata() + ) + + # Merge/carry metadata if requested, unwrapping any envelope produced by interceptors + payload, base_md = unwrap_payload_with_metadata(transformed_input) + # Start with current context metadata; then layer any interceptor-provided metadata on top + current_md = self.get_metadata() or {} + effective_md = {**current_md, **(base_md or {})} + effective_carryover = ( + carryover_headers if carryover_headers is not None else carryover_metadata + ) + if effective_carryover: + base = effective_md or {} + if isinstance(effective_carryover, dict): + md = {**base, **effective_carryover} + else: + md = base + payload = wrap_payload_with_metadata(payload, md) + else: + # If we had metadata from interceptors or context, preserve it + if effective_md: + payload = wrap_payload_with_metadata(payload, effective_md) + self.__obj.continue_as_new(payload, save_events=save_events) -def when_all(tasks: List[task.Task[T]]) -> task.WhenAllTask[T]: +def when_all(tasks: List[task.Task]) -> task.WhenAllTask: """Returns a task that completes when all of the provided tasks complete or when 
one of the tasks fail.""" return task.when_all(tasks) diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/deterministic.py b/ext/dapr-ext-workflow/dapr/ext/workflow/deterministic.py new file mode 100644 index 00000000..d33a02c6 --- /dev/null +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/deterministic.py @@ -0,0 +1,27 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +# Backward-compatible shim: import deterministic utilities from durabletask +from durabletask.deterministic import ( # type: ignore[F401] + DeterministicContextMixin, + deterministic_random, + deterministic_uuid4, +) + +__all__ = [ + 'DeterministicContextMixin', + 'deterministic_random', + 'deterministic_uuid4', +] diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/execution_info.py b/ext/dapr-ext-workflow/dapr/ext/workflow/execution_info.py new file mode 100644 index 00000000..0aacd710 --- /dev/null +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/execution_info.py @@ -0,0 +1,49 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +from dataclasses import dataclass + +""" +Minimal, deterministic snapshots of inbound durable metadata. + +Rationale +--------- + +Execution info previously mirrored many engine fields (IDs, tracing, attempts) already +available on the workflow/activity contexts. To remove redundancy and simplify usage, the +execution info types now only capture the durable ``inbound_metadata`` that was actually +propagated into this activation. Use context properties directly for engine fields. +""" + + +@dataclass +class WorkflowExecutionInfo: + """Per-activation snapshot for workflows. + + Only includes ``inbound_metadata`` that arrived with this activation. + """ + + inbound_metadata: dict[str, str] + + +@dataclass +class ActivityExecutionInfo: + """Per-activation snapshot for activities. + + Only includes ``inbound_metadata`` that arrived with this activity invocation. + """ + + inbound_metadata: dict[str, str] + activity_name: str diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/interceptors.py b/ext/dapr-ext-workflow/dapr/ext/workflow/interceptors.py new file mode 100644 index 00000000..90ebb7ea --- /dev/null +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/interceptors.py @@ -0,0 +1,414 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Callable, Generic, Protocol, TypeVar + +from dapr.ext.workflow.workflow_activity_context import WorkflowActivityContext +from dapr.ext.workflow.workflow_context import WorkflowContext + +# Type variables for generic interceptor payload typing +TInput = TypeVar('TInput') +TWorkflowInput = TypeVar('TWorkflowInput') +TActivityInput = TypeVar('TActivityInput') + +""" +Interceptor interfaces and chain utilities for the Dapr Workflow SDK. + +Providing a single enter/exit around calls. + +IMPORTANT: Generator wrappers for async workflows +-------------------------------------------------- +When writing runtime interceptors that touch workflow execution, be careful with generator +handling. If an interceptor obtains a workflow generator from user code (e.g., an async +orchestrator adapted into a generator) it must not manually iterate it using a for-loop +and yield the produced items. Doing so breaks send()/throw() propagation back into the +inner generator, which can cause resumed results from the durable runtime to be dropped +and appear as None to awaiters. + +Best practices: +- If the interceptor participates in composition and needs to return the generator, + return it directly (do not iterate it). +- If the interceptor must wrap the generator, always use "yield from inner_gen" so that + send()/throw() are forwarded correctly. + +Context managers with async workflows +-------------------------------------- +When using context managers (like ExitStack, logging contexts, or trace contexts) in an +interceptor for async workflows, be aware that calling `next(input)` returns a generator +object immediately, NOT the final result. The generator executes later when the durable +task runtime drives it. + +If you need a context manager to remain active during the workflow execution: + +**WRONG - Context exits before workflow runs:** + + def execute_workflow(self, input: ExecuteWorkflowRequest, next): + with setup_context(): + return next(input) # Returns generator, context exits immediately! + +**CORRECT - Context stays active throughout execution:** + + def execute_workflow(self, input: ExecuteWorkflowRequest, next): + def wrapper(): + with setup_context(): + gen = next(input) + yield from gen # Keep context alive while generator executes + return wrapper() + +For more complex scenarios with ExitStack or async context managers, wrap the generator +with `yield from` to ensure your context spans the entire workflow execution, including +all replay and continuation events. + +Example with ExitStack: + + def execute_workflow(self, input: ExecuteWorkflowRequest, next): + def wrapper(): + with ExitStack() as stack: + # Set up contexts (trace, logging, etc.) 
+ stack.enter_context(trace_context(...)) + stack.enter_context(logging_context(...)) + + # Get the generator from the next interceptor/handler + gen = next(input) + + # Keep contexts alive while generator executes + yield from gen + return wrapper() + +This pattern ensures your context manager remains active during: +- Initial workflow execution +- Replays from durable state +- Continuation after awaits +- Activity calls and child workflow invocations +""" + + +# Context metadata propagation +# ---------------------------- +# "metadata" is a durable, string-only map. It is serialized on the wire and propagates across +# boundaries (client → runtime → activity/child), surviving replays/retries. Use it when downstream +# components must observe the value. In-process ephemeral state should be handled within interceptors +# without attempting to propagate across process boundaries. + + +# ------------------------------ +# Client-side interceptor surface +# ------------------------------ + + +@dataclass +class ScheduleWorkflowRequest(Generic[TInput]): + workflow_name: str + input: TInput + instance_id: str | None + start_at: Any | None + reuse_id_policy: Any | None + # Durable context serialized and propagated across boundaries + metadata: dict[str, str] | None = None + + +@dataclass +class CallChildWorkflowRequest(Generic[TInput]): + workflow_name: str + input: TInput + instance_id: str | None + # Optional workflow context for outbound calls made inside workflows + workflow_ctx: Any | None = None + # Durable context serialized and propagated across boundaries + metadata: dict[str, str] | None = None + + +@dataclass +class ContinueAsNewRequest(Generic[TInput]): + input: TInput + # Optional workflow context for outbound calls made inside workflows + workflow_ctx: Any | None = None + # Durable context serialized and propagated across boundaries + metadata: dict[str, str] | None = None + + +@dataclass +class CallActivityRequest(Generic[TInput]): + activity_name: str + input: TInput + retry_policy: Any | None + # Optional workflow context for outbound calls made inside workflows + workflow_ctx: Any | None = None + # Durable context serialized and propagated across boundaries + metadata: dict[str, str] | None = None + + +class ClientInterceptor(Protocol, Generic[TInput]): + def schedule_new_workflow( + self, + input: ScheduleWorkflowRequest[TInput], + next: Callable[[ScheduleWorkflowRequest[TInput]], Any], + ) -> Any: ... + + +# ------------------------------- +# Runtime-side interceptor surface +# ------------------------------- + + +@dataclass +class ExecuteWorkflowRequest(Generic[TInput]): + ctx: WorkflowContext + input: TInput + # Durable metadata (runtime chain only; not injected into user code) + metadata: dict[str, str] | None = None + + +@dataclass +class ExecuteActivityRequest(Generic[TInput]): + ctx: WorkflowActivityContext + input: TInput + # Durable metadata (runtime chain only; not injected into user code) + metadata: dict[str, str] | None = None + + +class RuntimeInterceptor(Protocol, Generic[TWorkflowInput, TActivityInput]): + def execute_workflow( + self, + input: ExecuteWorkflowRequest[TWorkflowInput], + next: Callable[[ExecuteWorkflowRequest[TWorkflowInput]], Any], + ) -> Any: ... + + def execute_activity( + self, + input: ExecuteActivityRequest[TActivityInput], + next: Callable[[ExecuteActivityRequest[TActivityInput]], Any], + ) -> Any: ... 
+ + +# ------------------------------ +# Convenience base classes (devex) +# ------------------------------ + + +class BaseClientInterceptor(Generic[TInput]): + """Subclass this to get method name completion and safe defaults. + + Override any of the methods to customize behavior. By default, these + methods simply call `next` unchanged. + """ + + def schedule_new_workflow( + self, + input: ScheduleWorkflowRequest[TInput], + next: Callable[[ScheduleWorkflowRequest[TInput]], Any], + ) -> Any: # noqa: D401 + return next(input) + + # No workflow-outbound methods here; use WorkflowOutboundInterceptor for those + + +class BaseRuntimeInterceptor(Generic[TWorkflowInput, TActivityInput]): + """Subclass this to get method name completion and safe defaults.""" + + def execute_workflow( + self, + input: ExecuteWorkflowRequest[TWorkflowInput], + next: Callable[[ExecuteWorkflowRequest[TWorkflowInput]], Any], + ) -> Any: # noqa: D401 + return next(input) + + def execute_activity( + self, + input: ExecuteActivityRequest[TActivityInput], + next: Callable[[ExecuteActivityRequest[TActivityInput]], Any], + ) -> Any: # noqa: D401 + return next(input) + + +# ------------------------------ +# Helper: chain composition +# ------------------------------ + + +def compose_client_chain( + interceptors: list[ClientInterceptor], terminal: Callable[[Any], Any] +) -> Callable[[Any], Any]: + """Compose client interceptors into a single callable. + + Interceptors are applied in list order; each receives a ``next``. + The ``terminal`` callable is the final handler invoked after all interceptors; it + performs the base operation (e.g., scheduling the workflow) when the chain ends. + """ + next_fn = terminal + for icpt in reversed(interceptors or []): + + def make_next(curr_icpt: ClientInterceptor, nxt: Callable[[Any], Any]): + def runner(input: Any) -> Any: + if isinstance(input, ScheduleWorkflowRequest): + return curr_icpt.schedule_new_workflow(input, nxt) + return nxt(input) + + return runner + + next_fn = make_next(icpt, next_fn) + return next_fn + + +# ------------------------------ +# Workflow outbound interceptor surface +# ------------------------------ + + +class WorkflowOutboundInterceptor(Protocol, Generic[TWorkflowInput, TActivityInput]): + def call_child_workflow( + self, + input: CallChildWorkflowRequest[TWorkflowInput], + next: Callable[[CallChildWorkflowRequest[TWorkflowInput]], Any], + ) -> Any: ... + + def continue_as_new( + self, + input: ContinueAsNewRequest[TWorkflowInput], + next: Callable[[ContinueAsNewRequest[TWorkflowInput]], Any], + ) -> Any: ... + + def call_activity( + self, + input: CallActivityRequest[TActivityInput], + next: Callable[[CallActivityRequest[TActivityInput]], Any], + ) -> Any: ... 
+ + +class BaseWorkflowOutboundInterceptor(Generic[TWorkflowInput, TActivityInput]): + def call_child_workflow( + self, + input: CallChildWorkflowRequest[TWorkflowInput], + next: Callable[[CallChildWorkflowRequest[TWorkflowInput]], Any], + ) -> Any: + return next(input) + + def continue_as_new( + self, + input: ContinueAsNewRequest[TWorkflowInput], + next: Callable[[ContinueAsNewRequest[TWorkflowInput]], Any], + ) -> Any: + return next(input) + + def call_activity( + self, + input: CallActivityRequest[TActivityInput], + next: Callable[[CallActivityRequest[TActivityInput]], Any], + ) -> Any: + return next(input) + + +# ------------------------------ +# Backward-compat typing aliases +# ------------------------------ + + +def compose_workflow_outbound_chain( + interceptors: list[WorkflowOutboundInterceptor], + terminal: Callable[[Any], Any], +) -> Callable[[Any], Any]: + """Compose workflow outbound interceptors into a single callable. + + Interceptors are applied in list order; each receives a ``next``. + The ``terminal`` callable is the final handler invoked after all interceptors; it + performs the base operation (e.g., preparing outbound call args) when the chain ends. + """ + next_fn = terminal + for icpt in reversed(interceptors or []): + + def make_next(curr_icpt: WorkflowOutboundInterceptor, nxt: Callable[[Any], Any]): + def runner(input: Any) -> Any: + # Dispatch to the appropriate outbound method on the interceptor + if isinstance(input, CallActivityRequest): + return curr_icpt.call_activity(input, nxt) + if isinstance(input, CallChildWorkflowRequest): + return curr_icpt.call_child_workflow(input, nxt) + if isinstance(input, ContinueAsNewRequest): + return curr_icpt.continue_as_new(input, nxt) + # Fallback to next if input type unknown + return nxt(input) + + return runner + + next_fn = make_next(icpt, next_fn) + return next_fn + + +# ------------------------------ +# Helper: envelope for durable metadata +# ------------------------------ + +_META_KEY = '__dapr_meta__' +_META_VERSION = 1 +_PAYLOAD_KEY = '__dapr_payload__' + + +def wrap_payload_with_metadata(payload: Any, metadata: dict[str, str] | None) -> Any: + """If metadata is provided and non-empty, wrap payload in an envelope for persistence. + + Backward compatible: if metadata is falsy, return payload unchanged. + """ + if metadata: + return { + _META_KEY: { + 'v': _META_VERSION, + 'metadata': metadata, + }, + _PAYLOAD_KEY: payload, + } + return payload + + +def unwrap_payload_with_metadata(obj: Any) -> tuple[Any, dict[str, str] | None]: + """Extract payload and metadata from envelope if present. + + Returns (payload, metadata_dict_or_none). + """ + try: + if isinstance(obj, dict) and _META_KEY in obj and _PAYLOAD_KEY in obj: + meta = obj.get(_META_KEY) or {} + md = meta.get('metadata') if isinstance(meta, dict) else None + return obj.get(_PAYLOAD_KEY), md if isinstance(md, dict) else None + except Exception: + # Be robust: on any error, treat as raw payload + pass + return obj, None + + +def compose_runtime_chain( + interceptors: list[RuntimeInterceptor], final_handler: Callable[[Any], Any] +): + """Compose runtime interceptors into a single callable (synchronous). + + The ``final_handler`` callable is the final handler invoked after all interceptors; it + performs the core operation (e.g., calling user workflow/activity or returning a + workflow generator) when the chain ends. 
+ """ + next_fn = final_handler + for icpt in reversed(interceptors or []): + + def make_next(curr_icpt: RuntimeInterceptor, nxt: Callable[[Any], Any]): + def runner(input: Any) -> Any: + if isinstance(input, ExecuteWorkflowRequest): + return curr_icpt.execute_workflow(input, nxt) + if isinstance(input, ExecuteActivityRequest): + return curr_icpt.execute_activity(input, nxt) + return nxt(input) + + return runner + + next_fn = make_next(icpt, next_fn) + return next_fn diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/serializers.py b/ext/dapr-ext-workflow/dapr/ext/workflow/serializers.py new file mode 100644 index 00000000..09af5188 --- /dev/null +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/serializers.py @@ -0,0 +1,173 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +import json +from collections.abc import MutableMapping, MutableSequence +from typing import ( + Any, + Callable, + Dict, + Optional, + Protocol, + cast, +) + +""" +General-purpose, provider-agnostic JSON serialization helpers for workflow activities. + +This module focuses on generic extension points to ensure activity inputs/outputs are JSON-only +and replay-safe. It intentionally avoids provider-specific shapes (e.g., model/tool contracts), +which should live in examples or external packages. +""" + + +def _is_json_primitive(value: Any) -> bool: + return value is None or isinstance(value, (str, int, float, bool)) + + +def _to_json_safe(value: Any, *, strict: bool) -> Any: + """Convert a Python object to a JSON-serializable structure. + + - Dict keys become strings (lenient) or error (strict) if not str. + - Unsupported values become str(value) (lenient) or error (strict). 
+ """ + + if _is_json_primitive(value): + return value + + if isinstance(value, MutableSequence) or isinstance(value, tuple): + return [_to_json_safe(v, strict=strict) for v in value] + + if isinstance(value, MutableMapping) or isinstance(value, dict): + output: Dict[str, Any] = {} + for k, v in value.items(): + if not isinstance(k, str): + if strict: + raise ValueError('dict keys must be strings in strict mode') + k = str(k) + output[k] = _to_json_safe(v, strict=strict) + return output + + if strict: + # Attempt final json.dumps to surface type + try: + json.dumps(value) + except Exception as err: + raise ValueError(f'non-JSON-serializable value: {type(value).__name__}') from err + return value + + return str(value) + + +def _ensure_json(obj: Any, *, strict: bool) -> Any: + converted = _to_json_safe(obj, strict=strict) + # json.dumps as a final guard + json.dumps(converted) + return converted + + +# ---------------------------------------------------------------------------------------------- +# Generic helpers and extension points +# ---------------------------------------------------------------------------------------------- + + +class CanonicalSerializable(Protocol): + """Objects implementing this can produce a canonical JSON-serializable structure.""" + + def to_canonical_json(self, *, strict: bool = True) -> Any: ... + + +class GenericSerializer(Protocol): + """Serializer that converts arbitrary Python objects to/from JSON-serializable data.""" + + def serialize(self, obj: Any, *, strict: bool = True) -> Any: ... + + def deserialize(self, data: Any) -> Any: ... + + +_SERIALIZERS: Dict[str, GenericSerializer] = {} + + +def register_serializer(name: str, serializer: GenericSerializer) -> None: + if not name: + raise ValueError('serializer name must be non-empty') + _SERIALIZERS[name] = serializer + + +def get_serializer(name: str) -> Optional[GenericSerializer]: + return _SERIALIZERS.get(name) + + +def ensure_canonical_json(obj: Any, *, strict: bool = True) -> Any: + """Ensure any object is converted into a JSON-serializable structure. + + - If the object implements CanonicalSerializable, call to_canonical_json + - Else, coerce via the internal JSON-safe conversion + """ + + if hasattr(obj, 'to_canonical_json') and callable(getattr(obj, 'to_canonical_json')): + return _ensure_json( + cast(CanonicalSerializable, obj).to_canonical_json(strict=strict), strict=strict + ) + return _ensure_json(obj, strict=strict) + + +class ActivityIOAdapter(Protocol): + """Adapter to control how activity inputs/outputs are serialized.""" + + def serialize_input(self, input: Any, *, strict: bool = True) -> Any: ... + + def serialize_output(self, output: Any, *, strict: bool = True) -> Any: ... 
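Usage sketch (illustrative only, not part of this diff) for the serializer helpers above, assuming the module is importable as `dapr.ext.workflow.serializers`: objects implementing `to_canonical_json` are converted through that hook, lenient mode coerces non-JSON values to strings, and strict mode rejects them.

```python
from dapr.ext.workflow.serializers import ensure_canonical_json


class Order:
    """Example type implementing the CanonicalSerializable protocol."""

    def __init__(self, order_id: str, qty: int):
        self.order_id = order_id
        self.qty = qty

    def to_canonical_json(self, *, strict: bool = True):
        return {'order_id': self.order_id, 'qty': self.qty}


# Objects exposing to_canonical_json are converted via that hook.
assert ensure_canonical_json(Order('o-1', 2)) == {'order_id': 'o-1', 'qty': 2}

# Lenient mode: non-string keys are coerced and unsupported values stringified.
out = ensure_canonical_json({1: object()}, strict=False)
assert list(out) == ['1'] and isinstance(out['1'], str)

# Strict mode: non-JSON-serializable values raise ValueError instead.
try:
    ensure_canonical_json({'x': object()}, strict=True)
except ValueError:
    pass
```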
+ + +_ACTIVITY_ADAPTERS: Dict[str, ActivityIOAdapter] = {} + + +def register_activity_adapter(name: str, adapter: ActivityIOAdapter) -> None: + if not name: + raise ValueError('activity adapter name must be non-empty') + _ACTIVITY_ADAPTERS[name] = adapter + + +def get_activity_adapter(name: str) -> Optional[ActivityIOAdapter]: + return _ACTIVITY_ADAPTERS.get(name) + + +def use_activity_adapter( + adapter: ActivityIOAdapter, +) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + """Decorator to attach an ActivityIOAdapter to an activity function.""" + + def _decorate(f: Callable[..., Any]) -> Callable[..., Any]: + cast(Any, f).__dapr_activity_io_adapter__ = adapter + return f + + return _decorate + + +def serialize_activity_input(func: Callable[..., Any], input: Any, *, strict: bool = True) -> Any: + adapter = getattr(func, '__dapr_activity_io_adapter__', None) + if adapter: + return cast(ActivityIOAdapter, adapter).serialize_input(input, strict=strict) + return ensure_canonical_json(input, strict=strict) + + +def serialize_activity_output(func: Callable[..., Any], output: Any, *, strict: bool = True) -> Any: + adapter = getattr(func, '__dapr_activity_io_adapter__', None) + if adapter: + return cast(ActivityIOAdapter, adapter).serialize_output(output, strict=strict) + return ensure_canonical_json(output, strict=strict) diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_activity_context.py b/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_activity_context.py index 331ad6c2..79c53ab8 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_activity_context.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_activity_context.py @@ -17,6 +17,7 @@ from typing import Callable, TypeVar +from dapr.ext.workflow.execution_info import ActivityExecutionInfo from durabletask import task T = TypeVar('T') @@ -25,10 +26,20 @@ class WorkflowActivityContext: - """Defines properties and methods for task activity context objects.""" + """Wrapper for ``durabletask.task.ActivityContext`` with metadata helpers. + + Purpose + ------- + - Provide pass-throughs for engine fields (``trace_parent``, ``trace_state``, + and parent ``workflow_span_id`` when available). + - Surface ``execution_info``: a per-activation snapshot that includes the + ``inbound_metadata`` actually received for this activity. + - Offer ``get_metadata()/set_metadata()`` for SDK-level durable metadata management. 
+ """ def __init__(self, ctx: task.ActivityContext): self.__obj = ctx + self._metadata: dict[str, str] | None = None @property def workflow_id(self) -> str: @@ -43,6 +54,20 @@ def task_id(self) -> int: def get_inner_context(self) -> task.ActivityContext: return self.__obj + @property + def execution_info(self) -> ActivityExecutionInfo | None: + return getattr(self, '_execution_info', None) + + def _set_execution_info(self, info: ActivityExecutionInfo) -> None: + self._execution_info = info + + # Metadata accessors (SDK-level; set by runtime inbound if available) + def set_metadata(self, metadata: dict[str, str] | None) -> None: + self._metadata = dict(metadata) if metadata else None + + def get_metadata(self) -> dict[str, str] | None: + return dict(self._metadata) if self._metadata else None + # Activities are simple functions that can be scheduled by workflows Activity = Callable[..., TOutput] diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_context.py b/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_context.py index 8453e16e..75efc8f5 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_context.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_context.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """ Copyright 2023 The Dapr Authors Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,8 +15,9 @@ from abc import ABC, abstractmethod from datetime import datetime, timedelta -from typing import Any, Callable, Generator, Optional, TypeVar, Union +from typing import Any, Callable, Generator, TypeVar, Union +from dapr.ext.workflow.retry_policy import RetryPolicy from dapr.ext.workflow.workflow_activity_context import Activity from durabletask import task @@ -90,7 +89,7 @@ def set_custom_status(self, custom_status: str) -> None: pass @abstractmethod - def create_timer(self, fire_at: Union[datetime, timedelta]) -> task.Task: + def create_timer(self, fire_at: datetime | timedelta) -> task.Task: """Create a Timer Task to fire after at the specified deadline. Parameters @@ -110,8 +109,8 @@ def call_activity( self, activity: Union[Activity[TOutput], str], *, - input: Optional[TInput] = None, - app_id: Optional[str] = None, + input: TInput | None = None, + app_id: str | None = None, ) -> task.Task[TOutput]: """Schedule an activity for execution. @@ -123,6 +122,8 @@ def call_activity( The JSON-serializable input (or None) to pass to the activity. app_id: str | None The AppID that will execute the activity. + return_type: task.Task[TOutput] + The JSON-serializable output type to expect from the activity result. Returns ------- @@ -136,9 +137,10 @@ def call_child_workflow( self, orchestrator: Union[Workflow[TOutput], str], *, - input: Optional[TInput] = None, - instance_id: Optional[str] = None, - app_id: Optional[str] = None, + input: TInput | None = None, + instance_id: str | None = None, + app_id: str | None = None, + retry_policy: RetryPolicy | None = None, ) -> task.Task[TOutput]: """Schedule child-workflow function for execution. @@ -153,6 +155,9 @@ def call_child_workflow( random UUID will be used. app_id: str The AppID that will execute the workflow. + retry_policy: RetryPolicy | None + Optional retry policy for the child-workflow. When provided, failures will be retried + according to the policy. 
Returns ------- diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_runtime.py b/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_runtime.py index 593e55c6..a7098f63 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_runtime.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_runtime.py @@ -13,17 +13,34 @@ limitations under the License. """ +import asyncio import inspect +import traceback from functools import wraps -from typing import Optional, Sequence, TypeVar, Union +from typing import Any, Awaitable, Callable, List, Optional, Sequence, TypeVar, Union import grpc -from dapr.ext.workflow.dapr_workflow_context import DaprWorkflowContext +from dapr.ext.workflow.aio import AsyncWorkflowContext, CoroutineOrchestratorRunner +from dapr.ext.workflow.dapr_workflow_context import DaprWorkflowContext, Handlers +from dapr.ext.workflow.execution_info import ActivityExecutionInfo, WorkflowExecutionInfo +from dapr.ext.workflow.interceptors import ( + CallActivityRequest, + CallChildWorkflowRequest, + ExecuteActivityRequest, + ExecuteWorkflowRequest, + RuntimeInterceptor, + WorkflowOutboundInterceptor, + compose_runtime_chain, + compose_workflow_outbound_chain, + unwrap_payload_with_metadata, + wrap_payload_with_metadata, +) from dapr.ext.workflow.logger import Logger, LoggerOptions from dapr.ext.workflow.util import getAddress from dapr.ext.workflow.workflow_activity_context import Activity, WorkflowActivityContext from dapr.ext.workflow.workflow_context import Workflow from durabletask import task, worker +from durabletask.aio.sandbox import SandboxMode from dapr.clients import DaprInternalError from dapr.clients.http.client import DAPR_API_TOKEN_HEADER @@ -47,16 +64,19 @@ class WorkflowRuntime: def __init__( self, - host: Optional[str] = None, - port: Optional[str] = None, + host: str | None = None, + port: str | None = None, logger_options: Optional[LoggerOptions] = None, interceptors: Optional[Sequence[ClientInterceptor]] = None, maximum_concurrent_activity_work_items: Optional[int] = None, maximum_concurrent_orchestration_work_items: Optional[int] = None, maximum_thread_pool_workers: Optional[int] = None, + *, + runtime_interceptors: Optional[list[RuntimeInterceptor]] = None, + workflow_outbound_interceptors: Optional[list[WorkflowOutboundInterceptor]] = None, ): self._logger = Logger('WorkflowRuntime', logger_options) - metadata = tuple() + metadata = () if settings.DAPR_API_TOKEN: metadata = ((DAPR_API_TOKEN_HEADER, settings.DAPR_API_TOKEN),) address = getAddress(host, port) @@ -80,16 +100,128 @@ def __init__( maximum_thread_pool_workers=maximum_thread_pool_workers, ), ) + # Interceptors + self._runtime_interceptors: List[RuntimeInterceptor] = list(runtime_interceptors or []) + self._workflow_outbound_interceptors: List[WorkflowOutboundInterceptor] = list( + workflow_outbound_interceptors or [] + ) + + # Outbound helpers apply interceptors and wrap metadata; no built-in transformations. 
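Round-trip sketch (illustrative only, not part of this diff) of the durable-metadata envelope that these outbound helpers and the orchestration/activity entrypoint wrappers rely on: metadata is wrapped alongside the payload on the way out and unwrapped on the way in, while a payload without metadata passes through unchanged.

```python
from dapr.ext.workflow.interceptors import (
    unwrap_payload_with_metadata,
    wrap_payload_with_metadata,
)

wire = wrap_payload_with_metadata({'question': 'hi'}, {'tenant': 'acme'})
# wire == {'__dapr_meta__': {'v': 1, 'metadata': {'tenant': 'acme'}},
#          '__dapr_payload__': {'question': 'hi'}}

payload, metadata = unwrap_payload_with_metadata(wire)
assert payload == {'question': 'hi'}
assert metadata == {'tenant': 'acme'}

# Backward compatible: no metadata means the payload is untouched,
# and unwrapping a raw value yields (value, None).
assert wrap_payload_with_metadata({'question': 'hi'}, None) == {'question': 'hi'}
assert unwrap_payload_with_metadata('plain-string') == ('plain-string', None)
```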
+ def _apply_outbound_activity( + self, + ctx: Any, + activity: Callable[..., Any] | str, + input: Any, + retry_policy: Any | None, + metadata: dict[str, str] | None = None, + ): + # Build workflow-outbound chain to transform CallActivityRequest + name = ( + activity + if isinstance(activity, str) + else ( + activity.__dict__['_dapr_alternate_name'] + if hasattr(activity, '_dapr_alternate_name') + else activity.__name__ + ) + ) + + def terminal(term_req: CallActivityRequest) -> CallActivityRequest: + return term_req + + chain = compose_workflow_outbound_chain(self._workflow_outbound_interceptors, terminal) + # Use per-context default metadata when not provided + metadata = metadata or ctx.get_metadata() + act_req = CallActivityRequest( + activity_name=name, + input=input, + retry_policy=retry_policy, + workflow_ctx=ctx, + metadata=metadata, + ) + out = chain(act_req) + if isinstance(out, CallActivityRequest): + return wrap_payload_with_metadata(out.input, out.metadata) + return input + + def _apply_outbound_child( + self, + ctx: Any, + workflow: Callable[..., Any] | str, + input: Any, + metadata: dict[str, str] | None = None, + ): + name = ( + workflow + if isinstance(workflow, str) + else ( + workflow.__dict__['_dapr_alternate_name'] + if hasattr(workflow, '_dapr_alternate_name') + else workflow.__name__ + ) + ) + + def terminal(term_req: CallChildWorkflowRequest) -> CallChildWorkflowRequest: + return term_req + + chain = compose_workflow_outbound_chain(self._workflow_outbound_interceptors, terminal) + metadata = metadata or ctx.get_metadata() + child_req = CallChildWorkflowRequest( + workflow_name=name, input=input, instance_id=None, workflow_ctx=ctx, metadata=metadata + ) + out = chain(child_req) + if isinstance(out, CallChildWorkflowRequest): + return wrap_payload_with_metadata(out.input, out.metadata) + return input + + def _apply_outbound_continue_as_new( + self, + ctx: Any, + new_input: Any, + metadata: dict[str, str] | None = None, + ): + # Build workflow-outbound chain to transform ContinueAsNewRequest + from dapr.ext.workflow.interceptors import ContinueAsNewRequest + + def terminal(term_req: ContinueAsNewRequest) -> ContinueAsNewRequest: + return term_req + + chain = compose_workflow_outbound_chain(self._workflow_outbound_interceptors, terminal) + metadata = metadata or ctx.get_metadata() + cnr = ContinueAsNewRequest(input=new_input, workflow_ctx=ctx, metadata=metadata) + out = chain(cnr) + if isinstance(out, ContinueAsNewRequest): + return wrap_payload_with_metadata(out.input, out.metadata) + return new_input def register_workflow(self, fn: Workflow, *, name: Optional[str] = None): + # Seamlessly support async workflows using the existing API + if inspect.iscoroutinefunction(fn): + return self.register_async_workflow(fn, name=name) + self._logger.info(f"Registering workflow '{fn.__name__}' with runtime") - def orchestrationWrapper(ctx: task.OrchestrationContext, inp: Optional[TInput] = None): - """Responsible to call Workflow function in orchestrationWrapper""" - daprWfContext = DaprWorkflowContext(ctx, self._logger.get_options()) - if inp is None: - return fn(daprWfContext) - return fn(daprWfContext, inp) + def orchestration_wrapper(ctx: task.OrchestrationContext, inp: Optional[TInput] = None): + """Orchestration entrypoint wrapped by runtime interceptors.""" + payload, md = unwrap_payload_with_metadata(inp) + dapr_wf_context = self._get_workflow_context(ctx, md) + + # Build interceptor chain; terminal calls the user function (generator or non-generator) + def 
final_handler(exec_req: ExecuteWorkflowRequest) -> Any: + try: + return ( + fn(dapr_wf_context) + if exec_req.input is None + else fn(dapr_wf_context, exec_req.input) + ) + except Exception as exc: # log and re-raise to surface failure details + self._logger.error( + f"{ctx.instance_id}: workflow '{fn.__name__}' raised {type(exc).__name__}: {exc}\n{traceback.format_exc()}" + ) + raise + + chain = compose_runtime_chain(self._runtime_interceptors, final_handler) + return chain(ExecuteWorkflowRequest(ctx=dapr_wf_context, input=payload, metadata=md)) if hasattr(fn, '_workflow_registered'): # whenever a workflow is registered, it has a _dapr_alternate_name attribute @@ -104,7 +236,7 @@ def orchestrationWrapper(ctx: task.OrchestrationContext, inp: Optional[TInput] = fn.__dict__['_dapr_alternate_name'] = name if name else fn.__name__ self.__worker._registry.add_named_orchestrator( - fn.__dict__['_dapr_alternate_name'], orchestrationWrapper + fn.__dict__['_dapr_alternate_name'], orchestration_wrapper ) fn.__dict__['_workflow_registered'] = True @@ -114,12 +246,44 @@ def register_activity(self, fn: Activity, *, name: Optional[str] = None): """ self._logger.info(f"Registering activity '{fn.__name__}' with runtime") - def activityWrapper(ctx: task.ActivityContext, inp: Optional[TInput] = None): - """Responsible to call Activity function in activityWrapper""" - wfActivityContext = WorkflowActivityContext(ctx) - if inp is None: - return fn(wfActivityContext) - return fn(wfActivityContext, inp) + def activity_wrapper(ctx: task.ActivityContext, inp: Optional[TInput] = None): + """Activity entrypoint wrapped by runtime interceptors.""" + wf_activity_context = WorkflowActivityContext(ctx) + payload, md = unwrap_payload_with_metadata(inp) + # Populate inbound metadata onto activity context + wf_activity_context.set_metadata(md or {}) + + # Populate execution info + try: + # Determine activity name (registered alternate name or function __name__) + act_name = getattr(fn, '_dapr_alternate_name', fn.__name__) + ainfo = ActivityExecutionInfo(inbound_metadata=md or {}, activity_name=act_name) + wf_activity_context._set_execution_info(ainfo) + except Exception: + pass + + def final_handler(exec_req: ExecuteActivityRequest) -> Any: + try: + # Support async and sync activities + if inspect.iscoroutinefunction(fn): + if exec_req.input is None: + return asyncio.run(fn(wf_activity_context)) + return asyncio.run(fn(wf_activity_context, exec_req.input)) + if exec_req.input is None: + return fn(wf_activity_context) + return fn(wf_activity_context, exec_req.input) + except Exception as exc: + # Log details for troubleshooting (metadata, error type) + self._logger.error( + f"{ctx.orchestration_id}:{ctx.task_id} activity '{fn.__name__}' failed with {type(exc).__name__}: {exc}" + ) + self._logger.error(traceback.format_exc()) + raise + + chain = compose_runtime_chain(self._runtime_interceptors, final_handler) + return chain( + ExecuteActivityRequest(ctx=wf_activity_context, input=payload, metadata=md) + ) if hasattr(fn, '_activity_registered'): # whenever an activity is registered, it has a _dapr_alternate_name attribute @@ -134,7 +298,7 @@ def activityWrapper(ctx: task.ActivityContext, inp: Optional[TInput] = None): fn.__dict__['_dapr_alternate_name'] = name if name else fn.__name__ self.__worker._registry.add_named_activity( - fn.__dict__['_dapr_alternate_name'], activityWrapper + fn.__dict__['_dapr_alternate_name'], activity_wrapper ) fn.__dict__['_activity_registered'] = True @@ -144,7 +308,13 @@ def start(self): def 
shutdown(self): """Stops the listening for work items on a background thread.""" - self.__worker.stop() + try: + self._logger.info('Stopping gRPC worker...') + self.__worker.stop() + self._logger.info('Worker shutdown completed') + except Exception as exc: # pragma: no cover + # DurableTask worker may emit CANCELLED warnings during local shutdown; not fatal + self._logger.warning(f'Worker stop encountered {type(exc).__name__}: {exc}') def workflow(self, __fn: Workflow = None, *, name: Optional[str] = None): """Decorator to register a workflow function. @@ -174,7 +344,11 @@ def add(ctx, x: int, y: int) -> int: """ def wrapper(fn: Workflow): - self.register_workflow(fn, name=name) + # Auto-detect coroutine and delegate to async registration + if inspect.iscoroutinefunction(fn): + self.register_async_workflow(fn, name=name) + else: + self.register_workflow(fn, name=name) @wraps(fn) def innerfn(): @@ -194,6 +368,121 @@ def innerfn(): return wrapper + # Async orchestrator registration (additive) + def register_async_workflow( + self, + fn: Callable[[AsyncWorkflowContext, Any], Awaitable[Any]], + *, + name: Optional[str] = None, + sandbox_mode: SandboxMode = SandboxMode.OFF, + ) -> None: + """Register an async workflow function. + + The async workflow is wrapped by a coroutine-to-generator driver so it can be + executed by the Durable Task runtime alongside existing generator workflows. + + Args: + fn: The async workflow function, taking ``AsyncWorkflowContext`` and optional input. + name: Optional alternate name for registration. + sandbox_mode: Scoped compatibility patching mode. + """ + self._logger.info(f"Registering ASYNC workflow '{fn.__name__}' with runtime") + + if hasattr(fn, '_workflow_registered'): + alt_name = fn.__dict__['_dapr_alternate_name'] + raise ValueError(f'Workflow {fn.__name__} already registered as {alt_name}') + if hasattr(fn, '_dapr_alternate_name'): + alt_name = fn._dapr_alternate_name + if name is not None: + m = f'Workflow {fn.__name__} already has an alternate name {alt_name}' + raise ValueError(m) + else: + fn.__dict__['_dapr_alternate_name'] = name if name else fn.__name__ + + runner = CoroutineOrchestratorRunner(fn, sandbox_mode=sandbox_mode) + + def generator_orchestrator(ctx: task.OrchestrationContext, inp: Optional[Any] = None): + """Orchestration entrypoint wrapped by runtime interceptors.""" + payload, md = unwrap_payload_with_metadata(inp) + base_ctx = self._get_workflow_context(ctx, md) + + async_ctx = AsyncWorkflowContext(base_ctx) + + def final_handler(exec_req: ExecuteWorkflowRequest) -> Any: + # Build the generator using the (potentially shaped) input from interceptors. + shaped_input = exec_req.input + return runner.to_generator(async_ctx, shaped_input) + + chain = compose_runtime_chain(self._runtime_interceptors, final_handler) + return chain(ExecuteWorkflowRequest(ctx=async_ctx, input=payload, metadata=md)) + + self.__worker._registry.add_named_orchestrator( + fn.__dict__['_dapr_alternate_name'], generator_orchestrator + ) + fn.__dict__['_workflow_registered'] = True + + def _get_workflow_context( + self, ctx: task.OrchestrationContext, metadata: dict[str, str] | None = None + ) -> DaprWorkflowContext: + """Get the workflow context and execution info for the given orchestration context and metadata. + Execution info serves as a read-only snapshot of the workflow context. + + Args: + ctx: The orchestration context. + metadata: The metadata for the workflow. + + Returns: + The workflow context. 
+ """ + base_ctx = DaprWorkflowContext( + ctx, + self._logger.get_options(), + outbound_handlers={ + Handlers.CALL_ACTIVITY: self._apply_outbound_activity, + Handlers.CALL_CHILD_WORKFLOW: self._apply_outbound_child, + Handlers.CONTINUE_AS_NEW: self._apply_outbound_continue_as_new, + }, + ) + # Populate minimal execution info (only inbound metadata) + info = WorkflowExecutionInfo(inbound_metadata=metadata or {}) + base_ctx._set_execution_info(info) + base_ctx.set_metadata(metadata or {}) + return base_ctx + + def async_workflow( + self, + __fn: Callable[[AsyncWorkflowContext, Any], Awaitable[Any]] = None, + *, + name: Optional[str] = None, + sandbox_mode: SandboxMode = SandboxMode.OFF, + ): + """Decorator to register an async workflow function. + + Usage: + @runtime.async_workflow(name="my_wf") + async def my_wf(ctx: AsyncWorkflowContext, input): + ... + """ + + def wrapper(fn: Callable[[AsyncWorkflowContext, Any], Awaitable[Any]]): + self.register_async_workflow(fn, name=name, sandbox_mode=sandbox_mode) + + @wraps(fn) + def innerfn(): + return fn + + if hasattr(fn, '_dapr_alternate_name'): + innerfn.__dict__['_dapr_alternate_name'] = fn.__dict__['_dapr_alternate_name'] + else: + innerfn.__dict__['_dapr_alternate_name'] = name if name else fn.__name__ + innerfn.__signature__ = inspect.signature(fn) + return innerfn + + if __fn: + return wrapper(__fn) + + return wrapper + def activity(self, __fn: Activity = None, *, name: Optional[str] = None): """Decorator to register an activity function. diff --git a/ext/dapr-ext-workflow/examples/generics_interceptors_example.py b/ext/dapr-ext-workflow/examples/generics_interceptors_example.py new file mode 100644 index 00000000..7ef9c8f4 --- /dev/null +++ b/ext/dapr-ext-workflow/examples/generics_interceptors_example.py @@ -0,0 +1,197 @@ +from __future__ import annotations + +import os +from dataclasses import asdict, dataclass +from typing import List + +from dapr.ext.workflow import ( + DaprWorkflowClient, + WorkflowRuntime, +) +from dapr.ext.workflow.interceptors import ( + BaseClientInterceptor, + BaseRuntimeInterceptor, + BaseWorkflowOutboundInterceptor, + CallActivityRequest, + CallChildWorkflowRequest, + ContinueAsNewRequest, + ExecuteActivityRequest, + ExecuteWorkflowRequest, + ScheduleWorkflowRequest, +) + +# ------------------------------ +# Typed payloads carried by interceptors +# ------------------------------ + + +@dataclass +class MyWorkflowInput: + question: str + tags: List[str] + + +@dataclass +class MyActivityInput: + name: str + count: int + + +# ------------------------------ +# Interceptors with generics + minimal (de)serialization +# ------------------------------ + + +class MyClientInterceptor(BaseClientInterceptor[MyWorkflowInput]): + def schedule_new_workflow( + self, + input: ScheduleWorkflowRequest[MyWorkflowInput], + nxt, + ) -> str: + # Ensure wire format is JSON-serializable (dict) + payload = ( + asdict(input.input) if hasattr(input.input, '__dataclass_fields__') else input.input + ) + shaped = ScheduleWorkflowRequest[MyWorkflowInput]( + workflow_name=input.workflow_name, + input=payload, # type: ignore[arg-type] + instance_id=input.instance_id, + start_at=input.start_at, + reuse_id_policy=input.reuse_id_policy, + metadata=input.metadata, + ) + return nxt(shaped) + + +class MyRuntimeInterceptor(BaseRuntimeInterceptor[MyWorkflowInput, MyActivityInput]): + def execute_workflow( + self, + input: ExecuteWorkflowRequest[MyWorkflowInput], + nxt, + ): + # Convert inbound dict into typed model for workflow code + data = 
input.input + if isinstance(data, dict) and 'question' in data: + input.input = MyWorkflowInput( + question=data.get('question', ''), tags=list(data.get('tags', [])) + ) # type: ignore[assignment] + return nxt(input) + + def execute_activity( + self, + input: ExecuteActivityRequest[MyActivityInput], + nxt, + ): + data = input.input + if isinstance(data, dict) and 'name' in data: + input.input = MyActivityInput( + name=data.get('name', ''), count=int(data.get('count', 0)) + ) # type: ignore[assignment] + return nxt(input) + + +class MyOutboundInterceptor(BaseWorkflowOutboundInterceptor[MyWorkflowInput, MyActivityInput]): + def call_child_workflow( + self, + input: CallChildWorkflowRequest[MyWorkflowInput], + nxt, + ): + # Convert typed payload back to wire before sending + payload = ( + asdict(input.input) if hasattr(input.input, '__dataclass_fields__') else input.input + ) + shaped = CallChildWorkflowRequest[MyWorkflowInput]( + workflow_name=input.workflow_name, + input=payload, # type: ignore[arg-type] + instance_id=input.instance_id, + workflow_ctx=input.workflow_ctx, + metadata=input.metadata, + ) + return nxt(shaped) + + def continue_as_new( + self, + input: ContinueAsNewRequest[MyWorkflowInput], + nxt, + ): + payload = ( + asdict(input.input) if hasattr(input.input, '__dataclass_fields__') else input.input + ) + shaped = ContinueAsNewRequest[MyWorkflowInput]( + input=payload, # type: ignore[arg-type] + workflow_ctx=input.workflow_ctx, + metadata=input.metadata, + ) + return nxt(shaped) + + def call_activity( + self, + input: CallActivityRequest[MyActivityInput], + nxt, + ): + payload = ( + asdict(input.input) if hasattr(input.input, '__dataclass_fields__') else input.input + ) + shaped = CallActivityRequest[MyActivityInput]( + activity_name=input.activity_name, + input=payload, # type: ignore[arg-type] + retry_policy=input.retry_policy, + workflow_ctx=input.workflow_ctx, + metadata=input.metadata, + ) + return nxt(shaped) + + +# ------------------------------ +# Minimal runnable example with sidecar +# ------------------------------ + + +def main() -> None: + # Expect DAPR_GRPC_ENDPOINT (e.g., dns:127.0.0.1:56179) to be set for local sidecar/dev hub + ep = os.getenv('DAPR_GRPC_ENDPOINT') + if not ep: + print('WARNING: DAPR_GRPC_ENDPOINT not set; default sidecar address will be used') + + # Build runtime with interceptors + runtime = WorkflowRuntime( + runtime_interceptors=[MyRuntimeInterceptor()], + workflow_outbound_interceptors=[MyOutboundInterceptor()], + ) + + # Register a simple activity + @runtime.activity(name='greet') + def greet(_ctx, x: dict | None = None) -> str: # wire format at activity boundary is dict + x = x or {} + return f'Hello {x.get("name", "world")} x{x.get("count", 0)}' + + # Register an async workflow that calls the activity once + @runtime.async_workflow(name='wf_greet') + async def wf_greet(ctx, arg: MyWorkflowInput | dict | None = None): + # At this point, runtime interceptor converted inbound to MyWorkflowInput + if isinstance(arg, MyWorkflowInput): + act_in = MyActivityInput(name=arg.question, count=len(arg.tags)) + else: + # Fallback if interceptor not present + d = arg or {} + act_in = MyActivityInput(name=str(d.get('question', '')), count=len(d.get('tags', []))) + return await ctx.call_activity('greet', input=asdict(act_in)) + + runtime.start() + try: + # Client with client-side interceptor for schedule typing + client = DaprWorkflowClient(interceptors=[MyClientInterceptor()]) + wf_input = MyWorkflowInput(question='World', tags=['a', 'b']) + instance_id 
= client.schedule_new_workflow(wf_greet, input=wf_input) + print('Started instance:', instance_id) + client.wait_for_workflow_start(instance_id, timeout_in_seconds=30) + st = client.wait_for_workflow_completion(instance_id, timeout_in_seconds=60) + print('Final status:', getattr(st, 'runtime_status', None)) + if st: + print('Output:', st.to_json().get('serialized_output')) + finally: + runtime.shutdown() + + +if __name__ == '__main__': + main() diff --git a/ext/dapr-ext-workflow/pytest.ini b/ext/dapr-ext-workflow/pytest.ini new file mode 100644 index 00000000..d08eef80 --- /dev/null +++ b/ext/dapr-ext-workflow/pytest.ini @@ -0,0 +1,4 @@ +[pytest] +markers = + e2e: marks tests as end-to-end integration tests (deselect with '-m "not e2e"') + diff --git a/ext/dapr-ext-workflow/setup.cfg b/ext/dapr-ext-workflow/setup.cfg index 6efe6668..ee15484a 100644 --- a/ext/dapr-ext-workflow/setup.cfg +++ b/ext/dapr-ext-workflow/setup.cfg @@ -20,12 +20,12 @@ project_urls = Source = https://github.com/dapr/python-sdk [options] -python_requires = >=3.9 +python_requires = >=3.10 packages = find_namespace: include_package_data = True install_requires = dapr >= 1.16.0.dev - durabletask-dapr >= 0.2.0a9 + durabletask-dapr >= 0.2.0a13 [options.packages.find] include = diff --git a/ext/dapr-ext-workflow/tests/README.md b/ext/dapr-ext-workflow/tests/README.md new file mode 100644 index 00000000..6759a362 --- /dev/null +++ b/ext/dapr-ext-workflow/tests/README.md @@ -0,0 +1,94 @@ +## Workflow tests: unit, integration, and custom ports + +This directory contains unit tests (no sidecar required) and integration tests (require a running sidecar/runtime). + +### Prereqs + +- Python 3.11+ (tox will create an isolated venv) +- Dapr sidecar for integration tests (HTTP and gRPC ports) +- Optional: Durable Task gRPC endpoint for DT e2e tests + +### Run all tests via tox (recommended) + +```bash +tox -e py311 +``` + +This runs: +- Core SDK tests (unittest) +- Workflow extension unit tests (pytest) +- Workflow extension integration tests (pytest) if your sidecar/runtime is reachable + +### Run only workflow unit tests + +Unit tests live at `ext/dapr-ext-workflow/tests` excluding the `integration/` subfolder. + +With tox: +```bash +tox -e py311 -- pytest -q ext/dapr-ext-workflow/tests -k "not integration" +``` + +Directly (outside tox): +```bash +pytest -q ext/dapr-ext-workflow/tests -k "not integration" +``` + +### Run workflow integration tests + +Integration tests live under `ext/dapr-ext-workflow/tests/integration/` and require a running sidecar/runtime. + +With tox: +```bash +tox -e py311 -- pytest -q ext/dapr-ext-workflow/tests/integration +``` + +Directly (outside tox): +```bash +pytest -q ext/dapr-ext-workflow/tests/integration +``` + +If tests cannot reach your sidecar/runtime, they will skip or fail fast depending on the specific test. + +### Configure custom sidecar ports/endpoints + +The SDK reads connection settings from env vars (see `dapr.conf.global_settings`). Use these to point tests at custom ports: + +- Dapr gRPC: + - `DAPR_GRPC_ENDPOINT` (preferred): endpoint string, e.g. `dns:127.0.0.1:50051` + - or `DAPR_RUNTIME_HOST` and `DAPR_GRPC_PORT`, e.g. `DAPR_RUNTIME_HOST=127.0.0.1`, `DAPR_GRPC_PORT=50051` + +- Dapr HTTP (only for HTTP-based tests): + - `DAPR_HTTP_ENDPOINT`: e.g. `http://127.0.0.1:3600` + - or `DAPR_RUNTIME_HOST` and `DAPR_HTTP_PORT`, e.g. 
`DAPR_HTTP_PORT=3600` + +Examples: +```bash +# Use custom gRPC 50051 and HTTP 3600 +export DAPR_GRPC_ENDPOINT=dns:127.0.0.1:50051 +export DAPR_HTTP_ENDPOINT=http://127.0.0.1:3600 + +# Alternatively, using host/port pairs +export DAPR_RUNTIME_HOST=127.0.0.1 +export DAPR_GRPC_PORT=50051 +export DAPR_HTTP_PORT=3600 + +tox -e py311 -- pytest -q ext/dapr-ext-workflow/tests/integration +``` + +Note: For gRPC, avoid `http://` or `https://` schemes. Use `dns:host:port` or just set host/port separately. + +### Durable Task e2e tests (optional) + +Some tests (e.g., `integration/test_async_e2e_dt.py`) talk directly to a Durable Task gRPC endpoint. They use: + +- `DURABLETASK_GRPC_ENDPOINT` (default `localhost:56178`) + +If your DT runtime listens elsewhere: +```bash +export DURABLETASK_GRPC_ENDPOINT=127.0.0.1:56179 +tox -e py311 -- pytest -q ext/dapr-ext-workflow/tests/integration/test_async_e2e_dt.py +``` + + + + diff --git a/ext/dapr-ext-workflow/tests/_fakes.py b/ext/dapr-ext-workflow/tests/_fakes.py new file mode 100644 index 00000000..09051702 --- /dev/null +++ b/ext/dapr-ext-workflow/tests/_fakes.py @@ -0,0 +1,72 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +from datetime import datetime +from typing import Any + + +class FakeOrchestrationContext: + def __init__( + self, + *, + instance_id: str = 'wf-1', + current_utc_datetime: datetime | None = None, + is_replaying: bool = False, + workflow_name: str = 'wf', + parent_instance_id: str | None = None, + history_event_sequence: int | None = 1, + trace_parent: str | None = None, + trace_state: str | None = None, + orchestration_span_id: str | None = None, + workflow_attempt: int | None = None, + ) -> None: + self.instance_id = instance_id + self.current_utc_datetime = ( + current_utc_datetime if current_utc_datetime else datetime(2025, 1, 1) + ) + self.is_replaying = is_replaying + self.workflow_name = workflow_name + self.parent_instance_id = parent_instance_id + self.history_event_sequence = history_event_sequence + self.trace_parent = trace_parent + self.trace_state = trace_state + self.orchestration_span_id = orchestration_span_id + self.workflow_attempt = workflow_attempt + + +class FakeActivityContext: + def __init__( + self, + *, + orchestration_id: str = 'wf-1', + task_id: int = 1, + attempt: int | None = None, + trace_parent: str | None = None, + trace_state: str | None = None, + workflow_span_id: str | None = None, + ) -> None: + self.orchestration_id = orchestration_id + self.task_id = task_id + self.trace_parent = trace_parent + self.trace_state = trace_state + self.workflow_span_id = workflow_span_id + + +def make_orch_ctx(**overrides: Any) -> FakeOrchestrationContext: + return FakeOrchestrationContext(**overrides) + + +def make_act_ctx(**overrides: Any) -> FakeActivityContext: + return FakeActivityContext(**overrides) diff --git a/ext/dapr-ext-workflow/tests/conftest.py b/ext/dapr-ext-workflow/tests/conftest.py new file mode 100644 index 00000000..f20a225e --- /dev/null +++ 
b/ext/dapr-ext-workflow/tests/conftest.py @@ -0,0 +1,74 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +# Ensure tests prefer the local python-sdk repository over any installed site-packages +# This helps when running pytest directly (outside tox/CI), so changes in the repo are exercised. +from __future__ import annotations # noqa: I001 + +import sys +from pathlib import Path +import importlib +import pytest + + +def pytest_configure(config): # noqa: D401 (pytest hook) + """Pytest configuration hook that prepends the repo root to sys.path. + + This ensures `import dapr` resolves to the local source tree when running tests directly. + Under tox/CI (editable installs), this is a no-op but still safe. + """ + try: + # ext/dapr-ext-workflow/tests/conftest.py -> repo root is 3 parents up + repo_root = Path(__file__).resolve().parents[3] + except Exception: + return + + repo_str = str(repo_root) + if repo_str not in sys.path: + sys.path.insert(0, repo_str) + + # Best-effort diagnostic: show where dapr was imported from + try: + dapr_mod = importlib.import_module('dapr') + dapr_path = Path(getattr(dapr_mod, '__file__', '')).resolve() + where = 'site-packages' if 'site-packages' in str(dapr_path) else 'local-repo' + print(f'[dapr-ext-workflow/tests] dapr resolved from {where}: {dapr_path}', file=sys.stderr) + except Exception: + # If dapr isn't importable yet, that's fine; tests importing it later will use modified sys.path + pass + + +@pytest.fixture(autouse=True) +def cleanup_workflow_registrations(request): + """Clean up workflow/activity registration markers after each test. + + This prevents test interference when the same function objects are reused across tests. + The workflow runtime marks functions with _dapr_alternate_name and _activity_registered + attributes, which can cause 'already registered' errors in subsequent tests. + """ + yield # Run the test + + # After test completes, clean up functions defined in the test module + test_module = sys.modules.get(request.module.__name__) + if test_module: + for name in dir(test_module): + obj = getattr(test_module, name, None) + if callable(obj) and hasattr(obj, '__dict__'): + try: + # Only clean up if __dict__ is writable (not mappingproxy) + if isinstance(obj.__dict__, dict): + obj.__dict__.pop('_dapr_alternate_name', None) + obj.__dict__.pop('_activity_registered', None) + except (AttributeError, TypeError): + # Skip objects with read-only __dict__ + pass diff --git a/ext/dapr-ext-workflow/tests/integration/test_async_e2e_dt.py b/ext/dapr-ext-workflow/tests/integration/test_async_e2e_dt.py new file mode 100644 index 00000000..5c9feca6 --- /dev/null +++ b/ext/dapr-ext-workflow/tests/integration/test_async_e2e_dt.py @@ -0,0 +1,186 @@ +# -*- coding: utf-8 -*- + +""" +Async e2e tests using durabletask worker/client directly. + +These validate basic orchestration behavior against a running sidecar +to isolate environment issues from WorkflowRuntime wiring. 
+""" + +from __future__ import annotations + +import os +import time + +import pytest +from durabletask.aio import AsyncWorkflowContext +from durabletask.client import TaskHubGrpcClient +from durabletask.worker import TaskHubGrpcWorker + +pytestmark = pytest.mark.e2e + + +def _is_runtime_available(ep_str: str) -> bool: + import socket + + try: + host, port = ep_str.split(':') + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(1) + result = sock.connect_ex((host, int(port))) + sock.close() + return result == 0 + except Exception: + return False + + +endpoint = os.getenv('DAPR_GRPC_ENDPOINT', 'localhost:50001') + +skip_if_no_runtime = pytest.mark.skipif( + not _is_runtime_available(endpoint), + reason='DurableTask runtime not available', +) + + +@skip_if_no_runtime +def test_dt_simple_activity_e2e(): + # using global read-only endpoint variable + worker = TaskHubGrpcWorker(host_address=endpoint) + client = TaskHubGrpcClient(host_address=endpoint) + + def act(ctx, x: int) -> int: + return x * 3 + + worker.add_activity(act) + + @worker.add_async_orchestrator + async def orch(ctx: AsyncWorkflowContext, x: int) -> int: + return await ctx.call_activity(act, input=x) + + worker.start() + try: + try: + if hasattr(worker, 'wait_for_ready'): + worker.wait_for_ready(timeout=10) # type: ignore[attr-defined] + except Exception: + pass + iid = f'dt-e2e-act-{int(time.time() * 1000)}' + client.schedule_new_orchestration(orch, input=5, instance_id=iid) + st = client.wait_for_orchestration_completion(iid, timeout=30) + assert st is not None + assert st.runtime_status.name == 'COMPLETED' + # Output is JSON serialized scalar + assert st.serialized_output.strip() in ('15', '"15"') + finally: + try: + worker.stop() + except Exception: + pass + + +@skip_if_no_runtime +def test_dt_timer_e2e(): + # using global read-only endpoint variable + worker = TaskHubGrpcWorker(host_address=endpoint) + client = TaskHubGrpcClient(host_address=endpoint) + + @worker.add_async_orchestrator + async def orch(ctx: AsyncWorkflowContext, delay: float) -> dict: + start = ctx.now() + await ctx.sleep(delay) + end = ctx.now() + return {'start': start.isoformat(), 'end': end.isoformat(), 'delay': delay} + + worker.start() + try: + try: + if hasattr(worker, 'wait_for_ready'): + worker.wait_for_ready(timeout=10) # type: ignore[attr-defined] + except Exception: + pass + iid = f'dt-e2e-timer-{int(time.time() * 1000)}' + delay = 1.0 + client.schedule_new_orchestration(orch, input=delay, instance_id=iid) + st = client.wait_for_orchestration_completion(iid, timeout=30) + assert st is not None + assert st.runtime_status.name == 'COMPLETED' + finally: + try: + worker.stop() + except Exception: + pass + + +@skip_if_no_runtime +def test_dt_sub_orchestrator_e2e(): + # using global read-only endpoint variable + worker = TaskHubGrpcWorker(host_address=endpoint) + client = TaskHubGrpcClient(host_address=endpoint) + + def act(ctx, s: str) -> str: + return f'A:{s}' + + worker.add_activity(act) + + async def child(ctx: AsyncWorkflowContext, s: str) -> str: + print('[E2E DEBUG] child start', s) + try: + res = await ctx.call_activity(act, input=s) + print('[E2E DEBUG] child done', res) + return res + except Exception as exc: # pragma: no cover - troubleshooting aid + import traceback as _tb + + print('[E2E DEBUG] child exception:', type(exc).__name__, str(exc)) + print(_tb.format_exc()) + raise + + # Explicit registration to avoid decorator replacing symbol with a string in newer versions + worker.add_async_orchestrator(child) + + 
async def parent(ctx: AsyncWorkflowContext, s: str) -> str: + print('[E2E DEBUG] parent start', s) + try: + c = await ctx.call_sub_orchestrator(child, input=s) + out = f'P:{c}' + print('[E2E DEBUG] parent done', out) + return out + except Exception as exc: # pragma: no cover - troubleshooting aid + import traceback as _tb + + print('[E2E DEBUG] parent exception:', type(exc).__name__, str(exc)) + print(_tb.format_exc()) + raise + + worker.add_async_orchestrator(parent) + + worker.start() + try: + try: + if hasattr(worker, 'wait_for_ready'): + worker.wait_for_ready(timeout=10) # type: ignore[attr-defined] + except Exception: + pass + iid = f'dt-e2e-sub-{int(time.time() * 1000)}' + print('[E2E DEBUG] scheduling instance', iid) + client.schedule_new_orchestration(parent, input='x', instance_id=iid) + st = client.wait_for_orchestration_completion(iid, timeout=30) + assert st is not None + if st.runtime_status.name != 'COMPLETED': + # Print orchestration state details to aid debugging + print('[E2E DEBUG] orchestration FAILED; details:') + to_json = getattr(st, 'to_json', None) + if callable(to_json): + try: + print(to_json()) + except Exception: + pass + print('status=', getattr(st, 'runtime_status', None)) + print('output=', getattr(st, 'serialized_output', None)) + print('failure=', getattr(st, 'failure_details', None)) + assert st.runtime_status.name == 'COMPLETED' + finally: + try: + worker.stop() + except Exception: + pass diff --git a/ext/dapr-ext-workflow/tests/integration/test_integration_async_semantics.py b/ext/dapr-ext-workflow/tests/integration/test_integration_async_semantics.py new file mode 100644 index 00000000..cafb125d --- /dev/null +++ b/ext/dapr-ext-workflow/tests/integration/test_integration_async_semantics.py @@ -0,0 +1,965 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the specific language governing permissions and +limitations under the License. 
+""" + +import time + +import pytest +from dapr.ext.workflow import ( + AsyncWorkflowContext, + DaprWorkflowClient, + DaprWorkflowContext, + WorkflowRuntime, +) +from dapr.ext.workflow.interceptors import ( + BaseRuntimeInterceptor, + ExecuteActivityRequest, + ExecuteWorkflowRequest, +) + +pytestmark = pytest.mark.e2e + +skip_integration = pytest.mark.skipif( + False, + reason='integration enabled', +) + + +@skip_integration +def test_integration_suspension_and_buffering(): + runtime = WorkflowRuntime() + + @runtime.async_workflow(name='suspend_orchestrator_async') + async def suspend_orchestrator(ctx: AsyncWorkflowContext): + # Expose suspension state via custom status + ctx.set_custom_status({'is_suspended': getattr(ctx, 'is_suspended', False)}) + # Wait for 'resume_event' and then complete + data = await ctx.wait_for_external_event('resume_event') + return {'resumed_with': data} + + runtime.start() + try: + # Allow connection to stabilize before scheduling + time.sleep(3) + + client = DaprWorkflowClient() + instance_id = f'suspend-int-{int(time.time() * 1000)}' + client.schedule_new_workflow(workflow=suspend_orchestrator, instance_id=instance_id) + + # Wait until started + client.wait_for_workflow_start(instance_id, timeout_in_seconds=30) + + # Pause and verify state becomes SUSPENDED and custom status updates on next activation + client.pause_workflow(instance_id) + # Give the worker time to process suspension + time.sleep(1) + state = client.get_workflow_state(instance_id) + assert state is not None + assert state.runtime_status.name in ( + 'SUSPENDED', + 'RUNNING', + ) # some hubs report SUSPENDED explicitly + + # While suspended, raise the event; it should buffer + client.raise_workflow_event(instance_id, 'resume_event', data={'ok': True}) + + # Resume and expect completion + client.resume_workflow(instance_id) + final = client.wait_for_workflow_completion(instance_id, timeout_in_seconds=60) + assert final is not None + assert final.runtime_status.name == 'COMPLETED' + finally: + runtime.shutdown() + + +@skip_integration +def test_integration_generator_metadata_propagation(): + runtime = WorkflowRuntime() + + @runtime.activity(name='recv_md_gen') + def recv_md_gen(ctx, _=None): + return ctx.get_metadata() or {} + + @runtime.workflow(name='gen_parent_sets_md') + def parent_gen(ctx: DaprWorkflowContext): + ctx.set_metadata({'tenant': 'acme', 'tier': 'gold'}) + md = yield ctx.call_activity(recv_md_gen, input=None) + return md + + runtime.start() + try: + time.sleep(3) + client = DaprWorkflowClient() + iid = f'gen-md-int-{int(time.time() * 1000)}' + client.schedule_new_workflow(workflow=parent_gen, instance_id=iid) + client.wait_for_workflow_start(iid, timeout_in_seconds=30) + state = client.wait_for_workflow_completion(iid, timeout_in_seconds=60) + assert state is not None + assert state.runtime_status.name == 'COMPLETED' + import json as _json + + out = _json.loads(state.to_json().get('serialized_output') or '{}') + assert out.get('tenant') == 'acme' + assert out.get('tier') == 'gold' + finally: + runtime.shutdown() + + +@skip_integration +def test_integration_trace_context_child_workflow(): + runtime = WorkflowRuntime() + + @runtime.activity(name='trace_probe') + def trace_probe(ctx, _=None): + return { + 'tp': getattr(ctx, 'trace_parent', None), + 'ts': getattr(ctx, 'trace_state', None), + 'wf_span': getattr(ctx, 'workflow_span_id', None), + } + + @runtime.async_workflow(name='child_trace') + async def child(ctx: AsyncWorkflowContext, _=None): + return { + 'wf_tp': getattr(ctx, 
'trace_parent', None), + 'wf_ts': getattr(ctx, 'trace_state', None), + 'wf_span': getattr(ctx, 'workflow_span_id', None), + 'act': await ctx.call_activity(trace_probe, input=None), + } + + @runtime.async_workflow(name='parent_trace') + async def parent(ctx: AsyncWorkflowContext): + child_out = await ctx.call_child_workflow(child, input=None) + return { + 'parent_tp': getattr(ctx, 'trace_parent', None), + 'parent_span': getattr(ctx, 'workflow_span_id', None), + 'child': child_out, + } + + runtime.start() + try: + time.sleep(3) + client = DaprWorkflowClient() + iid = f'trace-child-int-{int(time.time() * 1000)}' + client.schedule_new_workflow(workflow=parent, instance_id=iid) + client.wait_for_workflow_start(iid, timeout_in_seconds=30) + st = client.wait_for_workflow_completion(iid, timeout_in_seconds=60) + assert st is not None and st.runtime_status.name == 'COMPLETED' + import json as _json + + data = _json.loads(st.to_json().get('serialized_output') or '{}') + + # TODO: assert more specifically when we have trace context information + + # Parent (engine-provided fields may be absent depending on runtime build/config) + assert isinstance(data.get('parent_tp'), (str, type(None))) + assert isinstance(data.get('parent_span'), (str, type(None))) + # Child orchestrator fields + _child = data.get('child') or {} + assert isinstance(_child.get('wf_tp'), (str, type(None))) + assert isinstance(_child.get('wf_span'), (str, type(None))) + # Activity fields under child + act = _child.get('act') or {} + assert isinstance(act.get('tp'), (str, type(None))) + assert isinstance(act.get('wf_span'), (str, type(None))) + + finally: + runtime.shutdown() + + +@skip_integration +def test_integration_trace_context_child_workflow_injected_metadata(): + # Deterministic trace propagation using interceptors via durable metadata + from dapr.ext.workflow import ( + BaseClientInterceptor, + BaseRuntimeInterceptor, + BaseWorkflowOutboundInterceptor, + CallActivityRequest, + CallChildWorkflowRequest, + ScheduleWorkflowRequest, + ) + + TRACE_KEY = 'otel.trace_id' + + class InjectTraceClient(BaseClientInterceptor): + def schedule_new_workflow(self, request: ScheduleWorkflowRequest, next): + md = dict(request.metadata or {}) + md.setdefault(TRACE_KEY, 'sdk-trace-123') + return next( + ScheduleWorkflowRequest( + workflow_name=request.workflow_name, + input=request.input, + instance_id=request.instance_id, + start_at=request.start_at, + reuse_id_policy=request.reuse_id_policy, + metadata=md, + ) + ) + + class InjectTraceOutbound(BaseWorkflowOutboundInterceptor): + def call_activity(self, request: CallActivityRequest, next): + md = dict(request.metadata or {}) + md.setdefault(TRACE_KEY, 'sdk-trace-123') + return next( + CallActivityRequest( + activity_name=request.activity_name, + input=request.input, + retry_policy=request.retry_policy, + workflow_ctx=request.workflow_ctx, + metadata=md, + ) + ) + + def call_child_workflow(self, request: CallChildWorkflowRequest, next): + md = dict(request.metadata or {}) + md.setdefault(TRACE_KEY, 'sdk-trace-123') + return next( + CallChildWorkflowRequest( + workflow_name=request.workflow_name, + input=request.input, + instance_id=request.instance_id, + workflow_ctx=request.workflow_ctx, + metadata=md, + ) + ) + + class RestoreTraceRuntime(BaseRuntimeInterceptor): + def execute_workflow(self, request: ExecuteWorkflowRequest, next): + # Ensure metadata arrives + assert isinstance((request.metadata or {}).get(TRACE_KEY), str) + return next(request) + + def execute_activity(self, request: 
ExecuteActivityRequest, next): + assert isinstance((request.metadata or {}).get(TRACE_KEY), str) + return next(request) + + runtime = WorkflowRuntime( + runtime_interceptors=[RestoreTraceRuntime()], + workflow_outbound_interceptors=[InjectTraceOutbound()], + ) + + @runtime.activity(name='trace_probe2') + def trace_probe2(ctx, _=None): + return getattr(ctx, 'get_metadata', lambda: {})().get(TRACE_KEY) + + @runtime.async_workflow(name='child_trace2') + async def child2(ctx: AsyncWorkflowContext, _=None): + return { + 'wf_md': (ctx.get_metadata() or {}).get(TRACE_KEY), + 'act_md': await ctx.call_activity(trace_probe2, input=None), + } + + @runtime.async_workflow(name='parent_trace2') + async def parent2(ctx: AsyncWorkflowContext): + out = await ctx.call_child_workflow(child2, input=None) + return { + 'parent_md': (ctx.get_metadata() or {}).get(TRACE_KEY), + 'child': out, + } + + runtime.start() + try: + time.sleep(3) + client = DaprWorkflowClient(interceptors=[InjectTraceClient()]) + iid = f'trace-child-md-{int(time.time() * 1000)}' + client.schedule_new_workflow(workflow=parent2, instance_id=iid) + client.wait_for_workflow_start(iid, timeout_in_seconds=30) + st = client.wait_for_workflow_completion(iid, timeout_in_seconds=60) + assert st is not None and st.runtime_status.name == 'COMPLETED' + import json as _json + + data = _json.loads(st.to_json().get('serialized_output') or '{}') + assert data.get('parent_md') == 'sdk-trace-123' + child = data.get('child') or {} + assert child.get('wf_md') == 'sdk-trace-123' + assert child.get('act_md') == 'sdk-trace-123' + finally: + runtime.shutdown() + + +@skip_integration +def test_integration_termination_semantics(): + runtime = WorkflowRuntime() + + @runtime.async_workflow(name='termination_orchestrator_async') + async def termination_orchestrator(ctx: AsyncWorkflowContext): + # Long timer; test will terminate before it fires + await ctx.create_timer(300.0) + return 'not-reached' + + print(list(runtime._WorkflowRuntime__worker._registry.orchestrators.keys())) + + runtime.start() + try: + time.sleep(3) + + client = DaprWorkflowClient() + instance_id = f'term-int-{int(time.time() * 1000)}' + client.schedule_new_workflow(workflow=termination_orchestrator, instance_id=instance_id) + client.wait_for_workflow_start(instance_id, timeout_in_seconds=30) + + # Terminate and assert TERMINATED state, not raising inside orchestrator + client.terminate_workflow(instance_id, output='terminated') + final = client.wait_for_workflow_completion(instance_id, timeout_in_seconds=60) + assert final is not None + assert final.runtime_status.name == 'TERMINATED' + finally: + runtime.shutdown() + + +@skip_integration +def test_integration_when_any_first_wins(): + runtime = WorkflowRuntime() + + @runtime.async_workflow(name='when_any_async') + async def when_any_orchestrator(ctx: AsyncWorkflowContext): + first = await ctx.when_any( + [ + ctx.wait_for_external_event('go'), + ctx.create_timer(300.0), + ] + ) + # Return a simple, serializable value (winner's result) to avoid output serialization issues + try: + result = first.get_result() + except Exception: + result = None + return {'winner_result': result} + + runtime.start() + try: + # Ensure worker has established streams before scheduling + try: + if hasattr(runtime, 'wait_for_ready'): + runtime.wait_for_ready(timeout=15) # type: ignore[attr-defined] + except Exception: + pass + time.sleep(2) + + client = DaprWorkflowClient() + instance_id = f'whenany-int-{int(time.time() * 1000)}' + 
client.schedule_new_workflow(workflow=when_any_orchestrator, instance_id=instance_id) + client.wait_for_workflow_start(instance_id, timeout_in_seconds=30) + # Confirm RUNNING state before raising event (mitigates race conditions) + try: + st = client.get_workflow_state(instance_id, fetch_payloads=False) + if ( + st is None + or getattr(st, 'runtime_status', None) is None + or st.runtime_status.name != 'RUNNING' + ): + end = time.time() + 10 + while time.time() < end: + st = client.get_workflow_state(instance_id, fetch_payloads=False) + if ( + st is not None + and getattr(st, 'runtime_status', None) is not None + and st.runtime_status.name == 'RUNNING' + ): + break + time.sleep(0.2) + except Exception: + pass + + # Raise event immediately to win the when_any + client.raise_workflow_event(instance_id, 'go', data={'ok': True}) + + # Brief delay to allow event processing, then strictly use DaprWorkflowClient + time.sleep(1.0) + final = None + try: + final = client.wait_for_workflow_completion(instance_id, timeout_in_seconds=60) + except TimeoutError: + final = None + if final is None: + deadline = time.time() + 30 + while time.time() < deadline: + s = client.get_workflow_state(instance_id, fetch_payloads=False) + if s is not None and getattr(s, 'runtime_status', None) is not None: + if s.runtime_status.name in ('COMPLETED', 'FAILED', 'TERMINATED'): + final = s + break + time.sleep(0.5) + assert final is not None + assert final.runtime_status.name == 'COMPLETED' + # TODO: when sidecar exposes command diagnostics, assert only one command set was emitted + finally: + runtime.shutdown() + + +@skip_integration +def test_integration_async_activity_completes(): + runtime = WorkflowRuntime() + + @runtime.activity(name='echo_int') + def echo_act(ctx, x: int) -> int: + return x + + @runtime.async_workflow(name='async_activity_once') + async def wf(ctx: AsyncWorkflowContext): + out = await ctx.call_activity(echo_act, input=7) + return out + + runtime.start() + try: + time.sleep(3) + client = DaprWorkflowClient() + iid = f'act-int-{int(time.time() * 1000)}' + client.schedule_new_workflow(workflow=wf, instance_id=iid) + client.wait_for_workflow_start(iid, timeout_in_seconds=30) + state = client.wait_for_workflow_completion(iid, timeout_in_seconds=60) + assert state is not None + if state.runtime_status.name != 'COMPLETED': + fd = getattr(state, 'failure_details', None) + msg = getattr(fd, 'message', None) if fd else None + et = getattr(fd, 'error_type', None) if fd else None + print(f'[INTEGRATION DEBUG] Failure details: {et} {msg}') + assert state.runtime_status.name == 'COMPLETED' + finally: + runtime.shutdown() + + +@skip_integration +def test_integration_metadata_outbound_to_activity(): + runtime = WorkflowRuntime() + + @runtime.activity(name='recv_md') + def recv_md(ctx, _=None): + md = ctx.get_metadata() if hasattr(ctx, 'get_metadata') else {} + return md + + @runtime.async_workflow(name='wf_with_md') + async def wf(ctx: AsyncWorkflowContext): + ctx.set_metadata({'tenant': 'acme'}) + md = await ctx.call_activity(recv_md, input=None) + return md + + runtime.start() + try: + time.sleep(3) + client = DaprWorkflowClient() + iid = f'md-int-{int(time.time() * 1000)}' + client.schedule_new_workflow(workflow=wf, instance_id=iid) + client.wait_for_workflow_start(iid, timeout_in_seconds=30) + state = client.wait_for_workflow_completion(iid, timeout_in_seconds=60) + assert state is not None + assert state.runtime_status.name == 'COMPLETED' + finally: + runtime.shutdown() + + +@skip_integration +def 
test_integration_metadata_outbound_to_child_workflow(): + runtime = WorkflowRuntime() + + @runtime.async_workflow(name='child_recv_md') + async def child(ctx: AsyncWorkflowContext, _=None): + # Echo inbound metadata + return ctx.get_metadata() or {} + + @runtime.async_workflow(name='parent_sets_md') + async def parent(ctx: AsyncWorkflowContext): + ctx.set_metadata({'tenant': 'acme', 'role': 'user'}) + out = await ctx.call_child_workflow(child, input=None) + return out + + runtime.start() + try: + time.sleep(3) + client = DaprWorkflowClient() + iid = f'md-child-int-{int(time.time() * 1000)}' + client.schedule_new_workflow(workflow=parent, instance_id=iid) + client.wait_for_workflow_start(iid, timeout_in_seconds=30) + state = client.wait_for_workflow_completion(iid, timeout_in_seconds=60) + assert state is not None + assert state.runtime_status.name == 'COMPLETED' + # Validate output has metadata keys + data = state.to_json() + import json as _json + + out = _json.loads(data.get('serialized_output') or '{}') + assert out.get('tenant') == 'acme' + assert out.get('role') == 'user' + finally: + runtime.shutdown() + + +@skip_integration +def test_integration_trace_context_with_runtime_interceptors(): + """E2E: Verify trace_parent and orchestration_span_id via runtime interceptors.""" + records = { # captured by interceptor + 'wf_tp': None, + 'wf_span': None, + 'act_tp': None, + 'act_span': None, + } + + class TraceInterceptor(BaseRuntimeInterceptor): + def execute_workflow(self, request: ExecuteWorkflowRequest, next): # type: ignore[override] + ctx = request.ctx + try: + records['wf_tp'] = getattr(ctx, 'trace_parent', None) + records['wf_span'] = getattr(ctx, 'workflow_span_id', None) + except Exception: + pass + return next(request) + + def execute_activity(self, request: ExecuteActivityRequest, next): # type: ignore[override] + ctx = request.ctx + try: + records['act_tp'] = getattr(ctx, 'trace_parent', None) + # Activity contexts don't have orchestration_span_id; capture task span if present + records['act_span'] = getattr(ctx, 'activity_span_id', None) + except Exception: + pass + return next(request) + + runtime = WorkflowRuntime(runtime_interceptors=[TraceInterceptor()]) + + @runtime.activity(name='trace_probe') + def trace_probe(ctx, _=None): + # Return trace context seen inside activity + return { + 'trace_parent': getattr(ctx, 'trace_parent', None), + 'trace_state': getattr(ctx, 'trace_state', None), + } + + @runtime.async_workflow(name='trace_parent_wf') + async def wf(ctx: AsyncWorkflowContext): + # Access orchestration span id and trace parent from workflow context + _ = getattr(ctx, 'workflow_span_id', None) + _ = getattr(ctx, 'trace_parent', None) + return await ctx.call_activity(trace_probe, input=None) + + runtime.start() + try: + time.sleep(3) + client = DaprWorkflowClient() + iid = f'trace-int-{int(time.time() * 1000)}' + client.schedule_new_workflow(workflow=wf, instance_id=iid) + client.wait_for_workflow_start(iid, timeout_in_seconds=30) + state = client.wait_for_workflow_completion(iid, timeout_in_seconds=60) + assert state is not None + assert state.runtime_status.name == 'COMPLETED' + import json as _json + + out = _json.loads(state.to_json().get('serialized_output') or '{}') + # Activity returned strings (may be empty); assert types + assert isinstance(out.get('trace_parent'), (str, type(None))) + assert isinstance(out.get('trace_state'), (str, type(None))) + # Interceptor captured workflow and activity contexts + wf_tp = records['wf_tp'] + wf_span = records['wf_span'] + 
act_tp = records['act_tp'] + # TODO: assert more specifically when we have trace context information + assert isinstance(wf_tp, (str, type(None))) + assert isinstance(wf_span, (str, type(None))) + assert isinstance(act_tp, (str, type(None))) + # If we have a workflow span id, it should appear as parent-id inside activity traceparent + if isinstance(wf_span, str) and wf_span and isinstance(act_tp, str) and act_tp: + assert wf_span.lower() in act_tp.lower() + finally: + runtime.shutdown() + + +@skip_integration +def test_integration_runtime_shutdown_is_clean(): + runtime = WorkflowRuntime() + + @runtime.async_workflow(name='noop') + async def noop(ctx: AsyncWorkflowContext): + return 'ok' + + runtime.start() + try: + time.sleep(2) + client = DaprWorkflowClient() + iid = f'shutdown-int-{int(time.time() * 1000)}' + client.schedule_new_workflow(workflow=noop, instance_id=iid) + client.wait_for_workflow_start(iid, timeout_in_seconds=30) + st = client.wait_for_workflow_completion(iid, timeout_in_seconds=30) + assert st is not None and st.runtime_status.name == 'COMPLETED' + finally: + # Call shutdown multiple times to ensure idempotent and clean behavior + for _ in range(3): + try: + runtime.shutdown() + except Exception: + # Test should not raise even if worker logs cancellation warnings + assert False, 'runtime.shutdown() raised unexpectedly' + # Recreate and shutdown again to ensure no lingering background threads break next startup + rt2 = WorkflowRuntime() + rt2.start() + try: + time.sleep(1) + finally: + try: + rt2.shutdown() + except Exception: + assert False, 'second runtime.shutdown() raised unexpectedly' + + +@skip_integration +def test_integration_continue_as_new_outbound_interceptor_metadata(): + # Verify continue_as_new outbound interceptor can inject metadata carried to the new run + from dapr.ext.workflow import BaseWorkflowOutboundInterceptor + + INJECT_KEY = 'injected' + + class InjectOnContinueAsNew(BaseWorkflowOutboundInterceptor): + def continue_as_new(self, request, next): # type: ignore[override] + md = dict(request.metadata or {}) + md.setdefault(INJECT_KEY, 'yes') + request.metadata = md + return next(request) + + runtime = WorkflowRuntime( + workflow_outbound_interceptors=[InjectOnContinueAsNew()], + ) + + @runtime.workflow(name='continue_as_new_probe') + def wf(ctx, arg: dict | None = None): + if not arg or arg.get('phase') != 'second': + ctx.set_metadata({'tenant': 'acme'}) + # carry over existing metadata; interceptor will also inject + ctx.continue_as_new({'phase': 'second'}, carryover_metadata=True) + return # Must not yield after continue_as_new + # Second run: return inbound metadata observed + return ctx.get_metadata() or {} + + runtime.start() + try: + time.sleep(2) + client = DaprWorkflowClient() + iid = f'can-int-{int(time.time() * 1000)}' + client.schedule_new_workflow(workflow=wf, instance_id=iid) + client.wait_for_workflow_start(iid, timeout_in_seconds=30) + st = client.wait_for_workflow_completion(iid, timeout_in_seconds=60) + assert st is not None and st.runtime_status.name == 'COMPLETED' + import json as _json + + out = _json.loads(st.to_json().get('serialized_output') or '{}') + # Confirm both carried and injected metadata are present + assert out.get('tenant') == 'acme' + assert out.get(INJECT_KEY) == 'yes' + finally: + runtime.shutdown() + + +@skip_integration +def test_integration_child_workflow_attempt_exposed(): + # Verify that child workflow ctx exposes workflow_attempt + runtime = WorkflowRuntime() + + 
@runtime.async_workflow(name='child_probe_attempt') + async def child_probe_attempt(ctx: AsyncWorkflowContext, _=None): + att = getattr(ctx, 'workflow_attempt', None) + return {'wf_attempt': att} + + @runtime.async_workflow(name='parent_calls_child_for_attempt') + async def parent_calls_child_for_attempt(ctx: AsyncWorkflowContext): + return await ctx.call_child_workflow(child_probe_attempt, input=None) + + runtime.start() + try: + time.sleep(2) + client = DaprWorkflowClient() + iid = f'child-attempt-{int(time.time() * 1000)}' + client.schedule_new_workflow(workflow=parent_calls_child_for_attempt, instance_id=iid) + client.wait_for_workflow_start(iid, timeout_in_seconds=30) + st = client.wait_for_workflow_completion(iid, timeout_in_seconds=60) + assert st is not None and st.runtime_status.name == 'COMPLETED' + import json as _json + + out = _json.loads(st.to_json().get('serialized_output') or '{}') + val = out.get('wf_attempt', None) + assert (val is None) or isinstance(val, int) + finally: + runtime.shutdown() + + +@skip_integration +def test_integration_async_contextvars_trace_propagation(monkeypatch): + # Demonstrates contextvars-based trace propagation via interceptors in async workflows + import contextvars + import json as _json + + from dapr.ext.workflow import ( + BaseClientInterceptor, + BaseWorkflowOutboundInterceptor, + CallActivityRequest, + CallChildWorkflowRequest, + ScheduleWorkflowRequest, + ) + + TRACE_KEY = 'otel.trace_ctx' + current_trace: contextvars.ContextVar[str | None] = contextvars.ContextVar( + 'trace', default=None + ) + + class CVClient(BaseClientInterceptor): + def schedule_new_workflow(self, request: ScheduleWorkflowRequest, next): # type: ignore[override] + md = dict(request.metadata or {}) + md.setdefault(TRACE_KEY, 'wf-parent') + return next( + ScheduleWorkflowRequest( + workflow_name=request.workflow_name, + input=request.input, + instance_id=request.instance_id, + start_at=request.start_at, + reuse_id_policy=request.reuse_id_policy, + metadata=md, + ) + ) + + class CVOutbound(BaseWorkflowOutboundInterceptor): + def call_activity(self, request: CallActivityRequest, next): # type: ignore[override] + md = dict(request.metadata or {}) + md.setdefault(TRACE_KEY, current_trace.get()) + return next( + CallActivityRequest( + activity_name=request.activity_name, + input=request.input, + retry_policy=request.retry_policy, + workflow_ctx=request.workflow_ctx, + metadata=md, + ) + ) + + def call_child_workflow(self, request: CallChildWorkflowRequest, next): # type: ignore[override] + md = dict(request.metadata or {}) + md.setdefault(TRACE_KEY, current_trace.get()) + return next( + CallChildWorkflowRequest( + workflow_name=request.workflow_name, + input=request.input, + instance_id=request.instance_id, + workflow_ctx=request.workflow_ctx, + metadata=md, + ) + ) + + class CVRuntime(BaseRuntimeInterceptor): + def execute_workflow(self, request: ExecuteWorkflowRequest, next): # type: ignore[override] + prev = current_trace.set((request.metadata or {}).get(TRACE_KEY)) + try: + return next(request) + finally: + current_trace.reset(prev) + + def execute_activity(self, request: ExecuteActivityRequest, next): # type: ignore[override] + prev = current_trace.set((request.metadata or {}).get(TRACE_KEY)) + try: + return next(request) + finally: + current_trace.reset(prev) + + runtime = WorkflowRuntime( + runtime_interceptors=[CVRuntime()], workflow_outbound_interceptors=[CVOutbound()] + ) + + @runtime.activity(name='cv_probe') + def cv_probe(_ctx, _=None): + before = 
current_trace.get() + tok = current_trace.set(f'{before}/act') if before else None + try: + inner = current_trace.get() + finally: + if tok is not None: + current_trace.reset(tok) + after = current_trace.get() + return {'before': before, 'inner': inner, 'after': after} + + flaky_call_count = [0] + + @runtime.activity(name='cv_flaky_probe') + def cv_flaky_probe(ctx, _=None): + before = current_trace.get() + flaky_call_count[0] += 1 + print(f'----------> flaky_call_count: {flaky_call_count[0]}') + if flaky_call_count[0] == 1: + # Fail first attempt to trigger retry + raise Exception('fail-once') + tok = current_trace.set(f'{before}/act-retry') if before else None + try: + inner = current_trace.get() + finally: + if tok is not None: + current_trace.reset(tok) + after = current_trace.get() + return {'before': before, 'inner': inner, 'after': after} + + @runtime.async_workflow(name='cv_child') + async def cv_child(ctx: AsyncWorkflowContext, _=None): + before = current_trace.get() + tok = current_trace.set(f'{before}/child') if before else None + try: + act = await ctx.call_activity(cv_probe, input=None) + finally: + if tok is not None: + current_trace.reset(tok) + restored = current_trace.get() + return {'before': before, 'restored': restored, 'act': act} + + @runtime.async_workflow(name='cv_parent') + async def cv_parent(ctx: AsyncWorkflowContext, _=None): + from datetime import timedelta + + from dapr.ext.workflow import RetryPolicy + + top_before = current_trace.get() + child = await ctx.call_child_workflow(cv_child, input=None) + after_child = current_trace.get() + act = await ctx.call_activity(cv_probe, input=None) + after_act = current_trace.get() + act_retry = await ctx.call_activity( + cv_flaky_probe, + input=None, + retry_policy=RetryPolicy( + first_retry_interval=timedelta(seconds=0), max_number_of_attempts=3 + ), + ) + return { + 'before': top_before, + 'child': child, + 'act': act, + 'act_retry': act_retry, + 'after_child': after_child, + 'after_act': after_act, + } + + runtime.start() + try: + time.sleep(2) + client = DaprWorkflowClient(interceptors=[CVClient()]) + iid = f'cv-ctx-{int(time.time() * 1000)}' + client.schedule_new_workflow(workflow=cv_parent, instance_id=iid) + client.wait_for_workflow_start(iid, timeout_in_seconds=30) + st = client.wait_for_workflow_completion(iid, timeout_in_seconds=60) + assert st is not None and st.runtime_status.name == 'COMPLETED' + out = _json.loads(st.to_json().get('serialized_output') or '{}') + # Top-level activity sees parent trace context during execution + act = out.get('act') or {} + assert act.get('before') == 'wf-parent' + assert act.get('inner') == 'wf-parent/act' + assert act.get('after') == 'wf-parent' + # Child workflow's activity at least inherits parent context + child = out.get('child') or {} + child_act = child.get('act') or {} + assert child_act.get('before') == 'wf-parent' + assert child_act.get('inner') == 'wf-parent/act' + assert child_act.get('after') == 'wf-parent' + # Flaky activity retried: second attempt succeeds and returns with parent context + act_retry = out.get('act_retry') or {} + assert act_retry.get('before') == 'wf-parent' + assert act_retry.get('inner') == 'wf-parent/act-retry' + assert act_retry.get('after') == 'wf-parent' + finally: + runtime.shutdown() + + +def test_runtime_interceptor_shapes_async_input(): + runtime = WorkflowRuntime() + + class ShapeInput(BaseRuntimeInterceptor): + def execute_workflow(self, request: ExecuteWorkflowRequest, next): # type: ignore[override] + data = request.input + # 
Mutate input passed to workflow + if isinstance(data, dict): + shaped = {**data, 'shaped': True} + else: + shaped = {'value': data, 'shaped': True} + request.input = shaped + return next(request) + + # Recreate runtime with interceptor wired in + runtime = WorkflowRuntime(runtime_interceptors=[ShapeInput()]) + + @runtime.async_workflow(name='wf_shape_input') + async def wf_shape_input(ctx: AsyncWorkflowContext, arg: dict | None = None): + # Verify shaped input is observed by the workflow + return arg + + runtime.start() + try: + from dapr.ext.workflow import DaprWorkflowClient + + client = DaprWorkflowClient() + iid = f'shape-{id(runtime)}' + client.schedule_new_workflow(workflow=wf_shape_input, instance_id=iid, input={'x': 1}) + client.wait_for_workflow_start(iid, timeout_in_seconds=30) + st = client.wait_for_workflow_completion(iid, timeout_in_seconds=60) + assert st is not None + assert st.runtime_status.name == 'COMPLETED' + import json as _json + + out = _json.loads(st.to_json().get('serialized_output') or '{}') + assert out.get('x') == 1 + assert out.get('shaped') is True + finally: + runtime.shutdown() + + +def test_runtime_interceptor_context_manager_with_async_workflow(): + """Test that context managers stay active during async workflow execution.""" + runtime = WorkflowRuntime() + + # Track when context enters and exits + context_state = {'entered': False, 'exited': False, 'workflow_ran': False} + + class ContextInterceptor(BaseRuntimeInterceptor): + def execute_workflow(self, request: ExecuteWorkflowRequest, next): # type: ignore[override] + # Wrapper generator to keep context manager alive + def wrapper(): + from contextlib import ExitStack + + with ExitStack(): + # Mark context as entered + context_state['entered'] = True + + # Get the workflow generator + gen = next(request) + + # Use yield from to keep context alive during execution + yield from gen + + # Context will exit after generator completes + context_state['exited'] = True + + return wrapper() + + runtime = WorkflowRuntime(runtime_interceptors=[ContextInterceptor()]) + + @runtime.async_workflow(name='wf_context_test') + async def wf_context_test(ctx: AsyncWorkflowContext, arg: dict | None = None): + context_state['workflow_ran'] = True + return {'result': 'ok'} + + runtime.start() + try: + from dapr.ext.workflow import DaprWorkflowClient + + client = DaprWorkflowClient() + iid = f'ctx-test-{id(runtime)}' + client.schedule_new_workflow(workflow=wf_context_test, instance_id=iid, input={}) + client.wait_for_workflow_start(iid, timeout_in_seconds=30) + st = client.wait_for_workflow_completion(iid, timeout_in_seconds=60) + assert st is not None + assert st.runtime_status.name == 'COMPLETED' + + # Verify context manager was active during workflow execution + assert context_state['entered'], 'Context should have been entered' + assert context_state['workflow_ran'], 'Workflow should have executed' + assert context_state['exited'], 'Context should have exited after completion' + finally: + runtime.shutdown() diff --git a/ext/dapr-ext-workflow/tests/test_async_activity_registration.py b/ext/dapr-ext-workflow/tests/test_async_activity_registration.py new file mode 100644 index 00000000..cbffdd5a --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_async_activity_registration.py @@ -0,0 +1,58 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from dapr.ext.workflow.workflow_runtime import WorkflowRuntime + + +class _FakeRegistry: + def __init__(self): + self.activities = {} + + def add_named_activity(self, name, fn): + self.activities[name] = fn + + +class _FakeWorker: + def __init__(self, *args, **kwargs): + self._registry = _FakeRegistry() + + def start(self): + pass + + def stop(self): + pass + + +def test_activity_decorator_supports_async(monkeypatch): + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + rt = WorkflowRuntime() + + @rt.activity(name='async_act') + async def async_act(ctx, x: int) -> int: + return x + 2 + + # Ensure registered + reg = rt._WorkflowRuntime__worker._registry + assert 'async_act' in reg.activities + + # Call the wrapper and ensure it runs the coroutine to completion + wrapper = reg.activities['async_act'] + + class _Ctx: + pass + + out = wrapper(_Ctx(), 5) + assert out == 7 diff --git a/ext/dapr-ext-workflow/tests/test_async_activity_retry_failure.py b/ext/dapr-ext-workflow/tests/test_async_activity_retry_failure.py new file mode 100644 index 00000000..a02aeca7 --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_async_activity_retry_failure.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import pytest +from dapr.ext.workflow.aio import AsyncWorkflowContext, CoroutineOrchestratorRunner + + +class FakeTask: + def __init__(self, name: str): + self.name = name + + +class FakeCtx: + def __init__(self): + import datetime + + self.current_utc_datetime = datetime.datetime(2024, 1, 1) + self.instance_id = 'iid-act-retry' + + def call_activity(self, activity, *, input=None, retry_policy=None, metadata=None): + return FakeTask('activity') + + def create_timer(self, fire_at): + return FakeTask('timer') + + +async def wf(ctx: AsyncWorkflowContext): + # One activity that ultimately fails after retries + await ctx.call_activity(lambda: None, retry_policy={'dummy': True}) + return 'not-reached' + + +def test_activity_retry_final_failure_raises(): + fake = FakeCtx() + runner = CoroutineOrchestratorRunner(wf) + gen = runner.to_generator(AsyncWorkflowContext(fake), None) + # Prime + next(gen) + # Simulate final failure after retry policy exhausts + with pytest.raises(RuntimeError, match='activity failed'): + gen.throw(RuntimeError('activity failed')) diff --git a/ext/dapr-ext-workflow/tests/test_async_api_coverage.py b/ext/dapr-ext-workflow/tests/test_async_api_coverage.py new file mode 100644 index 00000000..c4bb28bc --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_async_api_coverage.py @@ -0,0 +1,80 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +from datetime import datetime + +from dapr.ext.workflow.aio import AsyncWorkflowContext + + +class FakeCtx: + def __init__(self): + self.current_utc_datetime = datetime(2024, 1, 1) + self.instance_id = 'iid-cov' + self._status = None + + def set_custom_status(self, status): + self._status = status + + def continue_as_new(self, new_input, *, save_events=False): + self._continued = (new_input, save_events) + + # methods used by awaitables + def call_activity(self, activity, *, input=None, retry_policy=None): + class _T: + pass + + return _T() + + def call_child_workflow(self, workflow, *, input=None, instance_id=None, retry_policy=None): + class _T: + pass + + return _T() + + def create_timer(self, fire_at): + class _T: + pass + + return _T() + + def wait_for_external_event(self, name: str): + class _T: + pass + + return _T() + + +def test_async_context_exposes_required_methods(): + base = FakeCtx() + ctx = AsyncWorkflowContext(base) + + # basic deterministic utils existence + assert isinstance(ctx.now(), datetime) + _ = ctx.random() + _ = ctx.uuid4() + + # pass-throughs + ctx.set_custom_status('ok') + assert base._status == 'ok' + ctx.continue_as_new({'foo': 1}, save_events=True) + assert getattr(base, '_continued', None) == ({'foo': 1}, True) + + # awaitable constructors do not raise + ctx.call_activity(lambda: None, input={'x': 1}) + ctx.call_child_workflow(lambda: None) + ctx.sleep(1.0) + ctx.wait_for_external_event('go') + ctx.when_all([]) + ctx.when_any([]) diff --git a/ext/dapr-ext-workflow/tests/test_async_concurrency_and_determinism.py b/ext/dapr-ext-workflow/tests/test_async_concurrency_and_determinism.py new file mode 100644 index 00000000..bfce7e2a --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_async_concurrency_and_determinism.py @@ -0,0 +1,112 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +from datetime import datetime + +from dapr.ext.workflow.aio import AsyncWorkflowContext, CoroutineOrchestratorRunner +from durabletask import task as durable_task_module +from durabletask.deterministic import deterministic_random, deterministic_uuid4 + + +class FakeTask: + def __init__(self, name: str): + self.name = name + + +class FakeCtx: + def __init__(self): + self.current_utc_datetime = datetime(2024, 1, 1) + self.instance_id = 'iid-123' + + def call_activity(self, activity, *, input=None, retry_policy=None, metadata=None): + return FakeTask(f'activity:{getattr(activity, "__name__", str(activity))}') + + def call_child_workflow( + self, workflow, *, input=None, instance_id=None, retry_policy=None, metadata=None + ): + return FakeTask(f'sub:{getattr(workflow, "__name__", str(workflow))}') + + def create_timer(self, fire_at): + return FakeTask('timer') + + def wait_for_external_event(self, name: str): + return FakeTask(f'event:{name}') + + +def drive_first_wins(gen, winner_name): + # Simulate when_any: first send the winner, then finish + next(gen) # prime + result = gen.send({'task': winner_name}) + # the coroutine should complete; StopIteration will be raised by caller + return result + + +async def wf_when_all(ctx: AsyncWorkflowContext): + a = ctx.call_activity(lambda: None) + b = ctx.sleep(1.0) + res = await ctx.when_all([a, b]) + return res + + +def test_when_all_maps_and_completes(monkeypatch): + # Patch durabletask.when_all to accept our FakeTask inputs and return a FakeTask + monkeypatch.setattr(durable_task_module, 'when_all', lambda tasks: FakeTask('when_all')) + fake = FakeCtx() + runner = CoroutineOrchestratorRunner(wf_when_all) + gen = runner.to_generator(AsyncWorkflowContext(fake), None) + # Drive two yields: when_all yields a task once; we simply return a list result + try: + t = gen.send(None) + assert isinstance(t, FakeTask) + out = gen.send([{'task': 'activity:lambda'}, {'task': 'timer'}]) + except StopIteration as stop: + out = stop.value + assert isinstance(out, list) + assert len(out) == 2 + + +async def wf_when_any(ctx: AsyncWorkflowContext): + a = ctx.call_activity(lambda: None) + b = ctx.sleep(5.0) + first = await ctx.when_any([a, b]) + # Return the first result only; losers ignored deterministically + return first + + +def test_when_any_first_wins_behavior(monkeypatch): + monkeypatch.setattr(durable_task_module, 'when_any', lambda tasks: FakeTask('when_any')) + fake = FakeCtx() + runner = CoroutineOrchestratorRunner(wf_when_any) + gen = runner.to_generator(AsyncWorkflowContext(fake), None) + try: + t = gen.send(None) + assert isinstance(t, FakeTask) + out = gen.send({'task': 'activity:lambda'}) + except StopIteration as stop: + out = stop.value + assert out == {'task': 'activity:lambda'} + + +def test_deterministic_random_and_uuid_are_stable(): + iid = 'iid-123' + now = datetime(2024, 1, 1) + rnd1 = deterministic_random(iid, now) + rnd2 = deterministic_random(iid, now) + seq1 = [rnd1.random() for _ in range(5)] + seq2 = [rnd2.random() for _ in range(5)] + assert seq1 == seq2 + u1 = deterministic_uuid4(deterministic_random(iid, now)) + u2 = deterministic_uuid4(deterministic_random(iid, now)) + assert u1 == u2 diff --git a/ext/dapr-ext-workflow/tests/test_async_context.py b/ext/dapr-ext-workflow/tests/test_async_context.py new file mode 100644 index 00000000..1ca0e57a --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_async_context.py @@ -0,0 +1,202 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); 
+you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import types +from datetime import datetime, timedelta, timezone + +from dapr.ext.workflow import AsyncWorkflowContext +from dapr.ext.workflow.dapr_workflow_context import DaprWorkflowContext +from dapr.ext.workflow.workflow_context import WorkflowContext + + +class DummyBaseCtx: + def __init__(self): + self.instance_id = 'abc-123' + # freeze a deterministic timestamp + self.current_utc_datetime = datetime(2025, 1, 2, 3, 4, 5, tzinfo=timezone.utc) + self.is_replaying = False + self._custom_status = None + self._continued = None + self._metadata = None + self._ei = types.SimpleNamespace( + workflow_id='abc-123', + workflow_name='wf', + is_replaying=False, + history_event_sequence=1, + inbound_metadata={'a': 'b'}, + parent_instance_id=None, + ) + + def set_custom_status(self, s: str): + self._custom_status = s + + def continue_as_new(self, new_input, *, save_events: bool = False): + self._continued = (new_input, save_events) + + # Metadata parity + def set_metadata(self, md): + self._metadata = md + + def get_metadata(self): + return self._metadata + + @property + def execution_info(self): + return self._ei + + +def test_parity_properties_and_now(): + ctx = AsyncWorkflowContext(DummyBaseCtx()) + assert ctx.instance_id == 'abc-123' + assert ctx.current_utc_datetime == datetime(2025, 1, 2, 3, 4, 5, tzinfo=timezone.utc) + # now() should mirror current_utc_datetime + assert ctx.now() == ctx.current_utc_datetime + + +def test_timer_accepts_float_and_timedelta(): + base = DummyBaseCtx() + ctx = AsyncWorkflowContext(base) + + # Float should be interpreted as seconds and produce a SleepAwaitable + aw1 = ctx.create_timer(1.5) + # Timedelta should pass through + aw2 = ctx.create_timer(timedelta(seconds=2)) + + # We only assert types by duck-typing public attribute presence to avoid + # importing internal classes in tests + assert hasattr(aw1, '_ctx') and hasattr(aw1, '__await__') + assert hasattr(aw2, '_ctx') and hasattr(aw2, '__await__') + + +def test_wait_for_external_event_and_concurrency_factories(): + ctx = AsyncWorkflowContext(DummyBaseCtx()) + + evt = ctx.wait_for_external_event('go') + assert hasattr(evt, '__await__') + + # when_all/when_any/gather return awaitables + a = ctx.create_timer(0.1) + b = ctx.create_timer(0.2) + + all_aw = ctx.when_all([a, b]) + any_aw = ctx.when_any([a, b]) + gat_aw = ctx.gather(a, b) + gat_exc_aw = ctx.gather(a, b, return_exceptions=True) + + for x in (all_aw, any_aw, gat_aw, gat_exc_aw): + assert hasattr(x, '__await__') + + +def test_deterministic_utils_and_passthroughs(): + base = DummyBaseCtx() + ctx = AsyncWorkflowContext(base) + + rnd = ctx.random() + # should behave like a random.Random-like object; test a stable first value + val = rnd.random() + # Just assert it is within (0,1) and stable across two calls to the seeded RNG instance + assert 0.0 < val < 1.0 + assert rnd.random() != val # next value changes + + uid = ctx.uuid4() + # Should be a UUID-like string representation + assert isinstance(str(uid), str) and len(str(uid)) >= 32 + + # passthroughs + ctx.set_custom_status('hello') + assert 
base._custom_status == 'hello' + + ctx.continue_as_new({'x': 1}, save_events=True) + assert base._continued == ({'x': 1}, True) + + +def test_async_metadata_api_and_execution_info(): + base = DummyBaseCtx() + ctx = AsyncWorkflowContext(base) + ctx.set_metadata({'k': 'v'}) + assert base._metadata == {'k': 'v'} + assert ctx.get_metadata() == {'k': 'v'} + ei = ctx.execution_info + assert ei and ei.workflow_id == 'abc-123' and ei.workflow_name == 'wf' + + +def test_async_outbound_metadata_plumbed_into_awaitables(): + base = DummyBaseCtx() + ctx = AsyncWorkflowContext(base) + a = ctx.call_activity(lambda: None, input=1, metadata={'m': 'n'}) + c = ctx.call_child_workflow(lambda c, x: None, input=2, metadata={'x': 'y'}) + # Introspect for test (internal attribute) + assert getattr(a, '_metadata', None) == {'m': 'n'} + assert getattr(c, '_metadata', None) == {'x': 'y'} + + +def test_async_parity_surface_exists(): + # Guard: ensure essential parity members exist + ctx = AsyncWorkflowContext(DummyBaseCtx()) + for name in ( + 'set_metadata', + 'get_metadata', + 'execution_info', + 'call_activity', + 'call_child_workflow', + 'continue_as_new', + ): + assert hasattr(ctx, name) + + +def test_public_api_parity_against_workflowcontext_abc(): + # Derive the required sync API surface from the ABC plus metadata/execution_info + required = { + name + for name, attr in WorkflowContext.__dict__.items() + if getattr(attr, '__isabstractmethod__', False) + } + required.update({'set_metadata', 'get_metadata', 'execution_info'}) + + # Async context must expose the same names + async_ctx = AsyncWorkflowContext(DummyBaseCtx()) + missing_in_async = [name for name in required if not hasattr(async_ctx, name)] + assert not missing_in_async, f'AsyncWorkflowContext missing: {missing_in_async}' + + # Sync context should also expose these names + class _FakeOrchCtx: + def __init__(self): + self.instance_id = 'abc-123' + self.current_utc_datetime = datetime(2025, 1, 2, 3, 4, 5, tzinfo=timezone.utc) + self.is_replaying = False + + def set_custom_status(self, s: str): + pass + + def create_timer(self, fire_at): + return object() + + def wait_for_external_event(self, name: str): + return object() + + def continue_as_new(self, new_input, *, save_events: bool = False): + pass + + def call_activity( + self, *, activity, input=None, retry_policy=None, app_id: str | None = None + ): + return object() + + def call_sub_orchestrator( + self, fn, *, input=None, instance_id=None, retry_policy=None, app_id: str | None = None + ): + return object() + + sync_ctx = DaprWorkflowContext(_FakeOrchCtx()) + missing_in_sync = [name for name in required if not hasattr(sync_ctx, name)] + assert not missing_in_sync, f'DaprWorkflowContext missing: {missing_in_sync}' diff --git a/ext/dapr-ext-workflow/tests/test_async_errors_and_backcompat.py b/ext/dapr-ext-workflow/tests/test_async_errors_and_backcompat.py new file mode 100644 index 00000000..9fa9735b --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_async_errors_and_backcompat.py @@ -0,0 +1,169 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +""" + +from dapr.ext.workflow.aio import AsyncWorkflowContext, CoroutineOrchestratorRunner +from dapr.ext.workflow.workflow_runtime import WorkflowRuntime + + +class FakeTask: + def __init__(self, name: str): + self.name = name + + +class FakeOrchestrationContext: + def __init__(self): + import datetime + + self.current_utc_datetime = datetime.datetime(2024, 1, 1) + self.instance_id = 'iid-errors' + self.is_replaying = False + self._custom_status = None + self.workflow_name = 'wf' + self.parent_instance_id = None + self.history_event_sequence = 1 + self.trace_parent = None + self.trace_state = None + self.orchestration_span_id = None + + def call_activity(self, activity, *, input=None, retry_policy=None, metadata=None): + return FakeTask('activity') + + def create_timer(self, fire_at): + return FakeTask('timer') + + def wait_for_external_event(self, name: str): + return FakeTask(f'event:{name}') + + def set_custom_status(self, custom_status): + self._custom_status = custom_status + + +def drive_raise(gen, exc: Exception): + # Prime + task = gen.send(None) + assert isinstance(task, FakeTask) + # Simulate runtime failure of yielded task + try: + gen.throw(exc) + except StopIteration as stop: + return stop.value + + +async def wf_catches_activity_error(ctx: AsyncWorkflowContext): + try: + await ctx.call_activity(lambda: (_ for _ in ()).throw(RuntimeError('boom'))) + except RuntimeError as e: + return f'caught:{e}' + return 'not-reached' + + +def test_activity_error_propagates_into_coroutine_and_can_be_caught(): + fake = FakeOrchestrationContext() + runner = CoroutineOrchestratorRunner(wf_catches_activity_error) + gen = runner.to_generator(AsyncWorkflowContext(fake), None) + result = drive_raise(gen, RuntimeError('boom')) + assert result == 'caught:boom' + + +async def wf_returns_sync(ctx: AsyncWorkflowContext): + return 42 + + +def test_sync_return_is_handled_without_runtime_error(): + fake = FakeOrchestrationContext() + runner = CoroutineOrchestratorRunner(wf_returns_sync) + gen = runner.to_generator(AsyncWorkflowContext(fake), None) + # Prime and complete + try: + gen.send(None) + except StopIteration as stop: + assert stop.value == 42 + + +class _FakeRegistry: + def __init__(self): + self.orchestrators = {} + self.activities = {} + + def add_named_orchestrator(self, name, fn): + self.orchestrators[name] = fn + + def add_named_activity(self, name, fn): + self.activities[name] = fn + + +class _FakeWorker: + def __init__(self, *args, **kwargs): + self._registry = _FakeRegistry() + + def start(self): + pass + + def stop(self): + pass + + +def test_generator_and_async_registration_coexist(monkeypatch): + # Monkeypatch TaskHubGrpcWorker to avoid real gRPC + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + rt = WorkflowRuntime() + + @rt.workflow(name='gen_wf') + def gen(ctx): + yield ctx.create_timer(0) + return 'ok' + + async def async_wf(ctx: AsyncWorkflowContext): + await ctx.sleep(0) + return 'ok' + + rt.register_async_workflow(async_wf, name='async_wf') + + # Verify registry got both entries + reg = rt._WorkflowRuntime__worker._registry + assert 'gen_wf' in reg.orchestrators + assert 'async_wf' in reg.orchestrators + + # Drive generator orchestrator wrapper + gen_fn = reg.orchestrators['gen_wf'] + g = gen_fn(FakeOrchestrationContext()) + t = next(g) + assert isinstance(t, FakeTask) + try: + g.send(None) + except StopIteration 
as stop: + assert stop.value == 'ok' + + # Also verify CancelledError propagates and can be caught + import asyncio + + async def wf_cancel(ctx: AsyncWorkflowContext): + try: + await ctx.call_activity(lambda: None) + except asyncio.CancelledError: + return 'cancelled' + return 'not-reached' + + runner = CoroutineOrchestratorRunner(wf_cancel) + gen_2 = runner.to_generator(AsyncWorkflowContext(FakeOrchestrationContext()), None) + # prime + next(gen_2) + try: + gen_2.throw(asyncio.CancelledError()) + except StopIteration as stop: + assert stop.value == 'cancelled' diff --git a/ext/dapr-ext-workflow/tests/test_async_registration_via_workflow.py b/ext/dapr-ext-workflow/tests/test_async_registration_via_workflow.py new file mode 100644 index 00000000..f1df6720 --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_async_registration_via_workflow.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from dapr.ext.workflow.aio import AsyncWorkflowContext +from dapr.ext.workflow.workflow_runtime import WorkflowRuntime + + +class _FakeRegistry: + def __init__(self): + self.orchestrators = {} + + def add_named_orchestrator(self, name, fn): + self.orchestrators[name] = fn + + +class _FakeWorker: + def __init__(self, *args, **kwargs): + self._registry = _FakeRegistry() + + def start(self): + pass + + def stop(self): + pass + + +def test_workflow_decorator_detects_async_and_registers(monkeypatch): + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + rt = WorkflowRuntime() + + @rt.workflow(name='async_wf') + async def async_wf(ctx: AsyncWorkflowContext, x: int) -> int: + # no awaits to keep simple + return x + 1 + + # ensure it was placed into registry + reg = rt._WorkflowRuntime__worker._registry + assert 'async_wf' in reg.orchestrators diff --git a/ext/dapr-ext-workflow/tests/test_async_replay.py b/ext/dapr-ext-workflow/tests/test_async_replay.py new file mode 100644 index 00000000..8fa33f12 --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_async_replay.py @@ -0,0 +1,77 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +from datetime import datetime, timedelta + +from dapr.ext.workflow.aio import AsyncWorkflowContext, CoroutineOrchestratorRunner + + +class FakeTask: + def __init__(self, name: str): + self.name = name + + +class FakeCtx: + def __init__(self, instance_id: str = 'iid-replay', now: datetime | None = None): + self.current_utc_datetime = now or datetime(2024, 1, 1) + self.instance_id = instance_id + + def call_activity(self, activity, *, input=None, retry_policy=None, metadata=None): + return FakeTask(f'activity:{getattr(activity, "__name__", str(activity))}:{input}') + + def create_timer(self, fire_at): + return FakeTask(f'timer:{fire_at}') + + def wait_for_external_event(self, name: str): + return FakeTask(f'event:{name}') + + +def drive_with_history(gen, results): + """Drive the generator with a pre-baked sequence of results, simulating replay history.""" + try: + next(gen) + idx = 0 + while True: + gen.send(results[idx]) + idx += 1 + except StopIteration as stop: + return stop.value + + +async def wf_mixed(ctx: AsyncWorkflowContext): + # activity + r1 = await ctx.call_activity(lambda: None, input={'x': 1}) + # timer + await ctx.sleep(timedelta(seconds=5)) + # event + e = await ctx.wait_for_external_event('go') + # deterministic utils + t = ctx.now() + u = str(ctx.uuid4()) + return {'a': r1, 'e': e, 't': t.isoformat(), 'u': u} + + +def test_replay_same_history_same_outputs(): + fake = FakeCtx() + runner = CoroutineOrchestratorRunner(wf_mixed) + # Pre-bake results sequence corresponding to activity -> timer -> event + history = [ + {'task': "activity:lambda:{'x': 1}"}, + None, + {'event': 42}, + ] + out1 = drive_with_history(runner.to_generator(AsyncWorkflowContext(fake), None), history) + out2 = drive_with_history(runner.to_generator(AsyncWorkflowContext(fake), None), history) + assert out1 == out2 diff --git a/ext/dapr-ext-workflow/tests/test_async_sandbox.py b/ext/dapr-ext-workflow/tests/test_async_sandbox.py new file mode 100644 index 00000000..62428e24 --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_async_sandbox.py @@ -0,0 +1,103 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import asyncio +import random +import time + +import pytest +from dapr.ext.workflow.aio import AsyncWorkflowContext, CoroutineOrchestratorRunner +from durabletask.aio.errors import SandboxViolationError +from durabletask.aio.sandbox import SandboxMode + + +class FakeTask: + def __init__(self, name: str): + self.name = name + + +class FakeCtx: + def __init__(self): + self.current_utc_datetime = __import__('datetime').datetime(2024, 1, 1) + self.instance_id = 'iid-sandbox' + + def call_activity(self, activity, *, input=None, retry_policy=None, metadata=None): + return FakeTask('activity') + + def create_timer(self, fire_at): + return FakeTask('timer') + + def wait_for_external_event(self, name: str): + return FakeTask(f'event:{name}') + + +async def wf_sleep(ctx: AsyncWorkflowContext): + # asyncio.sleep should be patched to workflow timer + await asyncio.sleep(0.1) + return 'ok' + + +def drive(gen, first_result=None): + try: + task = gen.send(None) + assert isinstance(task, FakeTask) + result = first_result + while True: + task = gen.send(result) + assert isinstance(task, FakeTask) + result = None + except StopIteration as stop: + return stop.value + + +def test_sandbox_best_effort_patches_sleep(): + fake = FakeCtx() + runner = CoroutineOrchestratorRunner(wf_sleep, sandbox_mode=SandboxMode.BEST_EFFORT) + gen = runner.to_generator(AsyncWorkflowContext(fake), None) + result = drive(gen) + assert result == 'ok' + + +def test_sandbox_random_uuid_time_are_deterministic(): + fake = FakeCtx() + runner = CoroutineOrchestratorRunner( + lambda ctx: _wf_random_uuid_time(ctx), sandbox_mode=SandboxMode.BEST_EFFORT + ) + gen1 = runner.to_generator(AsyncWorkflowContext(fake), None) + out1 = drive(gen1) + gen2 = runner.to_generator(AsyncWorkflowContext(fake), None) + out2 = drive(gen2) + assert out1 == out2 + + +async def _wf_random_uuid_time(ctx: AsyncWorkflowContext): + r1 = random.random() + u1 = __import__('uuid').uuid4() + t1 = time.time(), getattr(time, 'time_ns', lambda: int(time.time() * 1_000_000_000))() + # no awaits needed; return tuple + return (r1, str(u1), t1[0], t1[1]) + + +def test_strict_blocks_create_task(): + async def wf(ctx: AsyncWorkflowContext): + with pytest.raises(SandboxViolationError): + asyncio.create_task(asyncio.sleep(0)) + return 'ok' + + fake = FakeCtx() + runner = CoroutineOrchestratorRunner(wf, sandbox_mode=SandboxMode.STRICT) + gen = runner.to_generator(AsyncWorkflowContext(fake), None) + result = drive(gen) + assert result == 'ok' diff --git a/ext/dapr-ext-workflow/tests/test_async_sub_orchestrator.py b/ext/dapr-ext-workflow/tests/test_async_sub_orchestrator.py new file mode 100644 index 00000000..ad38c50f --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_async_sub_orchestrator.py @@ -0,0 +1,96 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import pytest +from dapr.ext.workflow.aio import AsyncWorkflowContext, CoroutineOrchestratorRunner + + +class FakeTask: + def __init__(self, name: str): + self.name = name + + +class FakeCtx: + def __init__(self): + import datetime + + self.current_utc_datetime = datetime.datetime(2024, 1, 1) + self.instance_id = 'iid-sub' + + def call_activity(self, activity, *, input=None, retry_policy=None, metadata=None): + return FakeTask('activity') + + def call_child_workflow( + self, workflow, *, input=None, instance_id=None, retry_policy=None, metadata=None + ): + return FakeTask(f'sub:{getattr(workflow, "__name__", str(workflow))}') + + def create_timer(self, fire_at): + return FakeTask('timer') + + def wait_for_external_event(self, name: str): + return FakeTask(f'event:{name}') + + +def drive_success(gen, results): + try: + next(gen) + idx = 0 + while True: + gen.send(results[idx]) + idx += 1 + except StopIteration as stop: + return stop.value + + +def drive_raise(gen, exc: Exception): + # Prime + next(gen) + # Throw failure into orchestrator + return pytest.raises(Exception, gen.throw, exc) + + +async def child(ctx: AsyncWorkflowContext, n: int) -> int: + return n * 2 + + +async def parent_success(ctx: AsyncWorkflowContext): + res = await ctx.call_child_workflow(child, input=3) + return res + 1 + + +def test_sub_orchestrator_success(): + fake = FakeCtx() + runner = CoroutineOrchestratorRunner(parent_success) + gen = runner.to_generator(AsyncWorkflowContext(fake), None) + # First yield is the sub-orchestrator task + result = drive_success(gen, results=[6]) + assert result == 7 + + +async def parent_failure(ctx: AsyncWorkflowContext): + # Do not catch; allow failure to propagate + await ctx.call_child_workflow(child, input=1) + return 'not-reached' + + +def test_sub_orchestrator_failure_raises_into_orchestrator(): + fake = FakeCtx() + runner = CoroutineOrchestratorRunner(parent_failure) + gen = runner.to_generator(AsyncWorkflowContext(fake), None) + # Prime and then throw into the coroutine to simulate child failure + next(gen) + with pytest.raises(RuntimeError, match='child failed'): + gen.throw(RuntimeError('child failed')) diff --git a/ext/dapr-ext-workflow/tests/test_async_when_any_losers_policy.py b/ext/dapr-ext-workflow/tests/test_async_when_any_losers_policy.py new file mode 100644 index 00000000..b893dc3a --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_async_when_any_losers_policy.py @@ -0,0 +1,62 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +from dapr.ext.workflow.aio import AsyncWorkflowContext, CoroutineOrchestratorRunner +from durabletask import task as durable_task_module + + +class FakeTask: + def __init__(self, name: str): + self.name = name + + +class FakeCtx: + def __init__(self): + import datetime + + self.current_utc_datetime = datetime.datetime(2024, 1, 1) + self.instance_id = 'iid-any' + + def call_activity(self, activity, *, input=None, retry_policy=None, metadata=None): + return FakeTask('activity') + + def create_timer(self, fire_at): + return FakeTask('timer') + + +async def wf_when_any(ctx: AsyncWorkflowContext): + # Two awaitables: an activity and a timer + a = ctx.call_activity(lambda: None) + b = ctx.sleep(10) + first = await ctx.when_any([a, b]) + return first + + +def test_when_any_yields_once_and_returns_first_result(monkeypatch): + # Patch durabletask.when_any to avoid requiring real durabletask.Task objects + monkeypatch.setattr(durable_task_module, 'when_any', lambda tasks: FakeTask('when_any')) + fake = FakeCtx() + runner = CoroutineOrchestratorRunner(wf_when_any) + gen = runner.to_generator(AsyncWorkflowContext(fake), None) + + # Prime; expect a single composite yield + yielded = gen.send(None) + assert isinstance(yielded, FakeTask) + # Send the 'first' completion; generator should complete without yielding again + try: + gen.send({'task': 'activity'}) + raise AssertionError('generator should have completed') + except StopIteration as stop: + assert stop.value == {'task': 'activity'} diff --git a/ext/dapr-ext-workflow/tests/test_async_workflow_basic.py b/ext/dapr-ext-workflow/tests/test_async_workflow_basic.py new file mode 100644 index 00000000..b56f5af6 --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_async_workflow_basic.py @@ -0,0 +1,95 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +from dapr.ext.workflow.aio import AsyncWorkflowContext, CoroutineOrchestratorRunner + + +class FakeTask: + def __init__(self, name: str): + self.name = name + + +class FakeCtx: + def __init__(self): + self.current_utc_datetime = __import__('datetime').datetime(2024, 1, 1) + self.instance_id = 'test-instance' + self._events: dict[str, list] = {} + + def call_activity(self, activity, *, input=None, retry_policy=None, metadata=None): + return FakeTask(f'activity:{getattr(activity, "__name__", str(activity))}') + + def call_child_workflow( + self, workflow, *, input=None, instance_id=None, retry_policy=None, metadata=None + ): + return FakeTask(f'sub:{getattr(workflow, "__name__", str(workflow))}') + + def create_timer(self, fire_at): + return FakeTask('timer') + + def wait_for_external_event(self, name: str): + return FakeTask(f'event:{name}') + + +def drive(gen, first_result=None): + """Drive a generator produced by the async driver, emulating the runtime.""" + try: + task = gen.send(None) + assert isinstance(task, FakeTask) + result = first_result + while True: + task = gen.send(result) + assert isinstance(task, FakeTask) + # Provide a generic result for every yield + result = {'task': task.name} + except StopIteration as stop: + return stop.value + + +async def sample_activity(ctx: AsyncWorkflowContext): + return await ctx.call_activity(lambda: None) + + +def test_activity_awaitable_roundtrip(): + fake = FakeCtx() + runner = CoroutineOrchestratorRunner(sample_activity) + gen = runner.to_generator(AsyncWorkflowContext(fake), None) + result = drive(gen, first_result={'task': 'activity:lambda'}) + assert result == {'task': 'activity:lambda'} + + +async def sample_timer(ctx: AsyncWorkflowContext): + await ctx.create_timer(1.0) + return 'done' + + +def test_timer_awaitable_roundtrip(): + fake = FakeCtx() + runner = CoroutineOrchestratorRunner(sample_timer) + gen = runner.to_generator(AsyncWorkflowContext(fake), None) + result = drive(gen, first_result=None) + assert result == 'done' + + +async def sample_event(ctx: AsyncWorkflowContext): + data = await ctx.wait_for_external_event('go') + return ('event', data) + + +def test_event_awaitable_roundtrip(): + fake = FakeCtx() + runner = CoroutineOrchestratorRunner(sample_event) + gen = runner.to_generator(AsyncWorkflowContext(fake), None) + result = drive(gen, first_result={'hello': 'world'}) + assert result == ('event', {'hello': 'world'}) diff --git a/ext/dapr-ext-workflow/tests/test_deterministic.py b/ext/dapr-ext-workflow/tests/test_deterministic.py new file mode 100644 index 00000000..fa76f22f --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_deterministic.py @@ -0,0 +1,74 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +import datetime as _dt + +import pytest +from dapr.ext.workflow.aio import AsyncWorkflowContext +from dapr.ext.workflow.dapr_workflow_context import DaprWorkflowContext + +""" +Tests for deterministic helpers shared across workflow contexts. 
+""" + + +class _FakeBaseCtx: + def __init__(self, instance_id: str, dt: _dt.datetime): + self.instance_id = instance_id + self.current_utc_datetime = dt + + +def _fixed_dt(): + return _dt.datetime(2024, 1, 1) + + +def test_random_string_deterministic_across_instances_async(): + base = _FakeBaseCtx('iid-1', _fixed_dt()) + a_ctx = AsyncWorkflowContext(base) + b_ctx = AsyncWorkflowContext(base) + a = a_ctx.random_string(16) + b = b_ctx.random_string(16) + assert a == b + + +def test_random_string_deterministic_across_context_types(): + base = _FakeBaseCtx('iid-2', _fixed_dt()) + a_ctx = AsyncWorkflowContext(base) + s1 = a_ctx.random_string(12) + + # Minimal fake orchestration context for DaprWorkflowContext + d_ctx = DaprWorkflowContext(base) + s2 = d_ctx.random_string(12) + assert s1 == s2 + + +def test_random_string_respects_alphabet(): + base = _FakeBaseCtx('iid-3', _fixed_dt()) + ctx = AsyncWorkflowContext(base) + s = ctx.random_string(20, alphabet='abc') + assert set(s).issubset(set('abc')) + + +def test_random_string_length_and_edge_cases(): + base = _FakeBaseCtx('iid-4', _fixed_dt()) + ctx = AsyncWorkflowContext(base) + + assert ctx.random_string(0) == '' + + with pytest.raises(ValueError): + ctx.random_string(-1) + + with pytest.raises(ValueError): + ctx.random_string(5, alphabet='') diff --git a/ext/dapr-ext-workflow/tests/test_generic_serialization.py b/ext/dapr-ext-workflow/tests/test_generic_serialization.py new file mode 100644 index 00000000..0aeb0c84 --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_generic_serialization.py @@ -0,0 +1,64 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +from dataclasses import dataclass +from typing import Any + +from dapr.ext.workflow import ( + ActivityIOAdapter, + CanonicalSerializable, + ensure_canonical_json, + serialize_activity_input, + serialize_activity_output, + use_activity_adapter, +) + + +@dataclass +class _Point(CanonicalSerializable): + x: int + y: int + + def to_canonical_json(self, *, strict: bool = True) -> Any: + return {'x': self.x, 'y': self.y} + + +def test_ensure_canonical_json_on_custom_object(): + p = _Point(1, 2) + out = ensure_canonical_json(p, strict=True) + assert out == {'x': 1, 'y': 2} + + +class _IO(ActivityIOAdapter): + def serialize_input(self, input: Any, *, strict: bool = True) -> Any: + if isinstance(input, _Point): + return {'pt': [input.x, input.y]} + return ensure_canonical_json(input, strict=strict) + + def serialize_output(self, output: Any, *, strict: bool = True) -> Any: + return {'ok': ensure_canonical_json(output, strict=strict)} + + +def test_activity_adapter_decorator_customizes_io(): + _use = use_activity_adapter(_IO()) + + @_use + def act(obj): + return obj + + pt = _Point(3, 4) + inp = serialize_activity_input(act, pt, strict=True) + assert inp == {'pt': [3, 4]} + + out = serialize_activity_output(act, {'k': 'v'}, strict=True) + assert out == {'ok': {'k': 'v'}} diff --git a/ext/dapr-ext-workflow/tests/test_inbound_interceptors.py b/ext/dapr-ext-workflow/tests/test_inbound_interceptors.py new file mode 100644 index 00000000..bc3b28e0 --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_inbound_interceptors.py @@ -0,0 +1,559 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +import asyncio +from typing import Any + +import pytest +from dapr.ext.workflow import ( + ExecuteActivityRequest, + ExecuteWorkflowRequest, + RuntimeInterceptor, + WorkflowRuntime, +) + +from ._fakes import make_act_ctx as _make_act_ctx +from ._fakes import make_orch_ctx as _make_orch_ctx + +""" +Comprehensive inbound interceptor tests for Dapr WorkflowRuntime. + +Tests the current interceptor system for runtime-side workflow and activity execution. 
+""" + + +class _FakeRegistry: + def __init__(self): + self.orchestrators: dict[str, Any] = {} + self.activities: dict[str, Any] = {} + + def add_named_orchestrator(self, name, fn): + self.orchestrators[name] = fn + + def add_named_activity(self, name, fn): + self.activities[name] = fn + + +class _FakeWorker: + def __init__(self, *args, **kwargs): + self._registry = _FakeRegistry() + + def start(self): + pass + + def stop(self): + pass + + +class _TracingInterceptor(RuntimeInterceptor): + """Interceptor that injects and restores trace context.""" + + def __init__(self, events: list[str]): + self.events = events + + def execute_workflow(self, request: ExecuteWorkflowRequest, next): + # Extract tracing from input + tracing_data = None + if isinstance(request.input, dict) and 'tracing' in request.input: + tracing_data = request.input['tracing'] + self.events.append(f'wf_trace_restored:{tracing_data}') + + # Call next in chain + result = next(request) + + if tracing_data: + self.events.append(f'wf_trace_cleanup:{tracing_data}') + + return result + + def execute_activity(self, request: ExecuteActivityRequest, next): + # Extract tracing from input + tracing_data = None + if isinstance(request.input, dict) and 'tracing' in request.input: + tracing_data = request.input['tracing'] + self.events.append(f'act_trace_restored:{tracing_data}') + + # Call next in chain + result = next(request) + + if tracing_data: + self.events.append(f'act_trace_cleanup:{tracing_data}') + + return result + + +class _LoggingInterceptor(RuntimeInterceptor): + """Interceptor that logs workflow and activity execution.""" + + def __init__(self, events: list[str], label: str): + self.events = events + self.label = label + + def execute_workflow(self, request: ExecuteWorkflowRequest, next): + self.events.append(f'{self.label}:wf_start:{request.input!r}') + try: + result = next(request) + self.events.append(f'{self.label}:wf_complete:{result!r}') + return result + except Exception as e: + self.events.append(f'{self.label}:wf_error:{type(e).__name__}') + raise + + def execute_activity(self, request: ExecuteActivityRequest, next): + self.events.append(f'{self.label}:act_start:{request.input!r}') + try: + result = next(request) + self.events.append(f'{self.label}:act_complete:{result!r}') + return result + except Exception as e: + self.events.append(f'{self.label}:act_error:{type(e).__name__}') + raise + + +class _ValidationInterceptor(RuntimeInterceptor): + """Interceptor that validates inputs and outputs.""" + + def __init__(self, events: list[str]): + self.events = events + + def execute_workflow(self, request: ExecuteWorkflowRequest, next): + # Validate input + if isinstance(request.input, dict) and request.input.get('invalid'): + self.events.append('wf_validation_failed') + raise ValueError('Invalid workflow input') + + self.events.append('wf_validation_passed') + result = next(request) + + # Validate output + if isinstance(result, dict) and result.get('invalid_output'): + self.events.append('wf_output_validation_failed') + raise ValueError('Invalid workflow output') + + self.events.append('wf_output_validation_passed') + return result + + def execute_activity(self, request: ExecuteActivityRequest, next): + # Validate input + if isinstance(request.input, dict) and request.input.get('invalid'): + self.events.append('act_validation_failed') + raise ValueError('Invalid activity input') + + self.events.append('act_validation_passed') + result = next(request) + + # Validate output + if isinstance(result, str) and 'invalid' in 
result: + self.events.append('act_output_validation_failed') + raise ValueError('Invalid activity output') + + self.events.append('act_output_validation_passed') + return result + + +def test_single_interceptor_workflow_execution(monkeypatch): + """Test single interceptor around workflow execution.""" + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + events: list[str] = [] + logging_interceptor = _LoggingInterceptor(events, 'log') + rt = WorkflowRuntime(runtime_interceptors=[logging_interceptor]) + + @rt.workflow(name='simple') + def simple(ctx, x: int): + return x * 2 + + reg = rt._WorkflowRuntime__worker._registry + orch = reg.orchestrators['simple'] + result = orch(_make_orch_ctx(), 5) + + # For non-generator workflows, the result is returned directly + assert result == 10 + assert events == [ + 'log:wf_start:5', + 'log:wf_complete:10', + ] + + +def test_single_interceptor_activity_execution(monkeypatch): + """Test single interceptor around activity execution.""" + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + events: list[str] = [] + logging_interceptor = _LoggingInterceptor(events, 'log') + rt = WorkflowRuntime(runtime_interceptors=[logging_interceptor]) + + @rt.activity(name='double') + def double(ctx, x: int) -> int: + return x * 2 + + reg = rt._WorkflowRuntime__worker._registry + act = reg.activities['double'] + result = act(_make_act_ctx(), 7) + + assert result == 14 + assert events == [ + 'log:act_start:7', + 'log:act_complete:14', + ] + + +def test_multiple_interceptors_execution_order(monkeypatch): + """Test multiple interceptors execute in correct order.""" + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + events: list[str] = [] + outer_interceptor = _LoggingInterceptor(events, 'outer') + inner_interceptor = _LoggingInterceptor(events, 'inner') + + # First interceptor in list is outermost + rt = WorkflowRuntime(runtime_interceptors=[outer_interceptor, inner_interceptor]) + + @rt.workflow(name='ordered') + def ordered(ctx, x: int): + return x + 1 + + reg = rt._WorkflowRuntime__worker._registry + orch = reg.orchestrators['ordered'] + result = orch(_make_orch_ctx(), 3) + + assert result == 4 + # Outer interceptor enters first, exits last (stack semantics) + assert events == [ + 'outer:wf_start:3', + 'inner:wf_start:3', + 'inner:wf_complete:4', + 'outer:wf_complete:4', + ] + + +def test_tracing_interceptor_context_restoration(monkeypatch): + """Test tracing interceptor properly handles trace context.""" + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + events: list[str] = [] + tracing_interceptor = _TracingInterceptor(events) + rt = WorkflowRuntime(runtime_interceptors=[tracing_interceptor]) + + @rt.workflow(name='traced') + def traced(ctx, input_data): + # Workflow can access the trace context that was restored + return {'result': input_data.get('value', 0) * 2} + + reg = rt._WorkflowRuntime__worker._registry + orch = reg.orchestrators['traced'] + + # Input with tracing data + input_with_trace = {'value': 5, 'tracing': {'trace_id': 'abc123', 'span_id': 'def456'}} + + result = orch(_make_orch_ctx(), input_with_trace) + + assert result == {'result': 10} + assert events == [ + "wf_trace_restored:{'trace_id': 'abc123', 'span_id': 'def456'}", + "wf_trace_cleanup:{'trace_id': 'abc123', 'span_id': 'def456'}", + ] + + +def 
test_validation_interceptor_input_validation(monkeypatch): + """Test validation interceptor catches invalid inputs.""" + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + events: list[str] = [] + validation_interceptor = _ValidationInterceptor(events) + rt = WorkflowRuntime(runtime_interceptors=[validation_interceptor]) + + @rt.workflow(name='validated') + def validated(ctx, input_data): + return {'result': 'ok'} + + reg = rt._WorkflowRuntime__worker._registry + orch = reg.orchestrators['validated'] + + # Test valid input + result = orch(_make_orch_ctx(), {'value': 5}) + + assert result == {'result': 'ok'} + assert 'wf_validation_passed' in events + assert 'wf_output_validation_passed' in events + + # Test invalid input + events.clear() + + with pytest.raises(ValueError, match='Invalid workflow input'): + orch(_make_orch_ctx(), {'invalid': True}) + + assert 'wf_validation_failed' in events + + +def test_interceptor_error_handling_workflow(monkeypatch): + """Test interceptor properly handles workflow errors.""" + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + events: list[str] = [] + logging_interceptor = _LoggingInterceptor(events, 'log') + rt = WorkflowRuntime(runtime_interceptors=[logging_interceptor]) + + @rt.workflow(name='error_wf') + def error_wf(ctx, x: int): + raise ValueError('workflow error') + + reg = rt._WorkflowRuntime__worker._registry + orch = reg.orchestrators['error_wf'] + + with pytest.raises(ValueError, match='workflow error'): + orch(_make_orch_ctx(), 1) + + assert events == [ + 'log:wf_start:1', + 'log:wf_error:ValueError', + ] + + +def test_interceptor_error_handling_activity(monkeypatch): + """Test interceptor properly handles activity errors.""" + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + events: list[str] = [] + logging_interceptor = _LoggingInterceptor(events, 'log') + rt = WorkflowRuntime(runtime_interceptors=[logging_interceptor]) + + @rt.activity(name='error_act') + def error_act(ctx, x: int) -> int: + raise RuntimeError('activity error') + + reg = rt._WorkflowRuntime__worker._registry + act = reg.activities['error_act'] + + with pytest.raises(RuntimeError, match='activity error'): + act(_make_act_ctx(), 5) + + assert events == [ + 'log:act_start:5', + 'log:act_error:RuntimeError', + ] + + +def test_async_workflow_with_interceptors(monkeypatch): + """Test interceptors work with async workflows.""" + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + events: list[str] = [] + logging_interceptor = _LoggingInterceptor(events, 'log') + rt = WorkflowRuntime(runtime_interceptors=[logging_interceptor]) + + @rt.workflow(name='async_wf') + async def async_wf(ctx, x: int): + # Simple async workflow + return x * 3 + + reg = rt._WorkflowRuntime__worker._registry + orch = reg.orchestrators['async_wf'] + gen_result = orch(_make_orch_ctx(), 4) + + # Async workflows return a generator that needs to be driven + with pytest.raises(StopIteration) as stop: + next(gen_result) + result = stop.value.value + + assert result == 12 + # The interceptor sees the generator being returned, not the final result + assert events[0] == 'log:wf_start:4' + assert 'log:wf_complete:' in events[1] # The generator object is logged + + +def test_async_activity_with_interceptors(monkeypatch): + """Test interceptors work with async 
activities.""" + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + events: list[str] = [] + logging_interceptor = _LoggingInterceptor(events, 'log') + rt = WorkflowRuntime(runtime_interceptors=[logging_interceptor]) + + @rt.activity(name='async_act') + async def async_act(ctx, x: int) -> int: + await asyncio.sleep(0) # Simulate async work + return x * 4 + + reg = rt._WorkflowRuntime__worker._registry + act = reg.activities['async_act'] + result = act(_make_act_ctx(), 3) + + assert result == 12 + assert events == [ + 'log:act_start:3', + 'log:act_complete:12', + ] + + +def test_generator_workflow_with_interceptors(monkeypatch): + """Test interceptors work with generator workflows.""" + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + events: list[str] = [] + logging_interceptor = _LoggingInterceptor(events, 'log') + rt = WorkflowRuntime(runtime_interceptors=[logging_interceptor]) + + @rt.workflow(name='gen_wf') + def gen_wf(ctx, x: int): + v1 = yield 'step1' + v2 = yield 'step2' + return (x, v1, v2) + + reg = rt._WorkflowRuntime__worker._registry + orch = reg.orchestrators['gen_wf'] + gen_orch = orch(_make_orch_ctx(), 1) + + # Drive the generator + assert next(gen_orch) == 'step1' + assert gen_orch.send('result1') == 'step2' + with pytest.raises(StopIteration) as stop: + gen_orch.send('result2') + result = stop.value.value + + assert result == (1, 'result1', 'result2') + # For generator workflows, interceptor sees the generator being returned + assert events[0] == 'log:wf_start:1' + assert 'log:wf_complete:' in events[1] # The generator object is logged + + +def test_interceptor_chain_with_early_return(monkeypatch): + """Test interceptor can modify or short-circuit execution.""" + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + events: list[str] = [] + + class _ShortCircuitInterceptor(RuntimeInterceptor): + def execute_workflow(self, request: ExecuteWorkflowRequest, next): + events.append('short_circuit_check') + if isinstance(request.input, dict) and request.input.get('short_circuit'): + events.append('short_circuited') + return 'short_circuit_result' + return next(request) + + def execute_activity(self, request: ExecuteActivityRequest, next): + return next(request) + + logging_interceptor = _LoggingInterceptor(events, 'log') + short_circuit_interceptor = _ShortCircuitInterceptor() + + rt = WorkflowRuntime(runtime_interceptors=[short_circuit_interceptor, logging_interceptor]) + + @rt.workflow(name='maybe_short') + def maybe_short(ctx, input_data): + return 'normal_result' + + reg = rt._WorkflowRuntime__worker._registry + orch = reg.orchestrators['maybe_short'] + + # Test normal execution + result = orch(_make_orch_ctx(), {'value': 5}) + + assert result == 'normal_result' + assert 'short_circuit_check' in events + assert 'log:wf_start' in str(events) + assert 'log:wf_complete' in str(events) + + # Test short-circuit execution + events.clear() + result = orch(_make_orch_ctx(), {'short_circuit': True}) + + assert result == 'short_circuit_result' + assert 'short_circuit_check' in events + assert 'short_circuited' in events + # Logging interceptor should not be called when short-circuited + assert 'log:wf_start' not in str(events) + + +def test_interceptor_input_transformation(monkeypatch): + """Test interceptor can transform inputs before execution.""" + import durabletask.worker as worker_mod + + 
monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + events: list[str] = [] + + class _TransformInterceptor(RuntimeInterceptor): + def execute_workflow(self, request: ExecuteWorkflowRequest, next): + # Transform input by adding metadata + if isinstance(request.input, dict): + transformed_input = {**request.input, 'interceptor_metadata': 'added'} + new_input = ExecuteWorkflowRequest(ctx=request.ctx, input=transformed_input) + events.append(f'transformed_input:{transformed_input}') + return next(new_input) + return next(request) + + def execute_activity(self, request: ExecuteActivityRequest, next): + return next(request) + + transform_interceptor = _TransformInterceptor() + rt = WorkflowRuntime(runtime_interceptors=[transform_interceptor]) + + @rt.workflow(name='transform_test') + def transform_test(ctx, input_data): + # Workflow should see the transformed input + return input_data + + reg = rt._WorkflowRuntime__worker._registry + orch = reg.orchestrators['transform_test'] + result = orch(_make_orch_ctx(), {'original': 'value'}) + + # Result should include the interceptor metadata + assert result == {'original': 'value', 'interceptor_metadata': 'added'} + assert 'transformed_input:' in str(events) + + +def test_runtime_interceptor_can_shape_activity_result(monkeypatch): + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + class _ShapeResult(RuntimeInterceptor): + def execute_activity(self, request, next): # type: ignore[override] + res = next(request) + return {'wrapped': res} + + rt = WorkflowRuntime(runtime_interceptors=[_ShapeResult()]) + + @rt.activity(name='echo') + def echo(_ctx, x): + return x + + reg = rt._WorkflowRuntime__worker._registry + act = reg.activities['echo'] + out = act(_make_act_ctx(), 7) + assert out == {'wrapped': 7} diff --git a/ext/dapr-ext-workflow/tests/test_interceptors.py b/ext/dapr-ext-workflow/tests/test_interceptors.py new file mode 100644 index 00000000..9ba37287 --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_interceptors.py @@ -0,0 +1,176 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +from typing import Any + +import pytest +from dapr.ext.workflow import RuntimeInterceptor, WorkflowRuntime + +from ._fakes import make_act_ctx as _make_act_ctx +from ._fakes import make_orch_ctx as _make_orch_ctx + + +""" +Runtime interceptor chain tests for `WorkflowRuntime`. + +This suite intentionally uses a fake worker/registry to validate interceptor composition +without requiring a sidecar. It focuses on the "why" behind runtime interceptors: + +- Ensure `execute_workflow` and `execute_activity` hooks compose in order and are + invoked exactly once around workflow entry/activity execution.
+- Cover both generator-based and async workflows, asserting the chain returns a + generator to the runtime (rather than iterating it), preserving send()/throw() + semantics during orchestration replay. +- Keep signal-to-noise high for failures in chain logic independent of gRPC/sidecar. + +These tests complement outbound/client interceptor tests and e2e tests by providing +fast, deterministic coverage of the chaining behavior and generator handling rules. +""" + + +class _FakeRegistry: + def __init__(self): + self.orchestrators: dict[str, Any] = {} + self.activities: dict[str, Any] = {} + + def add_named_orchestrator(self, name, fn): + self.orchestrators[name] = fn + + def add_named_activity(self, name, fn): + self.activities[name] = fn + + +class _FakeWorker: + def __init__(self, *args, **kwargs): + self._registry = _FakeRegistry() + + def start(self): + pass + + def stop(self): + pass + + +class _RecorderInterceptor(RuntimeInterceptor): + def __init__(self, events: list[str], label: str): + self.events = events + self.label = label + + def execute_workflow(self, request, next): # type: ignore[override] + self.events.append(f'{self.label}:wf_enter:{request.input!r}') + ret = next(request) + self.events.append(f'{self.label}:wf_ret_type:{ret.__class__.__name__}') + return ret + + def execute_activity(self, request, next): # type: ignore[override] + self.events.append(f'{self.label}:act_enter:{request.input!r}') + res = next(request) + self.events.append(f'{self.label}:act_exit:{res!r}') + return res + + +def test_generator_workflow_hooks_sequence(monkeypatch): + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + events: list[str] = [] + ic = _RecorderInterceptor(events, 'mw') + rt = WorkflowRuntime(runtime_interceptors=[ic]) + + @rt.workflow(name='gen') + def gen(ctx, x: int): + v = yield 'A' + v2 = yield 'B' + return (x, v, v2) + + # Drive the registered orchestrator + reg = rt._WorkflowRuntime__worker._registry + orch = reg.orchestrators['gen'] + gen_driver = orch(_make_orch_ctx(), 10) + # Prime and run + assert next(gen_driver) == 'A' + assert gen_driver.send('ra') == 'B' + with pytest.raises(StopIteration) as stop: + gen_driver.send('rb') + result = stop.value.value + + assert result == (10, 'ra', 'rb') + # Interceptors run once around the workflow entry; they return a generator to the runtime + assert events[0] == 'mw:wf_enter:10' + assert events[1].startswith('mw:wf_ret_type:') + + +def test_async_workflow_hooks_called(monkeypatch): + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + events: list[str] = [] + ic = _RecorderInterceptor(events, 'mw') + rt = WorkflowRuntime(runtime_interceptors=[ic]) + + @rt.workflow(name='awf') + async def awf(ctx, x: int): + # No awaits to keep the driver simple + return x + 1 + + reg = rt._WorkflowRuntime__worker._registry + orch = reg.orchestrators['awf'] + gen_orch = orch(_make_orch_ctx(), 41) + with pytest.raises(StopIteration) as stop: + next(gen_orch) + result = stop.value.value + + assert result == 42 + # For async workflow, interceptor sees entry and a generator return type + assert events[0] == 'mw:wf_enter:41' + assert events[1].startswith('mw:wf_ret_type:') + + +def test_activity_hooks_and_policy(monkeypatch): + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + events: list[str] = [] + + class _ExplodingActivity(RuntimeInterceptor): + def 
execute_activity(self, request, next): # type: ignore[override] + raise RuntimeError('boom') + + def execute_workflow(self, request, next): # type: ignore[override] + return next(request) + + # Continue-on-error policy + rt = WorkflowRuntime( + runtime_interceptors=[_RecorderInterceptor(events, 'mw'), _ExplodingActivity()] + ) + + @rt.activity(name='double') + def double(ctx, x: int) -> int: + return x * 2 + + reg = rt._WorkflowRuntime__worker._registry + act = reg.activities['double'] + # Error in interceptor bubbles up + with pytest.raises(RuntimeError): + act(_make_act_ctx(), 5) diff --git a/ext/dapr-ext-workflow/tests/test_metadata_context.py b/ext/dapr-ext-workflow/tests/test_metadata_context.py new file mode 100644 index 00000000..8a1f86fd --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_metadata_context.py @@ -0,0 +1,371 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +from datetime import datetime +from typing import Any, Optional + +import pytest +from dapr.ext.workflow import ( + ClientInterceptor, + DaprWorkflowClient, + ExecuteActivityRequest, + ExecuteWorkflowRequest, + RuntimeInterceptor, + ScheduleWorkflowRequest, + WorkflowOutboundInterceptor, + WorkflowRuntime, +) + + +class _FakeRegistry: + def __init__(self): + self.orchestrators: dict[str, Any] = {} + self.activities: dict[str, Any] = {} + + def add_named_orchestrator(self, name, fn): + self.orchestrators[name] = fn + + def add_named_activity(self, name, fn): + self.activities[name] = fn + + +class _FakeWorker: + def __init__(self, *args, **kwargs): + self._registry = _FakeRegistry() + + def start(self): + pass + + def stop(self): + pass + + +class _FakeOrchCtx: + def __init__(self): + self.instance_id = 'id' + self.current_utc_datetime = datetime(2024, 1, 1) + self._custom_status = None + self.is_replaying = False + self.workflow_name = 'wf' + self.parent_instance_id = None + self.history_event_sequence = 1 + self.trace_parent = None + self.trace_state = None + self.orchestration_span_id = None + + def call_activity(self, activity, *, input=None, retry_policy=None, app_id=None): + class _T: + def __init__(self, v): + self._v = v + + return _T(input) + + def call_sub_orchestrator( + self, wf, *, input=None, instance_id=None, retry_policy=None, app_id=None + ): + class _T: + def __init__(self, v): + self._v = v + + return _T(input) + + def set_custom_status(self, custom_status): + self._custom_status = custom_status + + def create_timer(self, fire_at): + class _T: + def __init__(self, v): + self._v = v + + return _T(fire_at) + + def wait_for_external_event(self, name: str): + class _T: + def __init__(self, v): + self._v = v + + return _T(name) + + +def _drive(gen, returned): + try: + t = gen.send(None) + assert hasattr(t, '_v') + res = returned + while True: + t = gen.send(res) + assert hasattr(t, '_v') + except StopIteration as stop: + return stop.value + + +def test_client_schedule_metadata_envelope(monkeypatch): + import durabletask.client as client_mod + + 
captured: dict[str, Any] = {} + + class _FakeClient: + def __init__(self, *args, **kwargs): + pass + + def schedule_new_orchestration( + self, + name, + *, + input=None, + instance_id=None, + start_at: Optional[datetime] = None, + reuse_id_policy=None, + ): # noqa: E501 + captured['name'] = name + captured['input'] = input + captured['instance_id'] = instance_id + captured['start_at'] = start_at + captured['reuse_id_policy'] = reuse_id_policy + return 'id-1' + + monkeypatch.setattr(client_mod, 'TaskHubGrpcClient', _FakeClient) + + class _InjectMetadata(ClientInterceptor): + def schedule_new_workflow(self, request: ScheduleWorkflowRequest, next): # type: ignore[override] + # Add metadata without touching args + md = {'otel.trace_id': 't-123'} + new_request = ScheduleWorkflowRequest( + workflow_name=request.workflow_name, + input=request.input, + instance_id=request.instance_id, + start_at=request.start_at, + reuse_id_policy=request.reuse_id_policy, + metadata=md, + ) + return next(new_request) + + client = DaprWorkflowClient(interceptors=[_InjectMetadata()]) + + def wf(ctx, x): + yield 'noop' + + wf.__name__ = 'meta_wf' + instance_id = client.schedule_new_workflow(wf, input={'a': 1}) + assert instance_id == 'id-1' + env = captured['input'] + assert isinstance(env, dict) + assert '__dapr_meta__' in env and '__dapr_payload__' in env + assert env['__dapr_payload__'] == {'a': 1} + assert env['__dapr_meta__']['metadata']['otel.trace_id'] == 't-123' + + +def test_runtime_inbound_unwrap_and_metadata_visible(monkeypatch): + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + seen: dict[str, Any] = {} + + class _Recorder(RuntimeInterceptor): + def execute_workflow(self, request: ExecuteWorkflowRequest, next): # type: ignore[override] + seen['metadata'] = request.metadata + return next(request) + + def execute_activity(self, request: ExecuteActivityRequest, next): # type: ignore[override] + seen['act_metadata'] = request.metadata + return next(request) + + rt = WorkflowRuntime(runtime_interceptors=[_Recorder()]) + + @rt.workflow(name='unwrap') + def unwrap(ctx, x): + # x should be the original payload, not the envelope + assert x == {'hello': 'world'} + return 'ok' + + orch = rt._WorkflowRuntime__worker._registry.orchestrators['unwrap'] + envelope = { + '__dapr_meta__': {'v': 1, 'metadata': {'c': 'd'}}, + '__dapr_payload__': {'hello': 'world'}, + } + result = orch(_FakeOrchCtx(), envelope) + assert result == 'ok' + assert seen['metadata'] == {'c': 'd'} + + +def test_outbound_activity_and_child_wrap_metadata(monkeypatch): + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + class _AddActMeta(WorkflowOutboundInterceptor): + def call_activity(self, request, next): # type: ignore[override] + # Wrap returned args with metadata by returning a new CallActivityRequest + return next( + type(request)( + activity_name=request.activity_name, + input=request.input, + retry_policy=request.retry_policy, + workflow_ctx=request.workflow_ctx, + metadata={'k': 'v'}, + ) + ) + + def call_child_workflow(self, request, next): # type: ignore[override] + return next( + type(request)( + workflow_name=request.workflow_name, + input=request.input, + instance_id=request.instance_id, + workflow_ctx=request.workflow_ctx, + metadata={'p': 'q'}, + ) + ) + + rt = WorkflowRuntime(workflow_outbound_interceptors=[_AddActMeta()]) + + @rt.workflow(name='parent') + def parent(ctx, x): + a = yield 
ctx.call_activity(lambda: None, input={'i': 1}) + b = yield ctx.call_child_workflow(lambda c, y: None, input={'j': 2}) + # Return both so we can assert envelopes surfaced through our fake driver + return a, b + + orch = rt._WorkflowRuntime__worker._registry.orchestrators['parent'] + gen = orch(_FakeOrchCtx(), 0) + # First yield: activity token received by driver; shape may be envelope or raw depending on adapter + t1 = gen.send(None) + assert hasattr(t1, '_v') + # Resume with any value; our fake driver ignores and loops + t2 = gen.send({'act': 'done'}) + assert hasattr(t2, '_v') + with pytest.raises(StopIteration) as stop: + gen.send({'child': 'done'}) + result = stop.value.value + # The result is whatever user returned; envelopes validated above + assert isinstance(result, tuple) and len(result) == 2 + + +def test_context_set_metadata_default_propagation(monkeypatch): + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + # No outbound interceptor needed; runtime will wrap using ctx.get_metadata() + rt = WorkflowRuntime() + + @rt.workflow(name='use_ctx_md') + def use_ctx_md(ctx, x): + # Set default metadata on context + ctx.set_metadata({'k': 'ctx'}) + env = yield ctx.call_activity(lambda: None, input={'p': 1}) + # Return the raw yielded value for assertion + return env + + orch = rt._WorkflowRuntime__worker._registry.orchestrators['use_ctx_md'] + gen = orch(_FakeOrchCtx(), 0) + yielded = gen.send(None) + assert hasattr(yielded, '_v') + env = yielded._v + assert isinstance(env, dict) + assert env.get('__dapr_meta__', {}).get('metadata', {}).get('k') == 'ctx' + + +def test_per_call_metadata_overrides_context(monkeypatch): + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + rt = WorkflowRuntime() + + @rt.workflow(name='override_ctx_md') + def override_ctx_md(ctx, x): + ctx.set_metadata({'k': 'ctx'}) + env = yield ctx.call_activity(lambda: None, input={'p': 1}, metadata={'k': 'per'}) + return env + + orch = rt._WorkflowRuntime__worker._registry.orchestrators['override_ctx_md'] + gen = orch(_FakeOrchCtx(), 0) + yielded = gen.send(None) + env = yielded._v + assert isinstance(env, dict) + assert env.get('__dapr_meta__', {}).get('metadata', {}).get('k') == 'per' + + +def test_execution_info_workflow_and_activity(monkeypatch): + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + rt = WorkflowRuntime() + + def act(ctx, x): + # activity inbound metadata and execution info available + md = ctx.get_metadata() + ei = ctx.execution_info + assert md == {'m': 'v'} + assert ei is not None and ei.inbound_metadata == {'m': 'v'} + # activity_name should reflect the registered name + assert ei.activity_name == 'act' + return x + + @rt.workflow(name='execinfo') + def execinfo(ctx, x): + # set default metadata + ctx.set_metadata({'m': 'v'}) + # workflow execution info available (minimal inbound only) + wi = ctx.execution_info + assert wi is not None and wi.inbound_metadata == {} + v = yield ctx.call_activity(act, input=42) + return v + + # register activity + rt.activity(name='act')(act) + orch = rt._WorkflowRuntime__worker._registry.orchestrators['execinfo'] + gen = orch(_FakeOrchCtx(), 7) + # drive one yield (call_activity) + gen.send(None) + # send back a value for activity result + with pytest.raises(StopIteration) as stop: + gen.send(42) + assert stop.value.value == 42 + + +def 
test_client_interceptor_can_shape_schedule_response(monkeypatch): + import durabletask.client as client_mod + + captured: dict[str, Any] = {} + + class _FakeClient: + def __init__(self, *args, **kwargs): + pass + + def schedule_new_orchestration( + self, name, *, input=None, instance_id=None, start_at=None, reuse_id_policy=None + ): + captured['name'] = name + return 'raw-id-123' + + monkeypatch.setattr(client_mod, 'TaskHubGrpcClient', _FakeClient) + + class _ShapeId(ClientInterceptor): + def schedule_new_workflow(self, request: ScheduleWorkflowRequest, next): # type: ignore[override] + rid = next(request) + return f'shaped:{rid}' + + client = DaprWorkflowClient(interceptors=[_ShapeId()]) + + def wf(ctx): + yield 'noop' + + wf.__name__ = 'shape_test' + iid = client.schedule_new_workflow(wf, input=None) + assert iid == 'shaped:raw-id-123' diff --git a/ext/dapr-ext-workflow/tests/test_outbound_interceptors.py b/ext/dapr-ext-workflow/tests/test_outbound_interceptors.py new file mode 100644 index 00000000..09e67a8d --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_outbound_interceptors.py @@ -0,0 +1,203 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +from dapr.ext.workflow import ( + BaseWorkflowOutboundInterceptor, + WorkflowOutboundInterceptor, + WorkflowRuntime, +) + + +class _FakeRegistry: + def __init__(self): + self.orchestrators = {} + + def add_named_orchestrator(self, name, fn): + self.orchestrators[name] = fn + + +class _FakeWorker: + def __init__(self, *args, **kwargs): + self._registry = _FakeRegistry() + + def start(self): + pass + + def stop(self): + pass + + +class _FakeOrchCtx: + def __init__(self): + self.instance_id = 'id' + self.current_utc_datetime = __import__('datetime').datetime(2024, 1, 1) + self.is_replaying = False + self._custom_status = None + self.workflow_name = 'wf' + self.parent_instance_id = None + self.history_event_sequence = 1 + self.trace_parent = None + self.trace_state = None + self.orchestration_span_id = None + self._continued_payload = None + self.workflow_attempt = None + + def call_activity(self, activity, *, input=None, retry_policy=None, app_id=None): + # return input back for assertion through driver + class _T: + def __init__(self, v): + self._v = v + + return _T(input) + + def call_sub_orchestrator( + self, wf, *, input=None, instance_id=None, retry_policy=None, app_id=None + ): + class _T: + def __init__(self, v): + self._v = v + + return _T(input) + + def set_custom_status(self, custom_status): + self._custom_status = custom_status + + def create_timer(self, fire_at): + class _T: + def __init__(self, v): + self._v = v + + return _T(fire_at) + + def wait_for_external_event(self, name: str): + class _T: + def __init__(self, v): + self._v = v + + return _T(name) + + def continue_as_new(self, new_request, *, save_events: bool = False): + # Record payload for assertions + self._continued_payload = new_request + + +def drive(gen, returned): + try: + t = gen.send(None) + assert hasattr(t, 
'_v') + res = returned + while True: + t = gen.send(res) + assert hasattr(t, '_v') + except StopIteration as stop: + return stop.value + + +class _InjectTrace(WorkflowOutboundInterceptor): + def call_activity(self, request, next): # type: ignore[override] + x = request.input + if x is None: + request = type(request)( + activity_name=request.activity_name, + input={'tracing': 'T'}, + retry_policy=request.retry_policy, + ) + elif isinstance(x, dict): + out = dict(x) + out.setdefault('tracing', 'T') + request = type(request)( + activity_name=request.activity_name, input=out, retry_policy=request.retry_policy + ) + return next(request) + + def call_child_workflow(self, request, next): # type: ignore[override] + return next( + type(request)( + workflow_name=request.workflow_name, + input={'child': request.input}, + instance_id=request.instance_id, + ) + ) + + +def test_outbound_activity_injection(monkeypatch): + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + rt = WorkflowRuntime(workflow_outbound_interceptors=[_InjectTrace()]) + + @rt.workflow(name='w') + def w(ctx, x): + # schedule an activity; runtime should pass transformed input to durable task + y = yield ctx.call_activity(lambda: None, input={'a': 1}) + return y['tracing'] + + orch = rt._WorkflowRuntime__worker._registry.orchestrators['w'] + gen = orch(_FakeOrchCtx(), 0) + out = drive(gen, returned={'tracing': 'T', 'a': 1}) + assert out == 'T' + + +def test_outbound_child_injection(monkeypatch): + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + rt = WorkflowRuntime(workflow_outbound_interceptors=[_InjectTrace()]) + + def child(ctx, x): + yield 'noop' + + @rt.workflow(name='parent') + def parent(ctx, x): + y = yield ctx.call_child_workflow(child, input={'b': 2}) + return y + + orch = rt._WorkflowRuntime__worker._registry.orchestrators['parent'] + gen = orch(_FakeOrchCtx(), 0) + out = drive(gen, returned={'child': {'b': 2}}) + assert out == {'child': {'b': 2}} + + +def test_outbound_continue_as_new_injection(monkeypatch): + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + class _InjectCAN(BaseWorkflowOutboundInterceptor): + def continue_as_new(self, request, next): # type: ignore[override] + md = dict(request.metadata or {}) + md.setdefault('x', '1') + request.metadata = md + return next(request) + + rt = WorkflowRuntime(workflow_outbound_interceptors=[_InjectCAN()]) + + @rt.workflow(name='w2') + def w2(ctx, x): + ctx.continue_as_new({'p': 1}) + return 'unreached' + + orch = rt._WorkflowRuntime__worker._registry.orchestrators['w2'] + fake = _FakeOrchCtx() + _ = orch(fake, 0) + # Verify envelope contains injected metadata + assert isinstance(fake._continued_payload, dict) + meta = fake._continued_payload.get('__dapr_meta__') + payload = fake._continued_payload.get('__dapr_payload__') + assert isinstance(meta, dict) and isinstance(payload, dict) + assert meta.get('metadata', {}).get('x') == '1' + assert payload == {'p': 1} diff --git a/ext/dapr-ext-workflow/tests/test_sandbox_gather.py b/ext/dapr-ext-workflow/tests/test_sandbox_gather.py new file mode 100644 index 00000000..7bb509dd --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_sandbox_gather.py @@ -0,0 +1,169 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +import asyncio +from datetime import datetime, timedelta + +import pytest +from dapr.ext.workflow.aio import AsyncWorkflowContext, CoroutineOrchestratorRunner +from dapr.ext.workflow.aio.sandbox import sandbox_scope +from durabletask.aio.sandbox import SandboxMode + +""" +Tests for sandboxed asyncio.gather behavior in async orchestrators. +""" + + +class _FakeCtx: + def __init__(self): + self.current_utc_datetime = datetime(2024, 1, 1) + self.instance_id = 'test-instance' + + def create_timer(self, fire_at): + class _T: + def __init__(self): + self._parent = None + self.is_complete = False + + return _T() + + def wait_for_external_event(self, name: str): + class _T: + def __init__(self): + self._parent = None + self.is_complete = False + + return _T() + + +def drive(gen, results): + try: + gen.send(None) + i = 0 + while True: + gen.send(results[i]) + i += 1 + except StopIteration as stop: + return stop.value + + +async def _plain(value): + return value + + +async def awf_empty(ctx: AsyncWorkflowContext): + with sandbox_scope(ctx, SandboxMode.BEST_EFFORT): + out = await asyncio.gather() + return out + + +def test_sandbox_gather_empty_returns_list(): + fake = _FakeCtx() + runner = CoroutineOrchestratorRunner(awf_empty) + gen = runner.to_generator(AsyncWorkflowContext(fake), None) + out = drive(gen, results=[None]) + assert out == [] + + +async def awf_when_all(ctx: AsyncWorkflowContext): + a = ctx.create_timer(timedelta(seconds=0)) + b = ctx.wait_for_external_event('x') + with sandbox_scope(ctx, SandboxMode.BEST_EFFORT): + res = await asyncio.gather(a, b) + return res + + +def test_sandbox_gather_all_workflow_maps_to_when_all(): + fake = _FakeCtx() + runner = CoroutineOrchestratorRunner(awf_when_all) + gen = runner.to_generator(AsyncWorkflowContext(fake), None) + out = drive(gen, results=[[1, 2]]) + assert out == [1, 2] + + +async def awf_mixed(ctx: AsyncWorkflowContext): + a = ctx.create_timer(timedelta(seconds=0)) + with sandbox_scope(ctx, SandboxMode.BEST_EFFORT): + res = await asyncio.gather(a, _plain('ok')) + return res + + +def test_sandbox_gather_mixed_returns_sequential_results(): + fake = _FakeCtx() + runner = CoroutineOrchestratorRunner(awf_mixed) + gen = runner.to_generator(AsyncWorkflowContext(fake), None) + out = drive(gen, results=[123]) + assert out == [123, 'ok'] + + +async def awf_return_exceptions(ctx: AsyncWorkflowContext): + async def _boom(): + raise RuntimeError('x') + + a = ctx.create_timer(timedelta(seconds=0)) + with sandbox_scope(ctx, SandboxMode.BEST_EFFORT): + res = await asyncio.gather(a, _boom(), return_exceptions=True) + return res + + +def test_sandbox_gather_return_exceptions(): + fake = _FakeCtx() + runner = CoroutineOrchestratorRunner(awf_return_exceptions) + gen = runner.to_generator(AsyncWorkflowContext(fake), None) + out = drive(gen, results=[321]) + assert isinstance(out[1], RuntimeError) + + +async def awf_multi_await(ctx: AsyncWorkflowContext): + with sandbox_scope(ctx, SandboxMode.BEST_EFFORT): + g = asyncio.gather() + a = await g + b = await g + return (a, b) + + +def test_sandbox_gather_multi_await_safe(): + fake 
= _FakeCtx() + runner = CoroutineOrchestratorRunner(awf_multi_await) + gen = runner.to_generator(AsyncWorkflowContext(fake), None) + out = drive(gen, results=[None]) + assert out == ([], []) + + +def test_sandbox_gather_restored_outside(): + import asyncio as aio + + original = aio.gather + fake = _FakeCtx() + ctx = AsyncWorkflowContext(fake) + with sandbox_scope(ctx, SandboxMode.BEST_EFFORT): + pass + # After exit, gather should be restored + assert aio.gather is original + + +def test_strict_mode_blocks_create_task(): + import asyncio as aio + + fake = _FakeCtx() + ctx = AsyncWorkflowContext(fake) + with sandbox_scope(ctx, SandboxMode.STRICT): + if hasattr(aio, 'create_task'): + with pytest.raises(RuntimeError): + # Use a dummy coroutine to trigger the block + async def _c(): + return 1 + + aio.create_task(_c()) diff --git a/ext/dapr-ext-workflow/tests/test_trace_fields.py b/ext/dapr-ext-workflow/tests/test_trace_fields.py new file mode 100644 index 00000000..03d38e1e --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_trace_fields.py @@ -0,0 +1,60 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +from datetime import datetime, timezone + +from dapr.ext.workflow.execution_info import ActivityExecutionInfo, WorkflowExecutionInfo +from dapr.ext.workflow.workflow_activity_context import WorkflowActivityContext + + +class _FakeOrchCtx: + def __init__(self): + self.instance_id = 'wf-123' + self.current_utc_datetime = datetime(2025, 1, 1, tzinfo=timezone.utc) + self.is_replaying = False + self.workflow_name = 'wf_name' + self.parent_instance_id = 'parent-1' + self.history_event_sequence = 42 + self.trace_parent = '00-aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa-bbbbbbbbbbbbbbbb-01' + self.trace_state = 'vendor=state' + self.orchestration_span_id = 'bbbbbbbbbbbbbbbb' + + +class _FakeActivityCtx: + def __init__(self): + self.orchestration_id = 'wf-123' + self.task_id = 7 + self.trace_parent = '00-cccccccccccccccccccccccccccccccc-dddddddddddddddd-01' + self.trace_state = 'v=1' + + +def test_workflow_execution_info_minimal(): + ei = WorkflowExecutionInfo(inbound_metadata={'k': 'v'}) + assert ei.inbound_metadata == {'k': 'v'} + + +def test_activity_execution_info_minimal(): + aei = ActivityExecutionInfo(inbound_metadata={'m': 'v'}, activity_name='act_name') + assert aei.inbound_metadata == {'m': 'v'} + + +def test_workflow_activity_context_execution_info_trace_fields(): + base = _FakeActivityCtx() + actx = WorkflowActivityContext(base) + aei = ActivityExecutionInfo(inbound_metadata={}, activity_name='act_name') + actx._set_execution_info(aei) + got = actx.execution_info + assert got is not None + assert got.inbound_metadata == {} diff --git a/ext/dapr-ext-workflow/tests/test_tracing_interceptors.py b/ext/dapr-ext-workflow/tests/test_tracing_interceptors.py new file mode 100644 index 00000000..35f93361 --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_tracing_interceptors.py @@ -0,0 +1,171 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the 
Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from __future__ import annotations + +import uuid +from datetime import datetime +from typing import Any + +from dapr.ext.workflow import ( + ClientInterceptor, + DaprWorkflowClient, + RuntimeInterceptor, + WorkflowRuntime, +) + + +class _FakeRegistry: + def __init__(self): + self.orchestrators: dict[str, Any] = {} + self.activities: dict[str, Any] = {} + + def add_named_orchestrator(self, name, fn): + self.orchestrators[name] = fn + + def add_named_activity(self, name, fn): + self.activities[name] = fn + + +class _FakeWorker: + def __init__(self, *args, **kwargs): + self._registry = _FakeRegistry() + + def start(self): + pass + + def stop(self): + pass + + +class _FakeOrchestrationContext: + def __init__(self, *, is_replaying: bool = False): + self.instance_id = 'wf-1' + self.current_utc_datetime = datetime(2025, 1, 1) + self.is_replaying = is_replaying + self.workflow_name = 'wf' + self.parent_instance_id = None + self.history_event_sequence = 1 + self.trace_parent = None + self.trace_state = None + self.orchestration_span_id = None + + +def _drive_generator(gen, returned_value): + # Prime to first yield; then drive + next(gen) + while True: + try: + gen.send(returned_value) + except StopIteration as stop: + return stop.value + + +def test_client_injects_tracing_on_schedule(monkeypatch): + import durabletask.client as client_mod + + # monkeypatch TaskHubGrpcClient to capture inputs + scheduled: dict[str, Any] = {} + + class _FakeClient: + def __init__(self, *args, **kwargs): + pass + + def schedule_new_orchestration( + self, name, *, input=None, instance_id=None, start_at=None, reuse_id_policy=None + ): + scheduled['name'] = name + scheduled['input'] = input + scheduled['instance_id'] = instance_id + scheduled['start_at'] = start_at + scheduled['reuse_id_policy'] = reuse_id_policy + return 'id-1' + + monkeypatch.setattr(client_mod, 'TaskHubGrpcClient', _FakeClient) + + class _TracingClient(ClientInterceptor): + def schedule_new_workflow(self, request, next): # type: ignore[override] + tr = {'trace_id': uuid.uuid4().hex} + if isinstance(request.input, dict) and 'tracing' not in request.input: + request = type(request)( + workflow_name=request.workflow_name, + input={**request.input, 'tracing': tr}, + instance_id=request.instance_id, + start_at=request.start_at, + reuse_id_policy=request.reuse_id_policy, + ) + return next(request) + + client = DaprWorkflowClient(interceptors=[_TracingClient()]) + + # We only need a callable with a __name__ for scheduling + def wf(ctx): + yield 'noop' + + wf.__name__ = 'inject_test' + instance_id = client.schedule_new_workflow(wf, input={'a': 1}) + assert instance_id == 'id-1' + assert scheduled['name'] == 'inject_test' + assert isinstance(scheduled['input'], dict) + assert 'tracing' in scheduled['input'] + assert scheduled['input']['a'] == 1 + + +def test_runtime_restores_tracing_before_user_code(monkeypatch): + import durabletask.worker as worker_mod + + monkeypatch.setattr(worker_mod, 'TaskHubGrpcWorker', _FakeWorker) + + seen: dict[str, Any] = {} 
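+ # 'seen' records the tracing payload observed inside the workflow body so it can be asserted after the orchestrator is driven to completion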
+ + class _TracingRuntime(RuntimeInterceptor): + def execute_workflow(self, request, next): # type: ignore[override] + # no-op; real restoration is app concern; test just ensures input contains tracing + return next(request) + + def execute_activity(self, request, next): # type: ignore[override] + return next(request) + + class _TracingClient2(ClientInterceptor): + def schedule_new_workflow(self, request, next): # type: ignore[override] + tr = {'trace_id': 't1'} + if isinstance(request.input, dict): + request = type(request)( + workflow_name=request.workflow_name, + input={**request.input, 'tracing': tr}, + instance_id=request.instance_id, + start_at=request.start_at, + reuse_id_policy=request.reuse_id_policy, + ) + return next(request) + + rt = WorkflowRuntime( + runtime_interceptors=[_TracingRuntime()], + ) + + @rt.workflow(name='w') + def w(ctx, x): + # The tracing should already be present in input + assert isinstance(x, dict) + assert 'tracing' in x + seen['trace'] = x['tracing'] + yield 'noop' + return 'ok' + + orch = rt._WorkflowRuntime__worker._registry.orchestrators['w'] + # Orchestrator input will have tracing injected via outbound when scheduled as a child or via client + # Here, we directly pass the input simulating schedule with tracing present + gen = orch(_FakeOrchestrationContext(), {'hello': 'world', 'tracing': {'trace_id': 't1'}}) + out = _drive_generator(gen, returned_value='noop') + assert out == 'ok' + assert seen['trace']['trace_id'] == 't1' diff --git a/ext/dapr-ext-workflow/tests/test_workflow_runtime.py b/ext/dapr-ext-workflow/tests/test_workflow_runtime.py index bf18cd68..f367a43f 100644 --- a/ext/dapr-ext-workflow/tests/test_workflow_runtime.py +++ b/ext/dapr-ext-workflow/tests/test_workflow_runtime.py @@ -37,7 +37,10 @@ class WorkflowRuntimeTest(unittest.TestCase): def setUp(self): listActivities.clear() listOrchestrators.clear() - mock.patch('durabletask.worker._Registry', return_value=FakeTaskHubGrpcWorker()).start() + self.patcher = mock.patch( + 'durabletask.worker._Registry', return_value=FakeTaskHubGrpcWorker() + ) + self.patcher.start() self.runtime_options = WorkflowRuntime() if hasattr(self.mock_client_wf, '_dapr_alternate_name'): del self.mock_client_wf.__dict__['_dapr_alternate_name'] @@ -48,6 +51,11 @@ def setUp(self): if hasattr(self.mock_client_activity, '_activity_registered'): del self.mock_client_activity.__dict__['_activity_registered'] + def tearDown(self): + """Stop the mock patch to prevent interference with other tests.""" + self.patcher.stop() + mock.patch.stopall() # Ensure all patches are stopped + def mock_client_wf(ctx: DaprWorkflowContext, input): print(f'{input}') diff --git a/ext/dapr-ext-workflow/tests/test_workflow_util.py b/ext/dapr-ext-workflow/tests/test_workflow_util.py index 28e92e6c..c1b980ed 100644 --- a/ext/dapr-ext-workflow/tests/test_workflow_util.py +++ b/ext/dapr-ext-workflow/tests/test_workflow_util.py @@ -1,3 +1,16 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + import unittest from unittest.mock import patch @@ -7,6 +20,7 @@ class DaprWorkflowUtilTest(unittest.TestCase): + @patch.object(settings, 'DAPR_GRPC_ENDPOINT', '') def test_get_address_default(self): expected = f'{settings.DAPR_RUNTIME_HOST}:{settings.DAPR_GRPC_PORT}' self.assertEqual(expected, getAddress()) diff --git a/mypy.ini b/mypy.ini index 8c0fee4f..7ca609e0 100644 --- a/mypy.ini +++ b/mypy.ini @@ -12,6 +12,7 @@ files = dapr/clients/**/*.py, dapr/conf/**/*.py, dapr/serializers/**/*.py, + ext/dapr-ext-workflow/dapr/ext/workflow/**/*.py, ext/dapr-ext-grpc/dapr/**/*.py, ext/dapr-ext-fastapi/dapr/**/*.py, ext/flask_dapr/flask_dapr/*.py, @@ -19,3 +20,6 @@ files = [mypy-dapr.proto.*] ignore_errors = True + +[mypy-dapr.ext.workflow.*] +python_version = 3.11 diff --git a/pyproject.toml b/pyproject.toml index 0378a8c8..ed9fb11a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,8 +2,13 @@ target-version = "py310" line-length = 100 fix = true -extend-exclude = [".github", "dapr/proto"] - +extend-exclude = [ + ".github", + "dapr/proto", + "*_pb2.py", + "*_pb2_grpc.py", + "examples/**/.venv", +] [tool.ruff.lint] select = [ "I", # isort diff --git a/tox.ini b/tox.ini index 17041083..0e774eaf 100644 --- a/tox.ini +++ b/tox.ini @@ -2,27 +2,46 @@ skipsdist = True minversion = 3.10.0 envlist = - py{310,311,312,313} + py{310,311,312,313,314} ruff, mypy, +# TODO: switch runner to uv (tox-uv plugin) +runner = virtualenv [testenv] setenv = PYTHONDONTWRITEBYTECODE=1 deps = -rdev-requirements.txt +package = editable commands = coverage run -m unittest discover -v ./tests - coverage run -a -m unittest discover -v ./ext/dapr-ext-workflow/tests + # ext/dapr-ext-workflow uses pytest-based tests + !e2e: coverage run -a -m pytest -q -m "not e2e" ext/dapr-ext-workflow/tests + + # we need to setup dapr or durabletask-go to run e2e tests + # TODO: uncomment after properly setting up dapr or durabletask-go + #e2e: coverage run -a -m pytest -q -m "e2e" ext/dapr-ext-workflow/tests + coverage run -a -m unittest discover -v ./ext/dapr-ext-grpc/tests coverage run -a -m unittest discover -v ./ext/dapr-ext-fastapi/tests coverage run -a -m unittest discover -v ./ext/flask_dapr/tests coverage xml commands_pre = - pip3 install -e {toxinidir}/ - pip3 install -e {toxinidir}/ext/dapr-ext-workflow/ - pip3 install -e {toxinidir}/ext/dapr-ext-grpc/ - pip3 install -e {toxinidir}/ext/dapr-ext-fastapi/ - pip3 install -e {toxinidir}/ext/flask_dapr/ + # TODO: remove this before merging (after durable task is merged) + {envpython} -m pip install -e {toxinidir}/../durabletask-python/ + + {envpython} -m pip install -e {toxinidir}/ + {envpython} -m pip install -e {toxinidir}/ext/dapr-ext-workflow/ + {envpython} -m pip install -e {toxinidir}/ext/dapr-ext-grpc/ + {envpython} -m pip install -e {toxinidir}/ext/dapr-ext-fastapi/ + {envpython} -m pip install -e {toxinidir}/ext/flask_dapr/ +# allow for overriding sidecar ports +pass_env = DAPR_GRPC_ENDPOINT,DAPR_HTTP_ENDPOINT,DAPR_RUNTIME_HOST,DAPR_GRPC_PORT,DAPR_HTTP_PORT,DURABLETASK_GRPC_ENDPOINT + +[flake8] +extend-exclude = .tox,venv,build,dist,dapr/proto,examples/**/.venv +ignore = E203,E501,W503,E701,E704,F821 +max-line-length = 100 [testenv:ruff] basepython = python3 @@ -61,6 +80,9 @@ commands = ./validate.sh jobs ./validate.sh ../ commands_pre = + # TODO: remove this before merging (after durable task is merged) + pip3 install -e {toxinidir}/../durabletask-python/ + pip3 install -e {toxinidir}/ pip3 install -e {toxinidir}/ext/dapr-ext-workflow/ pip3 install -e 
diff --git a/tox.ini b/tox.ini
index 17041083..0e774eaf 100644
--- a/tox.ini
+++ b/tox.ini
@@ -2,27 +2,46 @@
 skipsdist = True
 minversion = 3.10.0
 envlist =
-    py{310,311,312,313}
+    py{310,311,312,313,314}
     ruff,
     mypy,
+# TODO: switch runner to uv (tox-uv plugin)
+runner = virtualenv
 
 [testenv]
 setenv =
     PYTHONDONTWRITEBYTECODE=1
 deps = -rdev-requirements.txt
+package = editable
 commands =
     coverage run -m unittest discover -v ./tests
-    coverage run -a -m unittest discover -v ./ext/dapr-ext-workflow/tests
+    # ext/dapr-ext-workflow uses pytest-based tests
+    !e2e: coverage run -a -m pytest -q -m "not e2e" ext/dapr-ext-workflow/tests
+
+    # we need to set up dapr or durabletask-go to run e2e tests
+    # TODO: uncomment after properly setting up dapr or durabletask-go
+    #e2e: coverage run -a -m pytest -q -m "e2e" ext/dapr-ext-workflow/tests
+
     coverage run -a -m unittest discover -v ./ext/dapr-ext-grpc/tests
     coverage run -a -m unittest discover -v ./ext/dapr-ext-fastapi/tests
     coverage run -a -m unittest discover -v ./ext/flask_dapr/tests
     coverage xml
 commands_pre =
-    pip3 install -e {toxinidir}/
-    pip3 install -e {toxinidir}/ext/dapr-ext-workflow/
-    pip3 install -e {toxinidir}/ext/dapr-ext-grpc/
-    pip3 install -e {toxinidir}/ext/dapr-ext-fastapi/
-    pip3 install -e {toxinidir}/ext/flask_dapr/
+    # TODO: remove this before merging (after durable task is merged)
+    {envpython} -m pip install -e {toxinidir}/../durabletask-python/
+
+    {envpython} -m pip install -e {toxinidir}/
+    {envpython} -m pip install -e {toxinidir}/ext/dapr-ext-workflow/
+    {envpython} -m pip install -e {toxinidir}/ext/dapr-ext-grpc/
+    {envpython} -m pip install -e {toxinidir}/ext/dapr-ext-fastapi/
+    {envpython} -m pip install -e {toxinidir}/ext/flask_dapr/
+# allow for overriding sidecar ports
+pass_env = DAPR_GRPC_ENDPOINT,DAPR_HTTP_ENDPOINT,DAPR_RUNTIME_HOST,DAPR_GRPC_PORT,DAPR_HTTP_PORT,DURABLETASK_GRPC_ENDPOINT
+
+[flake8]
+extend-exclude = .tox,venv,build,dist,dapr/proto,examples/**/.venv
+ignore = E203,E501,W503,E701,E704,F821
+max-line-length = 100
 
 [testenv:ruff]
 basepython = python3
@@ -61,6 +80,9 @@ commands =
     ./validate.sh jobs
     ./validate.sh ../
 commands_pre =
+    # TODO: remove this before merging (after durable task is merged)
+    pip3 install -e {toxinidir}/../durabletask-python/
+
     pip3 install -e {toxinidir}/
     pip3 install -e {toxinidir}/ext/dapr-ext-workflow/
     pip3 install -e {toxinidir}/ext/dapr-ext-grpc/
@@ -93,6 +115,9 @@ deps = -rdev-requirements.txt
 commands =
     mypy --config-file mypy.ini
 commands_pre =
+    # TODO: remove this before merging (after durable task is merged)
+    pip3 install -e {toxinidir}/../durabletask-python/
+
     pip3 install -e {toxinidir}/
     pip3 install -e {toxinidir}/ext/dapr-ext-workflow/
     pip3 install -e {toxinidir}/ext/dapr-ext-grpc/