diff --git a/.github/workflows/pr-validation.yml b/.github/workflows/pr-validation.yml index 3ed790b..cf7d41f 100644 --- a/.github/workflows/pr-validation.yml +++ b/.github/workflows/pr-validation.yml @@ -43,7 +43,6 @@ jobs: # Install and run the durabletask-go sidecar for running e2e tests - name: Pytest e2e tests run: | - # TODO: use dapr run instead of durabletask-go as it provides a more reliable sidecar behaviorfor e2e tests go install github.com/dapr/durabletask-go@main durabletask-go --port 4001 & tox -e py${{ matrix.python-version }}-e2e diff --git a/.gitignore b/.gitignore index 9f1046c..07f65c1 100644 --- a/.gitignore +++ b/.gitignore @@ -130,5 +130,6 @@ dmypy.json # IDEs .idea +.vscode coverage.lcov \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json index 1c929ac..a579b50 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,6 +1,6 @@ { "[python]": { - "editor.defaultFormatter": "ms-python.autopep8", + "editor.defaultFormatter": "charliermarsh.ruff", "editor.formatOnSave": true, "editor.codeActionsOnSave": { "source.organizeImports": "explicit" diff --git a/README.md b/README.md index d4604e0..1943839 100644 --- a/README.md +++ b/README.md @@ -138,6 +138,36 @@ Orchestrations are implemented using ordinary Python functions that take an `Orc Activities are implemented using ordinary Python functions that take an `ActivityContext` as their first parameter. Activity functions are scheduled by orchestrations and have at-least-once execution guarantees, meaning that they will be executed at least once but may be executed multiple times in the event of a transient failure. Activity functions are where the real "work" of any orchestration is done. +#### Async Activities + +Activities can be either synchronous or asynchronous functions. Async activities are useful for I/O-bound operations like HTTP requests, database queries, or file operations: + +```python +import aiohttp +from durabletask.task import ActivityContext + +# Synchronous activity +def sync_activity(ctx: ActivityContext, data: str) -> str: + return data.upper() + +# Asynchronous activity +async def async_activity(ctx: ActivityContext, data: str) -> str: + # Perform async I/O operations + async with aiohttp.ClientSession() as session: + async with session.get(f"https://api.example.com/{data}") as response: + result = await response.json() + return result +``` + +Both sync and async activities are registered the same way: + +```python +worker.add_activity(sync_activity) +worker.add_activity(async_activity) +``` + +Orchestrators call them identically regardless of whether they're sync or async; the SDK handles the execution automatically. + ### Durable timers Orchestrations can schedule durable timers using the `create_timer` API. These timers are durable, meaning that they will survive orchestrator restarts and will fire even if the orchestrator is not actively in memory. Durable timers can be of any duration, from milliseconds to months. @@ -150,7 +179,22 @@ Orchestrations can start child orchestrations using the `call_sub_orchestrator` Orchestrations can wait for external events using the `wait_for_external_event` API. External events are useful for implementing human interaction patterns, such as waiting for a user to approve an order before continuing. -### Continue-as-new (TODO) +### Continue-as-new Orchestrations can be continued as new using the `continue_as_new` API. This API allows an orchestration to restart itself from scratch, optionally with a new input.
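+
+A minimal sketch of what this can look like with the generator authoring model (the `poll_source` activity name here is illustrative, not part of the SDK):
+
+```python
+from datetime import timedelta
+
+import durabletask.task as task
+
+def polling_orchestrator(ctx: task.OrchestrationContext, polls_done: int):
+    # Do one unit of work, then wait before polling again
+    yield ctx.call_activity("poll_source", input=polls_done)
+    yield ctx.create_timer(timedelta(minutes=5))
+    # Restart with fresh history so the instance never grows unbounded
+    ctx.continue_as_new(polls_done + 1)
+```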
@@ -281,6 +310,9 @@ The following is more information about how to develop this project. Note that d ### Generating protobufs ```sh +# install dev dependencies for generating protobufs and running tests +pip3 install '.[dev]' + make gen-proto ``` @@ -319,9 +351,182 @@ dapr run --app-id test-app --dapr-grpc-port 4001 --resources-path ./examples/co To run the E2E tests on a specific python version (eg: 3.11), run the following command from the project root: ```sh -tox -e py311 -- e2e +tox -e py311-e2e +``` + +### Configuration + +#### Connection Configuration + +The SDK connects to a Durable Task sidecar. By default it uses `localhost:4001`. You can override via environment variables (checked in order): + +- `DAPR_GRPC_ENDPOINT` - Full endpoint (e.g., `localhost:4001`, `grpcs://host:443`) +- `DAPR_GRPC_HOST` (or `DAPR_RUNTIME_HOST`) and `DAPR_GRPC_PORT` - Host and port separately + +Example (common ports: 4001 for DurableTask-Go emulator, 50001 for Dapr sidecar): + +```sh +export DAPR_GRPC_ENDPOINT=localhost:4001 +# or +export DAPR_GRPC_HOST=localhost +export DAPR_GRPC_PORT=50001 +``` + + +#### Async Workflow Configuration + +Configure async workflow behavior and debugging: + +- `DAPR_WF_DISABLE_DETERMINISTIC_DETECTION` - Disable non-determinism detection (set to `true`) + +Example: + +```sh +export DAPR_WF_DISABLE_DETERMINISTIC_DETECTION=true +``` + +### Async workflow authoring + +For a deeper tour of the async authoring surface (determinism helpers, sandbox modes, timeouts, concurrency patterns), see the Async Enhancements guide: [ASYNCIO_ENHANCEMENTS.md](./durabletask/aio/ASYNCIO_ENHANCEMENTS.md). + +You can author orchestrators with `async def` using the new `durabletask.aio` package, which provides a comprehensive async workflow API: + +```python +from durabletask.worker import TaskHubGrpcWorker +from durabletask.aio import AsyncWorkflowContext + +async def my_orch(ctx: AsyncWorkflowContext, input) -> str: + r1 = await ctx.call_activity(act1, input=input) + await ctx.sleep(1.0) + r2 = await ctx.call_activity(act2, input=r1) + return r2 + +with TaskHubGrpcWorker() as worker: + worker.add_orchestrator(my_orch) +``` + +The sandbox (enabled by default) patches standard Python functions to deterministic equivalents during workflow execution. This allows natural async code like `asyncio.sleep()`, `random.random()`, and `asyncio.gather()` to work correctly with workflow replay. Three modes are available: + +- `"best_effort"` (default): Patches functions, minimal overhead +- `"strict"`: Patches + blocks dangerous operations (file I/O, `asyncio.create_task`) +- `"off"`: No patching (requires manual use of `ctx.*` methods everywhere) + +> **Enhanced Sandbox Features**: The enhanced version includes comprehensive non-determinism detection, timeout support, enhanced concurrency primitives, and debugging tools. See [ASYNCIO_ENHANCEMENTS.md](./durabletask/aio/ASYNCIO_ENHANCEMENTS.md) for complete documentation. + +#### Async patterns + +- Activities and sub-orchestrations can be referenced by function object (preferred for IDE/type support) or by their registered string name (useful across modules/languages). Both forms are supported:
+ +- Activities: +```python +result = await ctx.call_activity("process", input={"x": 1}) +# or: result = await ctx.call_activity(process, input={"x": 1}) ``` +- Timers: +```python +await ctx.sleep(1.5) # seconds or timedelta +``` + +- External events: +```python +val = await ctx.wait_for_external_event("approval") +``` + +- Concurrency: +```python +t1 = ctx.call_activity("a"); t2 = ctx.call_activity("b") +# when_all waits for all tasks and returns results in order +results = await ctx.when_all([t1, t2]) +# when_any returns (index, result) tuple of first completed task +idx, result = await ctx.when_any([ctx.wait_for_external_event("x"), ctx.create_timer(5)]) +``` + +#### Async vs. generator API differences + +- Async authoring (`durabletask.aio`): awaiting returns the operation's value. Exceptions are raised on `await` (no `is_failed`). +- Generator authoring (`durabletask.task`): yielding returns `Task` objects. Use `get_result()` to read values; failures surface via `is_failed()` or by raising on `get_result()`. + +Examples: + +```python +# Async authoring (await returns value) +# when_any returns a proxy that compares equal to the original awaitable +# and exposes get_result() for the completed item. +approval = ctx.wait_for_external_event("approval") +winner = await ctx.when_any([approval, ctx.sleep(60)]) +if winner == approval: + details = winner.get_result() +``` + +```python +# Async authoring (index + result) +idx, result = await ctx.when_any_with_result([approval, ctx.sleep(60)]) +if idx == 0: # approval won + details = result +``` + +```python +# Generator authoring (yield returns Task) +approval = ctx.wait_for_external_event("approval") +winner = yield task.when_any([approval, ctx.create_timer(timedelta(seconds=60))]) +if winner == approval: + details = approval.get_result() +``` + +Failure handling in async: + +```python +try: + val = await ctx.call_activity("might_fail") +except Exception as e: + # handle failure branch + ... +``` + +- Sub-orchestrations (function reference or registered name): +```python +out = await ctx.call_sub_orchestrator(child_fn, input=payload) +# or: out = await ctx.call_sub_orchestrator("child", input=payload) +``` + +- Deterministic utilities: +```python +now = ctx.now(); rid = ctx.random().random(); uid = ctx.uuid4() +``` + +- Cross-app activity/sub-orchestrator routing (async only for now): +```python +# Route activity to a different app via app_id +result = await ctx.call_activity("process", input=data, app_id="worker-app-2") + +# Route sub-orchestrator to a different app +child_result = await ctx.call_sub_orchestrator("child_workflow", input=data, app_id="orchestrator-app-2") +``` +Notes: +- The `app_id` parameter enables multi-app orchestrations where activities or child workflows run in different application instances. +- Requires sidecar support for cross-app invocation. + +#### Worker readiness + +When starting a worker and scheduling immediately, wait for the connection to the sidecar to be established: + +```python +with TaskHubGrpcWorker() as worker: + worker.add_orchestrator(my_orch) + worker.start() + worker.wait_for_ready(timeout=5) + # Now safe to schedule +``` + +#### Suspension & termination + +- `ctx.is_suspended` reflects suspension state during replay/processing. +- Suspend pauses progress without raising inside async orchestrators. +- Terminate completes with `TERMINATED` status; use client APIs to terminate/resume. 
+- Only new events are buffered while suspended; replay events continue to apply to rebuild local state deterministically. + + ## Contributing This project welcomes contributions and suggestions. Most contributions require you to agree to a diff --git a/dev-requirements.txt b/dev-requirements.txt index ba589ab..e69de29 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1 +0,0 @@ -grpcio-tools==1.62.3 # 1.62.X is the latest version before protobuf 1.26.X is used which has breaking changes for Python # supports protobuf 6.x and aligns with generated code \ No newline at end of file diff --git a/durabletask/__init__.py b/durabletask/__init__.py index 78ea7ca..1fe82f0 100644 --- a/durabletask/__init__.py +++ b/durabletask/__init__.py @@ -1,6 +1,9 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +# Public async exports (import directly from durabletask.aio) +from durabletask.aio import AsyncWorkflowContext, CoroutineOrchestratorRunner # noqa: F401 + """Durable Task SDK for Python""" PACKAGE_NAME = "durabletask" diff --git a/durabletask/aio/ASYNCIO_ENHANCEMENTS.md b/durabletask/aio/ASYNCIO_ENHANCEMENTS.md new file mode 100644 index 0000000..d5269bd --- /dev/null +++ b/durabletask/aio/ASYNCIO_ENHANCEMENTS.md @@ -0,0 +1,257 @@ +# Enhanced Async Workflow Features + +This document describes the enhanced async workflow capabilities added to this fork of durabletask-python. For a deep dive into architecture and internals, see [ASYNCIO_INTERNALS.md](ASYNCIO_INTERNALS.md). + +## Overview + +The durabletask-python SDK includes comprehensive async workflow enhancements, providing a production-ready async authoring experience with advanced debugging, error handling, and determinism enforcement. This works seamlessly with the existing Python workflow authoring experience. + +## Quick Start + +```python +from durabletask.worker import TaskHubGrpcWorker +from durabletask.aio import AsyncWorkflowContext, SandboxMode + +async def enhanced_workflow(ctx: AsyncWorkflowContext, input_data) -> dict: + # Enhanced error handling with rich context + try: + result = await ctx.with_timeout( + ctx.call_activity("my_activity", input=input_data), + 30.0, # 30 second timeout + ) + except TimeoutError: + result = "Activity timed out" + + # Enhanced concurrency with result indexing + tasks = [ctx.call_activity(f"task_{i}") for i in range(3)] + completed_index, first_result = await ctx.when_any_with_result(tasks) + + # Deterministic operations + current_time = ctx.now() + random_value = ctx.random().random() + unique_id = ctx.uuid4() + + return { + "result": result, + "first_completed": completed_index, + "timestamp": current_time.isoformat(), + "random": random_value, + "id": str(unique_id) + } + +# Register with enhanced features +with TaskHubGrpcWorker() as worker: + # Async orchestrators are auto-detected; both forms work: + worker.add_orchestrator(enhanced_workflow) # Auto-detects async + + # Or specify sandbox mode explicitly: + worker.add_orchestrator( + enhanced_workflow, + sandbox_mode=SandboxMode.BEST_EFFORT # or "best_effort" + ) + + worker.start() + # ... rest of your code +``` + +## Enhanced Features + +### 1. **Advanced Error Handling** +- `AsyncWorkflowError` with rich context (instance ID, workflow name, step) +- Enhanced error messages with actionable suggestions +- Better exception propagation and debugging support + +### 2.
**Non-Determinism Detection** +- Automatic detection of non-deterministic function calls +- Three modes: `"best_effort"` (default), `"strict"` (errors), `"off"` (no patching) +- Comprehensive coverage of problematic functions +- Helpful suggestions for deterministic alternatives + +### 3. **Enhanced Concurrency Primitives** +- `when_all()` - Waits for all tasks to complete and returns list of results in order +- `when_any()` - Returns (index, result) tuple indicating which task completed first +- `with_timeout()` - Add timeout to any operation + +### 4. **Async Context Management** +- Full async context manager support (`async with ctx:`) +- Cleanup task registry with `ctx.add_cleanup()` +- Automatic resource cleanup + +### 5. **Debugging and Monitoring** +- Operation history tracking when debug mode is enabled +- `ctx.get_debug_info()` for workflow introspection +- Enhanced logging with operation details + +### 6. **Performance Optimizations** +- `__slots__` on all awaitable classes for memory efficiency +- Optimized hot paths in coroutine-to-generator bridge +- Reduced object allocations + +### 7. **Enhanced Sandboxing** +- Extended coverage of non-deterministic functions +- Strict mode blocks for dangerous operations +- Better patching of time, random, and UUID functions + +### 8. **Type Safety** +- Runtime validation of workflow functions +- Enhanced type annotations +- `WorkflowFunction` protocol for better IDE support + +## Registration + +Async orchestrators are automatically detected when using `add_orchestrator()`: + +```python +from durabletask.aio import SandboxMode + +# Auto-detection - simplest form +worker.add_orchestrator(my_async_workflow) + +# With explicit sandbox mode +worker.add_orchestrator( + my_async_workflow, + sandbox_mode=SandboxMode.BEST_EFFORT # or "best_effort" string +) +``` + +Note: The `sandbox_mode` parameter accepts both `SandboxMode` enum values and string literals (`"off"`, `"best_effort"`, `"strict"`). + +## Sandbox Modes + +Control non-determinism detection with the `sandbox_mode` parameter: + +```python +# Default: Patches asyncio functions for determinism, optional warnings +worker.add_orchestrator(workflow) # Uses "best_effort" by default + +# Development: Same as default, warnings when debug mode enabled +worker.add_orchestrator(workflow, sandbox_mode=SandboxMode.BEST_EFFORT) + +# Testing: Errors for non-deterministic calls +worker.add_orchestrator(workflow, sandbox_mode=SandboxMode.STRICT) + +# No patching: Use only if all code uses ctx.* methods explicitly +worker.add_orchestrator(workflow, sandbox_mode="off") +``` + +Why "best_effort" is the default: +- Makes standard asyncio patterns work correctly (asyncio.sleep, asyncio.gather, etc.) +- Patches random/time/uuid to be deterministic automatically +- Optional warnings only when debug mode is enabled (low overhead) +- Provides "pit of success" for async workflow authoring + +### Performance Impact +- `"best_effort"` (default): Minimal overhead from function patching. Tracing overhead present but uses lightweight noop tracer unless debug mode is enabled. +- `"strict"`: ~100-200% overhead due to full Python tracing for detection +- `"off"`: Zero overhead (no patching, no tracing) +- Global disable: Set `DAPR_WF_DISABLE_DETERMINISTIC_DETECTION=true` environment variable + +Note: Function patching overhead is minimal (single-digit percentage). Tracing overhead (when enabled) is more significant due to Python's sys.settrace() mechanism. 
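+
+As a concrete illustration of these trade-offs, one pattern (echoed under Best Practices below) is to select the sandbox mode per environment; `my_workflow` stands in for any async orchestrator you have defined:
+
+```python
+import os
+
+from durabletask.aio import SandboxMode
+from durabletask.worker import TaskHubGrpcWorker
+
+# Strict mode's tracing cost is acceptable in CI, where catching
+# determinism violations early matters more than throughput.
+mode = SandboxMode.STRICT if os.getenv("CI") else SandboxMode.BEST_EFFORT
+
+with TaskHubGrpcWorker() as worker:
+    worker.add_orchestrator(my_workflow, sandbox_mode=mode)
+    worker.start()
+```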
+ +## Environment Variables + +- `DAPR_WF_DEBUG=true` / `DT_DEBUG=true` - Enable debug logging, operation tracking, and non-determinism warnings +- `DAPR_WF_DISABLE_DETERMINISTIC_DETECTION=true` - Globally disable non-determinism detection + +## Developer Mode + +Set `DAPR_WF_DEBUG=true` during development to enable: +- Non-determinism warnings for problematic function calls +- Detailed operation logging and debugging information +- Enhanced error messages with suggested alternatives + +```bash +# Enable developer warnings +export DAPR_WF_DEBUG=true +python your_workflow.py + +# Production mode (no warnings, optimal performance) +unset DAPR_WF_DEBUG +python your_workflow.py +``` + +This approach is similar to tools like mypy: rich feedback during development, zero runtime overhead in production. + +## Examples + +### Timeout Support +```python +from durabletask.aio import AsyncWorkflowContext + +async def workflow_with_timeout(ctx: AsyncWorkflowContext, input_data) -> str: + try: + result = await ctx.with_timeout( + ctx.call_activity("slow_activity"), + 10.0, # 10 second timeout + ) + except TimeoutError: + result = "Operation timed out" + return result +``` + +### when_any with index and result + +```python +async def competitive_workflow(ctx, input_data): + tasks = [ + ctx.call_activity("provider_a"), + ctx.call_activity("provider_b"), + ctx.call_activity("provider_c") + ] + + # when_any returns (index, result) tuple + winner_index, result = await ctx.when_any(tasks) + return f"Provider {winner_index} won with: {result}" +``` + +### Error Handling with Context + +```python +async def robust_workflow(ctx, input_data): + try: + return await ctx.call_activity("risky_activity") + except Exception as e: + # Enhanced error will include workflow context + debug_info = ctx.get_debug_info() + return {"error": str(e), "debug": debug_info} +``` + +## Best Practices + +1. **Use deterministic alternatives**: + - `ctx.now()` instead of `datetime.now()` (async workflows) + - `context.current_utc_datetime` instead of `datetime.now()` (generator/non-async) + - `ctx.random()` instead of the `random` module + - `ctx.uuid4()` instead of `uuid.uuid4()` + +2. **Use strict mode in testing**: + ```python + sandbox_mode = "strict" if os.getenv("CI") else "best_effort" + ``` + +3. **Add timeouts to external operations**: + ```python + result = await ctx.with_timeout(ctx.call_activity("external_api"), 30.0) + ``` + +4. **Enable debug mode during development**: + ```bash + export DAPR_WF_DEBUG=true + ``` diff --git a/durabletask/aio/ASYNCIO_INTERNALS.md b/durabletask/aio/ASYNCIO_INTERNALS.md new file mode 100644 index 0000000..b9401d6 --- /dev/null +++ b/durabletask/aio/ASYNCIO_INTERNALS.md @@ -0,0 +1,358 @@ +# Durable Task AsyncIO Internals + +This document explains how the AsyncIO implementation in this repository integrates with the existing generator‑based Durable Task runtime. It covers the coroutine→generator bridge, awaitable design, sandboxing and non‑determinism detection, error/cancellation semantics, debugging, and guidance for extending the system.
+ +## Scope and Goals + +- Async authoring model for orchestrators while preserving Durable Task's generator runtime contract +- Deterministic execution and replay correctness first +- Optional, scoped compatibility sandbox for common stdlib calls during development/test +- Minimal surface area changes to core non‑async code paths + +Key modules: +- `durabletask/aio/context.py` — Async workflow context and deterministic utilities +- `durabletask/aio/driver.py` — Coroutine→generator bridge +- `durabletask/aio/sandbox.py` — Scoped patching and non‑determinism detection + +## Architecture Overview + +### Coroutine→Generator Bridge + +Async orchestrators are authored as `async def` but executed by Durable Task as generators that yield `durabletask.task.Task` (or composite) instances. The bridge implements a driver that manually steps a coroutine and converts each `await` into a yielded Durable Task operation. + +High‑level flow: +1. `TaskHubGrpcWorker.add_async_orchestrator(async_fn, sandbox_mode=...)` wraps `async_fn` with a `CoroutineOrchestratorRunner` and registers a generator orchestrator with the worker. +2. At execution time, the runtime calls the registered generator orchestrator with a base `OrchestrationContext` and input. +3. The generator orchestrator constructs `AsyncWorkflowContext` and then calls `runner.to_generator(async_fn_ctx, input)` to obtain a generator. +4. The driver loop yields Durable Task operations to the engine and sends results back into the coroutine upon resume, until the coroutine completes. + +Driver responsibilities: +- Prime the coroutine (`coro.send(None)`) and handle immediate completion +- Recognize awaitables whose `__await__` yield driver‑recognized operation descriptors +- Yield the underlying Durable Task `task.Task` (or composite) to the engine +- Translate successful completions to `.send(value)` and failures to `.throw(exc)` on the coroutine +- Normalize `StopIteration` completions (PEP 479) so that orchestrations complete with a value rather than raising into the worker + +### Awaitables and Operation Descriptors + +Awaitables in `durabletask.aio` implement `__await__` to expose a small operation descriptor that the driver understands. Each descriptor maps deterministically to a Durable Task operation: + +- Activity: `ctx.activity(name, *, input)` → `task.call_activity(name, input)` +- Sub‑orchestrator: `ctx.sub_orchestrator(fn_or_name, *, input)` → `task.call_sub_orchestrator(...)` +- Timer: `ctx.sleep(duration)` → `task.create_timer(fire_at)` +- External event: `ctx.wait_for_external_event(name)` → `task.wait_for_external_event(name)` +- Concurrency: `ctx.when_all([...])` / `ctx.when_any([...])` → `task.when_all([...])` / `task.when_any([...])` + +Design rules: +- Awaitables are single‑use. Each call creates a fresh awaitable whose `__await__` returns a fresh iterator. This avoids "cannot reuse already awaited coroutine" during replay. +- All awaitables use `__slots__` for memory efficiency and replay stability. +- Composite awaitables convert their children to Durable Task tasks before yielding. + +### AsyncWorkflowContext + +`AsyncWorkflowContext` wraps the base generator `OrchestrationContext` and exposes deterministic utilities and async awaitables. + +Provided utilities (deterministic): +- `now()` — orchestration time based on history +- `random()` — PRNG seeded deterministically (e.g., instance/run ID); used by `uuid4()` +- `uuid4()` — derived from deterministic PRNG +- `is_replaying`, `is_suspended`, `workflow_name`, `instance_id`, etc. 
— passthrough metadata + +Concurrency: +- `when_all([...])` returns an awaitable that completes with a list of results +- `when_any([...])` returns an awaitable that completes with the first completed child +- `when_any_with_result([...])` returns `(index, result)` +- `with_timeout(awaitable, seconds|timedelta)` wraps any awaitable with a deterministic timer + +Debugging helpers (dev‑only): +- Operation history when debug is enabled (`DAPR_WF_DEBUG=true` or `DT_DEBUG=true`) +- `get_debug_info()` to inspect state for diagnostics + +### Error and Cancellation Semantics + +- Activity/sub‑orchestrator completion values are sent back into the coroutine. Final failures are injected via `coro.throw(...)`. +- Cancellations are mapped to `asyncio.CancelledError` where appropriate and thrown into the coroutine. +- Termination completes orchestrations with TERMINATED status (matching generator behavior); exceptions are surfaced as failureDetails in the runtime completion action. +- The driver consumes `StopIteration` from awaited iterators and returns the value to avoid leaking `RuntimeError("generator raised StopIteration")`. + +## Sequence Diagram + +### Mermaid (rendered in compatible viewers) + +```mermaid +sequenceDiagram + autonumber + participant W as TaskHubGrpcWorker + participant E as Durable Task Engine + participant G as Generator Orchestrator Wrapper + participant R as CoroutineOrchestratorRunner + participant C as Async Orchestrator (coroutine) + participant A as Awaitable (__await__) + participant S as Sandbox (optional) + + E->>G: invoke(name, ctx, input) + G->>R: to_generator(AsyncWorkflowContext(ctx), input) + R->>C: start coroutine (send None) + + opt sandbox_mode != "off" + G->>S: enter sandbox scope (patch) + S-->>G: patch asyncio.sleep/random/uuid/time + end + + Note right of C: await ctx.activity(...), ctx.sleep(...), ctx.when_any(...) + C-->>A: create awaitable + A-->>R: __await__ yields Durable Task op + R-->>E: yield task/composite + E-->>R: resume with result/failure + R->>C: send(result) / throw(error) + C-->>R: next awaitable or StopIteration + + alt next awaitable + R-->>E: yield next operation + else completed + R-->>G: return result (StopIteration.value) + G-->>E: completeOrchestration(result) + end + + opt sandbox_mode != "off" + G->>S: exit sandbox scope (restore) + end +``` + +### ASCII Flow (fallback) + +```text +Engine → Wrapper → Runner → Coroutine + │ │ │ ├─ await ctx.activity(...) + │ │ │ ├─ await ctx.sleep(...) + │ │ │ └─ await ctx.when_any([...]) + │ │ │ + │ │ └─ Awaitable.__await__ → yields Durable Task op + │ └─ yield op → Engine schedules/waits + └─ resume with result → Runner.send/throw → Coroutine step + +Loop until coroutine returns → Runner captures StopIteration.value → +Wrapper returns value → Engine emits completeOrchestration + +Optional Sandbox (per activation): + enter → patch asyncio.sleep/random/uuid/time → run step → restore +``` + +## Sandboxing and Non‑Determinism Detection + +The sandbox provides scoped compatibility and detection for common non‑deterministic stdlib calls. It is configured per orchestrator via `sandbox_mode`: + +- `best_effort` (default): Patch common functions within a scope and emit warnings on detected non‑determinism when debug mode is enabled. +- `strict`: Patch common functions and raise `SandboxViolationError` on detected calls. +- `off`: No patching or detection; zero overhead. Use deterministic APIs only. 
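+
+To make the scoping concrete, the patching can be pictured as a context manager that installs deterministic replacements for a single activation and always restores the originals afterwards. This is a conceptual sketch only, not the actual `sandbox.py` implementation:
+
+```python
+import asyncio
+import contextlib
+
+@contextlib.contextmanager
+def _scoped_sleep_patch(ctx):
+    # Conceptual only: the real sandbox covers more targets and
+    # cooperates with the non-determinism detector.
+    original_sleep = asyncio.sleep
+
+    def _deterministic_sleep(seconds, result=None):
+        # Route through the workflow context so replay sees a durable timer
+        return ctx.sleep(seconds)
+
+    asyncio.sleep = _deterministic_sleep
+    try:
+        yield
+    finally:
+        asyncio.sleep = original_sleep  # restore even if the step raises
+```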
+ +Patched targets (best‑effort and strict): +- `asyncio.sleep` → deterministic timer awaitable +- `asyncio.gather` → replay-safe one-shot awaitable wrapper using WhenAllAwaitable +- `random` module functions (random, randrange, randint, getrandbits via deterministic PRNG) +- `uuid.uuid4` → derived from deterministic PRNG +- `time.time/time_ns` → orchestration time + +Additional blocks in strict mode only: +- `asyncio.create_task` → raises SandboxViolationError +- `builtins.open` → raises SandboxViolationError +- `os.urandom` → raises SandboxViolationError +- `secrets.token_bytes/token_hex` → raises SandboxViolationError + +Important limitations: +- `datetime.datetime.now()` is not patched (type immutability). Use `ctx.now()` or `ctx.current_utc_datetime`. +- `from x import y` may bypass patches due to direct binding. +- Modules that cache callables at import time won’t see patch updates. +- This does not make I/O deterministic; all external I/O must be in activities. + +Detection engine: +- `_NonDeterminismDetector` tracks suspicious call sites using Python frame inspection +- Deduplicates warnings per call signature and location +- In strict mode, raises `SandboxViolationError` with actionable suggestions; in best‑effort, issues `NonDeterminismWarning` + +### Detector: What, When, and Why + +What it checks: +- Calls to common non‑deterministic functions (e.g., `time.time`, `random.random`, `uuid.uuid4`, `os.urandom`, `secrets.*`, `datetime.utcnow`) in user code +- Uses a lightweight global trace function (installed only in `best_effort` or `strict`) to inspect call frames and identify risky callsites +- Skips internal `durabletask` frames and built‑ins to reduce noise + +Modes and behavior: +- `SandboxMode.OFF`: + - No tracing, no patching, zero overhead + - Detector is not active +- `SandboxMode.BEST_EFFORT` (default): + - Patches selected stdlib functions (asyncio.sleep, random, uuid.uuid4, time.time, asyncio.gather) + - Installs tracer only when `ctx._debug_mode` is true; otherwise no tracer (minimal overhead) + - Emits `NonDeterminismWarning` once per unique callsite with a suggested deterministic alternative +- `SandboxMode.STRICT`: + - Patches selected stdlib functions and blocks dangerous operations (e.g., `open`, `os.urandom`, `secrets.*`, `asyncio.create_task`) + - Installs full tracer regardless of debug flag + - Raises `SandboxViolationError` on first detection with details and suggestions + +When to use each mode: +- `BEST_EFFORT` (default): Recommended for most use cases. Patches make standard asyncio patterns work correctly with minimal overhead. +- `STRICT`: Use in CI/testing to enforce determinism and catch violations early. +- `OFF`: Use only if you're certain all code uses `ctx.*` methods exclusively and want absolute zero overhead. + +Note: `BEST_EFFORT` is now the default because it makes workflows "just work" with standard asyncio code patterns. + +Enabling and controlling the detector: +- Per‑orchestrator registration: +```python +from durabletask.aio import SandboxMode + +worker.add_orchestrator(my_async_orch, sandbox_mode=SandboxMode.BEST_EFFORT) +``` +- Scoped usage in advanced scenarios: +```python +from durabletask.aio import sandbox_best_effort + +async def my_async_orch(ctx, _): + with sandbox_best_effort(ctx): + # code here benefits from patches + detection + ... +``` +- Debug gating (best_effort only): set `DAPR_WF_DEBUG=true` (or `DT_DEBUG=true`) to enable full detection; otherwise a no‑op tracer is used to minimize overhead. 
- Global disable (regardless of mode): set `DAPR_WF_DISABLE_DETERMINISTIC_DETECTION=true` to force `OFF` behavior without changing code. + +What warnings/errors look like: +- Warning (`BEST_EFFORT`): + - Category: `NonDeterminismWarning` + - Message includes function name, filename:line, the current function, and a deterministic alternative (e.g., “Use `ctx.now()` instead of `datetime.utcnow()`”). +- Error (`STRICT`): + - Exception: `SandboxViolationError` + - Includes violation type, suggested alternative, `workflow_name`, and `instance_id` when available + +Overhead and performance: +- `OFF`: zero overhead (no patching, no detection) +- `BEST_EFFORT` (default): minimal overhead from patching; lightweight noop tracer unless debug mode enabled (full detection tracer only when `DAPR_WF_DEBUG=true`) +- `STRICT`: ~100-200% overhead due to full Python tracing; recommended only for testing/enforcement + +Note: The patching overhead (module-level function replacement) is minimal. The tracing overhead (sys.settrace) is more significant when full detection is enabled. + +Limitations and caveats: +- Direct imports like `from random import random` bind the function and may bypass patching +- Libraries that cache function references at import time will not see patch changes +- `datetime.datetime.now()` cannot be patched; use `ctx.now()` instead +- The detector is advisory; it cannot prove determinism for arbitrary code. Treat it as a power tool for finding common pitfalls, not a formal verifier + +Quick mapping of alternatives: +- `datetime.now/utcnow` → `ctx.now()` (async) or `ctx.current_utc_datetime` +- `time.time/time_ns` → `ctx.now().timestamp()` / `int(ctx.now().timestamp() * 1e9)` +- `random.*` → `ctx.random().*` +- `uuid.uuid4` → `ctx.uuid4()` +- `os.urandom` / `secrets.*` → `ctx.random().randbytes()` (or move to an activity) + +Troubleshooting tips: +- Seeing repeated warnings? They are deduplicated per callsite; different files/lines will warn independently +- Unexpected strict errors during replay? Confirm you are not creating background tasks (`asyncio.create_task`) or performing I/O in the orchestrator +- Need to quiet a test temporarily? Use `sandbox_mode=SandboxMode.OFF` for that orchestrator or `DAPR_WF_DISABLE_DETERMINISTIC_DETECTION=true` during the run + +## Integration with Generator Runtime + +- Registration: `TaskHubGrpcWorker.add_async_orchestrator(async_fn, ...)` registers a generator wrapper that delegates to the driver. Generator orchestrators and async orchestrators can coexist. +- Execution loop remains owned by Durable Task; the driver only yields operations and processes resumes. +- Replay: The driver and awaitables are designed to be idempotent and to avoid reusing awaited iterators; orchestration state is reconstructed deterministically from history. + +## Debugging Guide + +Enable developer diagnostics: +- Set `DAPR_WF_DEBUG=true` (or `DT_DEBUG=true`) to enable operation logging and non‑determinism warnings. +- Use `ctx.get_debug_info()` to export state, operations, and instance metadata. + +Common issues: +- "coroutine was never awaited": Ensure all workflow operations are awaited and that no background tasks are spawned (`asyncio.create_task` is blocked in strict mode). +- "cannot reuse already awaited coroutine": Do not cache awaitables across activations; create them inline. All awaitables in this package are single‑use by design.
+- Orchestration hangs: Inspect last yielded operation in logs; verify that the corresponding history event occurs (activity completion, timer fired, external event received). For external events, ensure the event name matches exactly. +- Sandbox leakage: Verify patches are scoped by context manager and restored after activation. Avoid `from x import y` forms in orchestrator code when relying on patching. + +Runtime tracing tips: +- Log each yielded operation and each resume result in the driver (behind debug flag) to correlate with sidecar logs. +- Capture `instance_id` and `history_event_sequence` from `AsyncWorkflowContext` when logging. + +## Performance Characteristics + +- `sandbox_mode="off"`: zero overhead vs generator orchestrators +- `best_effort` / `strict`: additional overhead from Python tracing and patching; use during development and testing +- Awaitables use `__slots__` and avoid per‑step allocations in hot paths where feasible + +## Extending the System + +Adding a new awaitable: +1. Define a class with `__slots__` and a constructor capturing required arguments. +2. Implement `_to_task(self) -> durabletask.task.Task` that builds the deterministic operation. +3. Implement `__await__` to yield the driver‑recognized descriptor (or directly the task, depending on driver design). +4. Add unit tests for replay stability and error propagation. + +Adding sandbox coverage: +1. Add patch/unpatch logic inside `sandbox.py` with correct scoping and restoration. +2. Update `_NonDeterminismDetector` patterns and suggestions. +3. Document limitations and add tests for best‑effort and strict modes. + +## Interop Checklist (Async ↔ Generator) + +- Activities: identical behavior; only authoring differs (`yield` vs `await`). Activities themselves can be either sync or async functions and work identically from both generator and async orchestrators. +- Timers: map to the same `createTimer` actions. +- External events: same semantics for buffering and completion. +- Sub‑orchestrators: same create/complete/fail events. +- Suspension/Termination: same runtime events; async path observes `is_suspended` and maps termination to completion with TERMINATED. 
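+
+As a companion to the "Extending the System" steps above, a minimal template for a new awaitable is sketched below; the wrapped operation (`wait_for_external_event`) is purely a placeholder for whatever new operation you add:
+
+```python
+from typing import Any, cast
+
+from durabletask import task
+from durabletask.aio.awaitables import AwaitableBase
+
+class MyOperationAwaitable(AwaitableBase[Any]):
+    """Template for a custom awaitable; the wrapped operation is a placeholder."""
+
+    __slots__ = ("_ctx", "_name")
+
+    def __init__(self, ctx: Any, name: str):
+        super().__init__()
+        self._ctx = ctx
+        self._name = name
+
+    def _to_task(self) -> task.Task[Any]:
+        # Build the deterministic operation on the base context; swap in
+        # the new operation you are adding in a real implementation.
+        return cast(task.Task[Any], self._ctx.wait_for_external_event(self._name))
+
+# __await__ is inherited from AwaitableBase: each await builds and yields a
+# fresh durabletask.task.Task, so awaited iterators are never reused on replay.
+```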
+ +## References + +- `durabletask/aio/context.py` +- `durabletask/aio/driver.py` +- `durabletask/aio/sandbox.py` +- Tests under `tests/durabletask/` and `tests/aio/` + + diff --git a/durabletask/aio/__init__.py b/durabletask/aio/__init__.py index d446228..1d2cd32 100644 --- a/durabletask/aio/__init__.py +++ b/durabletask/aio/__init__.py @@ -1,5 +1,75 @@ +# Deterministic utilities +from durabletask.deterministic import ( + DeterminismSeed, + DeterministicContextMixin, + derive_seed, + deterministic_random, + deterministic_uuid4, +) + +# Awaitable classes +from .awaitables import ( + ActivityAwaitable, + AwaitableBase, + ExternalEventAwaitable, + SleepAwaitable, + SubOrchestratorAwaitable, + SwallowExceptionAwaitable, + TimeoutAwaitable, + WhenAllAwaitable, + WhenAnyAwaitable, +) from .client import AsyncTaskHubGrpcClient +# Compatibility protocol (core functionality only) +from .compatibility import OrchestrationContextProtocol, ensure_compatibility + +# Core context and driver +from .context import AsyncWorkflowContext +from .driver import CoroutineOrchestratorRunner, WorkflowFunction + +# Sandbox and error handling +from .errors import ( + AsyncWorkflowError, + NonDeterminismWarning, + SandboxViolationError, + WorkflowTimeoutError, + WorkflowValidationError, +) +from .sandbox import SandboxMode, _NonDeterminismDetector + __all__ = [ "AsyncTaskHubGrpcClient", + # Core classes + "AsyncWorkflowContext", + "CoroutineOrchestratorRunner", + "WorkflowFunction", + # Deterministic utilities + "DeterministicContextMixin", + "DeterminismSeed", + "derive_seed", + "deterministic_random", + "deterministic_uuid4", + # Awaitable classes + "AwaitableBase", + "ActivityAwaitable", + "SubOrchestratorAwaitable", + "SleepAwaitable", + "ExternalEventAwaitable", + "WhenAllAwaitable", + "WhenAnyAwaitable", + "TimeoutAwaitable", + "SwallowExceptionAwaitable", + # Sandbox and utilities + "SandboxMode", + "_NonDeterminismDetector", + # Compatibility protocol + "OrchestrationContextProtocol", + "ensure_compatibility", + # Exceptions + "AsyncWorkflowError", + "NonDeterminismWarning", + "WorkflowTimeoutError", + "WorkflowValidationError", + "SandboxViolationError", ] diff --git a/durabletask/aio/awaitables.py b/durabletask/aio/awaitables.py new file mode 100644 index 0000000..31c3000 --- /dev/null +++ b/durabletask/aio/awaitables.py @@ -0,0 +1,548 @@ +# Copyright 2025 The Dapr Authors +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Awaitable classes for async workflows. + +This module provides awaitable wrappers for DurableTask operations that can be +used in async workflows. Each awaitable yields a durabletask.task.Task which +the driver yields to the runtime and feeds the result back to the coroutine. 
+""" + +from __future__ import annotations + +import importlib +from datetime import datetime, timedelta +from typing import ( + Any, + Awaitable, + Callable, + Dict, + Generator, + Iterable, + List, + Optional, + TypeVar, + Union, + cast, +) + +from durabletask import task + +# Forward reference for the operation wrapper - imported at runtime to avoid circular imports + +TOutput = TypeVar("TOutput") + + +class AwaitableBase(Awaitable[TOutput]): + """ + Base class for all workflow awaitables. + + Provides the common interface for converting workflow operations + into DurableTask tasks that can be yielded to the runtime. + """ + + __slots__ = () + + def _to_task(self) -> task.Task[Any]: + """ + Convert this awaitable to a DurableTask task. + + Subclasses must implement this method to define how they + translate to the underlying task system. + + Returns: + A DurableTask task representing this operation + """ + raise NotImplementedError("Subclasses must implement _to_task") + + def __await__(self) -> Generator[Any, Any, TOutput]: + """ + Make this object awaitable by yielding the underlying task. + + This is called when the awaitable is used with 'await' in an + async workflow function. + """ + # Yield the task directly - the worker expects durabletask.task.Task objects + t = self._to_task() + result = yield t + return cast(TOutput, result) + + +class ActivityAwaitable(AwaitableBase[TOutput]): + """Awaitable for activity function calls.""" + + __slots__ = ("_ctx", "_activity_fn", "_input", "_retry_policy", "_app_id", "_metadata") + + def __init__( + self, + ctx: Any, + activity_fn: Union[Callable[..., Any], str], + *, + input: Any = None, + retry_policy: Any = None, + app_id: Optional[str] = None, + metadata: Optional[Dict[str, str]] = None, + ): + """ + Initialize an activity awaitable. 
+ + Args: + ctx: The workflow context + activity_fn: The activity function to call + input: Input data for the activity + retry_policy: Optional retry policy + app_id: Optional target app ID for routing + metadata: Optional metadata for the activity call + """ + super().__init__() + self._ctx = ctx + self._activity_fn = activity_fn + self._input = input + self._retry_policy = retry_policy + self._app_id = app_id + self._metadata = metadata + + def _to_task(self) -> task.Task[Any]: + """Convert to a call_activity task.""" + # Check if the context supports metadata parameter + import inspect + + sig = inspect.signature(self._ctx.call_activity) + supports_metadata = "metadata" in sig.parameters + supports_app_id = "app_id" in sig.parameters + + if self._retry_policy is None: + if (supports_metadata and self._metadata is not None) or ( + supports_app_id and self._app_id is not None + ): + kwargs: Dict[str, Any] = {"input": self._input} + if supports_metadata and self._metadata is not None: + kwargs["metadata"] = self._metadata + if supports_app_id and self._app_id is not None: + kwargs["app_id"] = self._app_id + return cast(task.Task[Any], self._ctx.call_activity(self._activity_fn, **kwargs)) + else: + return cast( + task.Task[Any], self._ctx.call_activity(self._activity_fn, input=self._input) + ) + else: + if (supports_metadata and self._metadata is not None) or ( + supports_app_id and self._app_id is not None + ): + kwargs2: Dict[str, Any] = {"input": self._input, "retry_policy": self._retry_policy} + if supports_metadata and self._metadata is not None: + kwargs2["metadata"] = self._metadata + if supports_app_id and self._app_id is not None: + kwargs2["app_id"] = self._app_id + return cast( + task.Task[Any], + self._ctx.call_activity( + self._activity_fn, + **kwargs2, + ), + ) + else: + return cast( + task.Task[Any], + self._ctx.call_activity( + self._activity_fn, + input=self._input, + retry_policy=self._retry_policy, + ), + ) + + +class SubOrchestratorAwaitable(AwaitableBase[TOutput]): + """Awaitable for sub-orchestrator calls.""" + + __slots__ = ( + "_ctx", + "_workflow_fn", + "_input", + "_instance_id", + "_retry_policy", + "_app_id", + "_metadata", + ) + + def __init__( + self, + ctx: Any, + workflow_fn: Union[Callable[..., Any], str], + *, + input: Any = None, + instance_id: Optional[str] = None, + retry_policy: Any = None, + app_id: Optional[str] = None, + metadata: Optional[Dict[str, str]] = None, + ): + """ + Initialize a sub-orchestrator awaitable. 
+ + Args: + ctx: The workflow context + workflow_fn: The sub-orchestrator function to call + input: Input data for the sub-orchestrator + instance_id: Optional instance ID for the sub-orchestrator + retry_policy: Optional retry policy + app_id: Optional target app ID for routing + metadata: Optional metadata for the sub-orchestrator call + """ + super().__init__() + self._ctx = ctx + self._workflow_fn = workflow_fn + self._input = input + self._instance_id = instance_id + self._retry_policy = retry_policy + self._app_id = app_id + self._metadata = metadata + + def _to_task(self) -> task.Task[Any]: + """Convert to a call_sub_orchestrator task.""" + # The underlying context uses call_sub_orchestrator (durabletask naming) + # Check if the context supports metadata parameter + import inspect + + sig = inspect.signature(self._ctx.call_sub_orchestrator) + supports_metadata = "metadata" in sig.parameters + supports_app_id = "app_id" in sig.parameters + + if self._retry_policy is None: + if (supports_metadata and self._metadata is not None) or ( + supports_app_id and self._app_id is not None + ): + kwargs: Dict[str, Any] = {"input": self._input, "instance_id": self._instance_id} + if supports_metadata and self._metadata is not None: + kwargs["metadata"] = self._metadata + if supports_app_id and self._app_id is not None: + kwargs["app_id"] = self._app_id + return cast( + task.Task[Any], + self._ctx.call_sub_orchestrator( + self._workflow_fn, + **kwargs, + ), + ) + else: + return cast( + task.Task[Any], + self._ctx.call_sub_orchestrator( + self._workflow_fn, + input=self._input, + instance_id=self._instance_id, + ), + ) + else: + if (supports_metadata and self._metadata is not None) or ( + supports_app_id and self._app_id is not None + ): + kwargs2: Dict[str, Any] = { + "input": self._input, + "instance_id": self._instance_id, + "retry_policy": self._retry_policy, + } + if supports_metadata and self._metadata is not None: + kwargs2["metadata"] = self._metadata + if supports_app_id and self._app_id is not None: + kwargs2["app_id"] = self._app_id + return cast( + task.Task[Any], + self._ctx.call_sub_orchestrator( + self._workflow_fn, + **kwargs2, + ), + ) + else: + return cast( + task.Task[Any], + self._ctx.call_sub_orchestrator( + self._workflow_fn, + input=self._input, + instance_id=self._instance_id, + retry_policy=self._retry_policy, + ), + ) + + +class SleepAwaitable(AwaitableBase[None]): + """Awaitable for timer/sleep operations.""" + + __slots__ = ("_ctx", "_duration") + + def __init__(self, ctx: Any, duration: Union[float, timedelta, datetime]): + """ + Initialize a sleep awaitable. + + Args: + ctx: The workflow context + duration: Sleep duration (seconds, timedelta, or absolute datetime) + """ + super().__init__() + self._ctx = ctx + self._duration = duration + + def _to_task(self) -> task.Task[Any]: + """Convert to a create_timer task.""" + # Convert numeric durations to timedelta objects + fire_at: Union[datetime, timedelta] + if isinstance(self._duration, (int, float)): + fire_at = timedelta(seconds=float(self._duration)) + else: + fire_at = self._duration + return cast(task.Task[Any], self._ctx.create_timer(fire_at)) + + +class ExternalEventAwaitable(AwaitableBase[TOutput]): + """Awaitable for external event operations.""" + + __slots__ = ("_ctx", "_name") + + def __init__(self, ctx: Any, name: str): + """ + Initialize an external event awaitable. 
+ + Args: + ctx: The workflow context + name: Name of the external event to wait for + """ + super().__init__() + self._ctx = ctx + self._name = name + + def _to_task(self) -> task.Task[Any]: + """Convert to a wait_for_external_event task.""" + return cast(task.Task[Any], self._ctx.wait_for_external_event(self._name)) + + +class WhenAllAwaitable(AwaitableBase[List[TOutput]]): + """Awaitable for when_all operations (wait for all tasks to complete). + + Adds: + - Empty fast-path: returns [] without creating a task + - Multi-await safety: caches the result/exception for repeated awaits + """ + + __slots__ = ("_tasks_like", "_cached_result", "_cached_exception") + + def __init__(self, tasks_like: Iterable[Union[AwaitableBase[Any], task.Task[Any]]]): + super().__init__() + self._tasks_like = list(tasks_like) + self._cached_result: Optional[List[Any]] = None + self._cached_exception: Optional[BaseException] = None + + def _to_task(self) -> task.Task[Any]: + """Convert to a when_all task.""" + # Empty fast-path: no durable task required + if len(self._tasks_like) == 0: + # Create a trivial completed task-like by when_all([]) + return cast(task.Task[Any], task.when_all([])) + underlying: List[task.Task[Any]] = [] + for a in self._tasks_like: + if isinstance(a, AwaitableBase): + underlying.append(a._to_task()) + elif isinstance(a, task.Task): + underlying.append(a) + else: + raise TypeError("when_all expects AwaitableBase or durabletask.task.Task") + return cast(task.Task[Any], task.when_all(underlying)) + + def __await__(self) -> Generator[Any, Any, List[TOutput]]: + if self._cached_exception is not None: + raise self._cached_exception + if self._cached_result is not None: + return cast(List[TOutput], self._cached_result) + # Empty fast-path: return [] immediately + if len(self._tasks_like) == 0: + self._cached_result = [] + return cast(List[TOutput], self._cached_result) + t = self._to_task() + try: + results = yield t + # Cache and return (ensure list) + self._cached_result = list(results) if isinstance(results, list) else [results] + return cast(List[TOutput], self._cached_result) + except BaseException as e: # noqa: BLE001 + self._cached_exception = e + raise + + +class WhenAnyAwaitable(AwaitableBase[tuple[int, Any]]): + """Awaitable for when_any operations (wait for any task to complete). + + Returns a tuple of (index, result) where index is the position of the completed task. + """ + + __slots__ = ("_originals", "_underlying") + + def __init__(self, tasks_like: Iterable[Union[AwaitableBase[Any], task.Task[Any]]]): + """ + Initialize a when_any awaitable. 
+ + Args: + tasks_like: Iterable of awaitables or tasks to wait for + """ + super().__init__() + self._originals = list(tasks_like) + # Defer conversion to avoid issues with incomplete mocks and coroutine reuse + self._underlying: Optional[List[task.Task[Any]]] = None + + def _ensure_underlying(self) -> List[task.Task[Any]]: + """Lazily convert originals to tasks, caching the result.""" + if self._underlying is None: + self._underlying = [] + for a in self._originals: + if isinstance(a, AwaitableBase): + self._underlying.append(a._to_task()) + elif isinstance(a, task.Task): + self._underlying.append(a) + else: + raise TypeError("when_any expects AwaitableBase or durabletask.task.Task") + return self._underlying + + def _to_task(self) -> task.Task[Any]: + """Convert to a when_any task.""" + return cast(task.Task[Any], task.when_any(self._ensure_underlying())) + + def __await__(self) -> Generator[Any, Any, tuple[int, Any]]: + """Return (index, result) tuple of the first completed task.""" + underlying = self._ensure_underlying() + when_any_task = task.when_any(underlying) + completed_task = yield when_any_task + + # The completed_task should match one of our underlying tasks + for i, underlying_task in enumerate(underlying): + if underlying_task == completed_task: + result = ( + completed_task.get_result() if hasattr(completed_task, "get_result") else None + ) + return (i, result) + + # Fallback: return the completed task result with index 0 + result = completed_task.get_result() if hasattr(completed_task, "get_result") else None + return (0, result) + + +class TimeoutAwaitable(AwaitableBase[TOutput]): + """ + Awaitable that adds timeout functionality to any other awaitable. + + Raises TimeoutError if the operation doesn't complete within the specified time. + """ + + __slots__ = ("_awaitable", "_timeout_seconds", "_timeout", "_ctx", "_timeout_task") + + def __init__(self, awaitable: AwaitableBase[TOutput], timeout_seconds: float, ctx: Any): + """ + Initialize a timeout awaitable. 
+ + Args: + awaitable: The awaitable to add timeout to + timeout_seconds: Timeout in seconds + ctx: The workflow context (needed for timer creation) + """ + super().__init__() + self._awaitable = awaitable + self._timeout_seconds = timeout_seconds + self._timeout = timeout_seconds # Alias for compatibility + self._ctx = ctx + self._timeout_task: Optional[task.Task[Any]] = None + + def _to_task(self) -> task.Task[Any]: + """Convert to a when_any between the operation and a timeout timer.""" + operation_task = self._awaitable._to_task() + # Cache the timeout task instance so __await__ compares against the same object + if self._timeout_task is None: + self._timeout_task = cast( + task.Task[Any], self._ctx.create_timer(timedelta(seconds=self._timeout_seconds)) + ) + return cast(task.Task[Any], task.when_any([operation_task, self._timeout_task])) + + def __await__(self) -> Generator[Any, Any, TOutput]: + """Override to handle timeout logic.""" + task_obj = self._to_task() + completed_task = yield task_obj + # If runtime provided a sentinel instead of a Task, decide heuristically + if not isinstance(completed_task, task.Task): + # Dicts, lists, tuples, and simple primitives are considered operation results + if isinstance(completed_task, (dict, list, tuple, str, int, float, bool, type(None))): + return cast(TOutput, completed_task) + # Otherwise, treat as timeout (e.g., mocks or opaque sentinels) + from .errors import WorkflowTimeoutError + + raise WorkflowTimeoutError( + timeout_seconds=self._timeout_seconds, + operation=str(self._awaitable.__class__.__name__), + ) + + # Check if it was the timeout that completed (compare to cached instance) + if self._timeout_task is not None and completed_task == self._timeout_task: + from .errors import WorkflowTimeoutError + + raise WorkflowTimeoutError( + timeout_seconds=self._timeout_seconds, + operation=str(self._awaitable.__class__.__name__), + ) + + # Return the actual result + return cast(TOutput, completed_task.result if hasattr(completed_task, "result") else None) + + +class SwallowExceptionAwaitable(AwaitableBase[Any]): + """ + Awaitable that swallows exceptions and returns them as values. + + This is useful for gather operations with return_exceptions=True. + """ + + __slots__ = ("_awaitable",) + + def __init__(self, awaitable: AwaitableBase[Any]): + """ + Initialize a swallow exception awaitable. + + Args: + awaitable: The awaitable to wrap + """ + super().__init__() + self._awaitable = awaitable + + def _to_task(self) -> task.Task[Any]: + """Convert to the underlying task.""" + return self._awaitable._to_task() + + def __await__(self) -> Generator[Any, Any, Any]: + """Override to catch and return exceptions.""" + try: + t = self._to_task() + result = yield t + return result + except Exception as e: # noqa: BLE001 + return e + + +# Utility functions for working with awaitables + + +def _resolve_callable(module_name: str, qualname: str) -> Callable[..., Any]: + """ + Resolve a callable from module name and qualified name. + + This is used internally for gather operations that need to serialize + and deserialize callable references. 
+ """ + mod = importlib.import_module(module_name) + obj: Any = mod + for part in qualname.split("."): + obj = getattr(obj, part) + if not callable(obj): + raise TypeError(f"resolved object {module_name}.{qualname} is not callable") + return cast(Callable[..., Any], obj) diff --git a/durabletask/aio/compatibility.py b/durabletask/aio/compatibility.py new file mode 100644 index 0000000..1f22b25 --- /dev/null +++ b/durabletask/aio/compatibility.py @@ -0,0 +1,153 @@ +# Copyright 2025 The Dapr Authors +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Compatibility protocol for AsyncWorkflowContext. + +This module provides the core protocol definition that AsyncWorkflowContext +must implement to maintain compatibility with OrchestrationContext. +""" + +from __future__ import annotations + +from datetime import datetime, timedelta +from typing import Any, Dict, Optional, Protocol, Union, runtime_checkable + +from durabletask import task + + +@runtime_checkable +class OrchestrationContextProtocol(Protocol): + """ + Protocol defining the interface that AsyncWorkflowContext must maintain + for compatibility with OrchestrationContext. + + This protocol ensures that AsyncWorkflowContext provides all the essential + properties and methods expected by the base OrchestrationContext interface. + """ + + # Core properties + @property + def instance_id(self) -> str: + """Get the ID of the current orchestration instance.""" + ... + + @property + def current_utc_datetime(self) -> datetime: + """Get the current date/time as UTC.""" + ... + + @property + def is_replaying(self) -> bool: + """Get whether the orchestrator is replaying from history.""" + ... + + @property + def workflow_name(self) -> Optional[str]: + """Get the orchestrator name/type for this instance.""" + ... + + @property + def is_suspended(self) -> bool: + """Get whether this orchestration is currently suspended.""" + ... + + # Core methods + def set_custom_status(self, custom_status: Any) -> None: + """Set the orchestration instance's custom status.""" + ... + + def create_timer(self, fire_at: Union[datetime, timedelta]) -> Any: + """Create a Timer Task to fire at the specified deadline.""" + ... + + def call_activity( + self, + activity: Union[task.Activity[Any, Any], str], + *, + input: Optional[Any] = None, + retry_policy: Optional[task.RetryPolicy] = None, + metadata: Optional[Dict[str, str]] = None, + ) -> Any: + """Schedule an activity for execution.""" + ... + + def call_sub_orchestrator( + self, + orchestrator: Union[task.Orchestrator[Any, Any], str], + *, + input: Optional[Any] = None, + instance_id: Optional[str] = None, + retry_policy: Optional[task.RetryPolicy] = None, + metadata: Optional[Dict[str, str]] = None, + ) -> Any: + """Schedule sub-orchestrator function for execution.""" + ... + + def wait_for_external_event(self, name: str) -> Any: + """Wait asynchronously for an event to be raised.""" + ... + + def continue_as_new(self, new_input: Any) -> None: + """Continue the orchestration execution as a new instance.""" + ... 
+ + +def ensure_compatibility(context_class: type) -> type: + """ + Decorator to ensure a context class maintains OrchestrationContext compatibility. + + This is a lightweight decorator that performs basic structural validation + at class definition time. + + Args: + context_class: The context class to validate + + Returns: + The same class (for use as decorator) + + Raises: + TypeError: If the class doesn't implement required protocol + """ + # Basic structural check - ensure required attributes exist + required_properties = [ + "instance_id", + "current_utc_datetime", + "is_replaying", + "workflow_name", + "is_suspended", + ] + + required_methods = [ + "set_custom_status", + "create_timer", + "call_activity", + "call_sub_orchestrator", + "wait_for_external_event", + "continue_as_new", + ] + + missing_items = [] + + for prop_name in required_properties: + if not hasattr(context_class, prop_name): + missing_items.append(f"property: {prop_name}") + + for method_name in required_methods: + if not hasattr(context_class, method_name): + missing_items.append(f"method: {method_name}") + + if missing_items: + raise TypeError( + f"{context_class.__name__} does not implement OrchestrationContextProtocol. " + f"Missing: {', '.join(missing_items)}" + ) + + return context_class diff --git a/durabletask/aio/context.py b/durabletask/aio/context.py new file mode 100644 index 0000000..030a95e --- /dev/null +++ b/durabletask/aio/context.py @@ -0,0 +1,380 @@ +# Copyright 2025 The Dapr Authors +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Generic async workflow context for DurableTask workflows. + +This module provides a generic AsyncWorkflowContext that can be used across +different SDK implementations, providing a consistent async interface for +workflow operations. +""" + +from __future__ import annotations + +import os +from datetime import datetime, timedelta +from typing import Any, Callable, Dict, List, Optional, TypeVar, Union, cast + +from durabletask import task as dt_task +from durabletask.deterministic import DeterministicContextMixin + +from .awaitables import ( + ActivityAwaitable, + AwaitableBase, + ExternalEventAwaitable, + SleepAwaitable, + SubOrchestratorAwaitable, + TimeoutAwaitable, + WhenAllAwaitable, + WhenAnyAwaitable, +) +from .compatibility import ensure_compatibility + +# Generic type variable for awaitable result (module-level) +T = TypeVar("T") + + +@ensure_compatibility +class AsyncWorkflowContext(DeterministicContextMixin): + """ + Generic async workflow context providing a consistent interface for workflow operations. + + This context wraps a base DurableTask OrchestrationContext and provides async-friendly + methods for common workflow operations like calling activities, creating timers, + waiting for external events, and coordinating multiple operations. 
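+
+    Example (illustrative orchestrator; ``process_order`` names a registered
+    activity)::
+
+        async def handle_order(ctx: AsyncWorkflowContext, order_id: str) -> str:
+            result = await ctx.call_activity("process_order", input=order_id)
+            await ctx.create_timer(5.0)
+            return result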
+    """
+
+    __slots__ = (
+        "_base_ctx",
+        "_rng",
+        "_debug_mode",
+        "_operation_history",
+        "_cleanup_tasks",
+        "_detection_disabled",
+        "_workflow_name",
+        "_current_step",
+        "_sandbox_originals",
+        "_sandbox_mode",
+        "_uuid_counter",
+        "_timestamp_counter",
+    )
+
+    def __init__(self, base_ctx: dt_task.OrchestrationContext):
+        """
+        Initialize the async workflow context.
+
+        Args:
+            base_ctx: The underlying DurableTask OrchestrationContext
+        """
+        super().__init__()
+        self._base_ctx = base_ctx
+        self._rng = None
+        self._debug_mode = os.getenv("DAPR_WF_DEBUG") == "true" or os.getenv("DT_DEBUG") == "true"
+        self._operation_history: list[Dict[str, Any]] = []
+        self._cleanup_tasks: list[Callable[[], Any]] = []
+        self._workflow_name: Optional[str] = None
+        self._current_step: Optional[str] = None
+        # Set by sandbox when active
+        self._sandbox_originals: Optional[Dict[str, Any]] = None
+        self._sandbox_mode: Optional[str] = None
+
+        # Performance optimization: Check if detection should be globally disabled
+        self._detection_disabled = os.getenv("DAPR_WF_DISABLE_DETERMINISTIC_DETECTION") == "true"
+
+    # Core properties from base context
+    @property
+    def instance_id(self) -> str:
+        """Get the workflow instance ID."""
+        return self._base_ctx.instance_id
+
+    @property
+    def current_utc_datetime(self) -> datetime:
+        """Get the current orchestration time."""
+        return self._base_ctx.current_utc_datetime
+
+    @property
+    def is_replaying(self) -> bool:
+        """Check if the workflow is currently replaying."""
+        return self._base_ctx.is_replaying
+
+    @property
+    def is_suspended(self) -> bool:
+        """Check if the workflow is currently suspended."""
+        return self._base_ctx.is_suspended
+
+    @property
+    def workflow_name(self) -> Optional[str]:
+        """Get the workflow name."""
+        return getattr(self._base_ctx, "workflow_name", None)
+
+    # Activity operations
+    def call_activity(
+        self,
+        activity_fn: Union[dt_task.Activity[Any, Any], str],
+        *,
+        input: Any = None,
+        retry_policy: Any = None,
+        app_id: Optional[str] = None,
+        metadata: Optional[Dict[str, str]] = None,
+    ) -> ActivityAwaitable[Any]:
+        """
+        Create an awaitable for calling an activity function.
+
+        Args:
+            activity_fn: The activity function or name to call
+            input: Input data for the activity
+            retry_policy: Optional retry policy
+            app_id: Optional target app ID, passed through to the awaitable
+            metadata: Optional metadata for the activity call
+
+        Returns:
+            An awaitable that will complete when the activity finishes
+        """
+        self._log_operation("activity", {"function": str(activity_fn), "input": input})
+        return ActivityAwaitable(
+            self._base_ctx,
+            cast(Callable[..., Any], activity_fn),
+            input=input,
+            retry_policy=retry_policy,
+            app_id=app_id,
+            metadata=metadata,
+        )
+
+    # Sub-orchestrator operations
+    def sub_orchestrator(
+        self,
+        workflow_fn: Union[dt_task.Orchestrator[Any, Any], str],
+        *,
+        input: Any = None,
+        instance_id: Optional[str] = None,
+        retry_policy: Any = None,
+        app_id: Optional[str] = None,
+        metadata: Optional[Dict[str, str]] = None,
+    ) -> SubOrchestratorAwaitable[Any]:
+        """
+        Create an awaitable for calling a sub-orchestrator.
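+
+        Example (illustrative; ``child_workflow`` is a registered orchestrator)::
+
+            child_result = await ctx.sub_orchestrator(child_workflow, input={"n": 1})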
+
+        Args:
+            workflow_fn: The sub-orchestrator function or name to call
+            input: Input data for the sub-orchestrator
+            instance_id: Optional instance ID for the sub-orchestrator
+            retry_policy: Optional retry policy
+            app_id: Optional target app ID, passed through to the awaitable
+            metadata: Optional metadata for the sub-orchestrator call
+
+        Returns:
+            An awaitable that will complete when the sub-orchestrator finishes
+        """
+        self._log_operation(
+            "sub_orchestrator",
+            {"function": str(workflow_fn), "input": input, "instance_id": instance_id},
+        )
+        return SubOrchestratorAwaitable(
+            self._base_ctx,
+            cast(Callable[..., Any], workflow_fn),
+            input=input,
+            instance_id=instance_id,
+            retry_policy=retry_policy,
+            app_id=app_id,
+            metadata=metadata,
+        )
+
+    def call_sub_orchestrator(
+        self,
+        workflow_fn: Union[dt_task.Orchestrator[Any, Any], str],
+        *,
+        input: Any = None,
+        instance_id: Optional[str] = None,
+        retry_policy: Any = None,
+        app_id: Optional[str] = None,
+        metadata: Optional[Dict[str, str]] = None,
+    ) -> SubOrchestratorAwaitable[Any]:
+        """Call a sub-orchestrator workflow (durabletask naming convention)."""
+        return self.sub_orchestrator(
+            workflow_fn,
+            input=input,
+            instance_id=instance_id,
+            retry_policy=retry_policy,
+            app_id=app_id,
+            metadata=metadata,
+        )
+
+    def create_timer(self, duration: Union[float, timedelta, datetime]) -> SleepAwaitable:
+        """
+        Create an awaitable for sleeping/waiting.
+
+        Args:
+            duration: Sleep duration (seconds, timedelta, or absolute datetime)
+
+        Returns:
+            An awaitable that will complete after the specified duration
+        """
+        self._log_operation("sleep", {"duration": duration})
+        return SleepAwaitable(self._base_ctx, duration)
+
+    # External event operations
+    def wait_for_external_event(self, name: str) -> ExternalEventAwaitable[Any]:
+        """
+        Create an awaitable for waiting for an external event.
+
+        Args:
+            name: Name of the external event to wait for
+
+        Returns:
+            An awaitable that will complete when the external event is received
+        """
+        self._log_operation("wait_for_external_event", {"name": name})
+        return ExternalEventAwaitable(self._base_ctx, name)
+
+    # Coordination operations
+    def when_all(self, awaitables: List[Any]) -> WhenAllAwaitable[Any]:
+        """
+        Create an awaitable that completes when all provided awaitables complete.
+
+        Args:
+            awaitables: List of awaitables to wait for
+
+        Returns:
+            An awaitable that will complete with a list of all results
+        """
+        self._log_operation("when_all", {"count": len(awaitables)})
+        return WhenAllAwaitable(awaitables)
+
+    def when_any(self, awaitables: List[Any]) -> WhenAnyAwaitable:
+        """
+        Create an awaitable that completes when any of the provided awaitables completes.
+
+        Args:
+            awaitables: List of awaitables to wait for
+
+        Returns:
+            An awaitable that will complete with (index, result) tuple where index is the
+            position of the first completed awaitable in the input list
+        """
+        self._log_operation("when_any", {"count": len(awaitables)})
+        return WhenAnyAwaitable(awaitables)
+
+    # Enhanced operations
+    def with_timeout(self, awaitable: "AwaitableBase[T]", timeout: float) -> TimeoutAwaitable[T]:
+        """
+        Add timeout functionality to any awaitable.
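+
+        Internally the awaitable is raced against a durable timer via ``when_any``;
+        if the timer fires first, ``WorkflowTimeoutError`` is raised.
+
+        Example (illustrative; ``fetch_data`` names a registered activity)::
+
+            data = await ctx.with_timeout(ctx.call_activity(fetch_data), timeout=30.0)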
+ + Args: + awaitable: The awaitable to add timeout to + timeout: Timeout in seconds + + Returns: + An awaitable that will raise TimeoutError if not completed within timeout + """ + self._log_operation("with_timeout", {"timeout": timeout}) + return TimeoutAwaitable(awaitable, float(timeout), self._base_ctx) + + # Custom status operations + def set_custom_status(self, status: Any) -> None: + """ + Set custom status for the workflow instance. + + Args: + status: Custom status object + """ + if hasattr(self._base_ctx, "set_custom_status"): + self._base_ctx.set_custom_status(status) + self._log_operation("set_custom_status", {"status": status}) + + def continue_as_new(self, input_data: Any = None, *, save_events: bool = False) -> None: + """ + Continue the workflow as new with optional new input. + + Args: + input_data: Optional new input data + save_events: Whether to save events (matches base durabletask API) + """ + self._log_operation("continue_as_new", {"input": input_data, "save_events": save_events}) + + if hasattr(self._base_ctx, "continue_as_new"): + # For compatibility with mocks/tests expecting positional-only when default is used, + # call without the keyword when save_events is False; otherwise pass explicitly. + if save_events is False: + self._base_ctx.continue_as_new(input_data) + else: + self._base_ctx.continue_as_new(input_data, save_events=save_events) + + # Enhanced context management + async def __aenter__(self) -> "AsyncWorkflowContext": + """Async context manager entry.""" + return self + + async def __aexit__( + self, + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[Any], + ) -> None: + """Async context manager exit with cleanup.""" + # Run cleanup tasks in reverse order (LIFO) + for cleanup_task in reversed(self._cleanup_tasks): + try: + result = cleanup_task() + # If the cleanup returns an awaitable, await it + try: + import inspect as _inspect + + if _inspect.isawaitable(result): + await result + except Exception: + # If inspection fails, ignore and continue + pass + except Exception as e: + if self._debug_mode: + print(f"[WORKFLOW DEBUG] Cleanup task failed: {e}") + + self._cleanup_tasks.clear() + + # Debug and monitoring + def _log_operation(self, operation: str, details: Dict[str, Any]) -> None: + """Log workflow operation for debugging.""" + if self._debug_mode: + entry = { + "type": operation, # Use "type" for compatibility + "operation": operation, + "details": details, + "sequence": len(self._operation_history), + "timestamp": self.current_utc_datetime.isoformat(), + "is_replaying": self.is_replaying, + } + self._operation_history.append(entry) + print(f"[WORKFLOW DEBUG] {operation}: {details}") + + def _get_info_snapshot(self) -> Dict[str, Any]: + """ + Get debug information about the workflow execution. 
+ + Returns: + Dictionary containing debug information + """ + return { + "instance_id": self.instance_id, + "current_time": self.current_utc_datetime.isoformat(), + "is_replaying": self.is_replaying, + "is_suspended": self.is_suspended, + "operation_history": self._operation_history.copy(), + "cleanup_tasks_count": len(self._cleanup_tasks), + "debug_mode": self._debug_mode, + "detection_disabled": self._detection_disabled, + } + + def __repr__(self) -> str: + """String representation of the context.""" + return ( + f"AsyncWorkflowContext(" + f"instance_id='{self.instance_id}', " + f"is_replaying={self.is_replaying}, " + f"operations={len(self._operation_history)})" + ) diff --git a/durabletask/aio/driver.py b/durabletask/aio/driver.py new file mode 100644 index 0000000..7510e6f --- /dev/null +++ b/durabletask/aio/driver.py @@ -0,0 +1,353 @@ +# Copyright 2025 The Dapr Authors +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Driver for async workflow orchestrators in durabletask.aio. + +This module provides the CoroutineOrchestratorRunner that bridges async/await +syntax with the generator-based DurableTask runtime, ensuring proper replay +semantics and deterministic execution. +""" + +from __future__ import annotations + +import inspect +from collections.abc import Awaitable, Generator +from typing import Any, Callable, Optional, Protocol, TypeVar, cast, runtime_checkable + +from durabletask import task +from durabletask.aio.errors import AsyncWorkflowError, WorkflowValidationError +from durabletask.aio.sandbox import SandboxMode + +TInput = TypeVar("TInput") +TOutput = TypeVar("TOutput") + + +@runtime_checkable +class WorkflowFunction(Protocol): + """Protocol for workflow functions.""" + + async def __call__(self, ctx: Any, input_data: Optional[Any] = None) -> Any: ... + + +class CoroutineOrchestratorRunner: + """ + Wraps an async orchestrator function into a generator-compatible runner. + + This class bridges the gap between async/await syntax and the generator-based + DurableTask runtime, enabling developers to write workflows using modern + async Python while maintaining deterministic execution semantics. + + The implementation uses an iterator pattern to properly handle replay scenarios + and avoid coroutine reuse issues that can occur during workflow replay. + """ + + __slots__ = ("_async_orchestrator", "_sandbox_mode", "_workflow_name") + + def __init__( + self, + async_orchestrator: Callable[..., Awaitable[Any]], + *, + sandbox_mode: str = "best_effort", + workflow_name: Optional[str] = None, + ): + """ + Initialize the coroutine orchestrator runner. + + Args: + async_orchestrator: The async workflow function to wrap + sandbox_mode: Sandbox mode ('off', 'best_effort', 'strict'). 
Default: 'best_effort'
+            workflow_name: Optional workflow name for error reporting
+        """
+        self._async_orchestrator = async_orchestrator
+        self._sandbox_mode = sandbox_mode
+        name_attr = getattr(async_orchestrator, "__name__", None)
+        base_name: str = name_attr if isinstance(name_attr, str) else "unknown"
+        self._workflow_name: str = workflow_name if workflow_name is not None else base_name
+        self._validate_orchestrator(async_orchestrator)
+
+    def _validate_orchestrator(self, orchestrator_fn: Callable[..., Awaitable[Any]]) -> None:
+        """
+        Validate that the orchestrator function is suitable for async workflows.
+
+        Args:
+            orchestrator_fn: The function to validate
+
+        Raises:
+            WorkflowValidationError: If the function is not valid
+        """
+        if not callable(orchestrator_fn):
+            raise WorkflowValidationError(
+                "Orchestrator must be callable",
+                validation_type="function_type",
+                workflow_name=self._workflow_name,
+            )
+
+        if not inspect.iscoroutinefunction(orchestrator_fn):
+            raise WorkflowValidationError(
+                "Orchestrator must be an async function (defined with 'async def')",
+                validation_type="async_function",
+                workflow_name=self._workflow_name,
+            )
+
+        # Check function signature
+        sig = inspect.signature(orchestrator_fn)
+        params = list(sig.parameters.values())
+
+        if len(params) < 1:
+            raise WorkflowValidationError(
+                "Orchestrator must accept at least one parameter (context)",
+                validation_type="function_signature",
+                workflow_name=self._workflow_name,
+            )
+
+        if len(params) > 2:
+            raise WorkflowValidationError(
+                "Orchestrator must accept at most two parameters (context, input)",
+                validation_type="function_signature",
+                workflow_name=self._workflow_name,
+            )
+
+    def to_generator(
+        self, async_ctx: Any, input_data: Optional[Any] = None
+    ) -> Generator[task.Task[Any], Any, Any]:
+        """
+        Convert the async orchestrator to a generator that the DurableTask runtime can drive.
+
+        This implementation uses an iterator pattern, mirroring the classic
+        generator-based orchestrators, to properly handle replay scenarios and
+        avoid coroutine reuse issues.
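+
+        The worker drives the returned generator with ``send``/``throw``, feeding
+        each completed durable task's result back into the coroutine (see the
+        ``generator_orchestrator`` wrapper in ``worker.py`` below).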
+ + Args: + async_ctx: The async workflow context + input_data: Optional input data for the workflow + + Returns: + A generator that yields tasks and receives results + + Raises: + AsyncWorkflowError: If there are issues during workflow execution + """ + # Import sandbox here to avoid circular imports + from .sandbox import _sandbox_scope + + def driver_gen() -> Generator[task.Task[Any], Any, Any]: + """Inner generator that drives the coroutine execution.""" + # Instantiate the coroutine with appropriate parameters + try: + sig = inspect.signature(self._async_orchestrator) + params = list(sig.parameters.values()) + + if len(params) == 1: + # Single parameter (context only) + coro = self._async_orchestrator(async_ctx) + else: + # Two parameters (context and input) + coro = self._async_orchestrator(async_ctx, input_data) + + except TypeError as e: + raise AsyncWorkflowError( + f"Failed to instantiate workflow coroutine: {e}", + workflow_name=self._workflow_name, + step="initialization", + ) from e + + # Prime the coroutine to first await point or finish synchronously + try: + if self._sandbox_mode == SandboxMode.OFF: + awaited_obj = cast(Any, coro).send(None) + else: + with _sandbox_scope(async_ctx, self._sandbox_mode): + awaited_obj = cast(Any, coro).send(None) + except StopIteration as stop: + return stop.value + except Exception as e: + # Close the coroutine to avoid "never awaited" warning + coro.close() + # Re-raise NonRetryableError directly to preserve its type for the runtime + if isinstance(e, task.NonRetryableError): + raise + raise AsyncWorkflowError( + f"Workflow failed during initialization: {e}", + workflow_name=self._workflow_name, + instance_id=getattr(async_ctx, "instance_id", None), + step="initialization", + ) from e + + def to_iter(obj: Any) -> Generator[Any, Any, Any]: + if hasattr(obj, "__await__"): + return cast(Generator[Any, Any, Any], obj.__await__()) + if isinstance(obj, task.Task): + # Wrap a single Task into a one-shot awaitable iterator + def _one_shot() -> Generator[task.Task[Any], Any, Any]: + res = yield obj + return res + + return _one_shot() + raise AsyncWorkflowError( + f"Async orchestrator awaited unsupported object type: {type(obj)}", + workflow_name=self._workflow_name, + step="awaitable_conversion", + ) + + awaited_iter = to_iter(awaited_obj) + while True: + # Advance the awaitable to a DT Task to yield + try: + request = awaited_iter.send(None) + except StopIteration as stop_await: + # Awaitable finished synchronously; feed result back to coroutine + try: + if self._sandbox_mode == "off": + awaited_obj = cast(Any, coro).send(stop_await.value) + else: + with _sandbox_scope(async_ctx, self._sandbox_mode): + awaited_obj = cast(Any, coro).send(stop_await.value) + except StopIteration as stop: + return stop.value + except Exception as e: + # Close the coroutine to avoid "never awaited" warning + coro.close() + # Check if this is a TaskFailedError wrapping a NonRetryableError + if isinstance(e, task.TaskFailedError): + details = e.details + if details.error_type == "NonRetryableError": + # Reconstruct NonRetryableError to preserve its type for the runtime + raise task.NonRetryableError(details.message) from e + # Re-raise NonRetryableError directly to preserve its type for the runtime + if isinstance(e, task.NonRetryableError): + raise + raise AsyncWorkflowError( + f"Workflow failed: {e}", + workflow_name=self._workflow_name, + step="execution", + ) from e + awaited_iter = to_iter(awaited_obj) + continue + + if not isinstance(request, task.Task): + raise 
AsyncWorkflowError( + f"Async awaitable yielded a non-Task object: {type(request)}", + workflow_name=self._workflow_name, + step="execution", + ) + + # Yield to runtime and resume awaitable with task result + try: + result = yield request + except Exception as e: + # Route exception into awaitable first; if it completes, continue; otherwise forward to coroutine + try: + awaited_iter.throw(e) + except StopIteration as stop_await: + try: + if self._sandbox_mode == SandboxMode.OFF: + awaited_obj = cast(Any, coro).send(stop_await.value) + else: + with _sandbox_scope(async_ctx, self._sandbox_mode): + awaited_obj = cast(Any, coro).send(stop_await.value) + except StopIteration as stop: + return stop.value + except Exception as workflow_exc: + # Close the coroutine to avoid "never awaited" warning + coro.close() + # Check if this is a TaskFailedError wrapping a NonRetryableError + if isinstance(workflow_exc, task.TaskFailedError): + details = workflow_exc.details + if details.error_type == "NonRetryableError": + # Reconstruct NonRetryableError to preserve its type for the runtime + raise task.NonRetryableError(details.message) from workflow_exc + # Re-raise NonRetryableError directly to preserve its type for the runtime + if isinstance(workflow_exc, task.NonRetryableError): + raise + raise AsyncWorkflowError( + f"Workflow failed: {workflow_exc}", + workflow_name=self._workflow_name, + step="execution", + ) from workflow_exc + awaited_iter = to_iter(awaited_obj) + except Exception as exc: + try: + if self._sandbox_mode == SandboxMode.OFF: + awaited_obj = cast(Any, coro).throw(exc) + else: + with _sandbox_scope(async_ctx, self._sandbox_mode): + awaited_obj = cast(Any, coro).throw(exc) + except StopIteration as stop: + return stop.value + except Exception as workflow_exc: + # Close the coroutine to avoid "never awaited" warning + coro.close() + # Check if this is a TaskFailedError wrapping a NonRetryableError + if isinstance(workflow_exc, task.TaskFailedError): + details = workflow_exc.details + if details.error_type == "NonRetryableError": + # Reconstruct NonRetryableError to preserve its type for the runtime + raise task.NonRetryableError(details.message) from workflow_exc + # Re-raise NonRetryableError directly to preserve its type for the runtime + if isinstance(workflow_exc, task.NonRetryableError): + raise + raise AsyncWorkflowError( + f"Workflow failed: {workflow_exc}", + workflow_name=self._workflow_name, + step="execution", + ) from workflow_exc + awaited_iter = to_iter(awaited_obj) + continue + + # Success: feed result to awaitable; it may yield more tasks until it stops + try: + next_req = awaited_iter.send(result) + while True: + if not isinstance(next_req, task.Task): + raise AsyncWorkflowError( + f"Async awaitable yielded a non-Task object: {type(next_req)}", + workflow_name=self._workflow_name, + step="execution", + ) + result = yield next_req + next_req = awaited_iter.send(result) + except StopIteration as stop_await: + try: + if self._sandbox_mode == SandboxMode.OFF: + awaited_obj = cast(Any, coro).send(stop_await.value) + else: + with _sandbox_scope(async_ctx, self._sandbox_mode): + awaited_obj = cast(Any, coro).send(stop_await.value) + except StopIteration as stop: + return stop.value + except Exception as e: + # Check if this is a TaskFailedError wrapping a NonRetryableError + if isinstance(e, task.TaskFailedError): + details = e.details + if details.error_type == "NonRetryableError": + # Reconstruct NonRetryableError to preserve its type for the runtime + raise 
task.NonRetryableError(details.message) from e + # Re-raise NonRetryableError directly to preserve its type for the runtime + if isinstance(e, task.NonRetryableError): + raise + raise AsyncWorkflowError( + f"Workflow failed: {e}", + workflow_name=self._workflow_name, + step="execution", + ) from e + awaited_iter = to_iter(awaited_obj) + + return driver_gen() + + @property + def workflow_name(self) -> str: + """Get the workflow name.""" + return self._workflow_name + + @property + def sandbox_mode(self) -> str: + """Get the sandbox mode.""" + return self._sandbox_mode diff --git a/durabletask/aio/errors.py b/durabletask/aio/errors.py new file mode 100644 index 0000000..5e859ea --- /dev/null +++ b/durabletask/aio/errors.py @@ -0,0 +1,146 @@ +# Copyright 2025 The Dapr Authors +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Enhanced error handling for async workflows. + +This module provides specialized exceptions for async workflow operations +with rich context information to aid in debugging and error handling. +""" + +from __future__ import annotations + +from typing import Any, Optional + + +class AsyncWorkflowError(Exception): + """Enhanced exception for async workflow issues with context information.""" + + def __init__( + self, + message: str, + *, + instance_id: Optional[str] = None, + step: Optional[str] = None, + workflow_name: Optional[str] = None, + ): + """ + Initialize an AsyncWorkflowError with context information. + + Args: + message: The error message + instance_id: The workflow instance ID where the error occurred + step: The workflow step/operation where the error occurred + workflow_name: The name of the workflow where the error occurred + """ + self.instance_id = instance_id + self.step = step + self.workflow_name = workflow_name + + context_parts = [] + if workflow_name: + context_parts.append(f"workflow: {workflow_name}") + if instance_id: + context_parts.append(f"instance: {instance_id}") + if step: + context_parts.append(f"step: {step}") + + context_str = f" ({', '.join(context_parts)})" if context_parts else "" + super().__init__(f"{message}{context_str}") + + +class NonDeterminismWarning(UserWarning): + """Warning raised when non-deterministic functions are detected in workflows.""" + + pass + + +class WorkflowTimeoutError(AsyncWorkflowError): + """Exception raised when a workflow operation times out.""" + + def __init__( + self, + message: str = "Operation timed out", + *, + timeout_seconds: Optional[float] = None, + operation: Optional[str] = None, + **kwargs: Any, + ): + """ + Initialize a WorkflowTimeoutError. 
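+
+        Note that when ``timeout_seconds`` and/or ``operation`` are provided, the
+        passed ``message`` is replaced with a composed description.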
+ + Args: + message: The error message + timeout_seconds: The timeout value that was exceeded + operation: The operation that timed out + **kwargs: Additional context passed to AsyncWorkflowError + """ + self.timeout_seconds = timeout_seconds + self.operation = operation + + if timeout_seconds and operation: + message = f"{operation} timed out after {timeout_seconds}s" + elif timeout_seconds: + message = f"Operation timed out after {timeout_seconds}s" + elif operation: + message = f"{operation} timed out" + + super().__init__(message, **kwargs) + + +class WorkflowValidationError(AsyncWorkflowError): + """Exception raised when workflow validation fails.""" + + def __init__(self, message: str, *, validation_type: Optional[str] = None, **kwargs: Any): + """ + Initialize a WorkflowValidationError. + + Args: + message: The error message + validation_type: The type of validation that failed + **kwargs: Additional context passed to AsyncWorkflowError + """ + self.validation_type = validation_type + + if validation_type: + message = f"{validation_type} validation failed: {message}" + + super().__init__(message, **kwargs) + + +class SandboxViolationError(AsyncWorkflowError): + """Exception raised when sandbox restrictions are violated.""" + + def __init__( + self, + message: str, + *, + violation_type: Optional[str] = None, + suggested_alternative: Optional[str] = None, + **kwargs: Any, + ): + """ + Initialize a SandboxViolationError. + + Args: + message: The error message + violation_type: The type of sandbox violation + suggested_alternative: Suggested alternative approach + **kwargs: Additional context passed to AsyncWorkflowError + """ + self.violation_type = violation_type + self.suggested_alternative = suggested_alternative + + full_message = message + if suggested_alternative: + full_message += f". Consider using: {suggested_alternative}" + + super().__init__(full_message, **kwargs) diff --git a/durabletask/aio/sandbox.py b/durabletask/aio/sandbox.py new file mode 100644 index 0000000..a6552b8 --- /dev/null +++ b/durabletask/aio/sandbox.py @@ -0,0 +1,756 @@ +# Copyright 2025 The Dapr Authors +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Sandbox for deterministic workflow execution. + +This module provides sandboxing capabilities that patch non-deterministic +Python functions with deterministic alternatives during workflow execution. +It also includes non-determinism detection to help developers identify +problematic code patterns. 
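+
+Example (illustrative; normally entered by the coroutine runner, not user code)::
+
+    with _sandbox_scope(async_ctx, "best_effort"):
+        ...  # asyncio.sleep/random/uuid4 now resolve to deterministic shims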
+""" + +from __future__ import annotations + +import contextlib +import os +import sys +import warnings +from contextlib import ContextDecorator +from datetime import timedelta +from enum import Enum +from types import FrameType +from typing import Any, Callable, Dict, Optional, Set, Type, Union, cast + +from durabletask.deterministic import deterministic_random, deterministic_uuid4 + +from .errors import NonDeterminismWarning, SandboxViolationError + +# Capture environment variable at module load to avoid triggering non-determinism detection +_DISABLE_DETECTION = os.getenv("DAPR_WF_DISABLE_DETERMINISTIC_DETECTION") == "true" + + +class SandboxMode(str, Enum): + """Sandbox mode options. + + Use as an alternative to string literals to avoid typos and enable IDE support. + """ + + OFF = "off" + BEST_EFFORT = "best_effort" + STRICT = "strict" + + @classmethod + def from_string(cls, mode: str) -> SandboxMode: + if mode not in cls.__members__.values(): + raise ValueError( + f"Invalid sandbox mode: {mode}. Must be one of: {cls.__members__.values()}." + ) + return cls(mode) + + +class _NonDeterminismDetector: + """Detects and warns about non-deterministic function calls in workflows.""" + + def __init__(self, async_ctx: Any, mode: Union[str, SandboxMode]): + self.async_ctx = async_ctx + self.mode = SandboxMode.from_string(mode) + self.detected_calls: Set[str] = set() + self.original_trace_func: Optional[Callable[[FrameType, str, Any], Any]] = None + self._restore_trace_func: Optional[Callable[[FrameType, str, Any], Any]] = None + self._active_trace_func: Optional[Callable[[FrameType, str, Any], Any]] = None + + def _noop_trace( + self, frame: FrameType, event: str, arg: Any + ) -> Optional[Callable[[FrameType, str, Any], Any]]: # lightweight tracer + return None + + def __enter__(self) -> "_NonDeterminismDetector": + enable_full_detection = self.mode == "strict" or ( + self.mode == "best_effort" and getattr(self.async_ctx, "_debug_mode", False) + ) + if self.mode in ("best_effort", "strict"): + self.original_trace_func = sys.gettrace() + # Use full detection tracer in strict or when debug mode is enabled + self._active_trace_func = ( + self._trace_calls if enable_full_detection else self._noop_trace + ) + sys.settrace(self._active_trace_func) + self._restore_trace_func = sys.gettrace() + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[Any], + ) -> None: + if self.mode in ("best_effort", "strict"): + # Restore to the trace function that was active before __enter__ + sys.settrace(self.original_trace_func) + + def _trace_calls( + self, frame: FrameType, event: str, arg: Any + ) -> Optional[Callable[[FrameType, str, Any], Any]]: + """Trace function calls to detect non-deterministic operations.""" + # Only handle function call events to minimize overhead + if event != "call": + return self.original_trace_func(frame, event, arg) if self.original_trace_func else None + + # Perform best-effort detection on call sites + self._check_frame_for_non_determinism(frame) + + # Do not install a per-frame local tracer; let original (if any) handle further events + return self.original_trace_func if self.original_trace_func else None + + def _check_frame_for_non_determinism(self, frame: FrameType) -> None: + """Check if the current frame contains non-deterministic function calls.""" + code = frame.f_code + filename = code.co_filename + func_name = code.co_name + + # Fast module/function check via globals to reduce overhead + 
try: + module_name = frame.f_globals.get("__name__", "") + except Exception: + module_name = "" + + if module_name: + fast_map = { + "datetime": {"now", "utcnow"}, + "time": {"time", "time_ns"}, + "random": {"random", "randint", "choice", "shuffle"}, + "uuid": {"uuid1", "uuid4"}, + "os": {"urandom", "getenv"}, + "secrets": {"token_bytes", "token_hex", "choice"}, + "socket": {"gethostname"}, + "platform": {"node"}, + "threading": {"current_thread"}, + } + funcs = fast_map.get(module_name) + if funcs and func_name in funcs: + # Whitelist deterministic RNG method calls bound to our patched RNG instance + if module_name == "random" and func_name in { + "random", + "randint", + "choice", + "shuffle", + }: + try: + bound_self = frame.f_locals.get("self") + if getattr(bound_self, "_dt_deterministic", False): + return + except Exception: + pass + self._handle_non_deterministic_call(f"{module_name}.{func_name}", frame) + if self.mode == "best_effort": + return + + # Skip our own code and system modules + if "durabletask" in filename or filename.startswith("<"): + return + + # Check for problematic function calls + non_deterministic_patterns = [ + ("datetime", "now"), + ("datetime", "utcnow"), + ("time", "time"), + ("time", "time_ns"), + ("random", "random"), + ("random", "randint"), + ("random", "choice"), + ("random", "shuffle"), + ("uuid", "uuid1"), + ("uuid", "uuid4"), + ("os", "urandom"), + ("os", "getenv"), + ("secrets", "token_bytes"), + ("secrets", "token_hex"), + ("secrets", "choice"), + ("socket", "gethostname"), + ("platform", "node"), + ("threading", "current_thread"), + ] + + if self.mode != "best_effort": + # Check local variables for module usage + for var_name, var_value in frame.f_locals.items(): + module_name = getattr(var_value, "__module__", None) + if module_name: + for pattern_module, pattern_func in non_deterministic_patterns: + if ( + pattern_module in module_name + and hasattr(var_value, pattern_func) + and func_name == pattern_func + ): + self._handle_non_deterministic_call( + f"{pattern_module}.{pattern_func}", frame + ) + + # Check for direct function calls in globals (guard against non-mapping f_globals) + try: + globals_map = frame.f_globals + except Exception: + globals_map = {} + for pattern_module, pattern_func in non_deterministic_patterns: + full_name = f"{pattern_module}.{pattern_func}" + try: + if ( + isinstance(globals_map, dict) + and full_name in globals_map + and func_name == pattern_func + ): + self._handle_non_deterministic_call(full_name, frame) + except Exception: + continue + + def _handle_non_deterministic_call(self, function_name: str, frame: FrameType) -> None: + """Handle detection of a non-deterministic function call.""" + if function_name in self.detected_calls: + return # Already reported + + self.detected_calls.add(function_name) + + # Get context information + code = frame.f_code + filename = code.co_filename + lineno = frame.f_lineno + func = code.co_name + + # Create detailed message with suggestions + suggestions = { + "datetime.now": "ctx.now()", + "datetime.utcnow": "ctx.now()", + "time.time": "ctx.now().timestamp()", + "time.time_ns": "int(ctx.now().timestamp() * 1_000_000_000)", + "random.random": "ctx.random().random()", + "random.randint": "ctx.random().randint()", + "random.choice": "ctx.random().choice()", + "uuid.uuid4": "ctx.uuid4()", + "os.urandom": "ctx.random().randbytes()", + "secrets.token_bytes": "ctx.random().randbytes()", + "secrets.token_hex": "ctx.random_string()", + } + + suggestion = suggestions.get(function_name, "a 
deterministic alternative") + message = ( + f"Non-deterministic function '{function_name}' detected at {filename}:{lineno} " + f"(in {func}). Consider using {suggestion} instead." + ) + + # Log debug information if enabled + if hasattr(self.async_ctx, "_debug_mode") and self.async_ctx._debug_mode: + print(f"[WORKFLOW DEBUG] {message}") + + if self.mode == "strict": + raise SandboxViolationError( + f"Non-deterministic function '{function_name}' is not allowed in strict mode", + violation_type="non_deterministic_call", + suggested_alternative=suggestion, + workflow_name=getattr(self.async_ctx, "_workflow_name", None), + instance_id=getattr(self.async_ctx, "instance_id", None), + ) + elif self.mode == "best_effort": + # Warn only once per function and do not escalate to error in best_effort + warnings.warn(message, NonDeterminismWarning, stacklevel=3) + + def _get_deterministic_alternative(self, function_name: str) -> str: + """Get deterministic alternative suggestion for a function.""" + suggestions = { + "datetime.now": "ctx.now()", + "datetime.utcnow": "ctx.now()", + "time.time": "ctx.now().timestamp()", + "time.time_ns": "int(ctx.now().timestamp() * 1_000_000_000)", + "random.random": "ctx.random().random()", + "random.randint": "ctx.random().randint()", + "random.choice": "ctx.random().choice()", + "random.shuffle": "ctx.random().shuffle()", + "uuid.uuid1": "ctx.uuid4() (deterministic)", + "uuid.uuid4": "ctx.uuid4()", + "os.urandom": "ctx.random().randbytes() or ctx.random().getrandbits()", + "secrets.token_bytes": "ctx.random().randbytes()", + "secrets.token_hex": "ctx.random().randbytes().hex()", + "socket.gethostname": "hardcoded hostname or activity call", + "threading.current_thread": "avoid threading in workflows", + } + return suggestions.get(function_name, "a deterministic alternative") + + +class _Sandbox(ContextDecorator): + """Context manager for sandboxing workflow execution.""" + + def __init__(self, async_ctx: Any, mode: Union[str, SandboxMode]): + self.async_ctx = async_ctx + self.mode = SandboxMode.from_string(mode) + self.originals: Dict[str, Any] = {} + self.detector: Optional[_NonDeterminismDetector] = None + + def __enter__(self) -> "_Sandbox": + if self.mode == SandboxMode.OFF: + return self + + # Check for global disable + if getattr(self.async_ctx, "_detection_disabled", False): + return self + + # Enable non-determinism detection + self.detector = _NonDeterminismDetector(self.async_ctx, self.mode) + self.detector.__enter__() + + # Apply patches for best_effort and strict modes + self._apply_patches() + + # Expose originals/mode to the async workflow context for controlled unsafe access + try: + self.async_ctx._sandbox_originals = dict(self.originals) + self.async_ctx._sandbox_mode = self.mode + except Exception: + # Context may not support attribute assignment; ignore + pass + + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[Any], + ) -> None: + if self.detector: + self.detector.__exit__(exc_type, exc_val, exc_tb) + + if self.mode != SandboxMode.OFF and self.originals: + self._restore_originals() + + # Remove exposed references from the async context + try: + if hasattr(self.async_ctx, "_sandbox_originals"): + delattr(self.async_ctx, "_sandbox_originals") + if hasattr(self.async_ctx, "_sandbox_mode"): + delattr(self.async_ctx, "_sandbox_mode") + except Exception: + pass + + def _apply_patches(self) -> None: + """Apply patches to non-deterministic functions.""" + import asyncio 
as _asyncio + import datetime as _datetime + import random as _random + import time as _time_mod + import uuid as _uuid_mod + + # Store originals for restoration + self.originals = { + "asyncio.sleep": _asyncio.sleep, + "asyncio.gather": getattr(_asyncio, "gather", None), + "asyncio.create_task": getattr(_asyncio, "create_task", None), + "random.random": _random.random, + "random.randrange": _random.randrange, + "random.randint": _random.randint, + "random.getrandbits": _random.getrandbits, + "uuid.uuid4": _uuid_mod.uuid4, + "time.time": _time_mod.time, + "time.time_ns": getattr(_time_mod, "time_ns", None), + "datetime.now": _datetime.datetime.now, + "datetime.utcnow": _datetime.datetime.utcnow, + } + + # Add strict mode blocks for potentially dangerous operations + if self.mode == SandboxMode.STRICT: + import builtins + import os as _os + import secrets as _secrets + + self.originals.update( + { + "builtins.open": builtins.open, + "os.urandom": getattr(_os, "urandom", None), + "secrets.token_bytes": getattr(_secrets, "token_bytes", None), + "secrets.token_hex": getattr(_secrets, "token_hex", None), + } + ) + + # Create patched functions + def patched_sleep(delay: Union[float, int]) -> Any: + # Capture the context in the closure + base_ctx = self.async_ctx._base_ctx + + class _PatchedSleepAwaitable: + def __await__(self) -> Any: + result = yield base_ctx.create_timer(timedelta(seconds=float(delay))) + return result + + # Pass through zero-or-negative delays to the original asyncio.sleep + try: + if float(delay) <= 0: + orig_sleep = self.originals.get("asyncio.sleep") + if orig_sleep is not None: + return orig_sleep(0) # return the original coroutine + except Exception: + pass + + return _PatchedSleepAwaitable() + + # Derive RNG from instance/time to make results deterministic per-context + # Fallbacks ensure this works with plain mocks used in tests + iid = getattr(self.async_ctx, "instance_id", None) + if iid is None: + base = getattr(self.async_ctx, "_base_ctx", None) + iid = getattr(base, "instance_id", "") if base is not None else "" + now_dt = None + if hasattr(self.async_ctx, "now"): + try: + now_dt = self.async_ctx.now() + except Exception: + now_dt = None + if now_dt is None: + if hasattr(self.async_ctx, "current_utc_datetime"): + now_dt = self.async_ctx.current_utc_datetime + else: + base = getattr(self.async_ctx, "_base_ctx", None) + now_dt = getattr(base, "current_utc_datetime", None) if base is not None else None + if now_dt is None: + now_dt = _datetime.datetime.fromtimestamp(0, _datetime.timezone.utc) + rng = deterministic_random(iid or "", now_dt) + # Mark as deterministic so the detector can whitelist bound method calls + try: + rng._dt_deterministic = True + except Exception: + pass + + def patched_random() -> float: + return rng.random() + + def patched_randrange( + start: int, + stop: Optional[int] = None, + step: int = 1, + _int: Callable[[float], int] = int, + ) -> int: + # Deterministic randrange using rng + if stop is None: + start, stop = 0, start + assert stop is not None + width = stop - start + if step == 1 and width > 0: + return start + _int(rng.random() * width) + # Fallback: generate until fits + while True: + n = start + _int(rng.random() * width) + if (n - start) % step == 0: + return n + + def patched_getrandbits(k: int) -> int: + return rng.getrandbits(k) + + def patched_randint(a: int, b: int) -> int: + return rng.randint(a, b) + + def patched_uuid4() -> Any: + return deterministic_uuid4(rng) + + def patched_time() -> float: + dt = self.async_ctx.now() + 
return float(dt.timestamp()) + + def patched_time_ns() -> int: + dt = self.async_ctx.now() + return int(dt.timestamp() * 1_000_000_000) + + def patched_datetime_now(tz: Optional[Any] = None) -> Any: + base_dt = self.async_ctx.now() + return base_dt.replace(tzinfo=tz) if tz else base_dt + + def patched_datetime_utcnow() -> Any: + return self.async_ctx.now() + + # Apply patches - only patch local imports to maintain context isolation + _asyncio.sleep = cast(Any, patched_sleep) + + # Patch asyncio.gather to a replay-safe, one-shot awaitable wrapper + def _is_workflow_awaitable(obj: Any) -> bool: + try: + from .awaitables import AwaitableBase as _AwaitableBase # local import + + if isinstance(obj, _AwaitableBase): + return True + except Exception: + pass + try: + from durabletask import task as _dt + + if isinstance(obj, _dt.Task): + return True + except Exception: + pass + return False + + class _OneShot: + """Replay-safe one-shot awaitable wrapper. + + Schedules the underlying coroutine/factory exactly once at the + first await, caches either the result or the exception, and on + subsequent awaits simply replays the cached outcome without + re-scheduling any work. This prevents side effects during + orchestrator replays and makes multiple awaits deterministic. + """ + + def __init__(self, factory: Callable[[], Any]) -> None: + self._factory = factory + self._done = False + self._res: Any = None + self._exc: Optional[BaseException] = None + + def __await__(self) -> Any: + if self._done: + + async def _replay() -> Any: + if self._exc is not None: + raise self._exc + return self._res + + return _replay().__await__() + + async def _compute() -> Any: + try: + out = await self._factory() + self._res = out + self._done = True + return out + except BaseException as e: # noqa: BLE001 + self._exc = e + self._done = True + raise + + return _compute().__await__() + + def _patched_gather(*aws: Any, return_exceptions: bool = False) -> Any: + """Replay-safe gather that returns a one-shot awaitable. + + - Empty input returns a cached empty list. + - If all inputs are workflow awaitables, uses WhenAllAwaitable (fan-out) + and caches the combined result. + - Mixed inputs: workflow awaitables are batched via WhenAll (fan-out), then + native awaitables are awaited sequentially; results are merged in the + original order. return_exceptions is honored for both groups. + + The returned object can be awaited multiple times safely without + re-scheduling underlying operations. 
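+
+            Example (illustrative, inside a sandboxed workflow; ``step_one`` and
+            ``step_two`` are registered activities)::
+
+                a, b = await asyncio.gather(
+                    ctx.call_activity(step_one), ctx.call_activity(step_two)
+                )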
+ """ + # Empty gather returns [] and can be awaited multiple times safely + if not aws: + + async def _empty() -> list[Any]: + return [] + + return _OneShot(_empty) + + # If all awaitables are workflow awaitables or durable tasks, map to when_all (fan-out best scenario) + if all(_is_workflow_awaitable(a) for a in aws): + + async def _await_when_all() -> Any: + from .awaitables import WhenAllAwaitable # local import to avoid cycles + + combined: Any = WhenAllAwaitable(list(aws)) + return await combined + + return _OneShot(_await_when_all) + + # Mixed inputs: fan-out workflow awaitables via WhenAll, then await native sequentially; merge preserving order + async def _run_mixed() -> list[Any]: + from .awaitables import AwaitableBase as _AwaitableBase + from .awaitables import SwallowExceptionAwaitable, WhenAllAwaitable + + items: list[Any] = list(aws) + total = len(items) + # Partition into workflow vs native + wf_indices: list[int] = [] + wf_items: list[Any] = [] + native_indices: list[int] = [] + native_items: list[Any] = [] + for idx, it in enumerate(items): + if _is_workflow_awaitable(it): + wf_indices.append(idx) + wf_items.append(it) + else: + native_indices.append(idx) + native_items.append(it) + merged: list[Any] = [None] * total + # Fan-out workflow group first (optionally swallow exceptions for AwaitableBase entries) + if wf_items: + wf_group: list[Any] = [] + if return_exceptions: + for it in wf_items: + if isinstance(it, _AwaitableBase): + wf_group.append(SwallowExceptionAwaitable(it)) + else: + wf_group.append(it) + else: + wf_group = wf_items + wf_results: list[Any] = await WhenAllAwaitable(wf_group) # type: ignore[assignment] + for pos, val in zip(wf_indices, wf_results, strict=False): + merged[pos] = val + # Then process native sequentially, honoring return_exceptions + for pos, it in zip(native_indices, native_items, strict=False): + try: + merged[pos] = await it + except Exception as e: # noqa: BLE001 + if return_exceptions: + merged[pos] = e + else: + raise + return merged + + return _OneShot(_run_mixed) + + if self.originals.get("asyncio.gather") is not None: + # Assign a fresh closure each enter so identity differs per context + def _patched_gather_wrapper_factory() -> Callable[..., Any]: + def _patched_gather_wrapper(*aws: Any, return_exceptions: bool = False) -> Any: + return _patched_gather(*aws, return_exceptions=return_exceptions) + + return _patched_gather_wrapper + + _asyncio.gather = cast(Any, _patched_gather_wrapper_factory()) + + if self.mode == SandboxMode.STRICT and hasattr(_asyncio, "create_task"): + + def _blocked_create_task(*args: Any, **kwargs: Any) -> None: + # If a coroutine object was already created by caller (e.g., create_task(dummy_coro())), close it + try: + import inspect as _inspect + + if args and _inspect.iscoroutine(args[0]) and hasattr(args[0], "close"): + try: + args[0].close() + except Exception: + pass + except Exception: + pass + raise SandboxViolationError( + "asyncio.create_task is not allowed in workflows (strict mode)", + violation_type="blocked_operation", + suggested_alternative="use workflow awaitables instead", + ) + + _asyncio.create_task = cast(Any, _blocked_create_task) + + _random.random = cast(Any, patched_random) + _random.randrange = cast(Any, patched_randrange) + _random.randint = cast(Any, patched_randint) + _random.getrandbits = cast(Any, patched_getrandbits) + _uuid_mod.uuid4 = cast(Any, patched_uuid4) + _time_mod.time = cast(Any, patched_time) + + if self.originals["time.time_ns"] is not None: + _time_mod.time_ns = 
cast(Any, patched_time_ns)
+
+        # Note: datetime.datetime is immutable, so we can't patch it directly
+        # This is a limitation of the current sandboxing approach
+        # Users should use ctx.now() instead of datetime.now() in workflows
+
+        # Apply strict mode blocks
+        if self.mode == SandboxMode.STRICT:
+            import builtins
+            import os as _os
+            import secrets as _secrets
+
+            def _blocked_open(*args: Any, **kwargs: Any) -> Any:
+                raise SandboxViolationError(
+                    "File I/O operations are not allowed in workflows (strict mode)",
+                    violation_type="blocked_operation",
+                    suggested_alternative="use activities for I/O operations",
+                )
+
+            def _blocked_urandom(*args: Any, **kwargs: Any) -> Any:
+                raise SandboxViolationError(
+                    "os.urandom is not allowed in workflows (strict mode)",
+                    violation_type="blocked_operation",
+                    suggested_alternative="ctx.random().randbytes()",
+                )
+
+            def _blocked_secrets(*args: Any, **kwargs: Any) -> Any:
+                raise SandboxViolationError(
+                    "secrets module is not allowed in workflows (strict mode)",
+                    violation_type="blocked_operation",
+                    suggested_alternative="ctx.random() methods",
+                )
+
+            builtins.open = cast(Any, _blocked_open)
+            if self.originals["os.urandom"] is not None:
+                _os.urandom = cast(Any, _blocked_urandom)
+            if self.originals["secrets.token_bytes"] is not None:
+                _secrets.token_bytes = cast(Any, _blocked_secrets)
+            if self.originals["secrets.token_hex"] is not None:
+                _secrets.token_hex = cast(Any, _blocked_secrets)
+
+    def _restore_originals(self) -> None:
+        """Restore original functions after sandboxing."""
+        import asyncio as _asyncio2
+        import random as _random2
+        import time as _time2
+        import uuid as _uuid2
+
+        _asyncio2.sleep = cast(Any, self.originals["asyncio.sleep"])
+        if self.originals["asyncio.gather"] is not None:
+            _asyncio2.gather = cast(Any, self.originals["asyncio.gather"])
+        if self.originals["asyncio.create_task"] is not None:
+            _asyncio2.create_task = cast(Any, self.originals["asyncio.create_task"])
+        _random2.random = cast(Any, self.originals["random.random"])
+        _random2.randrange = cast(Any, self.originals["random.randrange"])
+        _random2.randint = cast(Any, self.originals["random.randint"])
+        _random2.getrandbits = cast(Any, self.originals["random.getrandbits"])
+        _uuid2.uuid4 = cast(Any, self.originals["uuid.uuid4"])
+        _time2.time = cast(Any, self.originals["time.time"])
+
+        if self.originals["time.time_ns"] is not None:
+            _time2.time_ns = cast(Any, self.originals["time.time_ns"])
+
+        # Note: datetime.datetime is immutable, so we can't restore it
+        # This is a limitation of the current sandboxing approach
+
+        # Restore strict mode blocks
+        if self.mode == SandboxMode.STRICT:
+            import builtins
+            import os as _os
+            import secrets as _secrets
+
+            builtins.open = cast(Any, self.originals["builtins.open"])
+            if self.originals["os.urandom"] is not None:
+                _os.urandom = cast(Any, self.originals["os.urandom"])
+            if self.originals["secrets.token_bytes"] is not None:
+                _secrets.token_bytes = cast(Any, self.originals["secrets.token_bytes"])
+            if self.originals["secrets.token_hex"] is not None:
+                _secrets.token_hex = cast(Any, self.originals["secrets.token_hex"])
+
+
+@contextlib.contextmanager
+def _sandbox_scope(async_ctx: Any, mode: Union[str, SandboxMode]) -> Any:
+    """
+    Create a sandbox context for deterministic workflow execution.
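+
+    Non-deterministic builtins are patched on entry and restored on exit, so the
+    scope must wrap every resume of the workflow coroutine (the coroutine runner
+    does this around each ``send``/``throw``).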
+ + Args: + async_ctx: The async workflow context + mode: Sandbox mode ('off', 'best_effort', 'strict') + + Yields: + None + + Raises: + ValueError: If mode is invalid + SandboxViolationError: If non-deterministic operations are detected in strict mode + """ + mode = SandboxMode.from_string(mode) + # Check for global disable (captured at module load to avoid non-determinism detection) + if mode != SandboxMode.OFF and _DISABLE_DETECTION: + mode = SandboxMode.OFF + + with _Sandbox(async_ctx, mode): + yield diff --git a/durabletask/internal/helpers.py b/durabletask/internal/helpers.py index 8b67219..2b0d06c 100644 --- a/durabletask/internal/helpers.py +++ b/durabletask/internal/helpers.py @@ -121,6 +121,16 @@ def new_sub_orchestration_failed_event(event_id: int, ex: Exception) -> pb.Histo ) +def is_orchestration_terminal_status(status: pb.OrchestrationStatus) -> bool: + # https://github.com/dapr/durabletask-go/blob/7f28b2408db77ed48b1b03ecc71624fc456ccca3/api/orchestration.go#L196-L201 + return status in [ + pb.ORCHESTRATION_STATUS_COMPLETED, + pb.ORCHESTRATION_STATUS_FAILED, + pb.ORCHESTRATION_STATUS_TERMINATED, + pb.ORCHESTRATION_STATUS_CANCELED, + ] + + def new_failure_details(ex: Exception) -> pb.TaskFailureDetails: return pb.TaskFailureDetails( errorType=type(ex).__name__, diff --git a/durabletask/internal/shared.py b/durabletask/internal/shared.py index 3adb6b1..4fe6d73 100644 --- a/durabletask/internal/shared.py +++ b/durabletask/internal/shared.py @@ -102,7 +102,7 @@ def get_logger( # Add a default log handler if none is provided if log_handler is None: log_handler = logging.StreamHandler() - log_handler.setLevel(logging.INFO) + log_handler.setLevel(logging.DEBUG) logger.handlers.append(log_handler) # Set a default log formatter to our handler if none is provided diff --git a/durabletask/task.py b/durabletask/task.py index 0b27b6f..691c2de 100644 --- a/durabletask/task.py +++ b/durabletask/task.py @@ -7,7 +7,7 @@ import math from abc import ABC, abstractmethod from datetime import datetime, timedelta -from typing import Any, Callable, Generator, Generic, Optional, TypeVar, Union +from typing import Any, Awaitable, Callable, Generator, Generic, Optional, TypeVar, Union import durabletask.internal.helpers as pbh import durabletask.internal.orchestrator_service_pb2 as pb @@ -499,7 +499,10 @@ def task_id(self) -> int: Orchestrator = Callable[[OrchestrationContext, TInput], Union[Generator[Task, Any, Any], TOutput]] # Activities are simple functions that can be scheduled by orchestrators -Activity = Callable[[ActivityContext, TInput], TOutput] +Activity = Union[ + Callable[[ActivityContext, TInput], TOutput], + Callable[[ActivityContext, TInput], Awaitable[TOutput]], +] class RetryPolicy: diff --git a/durabletask/worker.py b/durabletask/worker.py index 29d67fc..d5aad8d 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -10,7 +10,7 @@ from datetime import datetime, timedelta from threading import Event, Thread from types import GeneratorType -from typing import Any, Generator, Optional, Sequence, TypeVar, Union +from typing import Any, Callable, Generator, Optional, Sequence, TypeVar, Union import grpc from google.protobuf import empty_pb2 @@ -20,6 +20,7 @@ import durabletask.internal.orchestrator_service_pb2_grpc as stubs import durabletask.internal.shared as shared from durabletask import deterministic, task +from durabletask.aio import AsyncWorkflowContext, CoroutineOrchestratorRunner, SandboxMode from durabletask.internal.grpc_interceptor import 
DefaultClientInterceptorImpl TInput = TypeVar("TInput") @@ -96,6 +97,60 @@ def add_named_orchestrator(self, name: str, fn: task.Orchestrator) -> None: self.orchestrators[name] = fn + # Internal helper: register async orchestrators directly on the registry. + # Primarily for unit tests and direct executor usage. For production, prefer + # using TaskHubGrpcWorker.add_async_orchestrator(), which wraps and registers + # on this registry under the hood. + def add_async_orchestrator( + self, + fn: Optional[Callable[[AsyncWorkflowContext, Any], Any]] = None, + *, + name: Optional[str] = None, + sandbox_mode: str | SandboxMode = "best_effort", + ) -> Union[str, Callable[[Callable[[AsyncWorkflowContext, Any], Any]], str]]: + """Registers an async orchestrator by wrapping it with the coroutine driver. + + Can be used as: + - Simple decorator: @registry.add_async_orchestrator + - Decorator with args: @registry.add_async_orchestrator(sandbox_mode="strict") + - Direct call: registry.add_async_orchestrator(my_func, name="MyOrch") + """ + + def _register(func: Callable[[AsyncWorkflowContext, Any], Any]) -> str: + runner = CoroutineOrchestratorRunner(func, sandbox_mode=sandbox_mode) + + def generator_orchestrator(ctx: task.OrchestrationContext, input_data: Any): + async_ctx = AsyncWorkflowContext(ctx) + gen = runner.to_generator(async_ctx, input_data) + result = None + while True: + try: + task_obj = gen.send(result) + except StopIteration as stop: + return stop.value + try: + result = yield task_obj + except Exception as e: + try: + result = gen.throw(e) + except StopIteration as stop: + return stop.value + + orch_name = name + if orch_name is None: + orch_name = task.get_name(func) if hasattr(func, "__name__") else None + if not orch_name: + raise ValueError("A non-empty orchestrator name is required.") + self.add_named_orchestrator(orch_name, generator_orchestrator) + return orch_name + + # If fn is provided, register directly (used as @decorator or direct call) + if fn is not None: + return _register(fn) + + # If fn is None, return decorator (used as @decorator(args)) + return _register + def get_orchestrator(self, name: str) -> Optional[task.Orchestrator]: return self.orchestrators.get(name) @@ -266,10 +321,76 @@ def __exit__(self, type, value, traceback): self.stop() def add_orchestrator(self, fn: task.Orchestrator) -> str: - """Registers an orchestrator function with the worker.""" + """Registers an orchestrator function with the worker. + + Automatically detects async functions and registers them as async orchestrators. + """ if self._is_running: raise RuntimeError("Orchestrators cannot be added while the worker is running.") - return self._registry.add_orchestrator(fn) + + # Auto-detect coroutine functions and delegate to async registration + if inspect.iscoroutinefunction(fn): + return self.add_async_orchestrator(fn) + else: + return self._registry.add_orchestrator(fn) + + # Async orchestrator support (opt-in) + + def add_async_orchestrator( + self, + fn: Optional[Callable[[AsyncWorkflowContext, Any], Any]] = None, + *, + name: Optional[str] = None, + sandbox_mode: str | SandboxMode = "best_effort", + ) -> Union[str, Callable[[Callable[[AsyncWorkflowContext, Any], Any]], str]]: + """Registers an async orchestrator by wrapping it with the coroutine driver. + + The provided coroutine function must only await awaitables created from + `AsyncWorkflowContext` (activities, timers, external events, when_any/all). 
+
+        Can be used as:
+        - Simple decorator: @worker.add_async_orchestrator
+        - Decorator with args: @worker.add_async_orchestrator(sandbox_mode="strict")
+        - Direct call: worker.add_async_orchestrator(my_func, name="MyOrch")
+        """
+
+        def _register(func: Callable[[AsyncWorkflowContext, Any], Any]) -> str:
+            if self._is_running:
+                raise RuntimeError("Orchestrators cannot be added while the worker is running.")
+
+            runner = CoroutineOrchestratorRunner(func, sandbox_mode=sandbox_mode)
+
+            def generator_orchestrator(ctx: task.OrchestrationContext, input_data: Any):
+                async_ctx = AsyncWorkflowContext(ctx)
+                gen = runner.to_generator(async_ctx, input_data)
+                result = None
+                while True:
+                    try:
+                        task_obj = gen.send(result)
+                    except StopIteration as stop:
+                        return stop.value
+                    try:
+                        result = yield task_obj
+                    except Exception as e:
+                        try:
+                            result = gen.throw(e)
+                        except StopIteration as stop:
+                            return stop.value
+
+            orch_name = name
+            if orch_name is None:
+                orch_name = task.get_name(func) if hasattr(func, "__name__") else None
+            if not orch_name:
+                raise ValueError("A non-empty orchestrator name is required.")
+            self._registry.add_named_orchestrator(orch_name, generator_orchestrator)
+            return orch_name
+
+        # If fn is provided, register directly (used as @decorator or direct call)
+        if fn is not None:
+            return _register(fn)
+
+        # If fn is None, return decorator (used as @decorator(args))
+        return _register

     def add_activity(self, fn: task.Activity) -> str:
         """Registers an activity function with the worker."""
@@ -571,7 +692,7 @@ def _execute_orchestrator(
                 f"Failed to deliver orchestrator response for '{req.instanceId}' to sidecar: {ex}"
             )

-    def _execute_activity(
+    async def _execute_activity(
         self,
         req: pb.ActivityRequest,
         stub: stubs.TaskHubSidecarServiceStub,
@@ -580,7 +701,7 @@ def _execute_activity(
         instance_id = req.orchestrationInstance.instanceId
         try:
             executor = _ActivityExecutor(self._registry, self._logger)
-            result = executor.execute(instance_id, req.name, req.taskId, req.input.value)
+            result = await executor.execute(instance_id, req.name, req.taskId, req.input.value)
             res = pb.ActivityResponse(
                 instanceId=instance_id,
                 taskId=req.taskId,
@@ -611,10 +732,11 @@ class _RuntimeOrchestrationContext(
     _generator: Optional[Generator[task.Task, Any, Any]]
     _previous_task: Optional[task.Task]

-    def __init__(self, instance_id: str):
+    def __init__(self, instance_id: str, workflow_name: Optional[str] = None):
         super().__init__()
         self._generator = None
         self._is_replaying = True
+        self._is_suspended = False
         self._is_complete = False
         self._result = None
         self._pending_actions: dict[int, pb.OrchestratorAction] = {}
@@ -622,6 +744,7 @@ def __init__(self, instance_id: str):
         self._sequence_number = 0
         self._current_utc_datetime = datetime(1000, 1, 1)
         self._instance_id = instance_id
+        self._workflow_name = workflow_name
         self._app_id = None
         self._completion_status: Optional[pb.OrchestrationStatus] = None
         self._received_events: dict[str, list[Any]] = {}
@@ -765,6 +888,15 @@ def current_utc_datetime(self, value: datetime):
     def is_replaying(self) -> bool:
         return self._is_replaying

+    @property
+    def is_suspended(self) -> bool:
+        return self._is_suspended
+
+    @property
+    def workflow_name(self) -> Optional[str]:
+        """Get the workflow name."""
+        return self._workflow_name
+
     def set_custom_status(self, custom_status: Any) -> None:
         self._encoded_custom_status = (
             shared.to_json(custom_status) if custom_status is not None else None
@@ -945,7 +1077,19 @@ def execute(
                "The new history event list must have at least one event in it."
            )

-        ctx = _RuntimeOrchestrationContext(instance_id)
+        # Extract workflow name from execution started event if available
+        workflow_name: Optional[str] = None
+        for event in old_events:
+            if event.HasField("executionStarted"):
+                workflow_name = event.executionStarted.name
+                break
+        if workflow_name is None:
+            for event in new_events:
+                if event.HasField("executionStarted"):
+                    workflow_name = event.executionStarted.name
+                    break
+
+        ctx = _RuntimeOrchestrationContext(instance_id, workflow_name=workflow_name)
         try:
             # Rebuild local state by replaying old history into the orchestrator function
             self._logger.debug(
@@ -992,7 +1136,7 @@ def execute(
         return ExecutionResults(actions=actions, encoded_custom_status=ctx._encoded_custom_status)

     def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEvent) -> None:
-        if self._is_suspended and _is_suspendable(event):
+        if self._is_suspended and _is_suspendable(event) and not ctx.is_replaying:
             # We are suspended, so we need to buffer this event until we are resumed
             self._suspended_events.append(event)
             return
@@ -1260,10 +1404,12 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven
             if not self._is_suspended and not ctx.is_replaying:
                 self._logger.info(f"{ctx.instance_id}: Execution suspended.")
             self._is_suspended = True
+            ctx._is_suspended = True
         elif event.HasField("executionResumed") and self._is_suspended:
             if not ctx.is_replaying:
                 self._logger.info(f"{ctx.instance_id}: Resuming execution.")
             self._is_suspended = False
+            ctx._is_suspended = False
             for e in self._suspended_events:
                 self.process_event(ctx, e)
             self._suspended_events = []
@@ -1295,7 +1441,7 @@ def __init__(self, registry: _Registry, logger: logging.Logger):
         self._registry = registry
         self._logger = logger

-    def execute(
+    async def execute(
         self,
         orchestration_id: str,
         name: str,
@@ -1313,8 +1459,11 @@ def execute(
         activity_input = shared.from_json(encoded_input) if encoded_input else None
         ctx = task.ActivityContext(orchestration_id, task_id)

-        # Execute the activity function
-        activity_output = fn(ctx, activity_input)
+        # Execute the activity function (sync or async)
+        if inspect.iscoroutinefunction(fn):
+            activity_output = await fn(ctx, activity_input)
+        else:
+            activity_output = fn(ctx, activity_input)

         encoded_output = shared.to_json(activity_output) if activity_output is not None else None
         chars = len(encoded_output) if encoded_output else 0
diff --git a/pyproject.toml b/pyproject.toml
index 6626bc2..55f9e4c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -25,7 +25,7 @@ requires-python = ">=3.10"
 license = {file = "LICENSE"}
 readme = "README.md"
 dependencies = [
-    "grpcio",
+    "grpcio>=1.75.1",
     "protobuf>=6.31.1,<7.0.0", # follows grpcio generation version https://github.com/grpc/grpc/blob/v1.75.1/tools/distrib/python/grpcio_tools/setup.py
     "asyncio"
 ]
diff --git a/tests/README.md b/tests/README.md
new file mode 100644
index 0000000..887f303
--- /dev/null
+++ b/tests/README.md
@@ -0,0 +1,370 @@
+# Testing Guide
+
+This directory contains comprehensive tests for the durabletask-python SDK, including both unit tests and end-to-end (E2E) tests.
+
+## Quick Start
+
+### Set up and start a Dapr sidecar (required for E2E tests only)
+
+```bash
+# Install the Dapr CLI (if not already installed)
+curl -fsSL https://raw.githubusercontent.com/dapr/cli/master/install/install.sh | /bin/bash
+
+# Initialize Dapr (one-time setup)
+dapr init
+
+# Start the Dapr sidecar for testing in another terminal (uses default port 4001).
+# It uses a Redis statestore for the workflow db; this is usually installed when doing `dapr init`.
+dapr run \
+  --app-id test-app \
+  --dapr-grpc-port 4001 \
+  --resources-path ./examples/components
+```
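+
+Once the sidecar is running, a quick end-to-end smoke check can confirm the SDK can reach it. This is a minimal sketch (the script name and the `hello`/`smoke_orchestrator` names are illustrative) using the SDK's gRPC worker and client:
+
+```python
+# smoke_test.py - schedule one orchestration against the local sidecar and wait for it
+from durabletask import client, task, worker
+
+
+def hello(ctx: task.ActivityContext, name: str) -> str:
+    return f"Hello, {name}!"
+
+
+def smoke_orchestrator(ctx: task.OrchestrationContext, name: str):
+    # Schedule a single activity and return its result
+    result = yield ctx.call_activity(hello, input=name)
+    return result
+
+
+with worker.TaskHubGrpcWorker() as w:
+    w.add_orchestrator(smoke_orchestrator)
+    w.add_activity(hello)
+    w.start()
+
+    c = client.TaskHubGrpcClient()
+    instance_id = c.schedule_new_orchestration(smoke_orchestrator, input="world")
+    state = c.wait_for_orchestration_completion(instance_id, timeout=30)
+    print(state.runtime_status if state else "no completion within timeout")
+```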
+
+### Run the tests
+
+```bash
+# Install tox if not already installed
+pip install tox
+
+# Install dependencies
+pip install -r dev-requirements.txt
+
+# Run unit tests with tox (recommended - uses clean environments)
+tox -e py310
+
+# Run E2E tests with tox (requires the sidecar - see setup below)
+tox -e py310-e2e
+
+# Test with a specific Python version
+tox -e py311  # or py312, py313, etc.
+
+# Run tests with coverage
+tox -e py310
+coverage report
+```
+
+## Test Categories
+
+### Unit Tests
+- **No external dependencies** - Run without a sidecar
+- **Fast execution** - Suitable for development and CI
+- **Isolated testing** - Mock external dependencies
+
+```bash
+# Run unit tests with tox (recommended)
+tox -e py310
+
+# Or test multiple Python versions
+tox -e py310,py311,py312
+```
+
+### End-to-End (E2E) Tests
+- **Require a sidecar** - Need a running Dapr sidecar
+- **Full integration** - Test complete workflow execution
+- **Slower execution** - Real network calls and orchestration
+
+```bash
+# Run E2E tests with tox (requires sidecar setup)
+tox -e py310-e2e
+
+# Or test multiple Python versions
+tox -e py310-e2e,py311-e2e
+```
+
+## Sidecar Setup for E2E Tests
+
+E2E tests require a running Dapr sidecar. The SDK connects to port **4001** by default.
+
+### Dapr Sidecar Setup (Recommended)
+
+```bash
+# Install the Dapr CLI (if not already installed)
+curl -fsSL https://raw.githubusercontent.com/dapr/cli/master/install/install.sh | /bin/bash
+
+# Initialize Dapr (one-time setup)
+dapr init
+
+# Start the Dapr sidecar for testing (uses default port 4001).
+# It uses a Redis statestore for the workflow db; this is usually installed when doing `dapr init`.
+dapr run \
+  --app-id test-app \
+  --dapr-grpc-port 4001 \
+  --resources-path ./examples/components
+```
+
+**Why Dapr:**
+- **Production parity**: Same runtime as deployed applications
+- **Full Dapr features**: State stores, pub/sub, bindings, etc.
+- **Real workflow backend**: Actual Dapr workflow engine
+- **Debugging**: Same logging and tracing as production
+
+## Configuration
+
+### Environment Variables
+
+The SDK connects to **localhost:4001** by default. Override for non-default configurations:
+
+```bash
+# Connect to a custom endpoint (format: host:port)
+export DAPR_GRPC_ENDPOINT=localhost:50001
+```
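+
+As documented in the main README, the host and port can also be set separately; a brief sketch of the equivalent configuration:
+
+```bash
+# Equivalent to DAPR_GRPC_ENDPOINT=localhost:50001
+export DAPR_GRPC_HOST=localhost
+export DAPR_GRPC_PORT=50001
+```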
+
+### Test-Specific Configuration
+
+```bash
+# Enable debug logging for tests
+export DAPR_WF_DEBUG=true
+export DT_DEBUG=true
+
+# Disable non-determinism detection globally
+export DAPR_WF_DISABLE_DETERMINISTIC_DETECTION=true
+```
+
+## Running Specific Test Suites
+
+### Core Functionality Tests
+```bash
+# Run all unit tests (recommended)
+tox -e py310
+
+# Run specific test file (use pytest directly if needed)
+python -m pytest tests/durabletask/test_orchestration_executor.py -v
+
+# Run specific test case in test file
+python -m pytest tests/durabletask/test_orchestration_executor.py::test_fan_in -v
+
+# Run specific test pattern
+python -m pytest -k "orchestration" -v
+```
+
+### Async Workflow Tests
+```bash
+# Run async-specific tests
+python -m pytest tests/aio -v
+
+# Run determinism tests
+python -m pytest tests/durabletask/test_deterministic.py -v
+```
+
+### End-to-End Tests
+```bash
+# Full E2E test suite with tox (requires sidecar on port 4001)
+tox -e py310-e2e
+
+# Run with custom endpoint
+DAPR_GRPC_ENDPOINT=localhost:50001 tox -e py310-e2e
+
+# Run specific E2E test
+python -m pytest tests/durabletask/test_orchestration_e2e.py -v
+```
+
+## Test Development Guidelines
+
+### Writing Unit Tests
+
+1. **Use mocks** for external dependencies
+2. **Test edge cases** and error conditions
+3. **Keep tests fast** and isolated
+4. **Use descriptive test names** that explain the scenario
+
+```python
+def test_async_workflow_context_timeout_with_cancellation():
+    """Test that timeout properly cancels ongoing operations."""
+    # Test implementation
+```
+
+### Writing E2E Tests
+
+1. **Mark with `@pytest.mark.e2e`** decorator
+2. **Use unique orchestration names** to avoid conflicts
+3. **Clean up resources** in test teardown
+4. **Test realistic scenarios** end-to-end
+
+```python
+@pytest.mark.e2e
+def test_complex_workflow_with_retries():
+    """Test complete workflow with retry policies and error handling."""
+    # Test implementation
+```
+
+### Enhanced Async Tests
+
+1. **Test both sync and async paths** when applicable
+2. **Verify determinism** in replay scenarios
+3. **Test sandbox modes** (`off`, `best_effort`, `strict`)
+4. **Include performance considerations**
+
+```python
+def test_sandbox_mode_performance_impact():
+    """Verify sandbox modes have expected performance characteristics."""
+    # Test implementation
+```
+
+## Debugging Tests
+
+### Enable Debug Logging
+
+```bash
+# Enable comprehensive debug logging
+export DAPR_WF_DEBUG=true
+export DT_DEBUG=true
+
+# Run tests with verbose output using tox
+tox -e py310 -- -v -s
+
+# Or run specific test directly with pytest
+python -m pytest tests/durabletask/test_deterministic.py -v -s
+```
+
+### Debug Specific Features
+
+```bash
+# Debug specific test with pytest
+python -m pytest tests/aio/test_context.py::TestAsyncWorkflowContext::test_deterministic_uuid -v -s
+
+# Debug with tox and custom pytest args
+tox -e py310 -- tests/aio/test_context.py -v -s
+```
+
+### Common Issues and Solutions
+
+#### Connection Issues
+```bash
+# Check if the Dapr sidecar is running
+dapr list
+
+# Verify port 4001 is listening
+lsof -i :4001
+
+# Test with custom endpoint
+DAPR_GRPC_ENDPOINT=localhost:50001 tox -e py310-e2e
+```
+
+#### Import Issues
+```bash
+# Install in development mode
+pip install -e .
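+
+# Optional: confirm pip and python point at the same environment
+# (a generic sanity check, not specific to this SDK)
+python -m pip --version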
+ +# Verify installation +python -c "import durabletask; print(durabletask.__file__)" +``` + + +### Local CI Simulation + +```bash +# Simulate CI environment locally with tox +pip install tox + +# Start Dapr sidecar +dapr run \ + --app-id test-app \ + --dapr-grpc-port 4001 \ + --log-level debug & +sleep 10 + +# Run tests with tox +tox -e py310 +tox -e py310-e2e +``` + +## Performance Testing + +### Benchmarking + +```bash +# Run performance-sensitive tests with tox +tox -e py310 -- tests/aio -v + +# Profile test execution +python -m cProfile -o profile.stats -m pytest tests/aio -v +``` + +### Load Testing + +```bash +# Run concurrency tests with tox +tox -e py310 -- tests/durabletask/test_worker_concurrency_loop.py -v +``` + +## Contributing Guidelines + +### Before Submitting Tests + +1. **Run the full test suite**: + ```bash + tox -e py310 + tox -e py310-e2e # with sidecar running + ``` + +2. **Check code formatting and linting**: + ```bash + tox -e ruff + ``` + +3. **Test with multiple Python versions**: + ```bash + tox -e py310,py311,py312 + ``` + +### Test Coverage + +Maintain high test coverage for new features: + +```bash +# Generate coverage report +tox -e py310 +coverage report + +# Generate HTML coverage report +coverage html +open htmlcov/index.html +``` + +### Test Organization + +- **Unit tests**: `test_*.py` files without `@pytest.mark.e2e` +- **E2E tests**: `test_*_e2e.py` files or tests marked with `@pytest.mark.e2e` +- **Feature tests**: Group related functionality under `tests/aio/` +- **Integration tests**: Test interactions between components + +### Documentation + +- **Document complex test scenarios** with clear comments +- **Include setup/teardown requirements** in test docstrings +- **Explain non-obvious test assertions** +- **Update this README** when adding new test categories + +## Troubleshooting + +### Common Test Failures + +1. **Connection refused**: Sidecar not running or wrong port +2. **Timeout errors**: Increase timeout or check sidecar performance +3. **Import errors**: Run `pip install -e .` to install in development mode +4. **Flaky tests**: Check for race conditions or resource cleanup issues + +### Getting Help + +- **Check existing issues** in the repository +- **Run tests with `-v -s`** for detailed output +- **Enable debug logging** with environment variables +- **Isolate failing tests** by running them individually + +### Reporting Issues + +When reporting test failures, include: + +1. **Python version**: `python --version` +2. **Test command**: Exact command that failed +3. **Environment variables**: Relevant configuration +4. **Sidecar setup**: How the sidecar was started +5. **Full error output**: Complete traceback and logs + +## Additional Resources + +- [Main README](../README.md) - General SDK documentation +- [ASYNC_ENHANCEMENTS.md](../ASYNC_ENHANCEMENTS.md) - Enhanced async features +- [Examples](../examples/) - Working code samples +- [Makefile](../Makefile) - Build and test commands +- [tox.ini](../tox.ini) - Multi-environment testing configuration diff --git a/tests/aio/__init__.py b/tests/aio/__init__.py new file mode 100644 index 0000000..e8c1cfc --- /dev/null +++ b/tests/aio/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Tests for the durabletask.aio package. + +This package contains comprehensive tests for all async workflow functionality +including deterministic utilities, awaitables, driver, sandbox, and context. 
+""" diff --git a/tests/aio/compatibility_utils.py b/tests/aio/compatibility_utils.py new file mode 100644 index 0000000..d47a7b1 --- /dev/null +++ b/tests/aio/compatibility_utils.py @@ -0,0 +1,243 @@ +# Copyright 2025 The Dapr Authors +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Compatibility testing utilities for AsyncWorkflowContext. + +This module provides utilities for testing and validating AsyncWorkflowContext +compatibility with OrchestrationContext. These are testing/validation utilities +and should not be part of the main production code. +""" + +from __future__ import annotations + +import inspect +import warnings +from datetime import datetime +from typing import Any, Dict +from unittest.mock import Mock + +from durabletask import task + + +class CompatibilityChecker: + """ + Utility class for checking AsyncWorkflowContext compatibility with OrchestrationContext. + + This class provides methods to validate that AsyncWorkflowContext maintains + all required properties and methods for compatibility. + """ + + @staticmethod + def check_protocol_compliance(context_class: type) -> bool: + """ + Check if a context class complies with the OrchestrationContextProtocol. + + Args: + context_class: The context class to check + + Returns: + True if the class complies with the protocol, False otherwise + """ + # For protocols with properties, we need to check the class structure + # rather than using issubclass() which doesn't work with property protocols + + # Get all required members from the protocol + required_properties = [ + "instance_id", + "current_utc_datetime", + "is_replaying", + "workflow_name", + "is_suspended", + ] + + required_methods = [ + "set_custom_status", + "create_timer", + "call_activity", + "call_sub_orchestrator", + "wait_for_external_event", + "continue_as_new", + ] + + # Check if the class has all required members + for prop_name in required_properties: + if not hasattr(context_class, prop_name): + return False + + for method_name in required_methods: + if not hasattr(context_class, method_name): + return False + + return True + + @staticmethod + def validate_context_compatibility(context_instance: Any) -> list[str]: + """ + Validate that a context instance has all required properties and methods. 
+ + Args: + context_instance: The context instance to validate + + Returns: + List of missing properties/methods (empty if fully compatible) + """ + missing_items = [] + + # Check required properties + required_properties = [ + "instance_id", + "current_utc_datetime", + "is_replaying", + "workflow_name", + "is_suspended", + ] + + for prop_name in required_properties: + if not hasattr(context_instance, prop_name): + missing_items.append(f"property: {prop_name}") + + # Check required methods + required_methods = [ + "set_custom_status", + "create_timer", + "call_activity", + "call_sub_orchestrator", + "wait_for_external_event", + "continue_as_new", + ] + + for method_name in required_methods: + if not hasattr(context_instance, method_name): + missing_items.append(f"method: {method_name}") + elif not callable(getattr(context_instance, method_name)): + missing_items.append(f"method: {method_name} (not callable)") + + return missing_items + + @staticmethod + def compare_with_orchestration_context(context_instance: Any) -> Dict[str, Any]: + """ + Compare a context instance with OrchestrationContext interface. + + Args: + context_instance: The context instance to compare + + Returns: + Dictionary with comparison results + """ + # Get OrchestrationContext members + base_members = {} + for name, member in inspect.getmembers(task.OrchestrationContext): + if not name.startswith("_"): + if isinstance(member, property): + base_members[name] = "property" + elif inspect.isfunction(member): + base_members[name] = "method" + + # Check context instance + context_members = {} + missing_members = [] + extra_members = [] + + for name, member_type in base_members.items(): + if hasattr(context_instance, name): + context_members[name] = member_type + else: + missing_members.append(f"{member_type}: {name}") + + # Find extra members (enhancements) + for name, member in inspect.getmembers(context_instance): + if ( + not name.startswith("_") + and name not in base_members + and (isinstance(member, property) or callable(member)) + ): + member_type = "property" if isinstance(member, property) else "method" + extra_members.append(f"{member_type}: {name}") + + return { + "base_members": base_members, + "context_members": context_members, + "missing_members": missing_members, + "extra_members": extra_members, + "is_compatible": len(missing_members) == 0, + } + + @staticmethod + def warn_about_compatibility_issues(context_instance: Any) -> None: + """ + Issue warnings about any compatibility issues found. + + Args: + context_instance: The context instance to check + """ + missing_items = CompatibilityChecker.validate_context_compatibility(context_instance) + + if missing_items: + warning_msg = ( + f"AsyncWorkflowContext compatibility issue: missing {', '.join(missing_items)}. " + "This may cause issues with upstream merges or when used as OrchestrationContext." + ) + warnings.warn(warning_msg, UserWarning, stacklevel=3) + + +def validate_runtime_compatibility(context_instance: Any, *, strict: bool = False) -> bool: + """ + Validate runtime compatibility of a context instance. 
+ + Args: + context_instance: The context instance to validate + strict: If True, raise exception on compatibility issues; if False, just warn + + Returns: + True if compatible, False otherwise + + Raises: + RuntimeError: If strict=True and compatibility issues are found + """ + missing_items = CompatibilityChecker.validate_context_compatibility(context_instance) + + if missing_items: + error_msg = ( + f"Runtime compatibility check failed: {context_instance.__class__.__name__} " + f"is missing {', '.join(missing_items)}" + ) + + if strict: + raise RuntimeError(error_msg) + else: + warnings.warn(error_msg, UserWarning, stacklevel=2) + return False + + return True + + +def check_async_context_compatibility() -> Dict[str, Any]: + """ + Check AsyncWorkflowContext compatibility with OrchestrationContext. + + Returns: + Dictionary with detailed compatibility information + """ + from durabletask.aio import AsyncWorkflowContext + + # Create a mock base context for testing + mock_base_ctx = Mock(spec=task.OrchestrationContext) + mock_base_ctx.instance_id = "test" + mock_base_ctx.current_utc_datetime = datetime.now() + mock_base_ctx.is_replaying = False + + # Create AsyncWorkflowContext instance + async_ctx = AsyncWorkflowContext(mock_base_ctx) + + # Perform compatibility check + return CompatibilityChecker.compare_with_orchestration_context(async_ctx) diff --git a/tests/aio/test_app_id_propagation.py b/tests/aio/test_app_id_propagation.py new file mode 100644 index 0000000..ff7649b --- /dev/null +++ b/tests/aio/test_app_id_propagation.py @@ -0,0 +1,132 @@ +# Copyright 2025 The Dapr Authors +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Tests for app_id propagation through aio AsyncWorkflowContext and awaitables. 
+""" + +from __future__ import annotations + +from typing import Any, Dict, Optional +from unittest.mock import Mock + +import durabletask.task as dt_task +from durabletask.aio import AsyncWorkflowContext + + +def test_activity_app_id_passed_to_base_ctx_when_supported(): + base_ctx = Mock(spec=dt_task.OrchestrationContext) + + # Mock call_activity signature to include app_id and metadata + def _call_activity( + activity: Any, + *, + input: Any = None, + retry_policy: Any = None, + app_id: Optional[str] = None, + metadata: Optional[Dict[str, str]] = None, + ): + # Return a durable task-like object; tests only need that it's called with kwargs + return dt_task.when_all([]) + + base_ctx.call_activity = _call_activity # type: ignore[attr-defined] + + async_ctx = AsyncWorkflowContext(base_ctx) + + awaitable = async_ctx.call_activity( + "do_work", input={"x": 1}, retry_policy=None, app_id="target-app", metadata={"k": "v"} + ) + task_obj = awaitable._to_task() + assert isinstance(task_obj, dt_task.Task) + + +def test_sub_orchestrator_app_id_passed_to_base_ctx_when_supported(): + base_ctx = Mock(spec=dt_task.OrchestrationContext) + + # Mock call_sub_orchestrator signature to include app_id and metadata + def _call_sub( + orchestrator: Any, + *, + input: Any = None, + instance_id: Optional[str] = None, + retry_policy: Any = None, + app_id: Optional[str] = None, + metadata: Optional[Dict[str, str]] = None, + ): + return dt_task.when_all([]) + + base_ctx.call_sub_orchestrator = _call_sub # type: ignore[attr-defined] + + async_ctx = AsyncWorkflowContext(base_ctx) + + awaitable = async_ctx.sub_orchestrator( + "child_wf", + input=None, + instance_id="abc", + retry_policy=None, + app_id="target-app", + metadata={"k2": "v2"}, + ) + task_obj = awaitable._to_task() + assert isinstance(task_obj, dt_task.Task) + + +def test_activity_app_id_not_passed_when_not_supported(): + base_ctx = Mock(spec=dt_task.OrchestrationContext) + + # Mock call_activity without app_id support + def _call_activity( + activity: Any, + *, + input: Any = None, + retry_policy: Any = None, + metadata: Optional[Dict[str, str]] = None, + ): + return dt_task.when_all([]) + + base_ctx.call_activity = _call_activity # type: ignore[attr-defined] + + async_ctx = AsyncWorkflowContext(base_ctx) + + awaitable = async_ctx.call_activity( + "do_work", input={"x": 1}, retry_policy=None, app_id="target-app", metadata={"k": "v"} + ) + task_obj = awaitable._to_task() + assert isinstance(task_obj, dt_task.Task) + + +def test_sub_orchestrator_app_id_not_passed_when_not_supported(): + base_ctx = Mock(spec=dt_task.OrchestrationContext) + + # Mock call_sub_orchestrator without app_id support + def _call_sub( + orchestrator: Any, + *, + input: Any = None, + instance_id: Optional[str] = None, + retry_policy: Any = None, + metadata: Optional[Dict[str, str]] = None, + ): + return dt_task.when_all([]) + + base_ctx.call_sub_orchestrator = _call_sub # type: ignore[attr-defined] + + async_ctx = AsyncWorkflowContext(base_ctx) + + awaitable = async_ctx.sub_orchestrator( + "child_wf", + input=None, + instance_id="abc", + retry_policy=None, + app_id="target-app", + metadata={"k2": "v2"}, + ) + task_obj = awaitable._to_task() + assert isinstance(task_obj, dt_task.Task) diff --git a/tests/aio/test_async_orchestrator.py b/tests/aio/test_async_orchestrator.py new file mode 100644 index 0000000..b3ca8d5 --- /dev/null +++ b/tests/aio/test_async_orchestrator.py @@ -0,0 +1,501 @@ +# Copyright 2025 The Dapr Authors +# Licensed under the Apache License, Version 2.0 (the 
"License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import json +import logging +import random +import time +import uuid +from datetime import timedelta # noqa: F401 + +import durabletask.internal.helpers as helpers +from durabletask.worker import _OrchestrationExecutor, _Registry + +TEST_INSTANCE_ID = "async-test-1" + + +def test_async_activity_and_sleep(): + async def orch(ctx, _): + a = await ctx.call_activity("echo", input=1) + await ctx.create_timer(1) + b = await ctx.call_activity("echo", input=a + 1) + return b + + def echo(_, x): + return x + + registry = _Registry() + name = registry.add_async_orchestrator(orch) # type: ignore[attr-defined] + activity_name = registry.add_activity(echo) + + # start → schedule first activity + new_events = [ + helpers.new_orchestrator_started_event(), + helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=None), + ] + exec = _OrchestrationExecutor(registry, logging.getLogger("tests")) + res = exec.execute(TEST_INSTANCE_ID, [], new_events) + assert len(res.actions) == 1 and res.actions[0].HasField("scheduleTask") + assert res.actions[0].scheduleTask.name == activity_name + + # complete first activity → expect timer + old_events = new_events + [helpers.new_task_scheduled_event(1, activity_name)] + new_events = [ + helpers.new_orchestrator_started_event(), + helpers.new_task_completed_event(1, encoded_output=json.dumps(1)), + ] + res = exec.execute(TEST_INSTANCE_ID, old_events, new_events) + assert len(res.actions) == 1 and res.actions[0].HasField("createTimer") + + # fire timer → expect second activity + now_dt = helpers.new_orchestrator_started_event().timestamp.ToDatetime() + old_events = ( + old_events + + new_events + + [ + helpers.new_timer_created_event(2, now_dt), + ] + ) + new_events = [ + helpers.new_orchestrator_started_event(), + helpers.new_timer_fired_event(2, now_dt), + ] + res = exec.execute(TEST_INSTANCE_ID, old_events, new_events) + assert len(res.actions) == 1 and res.actions[0].HasField("scheduleTask") + assert res.actions[0].scheduleTask.name == activity_name + + # complete second activity → done + old_events = old_events + new_events + [helpers.new_task_scheduled_event(1, activity_name)] + new_events = [ + helpers.new_orchestrator_started_event(), + helpers.new_task_completed_event(1, encoded_output=json.dumps(2)), + ] + res = exec.execute(TEST_INSTANCE_ID, old_events, new_events) + assert len(res.actions) == 1 and res.actions[0].HasField("completeOrchestration") + + +def test_async_when_all_any_and_events(): + async def orch(ctx, _): + t1 = ctx.call_activity("a", input=1) + t2 = ctx.call_activity("b", input=2) + await ctx.when_all([t1, t2]) + _ = await ctx.when_any([ctx.wait_for_external_event("x"), ctx.create_timer(0.1)]) + return "ok" + + def a(_, x): + return x + + def b(_, x): + return x + + registry = _Registry() + name = registry.add_async_orchestrator(orch) # type: ignore[attr-defined] + _ = registry.add_activity(a) + _ = registry.add_activity(b) + + # start → schedule both activities + new_events = [ + helpers.new_orchestrator_started_event(), + 
helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=None), + ] + exec = _OrchestrationExecutor(registry, logging.getLogger("tests")) + res = exec.execute(TEST_INSTANCE_ID, [], new_events) + assert len(res.actions) == 2 and all(a.HasField("scheduleTask") for a in res.actions) + + +def test_async_external_event_immediate_and_buffered(): + async def orch(ctx, _): + val = await ctx.wait_for_external_event("x") + return val + + registry = _Registry() + name = registry.add_async_orchestrator(orch) # type: ignore[attr-defined] + + # Start: expect no actions (waiting for event) + new_events = [ + helpers.new_orchestrator_started_event(), + helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=None), + ] + exec = _OrchestrationExecutor(registry, logging.getLogger("tests")) + res = exec.execute(TEST_INSTANCE_ID, [], new_events) + assert len(res.actions) == 0 + + # Deliver event and complete + old_events = new_events + new_events = [helpers.new_event_raised_event("x", encoded_input=json.dumps(42))] + res = exec.execute(TEST_INSTANCE_ID, old_events, new_events) + assert len(res.actions) == 1 and res.actions[0].HasField("completeOrchestration") + + +def test_async_sub_orchestrator_completion_and_failure(): + async def child(ctx, x): + return x + + async def parent(ctx, _): + return await ctx.sub_orchestrator(child, input=5) + + registry = _Registry() + child_name = registry.add_async_orchestrator(child) # type: ignore[attr-defined] + parent_name = registry.add_async_orchestrator(parent) # type: ignore[attr-defined] + + exec = _OrchestrationExecutor(registry, logging.getLogger("tests")) + + # Start parent → expect createSubOrchestration action + new_events = [ + helpers.new_orchestrator_started_event(), + helpers.new_execution_started_event(parent_name, TEST_INSTANCE_ID, encoded_input=None), + ] + res = exec.execute(TEST_INSTANCE_ID, [], new_events) + assert len(res.actions) == 1 and res.actions[0].HasField("createSubOrchestration") + assert res.actions[0].createSubOrchestration.name == child_name + + # Simulate sub-orch created then completed + old_events = [ + helpers.new_orchestrator_started_event(), + helpers.new_execution_started_event(parent_name, TEST_INSTANCE_ID, encoded_input=None), + helpers.new_sub_orchestration_created_event( + 1, child_name, f"{TEST_INSTANCE_ID}:0001", encoded_input=None + ), + ] + new_events = [helpers.new_sub_orchestration_completed_event(1, encoded_output=json.dumps(5))] + res = exec.execute(TEST_INSTANCE_ID, old_events, new_events) + assert len(res.actions) == 1 and res.actions[0].HasField("completeOrchestration") + + # Also verify the worker-level wrapper does not surface StopIteration + from durabletask.worker import TaskHubGrpcWorker + + w = TaskHubGrpcWorker() + w.add_async_orchestrator(child, name="child") + w.add_async_orchestrator(parent, name="parent") + + +def test_async_sandbox_sleep_patching_creates_timer(): + async def orch(ctx, _): + await asyncio.sleep(1) + return "done" + + registry = _Registry() + name = registry.add_async_orchestrator(orch, sandbox_mode="best_effort") # type: ignore[attr-defined] + exec = _OrchestrationExecutor(registry, logging.getLogger("tests")) + new_events = [ + helpers.new_orchestrator_started_event(), + helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=None), + ] + res = exec.execute(TEST_INSTANCE_ID, [], new_events) + assert len(res.actions) == 1 and res.actions[0].HasField("createTimer") + + +def test_async_sandbox_deterministic_random_uuid_time(): + async def 
orch(ctx, _): + r = random.random() + u = str(uuid.uuid4()) + t = int(time.time()) + return {"r": r, "u": u, "t": t} + + registry = _Registry() + name = registry.add_async_orchestrator(orch, sandbox_mode="best_effort") # type: ignore[attr-defined] + exec = _OrchestrationExecutor(registry, logging.getLogger("tests")) + + new_events = [ + helpers.new_orchestrator_started_event(), + helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=None), + ] + res1 = exec.execute(TEST_INSTANCE_ID, [], new_events) + out1 = res1.actions[0].completeOrchestration.result.value + + res2 = exec.execute(TEST_INSTANCE_ID, [], new_events) + out2 = res2.actions[0].completeOrchestration.result.value + assert out1 == out2 + + +def test_async_two_activities_no_timer(): + async def orch(ctx, _): + a = await ctx.call_activity("echo", input=1) + b = await ctx.call_activity("echo", input=a + 1) + return b + + def echo(_, x): + return x + + registry = _Registry() + name = registry.add_async_orchestrator(orch) # type: ignore[attr-defined] + activity_name = registry.add_activity(echo) + exec = _OrchestrationExecutor(registry, logging.getLogger("tests")) + + # start -> schedule first activity + new_events = [ + helpers.new_orchestrator_started_event(), + helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=None), + ] + res = exec.execute(TEST_INSTANCE_ID, [], new_events) + assert len(res.actions) == 1 and res.actions[0].HasField("scheduleTask") + + # complete first activity -> schedule second + old_events = new_events + [helpers.new_task_scheduled_event(1, activity_name)] + new_events = [ + helpers.new_orchestrator_started_event(), + helpers.new_task_completed_event(1, encoded_output=json.dumps(1)), + ] + res = exec.execute(TEST_INSTANCE_ID, old_events, new_events) + assert len(res.actions) == 1 and res.actions[0].HasField("scheduleTask") + + # complete second -> done + old_events = old_events + new_events + [helpers.new_task_scheduled_event(1, activity_name)] + new_events = [ + helpers.new_orchestrator_started_event(), + helpers.new_task_completed_event(1, encoded_output=json.dumps(2)), + ] + res = exec.execute(TEST_INSTANCE_ID, old_events, new_events) + assert len(res.actions) == 1 and res.actions[0].HasField("completeOrchestration") + + +def test_async_ctx_metadata_passthrough(): + async def orch(ctx, _): + # Access deterministic metadata via AsyncWorkflowContext + return { + "id": ctx.instance_id, + "replay": ctx.is_replaying, + "susp": ctx.is_suspended, + } + + registry = _Registry() + name = registry.add_async_orchestrator(orch) # type: ignore[attr-defined] + exec = _OrchestrationExecutor(registry, logging.getLogger("tests")) + + new_events = [ + helpers.new_orchestrator_started_event(), + helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=None), + ] + res = exec.execute(TEST_INSTANCE_ID, [], new_events) + assert len(res.actions) == 1 and res.actions[0].HasField("completeOrchestration") + out_json = res.actions[0].completeOrchestration.result.value + out = json.loads(out_json) + assert out["id"] == TEST_INSTANCE_ID + assert out["replay"] is False + + +def test_async_when_all_with_mixed_success_and_failure(): + async def orch(ctx, _): + a = ctx.call_activity("ok", input=1) + b = ctx.call_activity("boom", input=2) + c = ctx.call_activity("ok", input=3) + # when_all will fail-fast when b fails + try: + vals = await ctx.when_all([a, b, c]) + return vals + except Exception as e: + return {"error": str(e)} + + def ok(_, x): + return x + + def boom(_, __): + raise 
RuntimeError("fail!") + + registry = _Registry() + name = registry.add_async_orchestrator(orch) # type: ignore[attr-defined] + an_ok = registry.add_activity(ok) + an_boom = registry.add_activity(boom) + exec = _OrchestrationExecutor(registry, logging.getLogger("tests")) + + # start -> schedule three activities + new_events = [ + helpers.new_orchestrator_started_event(), + helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=None), + ] + res = exec.execute(TEST_INSTANCE_ID, [], new_events) + assert len(res.actions) == 3 and all(a.HasField("scheduleTask") for a in res.actions) + + # mark scheduled + old_events = new_events + [ + helpers.new_task_scheduled_event(1, an_ok), + helpers.new_task_scheduled_event(2, an_boom), + helpers.new_task_scheduled_event(3, an_ok), + ] + + # complete ok(1), fail boom(2), complete ok(3) + new_events = [ + helpers.new_orchestrator_started_event(), + helpers.new_task_completed_event(1, encoded_output=json.dumps(1)), + helpers.new_task_failed_event(2, RuntimeError("fail!")), + helpers.new_task_completed_event(3, encoded_output=json.dumps(3)), + ] + res = exec.execute(TEST_INSTANCE_ID, old_events, new_events) + assert len(res.actions) == 1 and res.actions[0].HasField("completeOrchestration") + + +def test_async_strict_sandbox_blocks_create_task(): + import asyncio + + import durabletask.internal.helpers as helpers + + async def orch(ctx, _): + # Should be blocked in strict mode during priming + asyncio.create_task(asyncio.sleep(0)) + return 1 + + registry = _Registry() + name = registry.add_async_orchestrator(orch, sandbox_mode="strict") # type: ignore[attr-defined] + exec = _OrchestrationExecutor(registry, logging.getLogger("tests")) + new_events = [ + helpers.new_orchestrator_started_event(), + helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=None), + ] + res = exec.execute(TEST_INSTANCE_ID, [], new_events) + assert len(res.actions) == 1 and res.actions[0].HasField("completeOrchestration") + # Expect failureDetails is set due to strict mode error + assert res.actions[0].completeOrchestration.HasField("failureDetails") + + +def test_async_when_any_ignores_losers_deterministically(): + import durabletask.internal.helpers as helpers + + async def orch(ctx, _): + a = ctx.call_activity("a", input=1) + b = ctx.call_activity("b", input=2) + await ctx.when_any([a, b]) + return "done" + + def a(_, x): + return x + + def b(_, x): + return x + + registry = _Registry() + name = registry.add_async_orchestrator(orch) # type: ignore[attr-defined] + an = registry.add_activity(a) + bn = registry.add_activity(b) + + exec = _OrchestrationExecutor(registry, logging.getLogger("tests")) + + # start -> schedule both + new_events = [ + helpers.new_orchestrator_started_event(), + helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=None), + ] + res = exec.execute(TEST_INSTANCE_ID, [], new_events) + assert len(res.actions) == 2 and all(a.HasField("scheduleTask") for a in res.actions) + + # winner completes -> orchestration should complete; no extra commands emitted to cancel loser + old_events = new_events + [ + helpers.new_task_scheduled_event(1, an), + helpers.new_task_scheduled_event(2, bn), + ] + new_events = [ + helpers.new_orchestrator_started_event(), + helpers.new_task_completed_event(1, encoded_output=json.dumps(1)), + ] + res = exec.execute(TEST_INSTANCE_ID, old_events, new_events) + assert len(res.actions) == 1 and res.actions[0].HasField("completeOrchestration") + + +def 
test_async_termination_maps_to_cancellation(): + async def orch(ctx, _): + try: + await ctx.create_timer(10) + except Exception as e: + # Should surface as cancellation + return type(e).__name__ + return "unexpected" + + registry = _Registry() + name = registry.add_async_orchestrator(orch) # type: ignore[attr-defined] + exec = _OrchestrationExecutor(registry, logging.getLogger("tests")) + + # start -> schedule timer + new_events = [ + helpers.new_orchestrator_started_event(), + helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=None), + ] + res = exec.execute(TEST_INSTANCE_ID, [], new_events) + assert any(a.HasField("createTimer") for a in res.actions) + # Capture the actual timer ID to avoid non-determinism in tests + _ = next(a.id for a in res.actions if a.HasField("createTimer")) + + # terminate -> expect completion with TERMINATED and encoded output preserved + old_events = new_events + new_events = [helpers.new_terminated_event(encoded_output=json.dumps("bye"))] + res = exec.execute(TEST_INSTANCE_ID, old_events, new_events) + assert len(res.actions) == 1 and res.actions[0].HasField("completeOrchestration") + assert res.actions[0].completeOrchestration.orchestrationStatus == 5 # TERMINATED + + +def test_async_suspend_sets_flag_and_resumes_without_raising(): + async def orch(ctx, _): + # observe suspension via flag and then continue normally + before = ctx.is_suspended + await ctx.create_timer(0.1) + after = ctx.is_suspended + return {"before": before, "after": after} + + registry = _Registry() + name = registry.add_async_orchestrator(orch) # type: ignore[attr-defined] + exec = _OrchestrationExecutor(registry, logging.getLogger("tests")) + + # start + new_events = [ + helpers.new_orchestrator_started_event(), + helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=None), + ] + res = exec.execute(TEST_INSTANCE_ID, [], new_events) + assert any(a.HasField("createTimer") for a in res.actions) + timer_id = next(a.id for a in res.actions if a.HasField("createTimer")) + + # suspend, then resume, then fire timer across separate activations, always with orchestratorStarted + now_dt = helpers.new_orchestrator_started_event().timestamp.ToDatetime() + old_events = new_events + new_events = [helpers.new_orchestrator_started_event(), helpers.new_suspend_event()] + res = exec.execute(TEST_INSTANCE_ID, old_events, new_events) + assert not any(a.HasField("completeOrchestration") for a in res.actions) + + # Confirm timer created after first activation + old_events = old_events + new_events + [helpers.new_timer_created_event(timer_id, now_dt)] + + # Resume activation + new_events = [helpers.new_orchestrator_started_event(), helpers.new_resume_event()] + res = exec.execute(TEST_INSTANCE_ID, old_events, new_events) + old_events = old_events + new_events + + # Timer fires in next activation + new_events = [ + helpers.new_orchestrator_started_event(), + helpers.new_timer_fired_event(timer_id, now_dt), + ] + res = exec.execute(TEST_INSTANCE_ID, old_events, new_events) + assert len(res.actions) == 1 and res.actions[0].HasField("completeOrchestration") + + +def test_async_suspend_resume_like_generator_test(): + async def orch(ctx, _): + val = await ctx.wait_for_external_event("my_event") + return val + + registry = _Registry() + name = registry.add_async_orchestrator(orch) # type: ignore[attr-defined] + exec = _OrchestrationExecutor(registry, logging.getLogger("tests")) + + old_events = [ + helpers.new_orchestrator_started_event(), + 
helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=None), + ] + new_events = [ + helpers.new_suspend_event(), + helpers.new_event_raised_event("my_event", encoded_input=json.dumps(42)), + ] + res = exec.execute(TEST_INSTANCE_ID, old_events, new_events) + assert len(res.actions) == 0 + + old_events = old_events + new_events + new_events = [helpers.new_resume_event()] + res = exec.execute(TEST_INSTANCE_ID, old_events, new_events) + assert len(res.actions) == 1 and res.actions[0].HasField("completeOrchestration") diff --git a/tests/aio/test_asyncio_enhanced_additions.py b/tests/aio/test_asyncio_enhanced_additions.py new file mode 100644 index 0000000..3d727fa --- /dev/null +++ b/tests/aio/test_asyncio_enhanced_additions.py @@ -0,0 +1,340 @@ +# Copyright 2025 The Dapr Authors +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Comprehensive tests for enhanced asyncio compatibility features. +""" + +import asyncio +import os +from datetime import datetime +from unittest.mock import Mock, patch + +import pytest + +from durabletask import task as dt_task +from durabletask.aio import ( + AsyncWorkflowContext, + AsyncWorkflowError, + CoroutineOrchestratorRunner, + SandboxViolationError, + WorkflowFunction, +) +from durabletask.aio.sandbox import _sandbox_scope + + +class TestAsyncWorkflowError: + """Test the enhanced error handling.""" + + def test_basic_error(self): + error = AsyncWorkflowError("Test error") + assert str(error) == "Test error" + + def test_error_with_context(self): + error = AsyncWorkflowError( + "Test error", + instance_id="test-123", + workflow_name="test_workflow", + step="initialization", + ) + expected = "Test error (workflow: test_workflow, instance: test-123, step: initialization)" + assert str(error) == expected + + def test_error_partial_context(self): + error = AsyncWorkflowError("Test error", instance_id="test-123") + assert str(error) == "Test error (instance: test-123)" + + +class TestAsyncWorkflowContext: + """Test enhanced AsyncWorkflowContext features.""" + + def setup_method(self): + self.mock_base_ctx = Mock(spec=dt_task.OrchestrationContext) + self.mock_base_ctx.instance_id = "test-instance-123" + self.mock_base_ctx.current_utc_datetime = datetime(2023, 1, 1, 12, 0, 0) + self.mock_base_ctx.is_replaying = False + self.mock_base_ctx.is_suspended = False + self.ctx = AsyncWorkflowContext(self.mock_base_ctx) + + def test_debug_mode_detection(self): + with patch.dict(os.environ, {"DAPR_WF_DEBUG": "true"}): + ctx = AsyncWorkflowContext(self.mock_base_ctx) + assert ctx._debug_mode is True + + with patch.dict(os.environ, {"DT_DEBUG": "true"}): + ctx = AsyncWorkflowContext(self.mock_base_ctx) + assert ctx._debug_mode is True + + with patch.dict(os.environ, {}, clear=True): + ctx = AsyncWorkflowContext(self.mock_base_ctx) + assert ctx._debug_mode is False + + def test_operation_logging(self): + with patch.dict(os.environ, {"DAPR_WF_DEBUG": "true"}): + ctx = AsyncWorkflowContext(self.mock_base_ctx) + + ctx._log_operation("test_op", {"param": "value"}) + + 
assert len(ctx._operation_history) == 1 + op = ctx._operation_history[0] + assert op["operation"] == "test_op" + assert op["details"] == {"param": "value"} + assert op["sequence"] == 0 + + def test_get_debug_info(self): + with patch.dict(os.environ, {"DAPR_WF_DEBUG": "true"}): + ctx = AsyncWorkflowContext(self.mock_base_ctx) + ctx._log_operation("test_op", {"param": "value"}) + + debug_info = ctx._get_info_snapshot() + + assert debug_info["instance_id"] == "test-instance-123" + assert len(debug_info["operation_history"]) == 1 + assert debug_info["operation_history"][0]["type"] == "test_op" + + def test_activity_logging(self): + with patch.dict(os.environ, {"DAPR_WF_DEBUG": "true"}): + ctx = AsyncWorkflowContext(self.mock_base_ctx) + + ctx.call_activity("test_activity", input="test") + + assert len(ctx._operation_history) == 1 + op = ctx._operation_history[0] + assert op["operation"] == "activity" + assert op["details"]["function"] == "test_activity" + assert op["details"]["input"] == "test" + + def test_sleep_logging(self): + with patch.dict(os.environ, {"DAPR_WF_DEBUG": "true"}): + ctx = AsyncWorkflowContext(self.mock_base_ctx) + + ctx.create_timer(5.0) + + assert len(ctx._operation_history) == 1 + op = ctx._operation_history[0] + assert op["operation"] == "sleep" + assert op["details"]["duration"] == 5.0 + + def test_with_timeout(self): + mock_awaitable = Mock() + timeout_awaitable = self.ctx.with_timeout(mock_awaitable, 10.0) + + assert timeout_awaitable is not None + assert hasattr(timeout_awaitable, "_timeout") + + +class TestCoroutineOrchestratorRunner: + """Test enhanced CoroutineOrchestratorRunner features.""" + + def test_orchestrator_validation_success(self): + async def valid_orchestrator(ctx, input_data): + return "result" + + # Should not raise + runner = CoroutineOrchestratorRunner(valid_orchestrator) + assert runner is not None + + def test_orchestrator_validation_not_callable(self): + with pytest.raises(AsyncWorkflowError, match="must be callable"): + CoroutineOrchestratorRunner("not_callable") + + def test_orchestrator_validation_wrong_params(self): + async def wrong_params(): # No parameters - should fail + return "result" + + with pytest.raises(AsyncWorkflowError, match="at least one parameter"): + CoroutineOrchestratorRunner(wrong_params) + + def test_orchestrator_validation_not_async(self): + def not_async(ctx, input_data): + return "result" + + with pytest.raises(AsyncWorkflowError, match="must be an async function"): + CoroutineOrchestratorRunner(not_async) + + def test_enhanced_error_context(self): + async def failing_orchestrator(ctx, input_data): + raise ValueError("Test error") + + runner = CoroutineOrchestratorRunner(failing_orchestrator) + mock_ctx = Mock(spec=dt_task.OrchestrationContext) + mock_ctx.instance_id = "test-123" + async_ctx = AsyncWorkflowContext(mock_ctx) + + gen = runner.to_generator(async_ctx, "input") + + with pytest.raises(AsyncWorkflowError) as exc_info: + next(gen) + + error = exc_info.value + assert "initialization" in str(error) + assert "test-123" in str(error) + + +class TestEnhancedSandboxing: + """Test enhanced sandboxing capabilities.""" + + def setup_method(self): + self.mock_base_ctx = Mock(spec=dt_task.OrchestrationContext) + self.mock_base_ctx.instance_id = "test-instance" + self.mock_base_ctx.current_utc_datetime = datetime(2023, 1, 1, 12, 0, 0) + self.mock_base_ctx.is_replaying = False + self.mock_base_ctx.is_suspended = False + self.async_ctx = AsyncWorkflowContext(self.mock_base_ctx) + + def 
test_datetime_patching_limitation(self): + # Note: datetime.datetime is immutable and cannot be patched + # This test documents the current limitation + import datetime as dt + + with _sandbox_scope(self.async_ctx, "best_effort"): + # datetime.now cannot be patched due to immutability + # Users should use ctx.now() instead + now_result = dt.datetime.now() + + # This will NOT be the deterministic time (unless by coincidence) + # We just verify that the call works and returns a datetime + assert isinstance(now_result, datetime) + + # The deterministic time is available via ctx.now() + deterministic_time = self.async_ctx.now() + assert isinstance(deterministic_time, datetime) + + # datetime.datetime methods remain unchanged (they can't be patched) + assert hasattr(dt.datetime, "now") + assert hasattr(dt.datetime, "utcnow") + + def test_random_getrandbits_patching(self): + import random + + original_getrandbits = random.getrandbits + + with _sandbox_scope(self.async_ctx, "best_effort"): + # Should use deterministic random + result1 = random.getrandbits(32) + result2 = random.getrandbits(32) + assert isinstance(result1, int) + assert isinstance(result2, int) + + # Should be restored + assert random.getrandbits is original_getrandbits + + def test_strict_mode_file_blocking(self): + with pytest.raises(SandboxViolationError, match="File I/O operations are not allowed"): + with _sandbox_scope(self.async_ctx, "strict"): + open("test.txt", "w") + + def test_strict_mode_urandom_blocking(self): + import os + + if hasattr(os, "urandom"): + with pytest.raises(SandboxViolationError, match="os.urandom is not allowed"): + with _sandbox_scope(self.async_ctx, "strict"): + os.urandom(16) + + def test_strict_mode_secrets_blocking(self): + try: + import secrets + + with pytest.raises(SandboxViolationError, match="secrets module is not allowed"): + with _sandbox_scope(self.async_ctx, "strict"): + secrets.token_bytes(16) + except ImportError: + # secrets module not available, skip test + pass + + def test_asyncio_sleep_patching(self): + original_sleep = asyncio.sleep + + with _sandbox_scope(self.async_ctx, "best_effort"): + # asyncio.sleep should be patched + sleep_awaitable = asyncio.sleep(1.0) + assert hasattr(sleep_awaitable, "__await__") + + # Should be restored + assert asyncio.sleep is original_sleep + + +class TestConcurrencyPrimitives: + """Test enhanced concurrency primitives.""" + + def setup_method(self): + self.mock_base_ctx = Mock(spec=dt_task.OrchestrationContext) + self.mock_base_ctx.instance_id = "test-instance" + self.mock_base_ctx.is_replaying = False + self.mock_base_ctx.is_suspended = False + self.ctx = AsyncWorkflowContext(self.mock_base_ctx) + + def test_timeout_awaitable(self): + from durabletask.aio import TimeoutAwaitable + + mock_awaitable = Mock() + timeout_awaitable = TimeoutAwaitable(mock_awaitable, 5.0, self.ctx) + + assert timeout_awaitable._awaitable is mock_awaitable + assert timeout_awaitable._timeout == 5.0 + assert timeout_awaitable._ctx is self.ctx + + +class TestPerformanceOptimizations: + """Test performance optimizations.""" + + def test_awaitable_slots(self): + from durabletask.aio import ( + ActivityAwaitable, + ExternalEventAwaitable, + SleepAwaitable, + SubOrchestratorAwaitable, + SwallowExceptionAwaitable, + WhenAllAwaitable, + WhenAnyAwaitable, + ) + + # All awaitable classes should have __slots__ + classes_with_slots = [ + ActivityAwaitable, + SubOrchestratorAwaitable, + SleepAwaitable, + ExternalEventAwaitable, + WhenAllAwaitable, + WhenAnyAwaitable, + 
SwallowExceptionAwaitable, + ] + + for cls in classes_with_slots: + assert hasattr(cls, "__slots__"), f"{cls.__name__} should have __slots__" + + +class TestWorkflowFunctionProtocol: + """Test WorkflowFunction protocol.""" + + def test_valid_workflow_function(self): + async def valid_workflow(ctx: AsyncWorkflowContext, input_data) -> str: + return "result" + + # Should be recognized as WorkflowFunction + assert isinstance(valid_workflow, WorkflowFunction) + + def test_invalid_workflow_function(self): + def not_async_workflow(ctx, input_data): + return "result" + + # Note: runtime_checkable protocols are structural, not nominal + # A function with the right signature will pass isinstance check + # The actual validation happens in CoroutineOrchestratorRunner + # This test documents the current behavior + assert isinstance( + not_async_workflow, WorkflowFunction + ) # This passes due to structural typing + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/aio/test_awaitables.py b/tests/aio/test_awaitables.py new file mode 100644 index 0000000..f666287 --- /dev/null +++ b/tests/aio/test_awaitables.py @@ -0,0 +1,690 @@ +# Copyright 2025 The Dapr Authors +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Tests for awaitable classes in durabletask.aio. 
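+
+For orientation, these awaitables back the ``AsyncWorkflowContext`` authoring
+surface. A minimal sketch of the pattern they enable (``fetch`` is a
+hypothetical activity, shown for illustration only):
+
+    async def orch(ctx: AsyncWorkflowContext, url: str) -> str:
+        data = await ctx.call_activity(fetch, input=url)  # ActivityAwaitable
+        await ctx.create_timer(5.0)                       # SleepAwaitable
+        return data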
+""" + +from datetime import datetime, timedelta +from unittest.mock import Mock, patch + +import pytest + +from durabletask import task as dt_task +from durabletask.aio import ( + ActivityAwaitable, + AwaitableBase, + ExternalEventAwaitable, + SleepAwaitable, + SubOrchestratorAwaitable, + SwallowExceptionAwaitable, + TimeoutAwaitable, + WhenAllAwaitable, + WhenAnyAwaitable, + WorkflowTimeoutError, +) + + +class TestAwaitableBase: + """Test AwaitableBase functionality.""" + + def test_awaitable_base_abstract(self): + """Test that AwaitableBase cannot be instantiated directly.""" + # AwaitableBase is not technically abstract but should not be used directly + # It will raise NotImplementedError when _to_task is called + awaitable = AwaitableBase() + with pytest.raises(NotImplementedError): + awaitable._to_task() + + def test_awaitable_base_slots(self): + """Test that AwaitableBase has __slots__.""" + assert hasattr(AwaitableBase, "__slots__") + + +class TestActivityAwaitable: + """Test ActivityAwaitable functionality.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_ctx = Mock() + self.mock_ctx.call_activity.return_value = dt_task.CompletableTask() + self.activity_fn = Mock(__name__="test_activity") + + def test_activity_awaitable_creation(self): + """Test creating an ActivityAwaitable.""" + awaitable = ActivityAwaitable( + self.mock_ctx, + self.activity_fn, + input="test_input", + retry_policy=None, + metadata={"key": "value"}, + ) + + assert awaitable._ctx is self.mock_ctx + assert awaitable._activity_fn is self.activity_fn + assert awaitable._input == "test_input" + assert awaitable._retry_policy is None + assert awaitable._metadata == {"key": "value"} + + def test_activity_awaitable_to_task(self): + """Test converting ActivityAwaitable to task.""" + awaitable = ActivityAwaitable(self.mock_ctx, self.activity_fn, input="test_input") + + task = awaitable._to_task() + + self.mock_ctx.call_activity.assert_called_once_with(self.activity_fn, input="test_input") + assert isinstance(task, dt_task.Task) + + def test_activity_awaitable_with_retry_policy(self): + """Test ActivityAwaitable with retry policy.""" + retry_policy = Mock() + awaitable = ActivityAwaitable( + self.mock_ctx, self.activity_fn, input="test_input", retry_policy=retry_policy + ) + + awaitable._to_task() + + self.mock_ctx.call_activity.assert_called_once_with( + self.activity_fn, input="test_input", retry_policy=retry_policy + ) + + def test_activity_awaitable_slots(self): + """Test that ActivityAwaitable has __slots__.""" + assert hasattr(ActivityAwaitable, "__slots__") + + +class TestSubOrchestratorAwaitable: + """Test SubOrchestratorAwaitable functionality.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_ctx = Mock() + self.mock_ctx.call_sub_orchestrator.return_value = dt_task.CompletableTask() + self.workflow_fn = Mock(__name__="test_workflow") + + def test_sub_orchestrator_awaitable_creation(self): + """Test creating a SubOrchestratorAwaitable.""" + awaitable = SubOrchestratorAwaitable( + self.mock_ctx, + self.workflow_fn, + input="test_input", + instance_id="test-instance", + retry_policy=None, + metadata={"key": "value"}, + ) + + assert awaitable._ctx is self.mock_ctx + assert awaitable._workflow_fn is self.workflow_fn + assert awaitable._input == "test_input" + assert awaitable._instance_id == "test-instance" + assert awaitable._retry_policy is None + assert awaitable._metadata == {"key": "value"} + + def test_sub_orchestrator_awaitable_to_task(self): + """Test converting 
SubOrchestratorAwaitable to task.""" + awaitable = SubOrchestratorAwaitable( + self.mock_ctx, self.workflow_fn, input="test_input", instance_id="test-instance" + ) + + task = awaitable._to_task() + + self.mock_ctx.call_sub_orchestrator.assert_called_once_with( + self.workflow_fn, input="test_input", instance_id="test-instance" + ) + assert isinstance(task, dt_task.Task) + + def test_sub_orchestrator_awaitable_slots(self): + """Test that SubOrchestratorAwaitable has __slots__.""" + assert hasattr(SubOrchestratorAwaitable, "__slots__") + + +class TestSleepAwaitable: + """Test SleepAwaitable functionality.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_ctx = Mock() + self.mock_ctx.create_timer.return_value = dt_task.CompletableTask() + + def test_sleep_awaitable_creation(self): + """Test creating a SleepAwaitable.""" + duration = timedelta(seconds=5) + awaitable = SleepAwaitable(self.mock_ctx, duration) + + assert awaitable._ctx is self.mock_ctx + assert awaitable._duration is duration + + def test_sleep_awaitable_to_task(self): + """Test converting SleepAwaitable to task.""" + duration = timedelta(seconds=5) + awaitable = SleepAwaitable(self.mock_ctx, duration) + + task = awaitable._to_task() + + self.mock_ctx.create_timer.assert_called_once_with(duration) + assert isinstance(task, dt_task.Task) + + def test_sleep_awaitable_with_float(self): + """Test SleepAwaitable with float duration.""" + awaitable = SleepAwaitable(self.mock_ctx, 5.0) + awaitable._to_task() + + self.mock_ctx.create_timer.assert_called_once_with(timedelta(seconds=5.0)) + + def test_sleep_awaitable_with_datetime(self): + """Test SleepAwaitable with datetime.""" + deadline = datetime(2023, 1, 1, 12, 0, 0) + awaitable = SleepAwaitable(self.mock_ctx, deadline) + awaitable._to_task() + + self.mock_ctx.create_timer.assert_called_once_with(deadline) + + def test_sleep_awaitable_slots(self): + """Test that SleepAwaitable has __slots__.""" + assert hasattr(SleepAwaitable, "__slots__") + + +class TestExternalEventAwaitable: + """Test ExternalEventAwaitable functionality.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_ctx = Mock() + self.mock_ctx.wait_for_external_event.return_value = dt_task.CompletableTask() + + def test_external_event_awaitable_creation(self): + """Test creating an ExternalEventAwaitable.""" + awaitable = ExternalEventAwaitable(self.mock_ctx, "test_event") + + assert awaitable._ctx is self.mock_ctx + assert awaitable._name == "test_event" + + def test_external_event_awaitable_to_task(self): + """Test converting ExternalEventAwaitable to task.""" + awaitable = ExternalEventAwaitable(self.mock_ctx, "test_event") + + task = awaitable._to_task() + + self.mock_ctx.wait_for_external_event.assert_called_once_with("test_event") + assert isinstance(task, dt_task.Task) + + def test_external_event_awaitable_slots(self): + """Test that ExternalEventAwaitable has __slots__.""" + assert hasattr(ExternalEventAwaitable, "__slots__") + + +class TestWhenAllAwaitable: + """Test WhenAllAwaitable functionality.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_task1 = Mock(spec=dt_task.Task) + self.mock_task2 = Mock(spec=dt_task.Task) + self.mock_awaitable1 = Mock(spec=AwaitableBase) + self.mock_awaitable1._to_task.return_value = self.mock_task1 + self.mock_awaitable2 = Mock(spec=AwaitableBase) + self.mock_awaitable2._to_task.return_value = self.mock_task2 + + def test_when_all_awaitable_creation(self): + """Test creating a WhenAllAwaitable.""" + awaitables = 
[self.mock_awaitable1, self.mock_awaitable2]
+        awaitable = WhenAllAwaitable(awaitables)
+
+        assert awaitable._tasks_like == awaitables
+
+    def test_when_all_awaitable_to_task(self):
+        """Test converting WhenAllAwaitable to task."""
+        awaitables = [self.mock_awaitable1, self.mock_awaitable2]
+        awaitable = WhenAllAwaitable(awaitables)
+
+        with patch("durabletask.task.when_all") as mock_when_all:
+            mock_when_all.return_value = Mock(spec=dt_task.Task)
+            task = awaitable._to_task()
+
+            mock_when_all.assert_called_once_with([self.mock_task1, self.mock_task2])
+            assert isinstance(task, dt_task.Task)
+
+    def test_when_all_awaitable_with_tasks(self):
+        """Test WhenAllAwaitable with direct tasks."""
+        tasks = [self.mock_task1, self.mock_task2]
+        awaitable = WhenAllAwaitable(tasks)
+
+        with patch("durabletask.task.when_all") as mock_when_all:
+            mock_when_all.return_value = Mock(spec=dt_task.Task)
+            awaitable._to_task()
+
+            mock_when_all.assert_called_once_with([self.mock_task1, self.mock_task2])
+
+    def test_when_all_awaitable_slots(self):
+        """Test that WhenAllAwaitable has __slots__."""
+        assert hasattr(WhenAllAwaitable, "__slots__")
+
+    def _drive_awaitable(self, awaitable, result):
+        gen = awaitable.__await__()
+        try:
+            # Typically yields a dt_task.Task, but the type is not strictly required here
+            next(gen)
+        except StopIteration as si:  # empty fast-path
+            return si.value
+        try:
+            return gen.send(result)
+        except StopIteration as si:
+            return si.value
+
+    def test_when_all_empty_fast_path(self):
+        awaitable = WhenAllAwaitable([])
+        # Should complete without yielding
+        gen = awaitable.__await__()
+        with pytest.raises(StopIteration) as si:
+            next(gen)
+        assert si.value.value == []
+
+    def test_when_all_success_and_caching(self):
+        awaitable = WhenAllAwaitable([self.mock_awaitable1, self.mock_awaitable2])
+        results = ["r1", "r2"]
+        with patch("durabletask.task.when_all") as mock_when_all:
+            mock_when_all.return_value = Mock(spec=dt_task.Task)
+            # Simulate runtime returning results list
+            gen = awaitable.__await__()
+            _ = next(gen)
+            with pytest.raises(StopIteration) as si:
+                gen.send(results)
+            assert si.value.value == results
+            # Re-await should return cached without yielding
+            gen2 = awaitable.__await__()
+            with pytest.raises(StopIteration) as si2:
+                next(gen2)
+            assert si2.value.value == results
+
+    def test_when_all_exception_and_caching(self):
+        awaitable = WhenAllAwaitable([self.mock_awaitable1, self.mock_awaitable2])
+        with patch("durabletask.task.when_all") as mock_when_all:
+            mock_when_all.return_value = Mock(spec=dt_task.Task)
+            gen = awaitable.__await__()
+            _ = next(gen)
+
+            class Boom(Exception):
+                pass
+
+            with pytest.raises(Boom):
+                gen.throw(Boom())
+            # Re-await should immediately raise cached exception
+            gen2 = awaitable.__await__()
+            with pytest.raises(Boom):
+                next(gen2)
+
+
+class TestWhenAnyAwaitable:
+    """Test WhenAnyAwaitable functionality."""
+
+    def setup_method(self):
+        """Set up test fixtures."""
+        self.mock_task1 = Mock(spec=dt_task.Task)
+        self.mock_task2 = Mock(spec=dt_task.Task)
+        self.mock_awaitable1 = Mock(spec=AwaitableBase)
+        self.mock_awaitable1._to_task.return_value = self.mock_task1
+        self.mock_awaitable2 = Mock(spec=AwaitableBase)
+        self.mock_awaitable2._to_task.return_value = self.mock_task2
+
+    def test_when_any_awaitable_creation(self):
+        """Test creating a WhenAnyAwaitable."""
+        awaitables = [self.mock_awaitable1, self.mock_awaitable2]
+        awaitable = WhenAnyAwaitable(awaitables)
+
+        assert awaitable._originals == awaitables
+        assert 
awaitable._underlying is None # Lazy initialization + # Trigger initialization + underlying = awaitable._ensure_underlying() + assert len(underlying) == 2 + + def test_when_any_awaitable_to_task(self): + """Test converting WhenAnyAwaitable to task.""" + awaitables = [self.mock_awaitable1, self.mock_awaitable2] + awaitable = WhenAnyAwaitable(awaitables) + + with patch("durabletask.task.when_any") as mock_when_any: + mock_when_any.return_value = Mock(spec=dt_task.Task) + task = awaitable._to_task() + + # Should use cached underlying tasks + assert mock_when_any.call_count >= 1 + assert isinstance(task, dt_task.Task) + + def test_when_any_awaitable_slots(self): + """Test that WhenAnyAwaitable has __slots__.""" + assert hasattr(WhenAnyAwaitable, "__slots__") + + def test_when_any_returns_index_and_result(self): + """Test that when_any returns (index, result) tuple.""" + awaitable = WhenAnyAwaitable([self.mock_awaitable1, self.mock_awaitable2]) + with patch("durabletask.task.when_any") as mock_when_any: + mock_when_any.return_value = Mock(spec=dt_task.Task) + gen = awaitable.__await__() + _ = next(gen) + # Simulate runtime returning that task1 completed + self.mock_task1.get_result = Mock(return_value="done1") + with pytest.raises(StopIteration) as si: + gen.send(self.mock_task1) + index, result = si.value.value + # Returns index of first task (0) and its result + assert index == 0 + assert result == "done1" + + def test_when_any_second_task_completes(self): + """Test when_any returns correct index when second task completes.""" + awaitable = WhenAnyAwaitable([self.mock_awaitable1, self.mock_awaitable2]) + with patch("durabletask.task.when_any") as mock_when_any: + mock_when_any.return_value = Mock(spec=dt_task.Task) + gen = awaitable.__await__() + _ = next(gen) + # Simulate runtime returning that task2 completed + self.mock_task2.get_result = Mock(return_value="done2") + with pytest.raises(StopIteration) as si: + gen.send(self.mock_task2) + index, result = si.value.value + # Returns index of second task (1) and its result + assert index == 1 + assert result == "done2" + + def test_when_any_no_coroutine_reuse_on_multiple_awaits(self): + """Test that awaiting the same WhenAnyAwaitable multiple times doesn't cause coroutine reuse errors.""" + awaitable = WhenAnyAwaitable([self.mock_awaitable1, self.mock_awaitable2]) + + # First await + with patch("durabletask.task.when_any") as mock_when_any: + mock_when_any.return_value = Mock(spec=dt_task.Task) + gen1 = awaitable.__await__() + _ = next(gen1) + self.mock_task1.get_result = Mock(return_value="result1") + with pytest.raises(StopIteration) as si1: + gen1.send(self.mock_task1) + index1, result1 = si1.value.value + + # Second await (simulates replay scenario) + with patch("durabletask.task.when_any") as mock_when_any: + mock_when_any.return_value = Mock(spec=dt_task.Task) + gen2 = awaitable.__await__() + _ = next(gen2) + self.mock_task2.get_result = Mock(return_value="result2") + with pytest.raises(StopIteration) as si2: + gen2.send(self.mock_task2) + index2, result2 = si2.value.value + + # Both should succeed without coroutine reuse errors + assert index1 == 0 + assert result1 == "result1" + assert index2 == 1 + assert result2 == "result2" + + def test_when_any_exception_replay_path(self): + """Test that gen.throw() works correctly (simulates exception during replay).""" + awaitable = WhenAnyAwaitable([self.mock_awaitable1, self.mock_awaitable2]) + + with patch("durabletask.task.when_any") as mock_when_any: + mock_when_any.return_value = 
Mock(spec=dt_task.Task) + gen = awaitable.__await__() + _ = next(gen) + + # Simulate exception being thrown into the generator + test_error = RuntimeError("test exception") + with pytest.raises(RuntimeError) as exc_info: + gen.throw(test_error) + + assert exc_info.value is test_error + + +class TestSwallowExceptionAwaitable: + """Test SwallowExceptionAwaitable functionality.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_awaitable = Mock(spec=AwaitableBase) + self.mock_task = Mock(spec=dt_task.Task) + self.mock_awaitable._to_task.return_value = self.mock_task + + def test_swallow_exception_awaitable_creation(self): + """Test creating a SwallowExceptionAwaitable.""" + awaitable = SwallowExceptionAwaitable(self.mock_awaitable) + + assert awaitable._awaitable is self.mock_awaitable + + def test_swallow_exception_awaitable_to_task(self): + """Test converting SwallowExceptionAwaitable to task.""" + awaitable = SwallowExceptionAwaitable(self.mock_awaitable) + + task = awaitable._to_task() + + self.mock_awaitable._to_task.assert_called_once() + assert task is self.mock_task + + def test_swallow_exception_awaitable_slots(self): + """Test that SwallowExceptionAwaitable has __slots__.""" + assert hasattr(SwallowExceptionAwaitable, "__slots__") + + def test_swallow_exception_runtime_success_and_failure(self): + awaitable = SwallowExceptionAwaitable(self.mock_awaitable) + # Success path + gen = awaitable.__await__() + _ = next(gen) + with pytest.raises(StopIteration) as si: + gen.send("ok") + assert si.value.value == "ok" + # Failure path returns exception instance via StopIteration.value + awaitable2 = SwallowExceptionAwaitable(self.mock_awaitable) + gen2 = awaitable2.__await__() + _ = next(gen2) + err = RuntimeError("boom") + with pytest.raises(StopIteration) as si2: + gen2.throw(err) + assert si2.value.value is err + + +class TestTimeoutAwaitable: + """Test TimeoutAwaitable functionality.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_ctx = Mock() + self.mock_ctx.create_timer.return_value = Mock(spec=dt_task.Task) + self.mock_awaitable = Mock(spec=AwaitableBase) + self.mock_task = Mock(spec=dt_task.Task) + self.mock_awaitable._to_task.return_value = self.mock_task + + def test_timeout_awaitable_creation(self): + """Test creating a TimeoutAwaitable.""" + awaitable = TimeoutAwaitable(self.mock_awaitable, 5.0, self.mock_ctx) + + assert awaitable._ctx is self.mock_ctx + assert awaitable._awaitable is self.mock_awaitable + assert awaitable._timeout_seconds == 5.0 + + def test_timeout_awaitable_to_task(self): + """Test converting TimeoutAwaitable to task.""" + awaitable = TimeoutAwaitable(self.mock_awaitable, 5.0, self.mock_ctx) + + with patch("durabletask.task.when_any") as mock_when_any: + mock_when_any.return_value = Mock(spec=dt_task.Task) + task = awaitable._to_task() + + # Should create timer and call when_any + self.mock_ctx.create_timer.assert_called_once() + self.mock_awaitable._to_task.assert_called_once() + mock_when_any.assert_called_once() + assert isinstance(task, dt_task.Task) + + def test_timeout_awaitable_slots(self): + """Test that TimeoutAwaitable has __slots__.""" + assert hasattr(TimeoutAwaitable, "__slots__") + + def test_timeout_awaitable_timeout_hits(self): + awaitable = TimeoutAwaitable(self.mock_awaitable, 5.0, self.mock_ctx) + # Capture the cached timeout task instance created by _to_task + gen = awaitable.__await__() + _ = next(gen) + timeout_task = awaitable._timeout_task + assert timeout_task is not None + with 
pytest.raises(WorkflowTimeoutError): + gen.send(timeout_task) + + def test_timeout_awaitable_operation_completes_first(self): + awaitable = TimeoutAwaitable(self.mock_awaitable, 5.0, self.mock_ctx) + gen = awaitable.__await__() + _ = next(gen) + # If the operation completed first, runtime returns the operation task + self.mock_task.result = "value" + with pytest.raises(StopIteration) as si: + gen.send(self.mock_task) + assert si.value.value == "value" + + def test_timeout_awaitable_non_task_sentinel_heuristic(self): + awaitable = TimeoutAwaitable(self.mock_awaitable, 5.0, self.mock_ctx) + gen = awaitable.__await__() + _ = next(gen) + with pytest.raises(StopIteration) as si: + gen.send({"x": 1}) + assert si.value.value == {"x": 1} + + +class TestPropagationForActivityAndSubOrch: + """Test propagation of app_id/metadata/retry_policy to context methods via signature detection.""" + + class _CtxWithSignatures: + def __init__(self): + self.call_activity_called_with = None + self.call_sub_orchestrator_called_with = None + + def call_activity( + self, activity_fn, *, input=None, retry_policy=None, app_id=None, metadata=None + ): + self.call_activity_called_with = dict( + activity_fn=activity_fn, + input=input, + retry_policy=retry_policy, + app_id=app_id, + metadata=metadata, + ) + return dt_task.CompletableTask() + + def call_sub_orchestrator( + self, + workflow_fn, + *, + input=None, + instance_id=None, + retry_policy=None, + app_id=None, + metadata=None, + ): + self.call_sub_orchestrator_called_with = dict( + workflow_fn=workflow_fn, + input=input, + instance_id=instance_id, + retry_policy=retry_policy, + app_id=app_id, + metadata=metadata, + ) + return dt_task.CompletableTask() + + def test_activity_propagation_app_id_metadata_retry(self): + ctx = self._CtxWithSignatures() + activity_fn = lambda: None # noqa: E731 + rp = dt_task.RetryPolicy( + first_retry_interval=timedelta(seconds=1), max_number_of_attempts=2 + ) + awaitable = ActivityAwaitable( + ctx, activity_fn, input={"a": 1}, retry_policy=rp, app_id="app-x", metadata={"h": "v"} + ) + _ = awaitable._to_task() + called = ctx.call_activity_called_with + assert called["activity_fn"] is activity_fn + assert called["input"] == {"a": 1} + assert called["retry_policy"] is rp + assert called["app_id"] == "app-x" + assert called["metadata"] == {"h": "v"} + + def test_suborch_propagation_all_fields(self): + ctx = self._CtxWithSignatures() + workflow_fn = lambda: None # noqa: E731 + rp = dt_task.RetryPolicy( + first_retry_interval=timedelta(seconds=1), max_number_of_attempts=2 + ) + awaitable = SubOrchestratorAwaitable( + ctx, + workflow_fn, + input=123, + instance_id="iid-1", + retry_policy=rp, + app_id="app-y", + metadata={"k": "m"}, + ) + _ = awaitable._to_task() + called = ctx.call_sub_orchestrator_called_with + assert called["workflow_fn"] is workflow_fn + assert called["input"] == 123 + assert called["instance_id"] == "iid-1" + assert called["retry_policy"] is rp + assert called["app_id"] == "app-y" + assert called["metadata"] == {"k": "m"} + + +class TestExternalEventIntegration: + """Integration-like tests combining ExternalEventAwaitable with when_any/timeout wrappers.""" + + def setup_method(self): + self.ctx = Mock() + # Provide stable task instances for mapping + self.event_task = Mock(spec=dt_task.Task) + self.timer_task = Mock(spec=dt_task.Task) + self.ctx.wait_for_external_event.return_value = self.event_task + self.ctx.create_timer.return_value = self.timer_task + + def test_when_any_between_event_and_timer_event_wins(self): + 
event_aw = ExternalEventAwaitable(self.ctx, "ev") + timer_aw = SleepAwaitable(self.ctx, 1.0) + wa = WhenAnyAwaitable([event_aw, timer_aw]) + with patch("durabletask.task.when_any") as mock_when_any: + mock_when_any.return_value = Mock(spec=dt_task.Task) + gen = wa.__await__() + _ = next(gen) + with pytest.raises(StopIteration) as si: + gen.send(self.event_task) + index, result = si.value.value + assert index == 0 + + def test_timeout_wrapper_times_out_before_event(self): + event_aw = ExternalEventAwaitable(self.ctx, "ev") + tw = TimeoutAwaitable(event_aw, 2.0, self.ctx) + gen = tw.__await__() + _ = next(gen) + # Should have cached timeout task equal to ctx.create_timer return + assert tw._timeout_task is self.timer_task + with pytest.raises(WorkflowTimeoutError): + gen.send(self.timer_task) + + +class TestAwaitableSlots: + """Test that all awaitable classes use __slots__ for performance.""" + + def test_all_awaitables_have_slots(self): + """Test that all awaitable classes have __slots__.""" + awaitable_classes = [ + AwaitableBase, + ActivityAwaitable, + SubOrchestratorAwaitable, + SleepAwaitable, + ExternalEventAwaitable, + WhenAllAwaitable, + WhenAnyAwaitable, + SwallowExceptionAwaitable, + TimeoutAwaitable, + ] + + for cls in awaitable_classes: + assert hasattr(cls, "__slots__"), f"{cls.__name__} should have __slots__" diff --git a/tests/aio/test_ci_compatibility.py b/tests/aio/test_ci_compatibility.py new file mode 100644 index 0000000..41a39fb --- /dev/null +++ b/tests/aio/test_ci_compatibility.py @@ -0,0 +1,239 @@ +# Copyright 2025 The Dapr Authors +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +CI/CD compatibility tests for AsyncWorkflowContext. + +These tests are designed to be run in continuous integration to catch +compatibility regressions early and ensure smooth upstream merges. +""" + +import pytest + +from durabletask.aio import AsyncWorkflowContext + +from .compatibility_utils import ( + CompatibilityChecker, + check_async_context_compatibility, + validate_runtime_compatibility, +) + + +class TestCICompatibility: + """CI/CD compatibility validation tests.""" + + def test_async_context_maintains_full_compatibility(self): + """ + Critical test: Ensure AsyncWorkflowContext maintains full compatibility. + + This test should NEVER fail in CI. If it does, it indicates a breaking + change that could cause issues with upstream merges or existing code. + """ + report = check_async_context_compatibility() + + assert report["is_compatible"], ( + f"CRITICAL: AsyncWorkflowContext compatibility broken! " + f"Missing members: {report['missing_members']}" + ) + + # Ensure we have no missing members + assert len(report["missing_members"]) == 0, ( + f"Missing required members: {report['missing_members']}" + ) + + def test_no_regression_in_base_interface(self): + """ + Test that no base OrchestrationContext interface members are missing. + + This catches regressions where required properties or methods are + accidentally removed or renamed. 
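+
+        For example, if a refactor renamed ``create_timer`` without keeping the
+        old name (a hypothetical scenario), it would surface here as a missing
+        member.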
+ """ + from unittest.mock import Mock + + import durabletask.task as dt_task + + # Create a test instance + mock_base_ctx = Mock(spec=dt_task.OrchestrationContext) + mock_base_ctx.instance_id = "ci-test" + mock_base_ctx.current_utc_datetime = None + mock_base_ctx.is_replaying = False + mock_base_ctx.is_suspended = False + mock_base_ctx.workflow_name = "test-workflow" + + async_ctx = AsyncWorkflowContext(mock_base_ctx) + + # Validate runtime compatibility + missing_items = CompatibilityChecker.validate_context_compatibility(async_ctx) + + assert len(missing_items) == 0, ( + f"Compatibility regression detected! Missing: {missing_items}" + ) + + def test_runtime_validation_passes(self): + """ + Test that runtime validation passes for AsyncWorkflowContext. + + This ensures the context can be used wherever OrchestrationContext + is expected without runtime errors. + """ + from unittest.mock import Mock + + import durabletask.task as dt_task + + mock_base_ctx = Mock(spec=dt_task.OrchestrationContext) + mock_base_ctx.instance_id = "runtime-test" + mock_base_ctx.current_utc_datetime = None + mock_base_ctx.is_replaying = False + mock_base_ctx.is_suspended = False + mock_base_ctx.workflow_name = "test-workflow" + + async_ctx = AsyncWorkflowContext(mock_base_ctx) + + # This should pass without warnings or errors + is_valid = validate_runtime_compatibility(async_ctx, strict=True) + assert is_valid, "Runtime validation failed" + + def test_enhanced_methods_are_additive_only(self): + """ + Test that enhanced methods are purely additive and don't break base functionality. + + This ensures that new async-specific methods don't interfere with + the base OrchestrationContext interface. + """ + report = check_async_context_compatibility() + + # We should have extra methods (enhancements) but no missing ones + assert len(report["extra_members"]) > 0, "No enhanced methods found" + assert len(report["missing_members"]) == 0, "Base methods are missing" + + # Verify some expected enhancements exist + extra_methods = [ + item.split(": ")[1] for item in report["extra_members"] if "method:" in item + ] + expected_enhancements = ["when_all", "when_any"] + + for enhancement in expected_enhancements: + assert enhancement in extra_methods, f"Expected enhancement '{enhancement}' not found" + + def test_protocol_compliance_at_class_level(self): + """ + Test that AsyncWorkflowContext class complies with the protocol. + + This is a compile-time style check that validates the class structure + without needing to instantiate it. + """ + is_compliant = CompatibilityChecker.check_protocol_compliance(AsyncWorkflowContext) + assert is_compliant, ( + "AsyncWorkflowContext does not comply with OrchestrationContextProtocol" + ) + + @pytest.mark.parametrize( + "property_name", + [ + "instance_id", + "current_utc_datetime", + "is_replaying", + "workflow_name", + "is_suspended", + ], + ) + def test_required_property_exists(self, property_name): + """ + Test that each required property exists on AsyncWorkflowContext. + + This parameterized test ensures all OrchestrationContext properties + are available on AsyncWorkflowContext. 
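+
+        A failure for any single property (say, ``is_suspended``) would indicate
+        that the corresponding delegating property was dropped or renamed.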
+ """ + assert hasattr(AsyncWorkflowContext, property_name), ( + f"Required property '{property_name}' missing from AsyncWorkflowContext" + ) + + @pytest.mark.parametrize( + "method_name", + [ + "set_custom_status", + "create_timer", + "call_activity", + "call_sub_orchestrator", + "wait_for_external_event", + "continue_as_new", + ], + ) + def test_required_method_exists(self, method_name): + """ + Test that each required method exists on AsyncWorkflowContext. + + This parameterized test ensures all OrchestrationContext methods + are available on AsyncWorkflowContext. + """ + assert hasattr(AsyncWorkflowContext, method_name), ( + f"Required method '{method_name}' missing from AsyncWorkflowContext" + ) + + method = getattr(AsyncWorkflowContext, method_name) + assert callable(method), f"Required method '{method_name}' is not callable" + + +class TestUpstreamMergeReadiness: + """Tests to ensure readiness for upstream merges.""" + + def test_no_breaking_changes_in_public_api(self): + """ + Test that the public API hasn't changed in breaking ways. + + This test helps ensure that upstream merges won't break existing + code that depends on AsyncWorkflowContext. + """ + report = check_async_context_compatibility() + + # Should have all base members + base_member_count = len(report["base_members"]) + context_member_count = len(report["context_members"]) + + assert context_member_count >= base_member_count, ( + "AsyncWorkflowContext has fewer members than OrchestrationContext" + ) + + def test_backward_compatibility_maintained(self): + """ + Test that backward compatibility is maintained. + + This ensures that code written against the base OrchestrationContext + interface will continue to work with AsyncWorkflowContext. + """ + from unittest.mock import Mock + + import durabletask.task as dt_task + + mock_base_ctx = Mock(spec=dt_task.OrchestrationContext) + mock_base_ctx.instance_id = "backward-compat-test" + mock_base_ctx.current_utc_datetime = None + mock_base_ctx.is_replaying = False + + async_ctx = AsyncWorkflowContext(mock_base_ctx) + + # Test that it can be used in functions expecting OrchestrationContext + def function_expecting_base_context(ctx): + # This should work without any issues + return { + "id": ctx.instance_id, + "replaying": ctx.is_replaying, + "time": ctx.current_utc_datetime, + } + + # This should not raise any errors + result = function_expecting_base_context(async_ctx) + assert result["id"] == "backward-compat-test" + assert result["replaying"] is False + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/aio/test_context.py b/tests/aio/test_context.py new file mode 100644 index 0000000..6625f41 --- /dev/null +++ b/tests/aio/test_context.py @@ -0,0 +1,395 @@ +# Copyright 2025 The Dapr Authors +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests for AsyncWorkflowContext in durabletask.aio. 
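+
+The deterministic helpers exercised below stand in for their non-deterministic
+stdlib counterparts during orchestration. A sketch (illustrative only):
+
+    async def orch(ctx: AsyncWorkflowContext, _) -> str:
+        started = ctx.now()           # replay-safe alternative to datetime.now()
+        request_id = ctx.uuid4()      # deterministic; UUID v5 under the hood
+        token = ctx.random_string(8)  # deterministic random string
+        return f"{request_id}:{token}:{started.isoformat()}"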
+""" + +import random +import uuid +from datetime import datetime, timedelta +from unittest.mock import Mock + +import pytest + +from durabletask import task as dt_task +from durabletask.aio import ( + ActivityAwaitable, + AsyncWorkflowContext, + AwaitableBase, + ExternalEventAwaitable, + SleepAwaitable, + SubOrchestratorAwaitable, + TimeoutAwaitable, + WhenAllAwaitable, + WhenAnyAwaitable, +) + + +class TestAsyncWorkflowContext: + """Test AsyncWorkflowContext functionality.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_base_ctx = Mock(spec=dt_task.OrchestrationContext) + self.mock_base_ctx.instance_id = "test-instance-123" + self.mock_base_ctx.current_utc_datetime = datetime(2023, 1, 1, 12, 0, 0) + self.mock_base_ctx.is_replaying = False + self.mock_base_ctx.is_suspended = False + + # Mock methods + self.mock_base_ctx.call_activity.return_value = Mock(spec=dt_task.Task) + self.mock_base_ctx.call_sub_orchestrator.return_value = Mock(spec=dt_task.Task) + self.mock_base_ctx.create_timer.return_value = Mock(spec=dt_task.Task) + self.mock_base_ctx.wait_for_external_event.return_value = Mock(spec=dt_task.Task) + self.mock_base_ctx.set_custom_status = Mock() + self.mock_base_ctx.continue_as_new = Mock() + + self.ctx = AsyncWorkflowContext(self.mock_base_ctx) + + def test_context_creation(self): + """Test creating AsyncWorkflowContext.""" + assert self.ctx._base_ctx is self.mock_base_ctx + assert isinstance(self.ctx._operation_history, list) + assert isinstance(self.ctx._cleanup_tasks, list) + + def test_instance_id_property(self): + """Test instance_id property.""" + assert self.ctx.instance_id == "test-instance-123" + + def test_current_utc_datetime_property(self): + """Test current_utc_datetime property.""" + assert self.ctx.current_utc_datetime == datetime(2023, 1, 1, 12, 0, 0) + + def test_is_replaying_property(self): + """Test is_replaying property.""" + assert self.ctx.is_replaying == False + + self.mock_base_ctx.is_replaying = True + assert self.ctx.is_replaying == True + + def test_is_suspended_property(self): + """Test is_suspended property.""" + assert self.ctx.is_suspended == False + + self.mock_base_ctx.is_suspended = True + assert self.ctx.is_suspended == True + + def test_now_method(self): + """Test now() method from DeterministicContextMixin.""" + now = self.ctx.now() + assert now == datetime(2023, 1, 1, 12, 0, 0) + assert now is self.ctx.current_utc_datetime + + def test_random_method(self): + """Test random() method from DeterministicContextMixin.""" + rng = self.ctx.random() + assert isinstance(rng, random.Random) + + # Should be deterministic + rng1 = self.ctx.random() + rng2 = self.ctx.random() + + val1 = rng1.random() + val2 = rng2.random() + assert val1 == val2 # Same seed should produce same values + + def test_uuid4_method(self): + """Test uuid4() method from DeterministicContextMixin.""" + test_uuid = self.ctx.uuid4() + assert isinstance(test_uuid, uuid.UUID) + assert test_uuid.version == 5 # Now using UUID v5 for .NET compatibility + + # Should increment counter - each call produces different UUID + uuid1 = self.ctx.uuid4() + uuid2 = self.ctx.uuid4() + assert uuid1 != uuid2 # Counter increments + + def test_new_guid_method(self): + """Test new_guid() alias method.""" + guid = self.ctx.new_guid() + assert isinstance(guid, uuid.UUID) + assert guid.version == 5 # Now using UUID v5 for .NET compatibility + + def test_random_string_method(self): + """Test random_string() method from DeterministicContextMixin.""" + # Test default alphabet + s1 = 
self.ctx.random_string(10) + assert len(s1) == 10 + assert all(c.isalnum() for c in s1) + + # Test custom alphabet + s2 = self.ctx.random_string(5, alphabet="ABC") + assert len(s2) == 5 + assert all(c in "ABC" for c in s2) + + # Test deterministic behavior + s3 = self.ctx.random_string(10) + assert s1 == s3 # Same context should produce same string + + def test_call_activity_method(self): + """Test call_activity() method.""" + activity_fn = Mock(__name__="test_activity") + + # Basic call + awaitable = self.ctx.call_activity(activity_fn, input="test_input") + + assert isinstance(awaitable, ActivityAwaitable) + assert awaitable._ctx is self.mock_base_ctx + assert awaitable._activity_fn is activity_fn + assert awaitable._input == "test_input" + assert awaitable._retry_policy is None + assert awaitable._metadata is None + + def test_call_activity_with_retry_policy(self): + """Test call_activity() with retry policy.""" + activity_fn = Mock(__name__="test_activity") + retry_policy = Mock() + + awaitable = self.ctx.call_activity( + activity_fn, input="test_input", retry_policy=retry_policy + ) + + assert awaitable._retry_policy is retry_policy + + def test_call_activity_with_metadata(self): + """Test call_activity() with metadata.""" + activity_fn = Mock(__name__="test_activity") + metadata = {"key": "value"} + + awaitable = self.ctx.call_activity(activity_fn, input="test_input", metadata=metadata) + + assert awaitable._metadata == metadata + + def test_call_sub_orchestrator_method(self): + """Test call_sub_orchestrator() method.""" + workflow_fn = Mock(__name__="test_workflow") + + awaitable = self.ctx.call_sub_orchestrator( + workflow_fn, input="test_input", instance_id="sub-instance" + ) + + assert isinstance(awaitable, SubOrchestratorAwaitable) + assert awaitable._ctx is self.mock_base_ctx + assert awaitable._workflow_fn is workflow_fn + assert awaitable._input == "test_input" + assert awaitable._instance_id == "sub-instance" + + def test_create_timer_method(self): + """Test create_timer() method.""" + # Test with timedelta + duration = timedelta(seconds=30) + awaitable = self.ctx.create_timer(duration) + + assert isinstance(awaitable, SleepAwaitable) + assert awaitable._ctx is self.mock_base_ctx + assert awaitable._duration is duration + + def test_sleep_method(self): + """Test sleep() method.""" + # Test with float + awaitable = self.ctx.create_timer(5.0) + + assert isinstance(awaitable, SleepAwaitable) + assert awaitable._duration == 5.0 + + # Test with timedelta + duration = timedelta(minutes=1) + awaitable = self.ctx.create_timer(duration) + assert awaitable._duration is duration + + # Test with datetime + deadline = datetime(2023, 1, 1, 13, 0, 0) + awaitable = self.ctx.create_timer(deadline) + assert awaitable._duration is deadline + + def test_wait_for_external_event_method(self): + """Test wait_for_external_event() method.""" + awaitable = self.ctx.wait_for_external_event("test_event") + + assert isinstance(awaitable, ExternalEventAwaitable) + assert awaitable._ctx is self.mock_base_ctx + assert awaitable._name == "test_event" + + def test_when_all_method(self): + """Test when_all() method.""" + # Create mock awaitables + awaitable1 = Mock() + awaitable2 = Mock() + awaitables = [awaitable1, awaitable2] + + result = self.ctx.when_all(awaitables) + + assert isinstance(result, WhenAllAwaitable) + assert result._tasks_like == awaitables + + def test_when_any_method(self): + """Test when_any() method.""" + awaitable1 = Mock(spec=AwaitableBase) + awaitable1._to_task.return_value = 
Mock(spec=dt_task.Task) + awaitable2 = Mock(spec=AwaitableBase) + awaitable2._to_task.return_value = Mock(spec=dt_task.Task) + awaitables = [awaitable1, awaitable2] + + result = self.ctx.when_any(awaitables) + + assert isinstance(result, WhenAnyAwaitable) + assert result._originals == awaitables + + def test_with_timeout_method(self): + """Test with_timeout() method.""" + mock_awaitable = Mock() + + result = self.ctx.with_timeout(mock_awaitable, 5.0) + + assert isinstance(result, TimeoutAwaitable) + assert result._awaitable is mock_awaitable + assert result._timeout_seconds == 5.0 + assert result._ctx is self.mock_base_ctx + + def test_set_custom_status_method(self): + """Test set_custom_status() method.""" + self.ctx.set_custom_status("Processing data") + + self.mock_base_ctx.set_custom_status.assert_called_once_with("Processing data") + + def test_set_custom_status_not_supported(self): + """Test set_custom_status() when not supported by base context.""" + # Remove the method to simulate unsupported base context + del self.mock_base_ctx.set_custom_status + + # Should not raise error + self.ctx.set_custom_status("test") + + def test_continue_as_new_method(self): + """Test continue_as_new() method.""" + new_input = {"restart": True} + + self.ctx.continue_as_new(new_input, save_events=True) + + self.mock_base_ctx.continue_as_new.assert_called_once_with(new_input, save_events=True) + + def test_debug_mode_enabled(self): + """Test debug mode functionality.""" + import os + from unittest.mock import patch + + # Test with DAPR_WF_DEBUG + with patch.dict(os.environ, {"DAPR_WF_DEBUG": "true"}): + debug_ctx = AsyncWorkflowContext(self.mock_base_ctx) + assert debug_ctx._debug_mode == True + + # Test with DT_DEBUG + with patch.dict(os.environ, {"DT_DEBUG": "true"}): + debug_ctx = AsyncWorkflowContext(self.mock_base_ctx) + assert debug_ctx._debug_mode == True + + def test_operation_logging_in_debug_mode(self): + """Test that operations are logged in debug mode.""" + import os + from unittest.mock import patch + + with patch.dict(os.environ, {"DAPR_WF_DEBUG": "true"}): + debug_ctx = AsyncWorkflowContext(self.mock_base_ctx) + + # Perform some operations + debug_ctx.call_activity("test_activity", input="test") + debug_ctx.create_timer(5.0) + debug_ctx.wait_for_external_event("test_event") + + # Should have logged operations + assert len(debug_ctx._operation_history) == 3 + + # Check operation details + ops = debug_ctx._operation_history + assert ops[0]["type"] == "activity" + assert ops[1]["type"] == "sleep" + assert ops[2]["type"] == "wait_for_external_event" + + def test_get_debug_info_method(self): + """Test get_debug_info() method.""" + debug_info = self.ctx._get_info_snapshot() + + assert isinstance(debug_info, dict) + assert debug_info["instance_id"] == "test-instance-123" + assert debug_info["is_replaying"] == False + assert "operation_history" in debug_info + assert "cleanup_tasks_count" in debug_info + + def test_detection_disabled_property(self): + """Test _detection_disabled property.""" + import os + from unittest.mock import patch + + # Test with environment variable + with patch.dict(os.environ, {"DAPR_WF_DISABLE_DETERMINISTIC_DETECTION": "true"}): + disabled_ctx = AsyncWorkflowContext(self.mock_base_ctx) + assert disabled_ctx._detection_disabled == True + + # Test without environment variable + assert self.ctx._detection_disabled == False + + def test_workflow_name_tracking(self): + """Test workflow name tracking.""" + # Should start as None + assert self.ctx._workflow_name is None + + # 
Can be set + self.ctx._workflow_name = "test_workflow" + assert self.ctx._workflow_name == "test_workflow" + + def test_current_step_tracking(self): + """Test current step tracking.""" + # Should start as None + assert self.ctx._current_step is None + + # Can be set + self.ctx._current_step = "step_1" + assert self.ctx._current_step == "step_1" + + def test_context_slots(self): + """Test that AsyncWorkflowContext uses __slots__.""" + assert hasattr(AsyncWorkflowContext, "__slots__") + + def test_deterministic_context_mixin_integration(self): + """Test integration with DeterministicContextMixin.""" + from durabletask.deterministic import DeterministicContextMixin + + # Should be an instance of the mixin + assert isinstance(self.ctx, DeterministicContextMixin) + + # Should have all mixin methods + assert hasattr(self.ctx, "now") + assert hasattr(self.ctx, "random") + assert hasattr(self.ctx, "uuid4") + assert hasattr(self.ctx, "new_guid") + assert hasattr(self.ctx, "random_string") + + def test_context_with_string_activity_name(self): + """Test context methods with string activity/workflow names.""" + # Test with string activity name + awaitable = self.ctx.call_activity("string_activity_name", input="test") + assert isinstance(awaitable, ActivityAwaitable) + assert awaitable._activity_fn == "string_activity_name" + + # Test with string workflow name + awaitable = self.ctx.call_sub_orchestrator("string_workflow_name", input="test") + assert isinstance(awaitable, SubOrchestratorAwaitable) + assert awaitable._workflow_fn == "string_workflow_name" + + def test_context_method_parameter_validation(self): + """Test parameter validation in context methods.""" + # Test random_string with invalid parameters + with pytest.raises(ValueError): + self.ctx.random_string(-1) # Negative length + + with pytest.raises(ValueError): + self.ctx.random_string(5, alphabet="") # Empty alphabet diff --git a/tests/aio/test_context_compatibility.py b/tests/aio/test_context_compatibility.py new file mode 100644 index 0000000..00a45fa --- /dev/null +++ b/tests/aio/test_context_compatibility.py @@ -0,0 +1,338 @@ +# Copyright 2025 The Dapr Authors +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Compatibility tests to ensure AsyncWorkflowContext maintains API compatibility +with the base OrchestrationContext interface. + +This test suite validates that AsyncWorkflowContext provides all the properties +and methods expected by the base OrchestrationContext, helping prevent regressions +and ensuring smooth upstream merges. 
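+
+In practice, compatibility means code written against the base interface keeps
+working unchanged. A sketch:
+
+    def describe(ctx: task.OrchestrationContext) -> str:
+        return f"{ctx.instance_id} (replaying={ctx.is_replaying})"
+
+    # An AsyncWorkflowContext instance is accepted wherever the base
+    # OrchestrationContext is expected:
+    # describe(AsyncWorkflowContext(base_ctx))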
+""" + +import inspect +from datetime import datetime, timedelta +from unittest.mock import Mock + +import pytest + +from durabletask import task +from durabletask.aio import AsyncWorkflowContext + + +class TestAsyncWorkflowContextCompatibility: + """Test suite to validate AsyncWorkflowContext compatibility with OrchestrationContext.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_base_ctx = Mock(spec=task.OrchestrationContext) + self.mock_base_ctx.instance_id = "test-instance-123" + self.mock_base_ctx.current_utc_datetime = datetime(2023, 1, 1, 12, 0, 0) + self.mock_base_ctx.is_replaying = False + self.mock_base_ctx.workflow_name = "test_workflow" + self.mock_base_ctx.is_suspended = False + + self.async_ctx = AsyncWorkflowContext(self.mock_base_ctx) + + def test_all_orchestration_context_properties_exist(self): + """Test that AsyncWorkflowContext has all properties from OrchestrationContext.""" + # Get all properties from OrchestrationContext + orchestration_properties = [] + for name, member in inspect.getmembers(task.OrchestrationContext): + if isinstance(member, property) and not name.startswith("_"): + orchestration_properties.append(name) + + # Check that AsyncWorkflowContext has all these properties + for prop_name in orchestration_properties: + assert hasattr(self.async_ctx, prop_name), ( + f"AsyncWorkflowContext is missing property: {prop_name}" + ) + + # Verify the property is actually callable (not just an attribute) + prop_value = getattr(self.async_ctx, prop_name) + assert prop_value is not None, f"Property {prop_name} returned None unexpectedly" + + def test_all_orchestration_context_methods_exist(self): + """Test that AsyncWorkflowContext has all methods from OrchestrationContext.""" + # Get all abstract methods from OrchestrationContext + orchestration_methods = [] + for name, member in inspect.getmembers(task.OrchestrationContext): + if (inspect.isfunction(member) or inspect.ismethod(member)) and not name.startswith( + "_" + ): + orchestration_methods.append(name) + + # Check that AsyncWorkflowContext has all these methods + for method_name in orchestration_methods: + assert hasattr(self.async_ctx, method_name), ( + f"AsyncWorkflowContext is missing method: {method_name}" + ) + + # Verify the method is callable + method = getattr(self.async_ctx, method_name) + assert callable(method), f"Method {method_name} is not callable" + + def test_property_compatibility_instance_id(self): + """Test instance_id property compatibility.""" + assert self.async_ctx.instance_id == "test-instance-123" + assert isinstance(self.async_ctx.instance_id, str) + + def test_property_compatibility_current_utc_datetime(self): + """Test current_utc_datetime property compatibility.""" + assert self.async_ctx.current_utc_datetime == datetime(2023, 1, 1, 12, 0, 0) + assert isinstance(self.async_ctx.current_utc_datetime, datetime) + + def test_property_compatibility_is_replaying(self): + """Test is_replaying property compatibility.""" + assert self.async_ctx.is_replaying is False + assert isinstance(self.async_ctx.is_replaying, bool) + + def test_property_compatibility_workflow_name(self): + """Test workflow_name property compatibility.""" + assert self.async_ctx.workflow_name == "test_workflow" + assert isinstance(self.async_ctx.workflow_name, (str, type(None))) + + def test_property_compatibility_is_suspended(self): + """Test is_suspended property compatibility.""" + assert self.async_ctx.is_suspended is False + assert isinstance(self.async_ctx.is_suspended, bool) + + def 
test_method_compatibility_set_custom_status(self):
+        """Test set_custom_status method compatibility."""
+        # Test that method exists and can be called
+        self.async_ctx.set_custom_status({"status": "running"})
+        self.mock_base_ctx.set_custom_status.assert_called_once_with({"status": "running"})
+
+    def test_method_compatibility_create_timer(self):
+        """Test create_timer method compatibility."""
+        # Mock the return value
+        mock_task = Mock(spec=task.Task)
+        self.mock_base_ctx.create_timer.return_value = mock_task
+
+        # Test with timedelta
+        timer_awaitable = self.async_ctx.create_timer(timedelta(seconds=30))
+        assert timer_awaitable is not None
+
+        # Test with datetime
+        future_time = datetime(2023, 1, 1, 13, 0, 0)
+        timer_awaitable2 = self.async_ctx.create_timer(future_time)
+        assert timer_awaitable2 is not None
+
+    def test_method_compatibility_call_activity(self):
+        """Test call_activity method compatibility."""
+
+        def test_activity(input_data):
+            return f"processed: {input_data}"
+
+        activity_awaitable = self.async_ctx.call_activity(test_activity, input="test")
+        assert activity_awaitable is not None
+
+        # Test with retry policy
+        retry_policy = task.RetryPolicy(
+            first_retry_interval=timedelta(seconds=1), max_number_of_attempts=3
+        )
+        activity_awaitable2 = self.async_ctx.call_activity(
+            test_activity, input="test", retry_policy=retry_policy
+        )
+        assert activity_awaitable2 is not None
+
+    def test_method_compatibility_call_sub_orchestrator(self):
+        """Test call_sub_orchestrator method compatibility."""
+
+        async def test_orchestrator(ctx, input_data):
+            return f"orchestrated: {input_data}"
+
+        sub_orch_awaitable = self.async_ctx.call_sub_orchestrator(test_orchestrator, input="test")
+        assert sub_orch_awaitable is not None
+
+        # Test with instance_id and retry_policy
+        retry_policy = task.RetryPolicy(
+            first_retry_interval=timedelta(seconds=2), max_number_of_attempts=2
+        )
+        sub_orch_awaitable2 = self.async_ctx.call_sub_orchestrator(
+            test_orchestrator, input="test", instance_id="sub-123", retry_policy=retry_policy
+        )
+        assert sub_orch_awaitable2 is not None
+
+    def test_method_compatibility_wait_for_external_event(self):
+        """Test wait_for_external_event method compatibility."""
+        event_awaitable = self.async_ctx.wait_for_external_event("test_event")
+        assert event_awaitable is not None
+
+    def test_method_compatibility_continue_as_new(self):
+        """Test continue_as_new method compatibility."""
+        # Test basic call
+        self.async_ctx.continue_as_new({"new": "input"})
+        self.mock_base_ctx.continue_as_new.assert_called_once_with({"new": "input"})
+
+    def test_method_signature_compatibility(self):
+        """Test that method signatures are compatible with OrchestrationContext."""
+        # Get method signatures from both classes
+        base_methods = {}
+        for name, method in inspect.getmembers(
+            task.OrchestrationContext, predicate=inspect.isfunction
+        ):
+            if not name.startswith("_"):
+                base_methods[name] = inspect.signature(method)
+
+        async_methods = {}
+        # On a class (rather than an instance), plain methods are functions, so
+        # inspect.isfunction is the right predicate here too; inspect.ismethod
+        # would match nothing and leave this check vacuous.
+        for name, method in inspect.getmembers(
+            AsyncWorkflowContext, predicate=inspect.isfunction
+        ):
+            if not name.startswith("_") and name in base_methods:
+                async_methods[name] = inspect.signature(method)
+
+        # Compare signatures (allowing for additional parameters in async version)
+        for method_name, base_sig in base_methods.items():
+            if method_name in async_methods:
+                async_sig = async_methods[method_name]
+
+                # Check that all base parameters exist in async version
+                base_params = list(base_sig.parameters.keys())
+                async_params = 
list(async_sig.parameters.keys()) + + # Skip 'self' parameter for comparison + if "self" in base_params: + base_params.remove("self") + if "self" in async_params: + async_params.remove("self") + + # Async version can have additional parameters, but must have all base ones + for param in base_params: + assert param in async_params or param == "self", ( + f"Method {method_name} missing parameter {param} in AsyncWorkflowContext" + ) + + def test_return_type_compatibility(self): + """Test that methods return compatible types.""" + + # Test that activity calls return awaitables + def test_activity(): + return "result" + + activity_result = self.async_ctx.call_activity(test_activity) + assert hasattr(activity_result, "__await__"), "call_activity should return an awaitable" + + # Test that timer calls return awaitables + timer_result = self.async_ctx.create_timer(timedelta(seconds=1)) + assert hasattr(timer_result, "__await__"), "create_timer should return an awaitable" + + # Test that external event calls return awaitables + event_result = self.async_ctx.wait_for_external_event("test") + assert hasattr(event_result, "__await__"), ( + "wait_for_external_event should return an awaitable" + ) + + def test_async_context_additional_methods(self): + """Test that AsyncWorkflowContext provides additional async-specific methods.""" + # These are enhancements that don't exist in base OrchestrationContext + additional_methods = [ + "sub_orchestrator", # Alias for call_sub_orchestrator + "when_all", # Concurrency primitive + "when_any", # Concurrency primitive (returns tuple) + "with_timeout", # Timeout wrapper + "now", # Deterministic datetime (from mixin) + "random", # Deterministic random (from mixin) + "uuid4", # Deterministic UUID (from mixin) + "new_guid", # Alias for uuid4 + "random_string", # Deterministic string generation + "_get_info_snapshot", # Debug information + ] + + for method_name in additional_methods: + assert hasattr(self.async_ctx, method_name), ( + f"AsyncWorkflowContext missing enhanced method: {method_name}" + ) + + method = getattr(self.async_ctx, method_name) + assert callable(method), f"Enhanced method {method_name} is not callable" + + def test_async_context_manager_compatibility(self): + """Test that AsyncWorkflowContext supports async context manager protocol.""" + assert hasattr(self.async_ctx, "__aenter__"), ( + "AsyncWorkflowContext should support async context manager (__aenter__)" + ) + assert hasattr(self.async_ctx, "__aexit__"), ( + "AsyncWorkflowContext should support async context manager (__aexit__)" + ) + + def test_property_delegation_to_base_context(self): + """Test that properties correctly delegate to the base context.""" + # Change base context values and verify async context reflects them + self.mock_base_ctx.instance_id = "new-instance-456" + assert self.async_ctx.instance_id == "new-instance-456" + + new_time = datetime(2023, 6, 15, 10, 30, 0) + self.mock_base_ctx.current_utc_datetime = new_time + assert self.async_ctx.current_utc_datetime == new_time + + self.mock_base_ctx.is_replaying = True + assert self.async_ctx.is_replaying is True + + def test_method_delegation_to_base_context(self): + """Test that methods correctly delegate to the base context.""" + # Test set_custom_status delegation + self.async_ctx.set_custom_status("test_status") + self.mock_base_ctx.set_custom_status.assert_called_with("test_status") + + # Test continue_as_new delegation + self.async_ctx.continue_as_new("new_input") + 
self.mock_base_ctx.continue_as_new.assert_called_with("new_input") + + +class TestOrchestrationContextProtocolCompliance: + """Test that AsyncWorkflowContext can be used wherever OrchestrationContext is expected.""" + + def test_async_context_is_orchestration_context_compatible(self): + """Test that AsyncWorkflowContext can be used as OrchestrationContext.""" + mock_base_ctx = Mock(spec=task.OrchestrationContext) + mock_base_ctx.instance_id = "test-123" + mock_base_ctx.current_utc_datetime = datetime.now() + mock_base_ctx.is_replaying = False + + async_ctx = AsyncWorkflowContext(mock_base_ctx) + + # Test that it can be used in functions expecting OrchestrationContext + def function_expecting_orchestration_context(ctx: task.OrchestrationContext) -> str: + return f"Instance: {ctx.instance_id}, Replaying: {ctx.is_replaying}" + + # This should work without type errors + result = function_expecting_orchestration_context(async_ctx) + assert "test-123" in result + assert "False" in result + + def test_duck_typing_compatibility(self): + """Test that AsyncWorkflowContext satisfies duck typing for OrchestrationContext.""" + mock_base_ctx = Mock(spec=task.OrchestrationContext) + mock_base_ctx.instance_id = "duck-test" + mock_base_ctx.current_utc_datetime = datetime(2023, 1, 1) + mock_base_ctx.is_replaying = False + mock_base_ctx.workflow_name = "duck_workflow" + + async_ctx = AsyncWorkflowContext(mock_base_ctx) + + # Test all the key properties and methods that would be used in duck typing + assert hasattr(async_ctx, "instance_id") + assert hasattr(async_ctx, "current_utc_datetime") + assert hasattr(async_ctx, "is_replaying") + assert hasattr(async_ctx, "call_activity") + assert hasattr(async_ctx, "call_sub_orchestrator") + assert hasattr(async_ctx, "create_timer") + assert hasattr(async_ctx, "wait_for_external_event") + assert hasattr(async_ctx, "set_custom_status") + assert hasattr(async_ctx, "continue_as_new") + + # Test that they return the expected types + assert isinstance(async_ctx.instance_id, str) + assert isinstance(async_ctx.current_utc_datetime, datetime) + assert isinstance(async_ctx.is_replaying, bool) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/aio/test_context_simple.py b/tests/aio/test_context_simple.py new file mode 100644 index 0000000..7e2edbd --- /dev/null +++ b/tests/aio/test_context_simple.py @@ -0,0 +1,372 @@ +# Copyright 2025 The Dapr Authors +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Simplified tests for AsyncWorkflowContext in durabletask.aio. + +These tests focus on the actual implementation rather than expected features. 
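+
+Each test drives AsyncWorkflowContext against a Mock(spec=OrchestrationContext)
+base context, so no sidecar or runtime is required.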
+""" + +import asyncio +import random +import uuid +from datetime import datetime, timedelta +from unittest.mock import Mock + +import pytest + +from durabletask import task as dt_task +from durabletask.aio import ( + ActivityAwaitable, + AsyncWorkflowContext, + AwaitableBase, + ExternalEventAwaitable, + SleepAwaitable, + SubOrchestratorAwaitable, + TimeoutAwaitable, + WhenAllAwaitable, + WhenAnyAwaitable, +) + + +class TestAsyncWorkflowContextBasic: + """Test basic AsyncWorkflowContext functionality.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_base_ctx = Mock(spec=dt_task.OrchestrationContext) + self.mock_base_ctx.instance_id = "test-instance-123" + self.mock_base_ctx.current_utc_datetime = datetime(2023, 1, 1, 12, 0, 0) + self.mock_base_ctx.is_replaying = False + self.mock_base_ctx.is_suspended = False + + # Mock methods that might exist + self.mock_base_ctx.call_activity.return_value = Mock(spec=dt_task.Task) + self.mock_base_ctx.call_sub_orchestrator.return_value = Mock(spec=dt_task.Task) + self.mock_base_ctx.create_timer.return_value = Mock(spec=dt_task.Task) + self.mock_base_ctx.wait_for_external_event.return_value = Mock(spec=dt_task.Task) + + self.ctx = AsyncWorkflowContext(self.mock_base_ctx) + + def test_context_creation(self): + """Test creating AsyncWorkflowContext.""" + assert self.ctx._base_ctx is self.mock_base_ctx + + def test_instance_id_property(self): + """Test instance_id property.""" + assert self.ctx.instance_id == "test-instance-123" + + def test_current_utc_datetime_property(self): + """Test current_utc_datetime property.""" + assert self.ctx.current_utc_datetime == datetime(2023, 1, 1, 12, 0, 0) + + def test_is_replaying_property(self): + """Test is_replaying property.""" + assert self.ctx.is_replaying == False + + self.mock_base_ctx.is_replaying = True + assert self.ctx.is_replaying == True + + def test_is_suspended_property(self): + """Test is_suspended property.""" + assert self.ctx.is_suspended == False + + self.mock_base_ctx.is_suspended = True + assert self.ctx.is_suspended == True + + def test_now_method(self): + """Test now() method from DeterministicContextMixin.""" + now = self.ctx.now() + assert now == datetime(2023, 1, 1, 12, 0, 0) + assert now is self.ctx.current_utc_datetime + + def test_random_method(self): + """Test random() method from DeterministicContextMixin.""" + rng = self.ctx.random() + assert isinstance(rng, random.Random) + + # Should be deterministic + rng1 = self.ctx.random() + rng2 = self.ctx.random() + + val1 = rng1.random() + val2 = rng2.random() + assert val1 == val2 # Same seed should produce same values + + def test_uuid4_method(self): + """Test uuid4() method from DeterministicContextMixin.""" + test_uuid = self.ctx.uuid4() + assert isinstance(test_uuid, uuid.UUID) + assert test_uuid.version == 5 # Now using UUID v5 for .NET compatibility + + # Should increment counter - each call produces different UUID + uuid1 = self.ctx.uuid4() + uuid2 = self.ctx.uuid4() + assert uuid1 != uuid2 # Counter increments + + def test_new_guid_method(self): + """Test new_guid() alias method.""" + guid = self.ctx.new_guid() + assert isinstance(guid, uuid.UUID) + assert guid.version == 5 # Now using UUID v5 for .NET compatibility + + def test_random_string_method(self): + """Test random_string() method from DeterministicContextMixin.""" + # Test default alphabet + s1 = self.ctx.random_string(10) + assert len(s1) == 10 + assert all(c.isalnum() for c in s1) + + # Test custom alphabet + s2 = self.ctx.random_string(5, 
alphabet="ABC") + assert len(s2) == 5 + assert all(c in "ABC" for c in s2) + + # Test deterministic behavior + s3 = self.ctx.random_string(10) + assert s1 == s3 # Same context should produce same string + + def test_call_activity_method(self): + """Test call_activity() method.""" + activity_fn = Mock(__name__="test_activity") + + # Basic call + awaitable = self.ctx.call_activity(activity_fn, input="test_input") + + assert isinstance(awaitable, ActivityAwaitable) + assert awaitable._ctx is self.mock_base_ctx + assert awaitable._activity_fn is activity_fn + assert awaitable._input == "test_input" + + def test_activity_method_alias(self): + """Test activity() method alias.""" + activity_fn = Mock(__name__="test_activity") + + awaitable = self.ctx.call_activity(activity_fn, input="test_input") + + assert isinstance(awaitable, ActivityAwaitable) + assert awaitable._activity_fn is activity_fn + + def test_call_sub_orchestrator_method(self): + """Test call_sub_orchestrator() method.""" + workflow_fn = Mock(__name__="test_workflow") + + awaitable = self.ctx.call_sub_orchestrator( + workflow_fn, input="test_input", instance_id="sub-instance" + ) + + assert isinstance(awaitable, SubOrchestratorAwaitable) + assert awaitable._ctx is self.mock_base_ctx + assert awaitable._workflow_fn is workflow_fn + assert awaitable._input == "test_input" + assert awaitable._instance_id == "sub-instance" + + def test_sub_orchestrator_method_alias(self): + """Test sub_orchestrator() method alias.""" + workflow_fn = Mock(__name__="test_workflow") + + awaitable = self.ctx.sub_orchestrator(workflow_fn, input="test_input") + + assert isinstance(awaitable, SubOrchestratorAwaitable) + assert awaitable._workflow_fn is workflow_fn + + def test_sleep_method(self): + """Test sleep() method.""" + # Test with float + awaitable = self.ctx.create_timer(5.0) + + assert isinstance(awaitable, SleepAwaitable) + assert awaitable._duration == 5.0 + + # Test with timedelta + duration = timedelta(minutes=1) + awaitable = self.ctx.create_timer(duration) + assert awaitable._duration is duration + + # Test with datetime + deadline = datetime(2023, 1, 1, 13, 0, 0) + awaitable = self.ctx.create_timer(deadline) + assert awaitable._duration is deadline + + def test_create_timer_method(self): + """Test create_timer() method.""" + # Test with timedelta + duration = timedelta(seconds=30) + awaitable = self.ctx.create_timer(duration) + + assert isinstance(awaitable, SleepAwaitable) + assert awaitable._ctx is self.mock_base_ctx + assert awaitable._duration is duration + + def test_wait_for_external_event_method(self): + """Test wait_for_external_event() method.""" + awaitable = self.ctx.wait_for_external_event("test_event") + + assert isinstance(awaitable, ExternalEventAwaitable) + assert awaitable._ctx is self.mock_base_ctx + assert awaitable._name == "test_event" + + def test_when_all_method(self): + """Test when_all() method.""" + # Create mock awaitables + awaitable1 = Mock() + awaitable2 = Mock() + awaitables = [awaitable1, awaitable2] + + result = self.ctx.when_all(awaitables) + + assert isinstance(result, WhenAllAwaitable) + assert result._tasks_like == awaitables + + def test_when_any_method(self): + """Test when_any() method.""" + awaitable1 = Mock(spec=AwaitableBase) + awaitable1._to_task.return_value = Mock(spec=dt_task.Task) + awaitable2 = Mock(spec=AwaitableBase) + awaitable2._to_task.return_value = Mock(spec=dt_task.Task) + awaitables = [awaitable1, awaitable2] + + result = self.ctx.when_any(awaitables) + + assert isinstance(result, 
+        assert result._originals == awaitables
+
+    def test_with_timeout_method(self):
+        """Test with_timeout() method."""
+        mock_awaitable = Mock()
+
+        result = self.ctx.with_timeout(mock_awaitable, 5.0)
+
+        assert isinstance(result, TimeoutAwaitable)
+        assert result._awaitable is mock_awaitable
+        assert result._timeout_seconds == 5.0
+
+    def test_set_custom_status_method(self):
+        """Test set_custom_status() method."""
+        # Delegates to the base context; should not raise
+        self.ctx.set_custom_status("Processing data")
+
+    def test_continue_as_new_method(self):
+        """Test continue_as_new() method."""
+        new_input = {"restart": True}
+
+        # Delegates to the base context; should not raise
+        self.ctx.continue_as_new(new_input)
+
+    def test_get_info_snapshot_method(self):
+        """Test the _get_info_snapshot() debug helper."""
+        debug_info = self.ctx._get_info_snapshot()
+
+        assert isinstance(debug_info, dict)
+        assert debug_info["instance_id"] == "test-instance-123"
+        assert debug_info["is_replaying"] is False
+
+    def test_deterministic_context_mixin_integration(self):
+        """Test integration with DeterministicContextMixin."""
+        from durabletask.deterministic import DeterministicContextMixin
+
+        # Should be an instance of the mixin
+        assert isinstance(self.ctx, DeterministicContextMixin)
+
+        # Should have all mixin methods
+        assert hasattr(self.ctx, "now")
+        assert hasattr(self.ctx, "random")
+        assert hasattr(self.ctx, "uuid4")
+        assert hasattr(self.ctx, "new_guid")
+        assert hasattr(self.ctx, "random_string")
+
+    def test_context_with_string_activity_name(self):
+        """Test context methods with string activity/workflow names."""
+        # Test with string activity name
+        awaitable = self.ctx.call_activity("string_activity_name", input="test")
+        assert isinstance(awaitable, ActivityAwaitable)
+        assert awaitable._activity_fn == "string_activity_name"
+
+        # Test with string workflow name
+        awaitable = self.ctx.call_sub_orchestrator("string_workflow_name", input="test")
+        assert isinstance(awaitable, SubOrchestratorAwaitable)
+        assert awaitable._workflow_fn == "string_workflow_name"
+
+    def test_context_method_parameter_validation(self):
+        """Test parameter validation in context methods."""
+        # Test random_string with invalid parameters
+        with pytest.raises(ValueError):
+            self.ctx.random_string(-1)  # Negative length
+
+        with pytest.raises(ValueError):
+            self.ctx.random_string(5, alphabet="")  # Empty alphabet
+
+    def test_context_repr(self):
+        """Test context string representation."""
+        repr_str = repr(self.ctx)
+        assert "AsyncWorkflowContext" in repr_str
+        assert "test-instance-123" in repr_str
+
+
+class TestAsyncActivities:
+    """Test async activities called from AsyncWorkflowContext."""
+
+    def setup_method(self):
+        """Set up test fixtures."""
+        self.mock_base_ctx = Mock(spec=dt_task.OrchestrationContext)
+        self.mock_base_ctx.instance_id = "test-instance-123"
+        self.mock_base_ctx.current_utc_datetime = datetime(2023, 1, 1, 12, 0, 0)
+        self.mock_base_ctx.is_replaying = False
+        self.mock_base_ctx.is_suspended = False
+
+    def test_async_activity_call(self):
+        """Test AsyncWorkflowContext calling an async activity."""
+
+        async def async_activity(ctx: dt_task.ActivityContext, input_data: str):
+            await asyncio.sleep(0.001)
+            return input_data.upper()
+
+        ctx = AsyncWorkflowContext(self.mock_base_ctx)
+        awaitable = ctx.call_activity(async_activity, input="test")
+
+        assert isinstance(awaitable, ActivityAwaitable)
+        assert awaitable._activity_fn == async_activity
+        assert awaitable._input == "test"
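+        # Note: constructing the awaitable only records the call; the activity
+        # body runs later when the durable runtime executes the scheduled Task.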
+ + def test_async_activity_with_when_all(self): + """Test when_all with async activities""" + + async def async_activity(ctx: dt_task.ActivityContext, input_data: int): + await asyncio.sleep(0.001) + return input_data * 2 + + ctx = AsyncWorkflowContext(self.mock_base_ctx) + + # Create multiple async activity awaitables + awaitables = [ctx.call_activity(async_activity, input=i) for i in range(3)] + when_all = ctx.when_all(awaitables) + + assert isinstance(when_all, WhenAllAwaitable) + + def test_async_activity_with_when_any(self): + """Test when_any with async activities""" + + async def async_activity_fast(ctx: dt_task.ActivityContext, _): + await asyncio.sleep(0.001) + return "fast" + + async def async_activity_slow(ctx: dt_task.ActivityContext, _): + await asyncio.sleep(0.1) + return "slow" + + ctx = AsyncWorkflowContext(self.mock_base_ctx) + + # Create async activity awaitables + fast_awaitable = ctx.call_activity(async_activity_fast, input=None) + slow_awaitable = ctx.call_activity(async_activity_slow, input=None) + when_any = ctx.when_any([fast_awaitable, slow_awaitable]) + + assert isinstance(when_any, WhenAnyAwaitable) diff --git a/tests/aio/test_driver.py b/tests/aio/test_driver.py new file mode 100644 index 0000000..b8621cb --- /dev/null +++ b/tests/aio/test_driver.py @@ -0,0 +1,1148 @@ +# Copyright 2025 The Dapr Authors +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests for driver functionality in durabletask.aio. 
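+
+CoroutineOrchestratorRunner adapts an async orchestrator to the generator
+protocol the durable task worker replays: to_generator() yields the durable
+Task behind each awaited operation, and results (or exceptions) are pushed
+back in with send()/throw() to resume the coroutine.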
+""" + +from typing import Any +from unittest.mock import Mock + +import pytest + +from durabletask import task as dt_task +from durabletask.aio import ( + AsyncWorkflowContext, + AsyncWorkflowError, + CoroutineOrchestratorRunner, + WorkflowFunction, + WorkflowValidationError, +) + +# DTPOperation deprecated: tests removed + + +class TestWorkflowFunction: + """Test WorkflowFunction protocol.""" + + def test_workflow_function_protocol(self): + """Test WorkflowFunction protocol recognition.""" + + # Valid async workflow function + async def valid_workflow(ctx: AsyncWorkflowContext, input_data: Any) -> str: + return "result" + + # Should be recognized as WorkflowFunction + assert isinstance(valid_workflow, WorkflowFunction) + + def test_non_async_function_protocol(self): + """Test that non-async functions are still recognized structurally.""" + + # Non-async function with correct signature + def not_async_workflow(ctx: AsyncWorkflowContext, input_data: Any) -> str: + return "result" + + # Should still be recognized as WorkflowFunction due to structural typing + # The actual async validation happens in CoroutineOrchestratorRunner + assert isinstance(not_async_workflow, WorkflowFunction) + + +class TestCoroutineOrchestratorRunner: + """Test CoroutineOrchestratorRunner functionality.""" + + def setup_method(self): + """Set up test fixtures.""" + from datetime import datetime + + self.mock_base_ctx = Mock(spec=dt_task.OrchestrationContext) + self.mock_base_ctx.instance_id = "test-instance" + self.mock_base_ctx.current_utc_datetime = datetime(2025, 1, 1, 12, 0, 0) + + def test_runner_creation(self): + """Test creating a CoroutineOrchestratorRunner.""" + + async def test_workflow(ctx: AsyncWorkflowContext, input_data: Any) -> str: + return "result" + + runner = CoroutineOrchestratorRunner(test_workflow) + + assert runner._async_orchestrator is test_workflow + assert runner._sandbox_mode == "best_effort" + assert runner._workflow_name == "test_workflow" + + def test_runner_with_sandbox_mode(self): + """Test creating runner with sandbox mode.""" + + async def test_workflow(ctx: AsyncWorkflowContext, input_data: Any) -> str: + return "result" + + runner = CoroutineOrchestratorRunner(test_workflow, sandbox_mode="strict") + + assert runner._sandbox_mode == "strict" + + def test_runner_with_lambda_function(self): + """Test creating runner with lambda function.""" + + # Lambda functions must be async to be valid + def lambda_workflow(ctx, input_data): + return "result" + + # Should raise validation error for non-async lambda + with pytest.raises(WorkflowValidationError) as exc_info: + CoroutineOrchestratorRunner(lambda_workflow) + + assert "async function" in str(exc_info.value) + + def test_simple_synchronous_workflow(self): + """Test running a simple synchronous workflow.""" + + async def simple_workflow(ctx: AsyncWorkflowContext) -> str: + return "hello world" + + runner = CoroutineOrchestratorRunner(simple_workflow) + async_ctx = AsyncWorkflowContext(self.mock_base_ctx) + + # Convert to generator and run + gen = runner.to_generator(async_ctx, None) + + # Should complete immediately with StopIteration + with pytest.raises(StopIteration) as exc_info: + next(gen) + + assert exc_info.value.value == "hello world" + + def test_workflow_with_single_activity(self): + """Test workflow with a single activity call.""" + + async def activity_workflow(ctx: AsyncWorkflowContext, input_data: str) -> str: + result = await ctx.call_activity("test_activity", input=input_data) + return f"processed: {result}" + + # Mock 
the activity call + mock_task = Mock(spec=dt_task.Task) + self.mock_base_ctx.call_activity.return_value = mock_task + + runner = CoroutineOrchestratorRunner(activity_workflow) + async_ctx = AsyncWorkflowContext(self.mock_base_ctx) + + # Convert to generator + gen = runner.to_generator(async_ctx, "test_input") + + # First yield should be the activity task + yielded_task = next(gen) + assert yielded_task is mock_task + + # Send result back + try: + gen.send("activity_result") + except StopIteration as stop: + assert stop.value == "processed: activity_result" + else: + pytest.fail("Expected StopIteration") + + def test_workflow_initialization_error(self): + """Test workflow initialization error handling.""" + + async def failing_workflow(ctx: AsyncWorkflowContext, input_data: Any) -> str: + raise ValueError("Initialization failed") + + runner = CoroutineOrchestratorRunner(failing_workflow) + async_ctx = AsyncWorkflowContext(self.mock_base_ctx) + + # The error should be raised when we try to start the generator + gen = runner.to_generator(async_ctx, None) + + with pytest.raises(AsyncWorkflowError) as exc_info: + next(gen) # This will trigger the initialization error + + assert "Workflow failed during initialization" in str(exc_info.value) + assert exc_info.value.workflow_name == "failing_workflow" + assert exc_info.value.step == "initialization" + + def test_workflow_invalid_signature(self): + """Test workflow with invalid signature.""" + + async def invalid_workflow() -> str: # Missing ctx parameter + return "result" + + # Should raise validation error during runner creation + with pytest.raises(WorkflowValidationError) as exc_info: + CoroutineOrchestratorRunner(invalid_workflow) + + assert "at least one parameter" in str(exc_info.value) + + def test_workflow_yielding_invalid_object(self): + """Test workflow yielding invalid object.""" + + # Create a workflow that yields an invalid object + # We need to simulate this by creating a workflow that awaits something invalid + class InvalidAwaitable: + def __await__(self): + yield "invalid" # This will cause the error + return "result" + + async def invalid_yield_workflow(ctx: AsyncWorkflowContext) -> str: + result = await InvalidAwaitable() + return result + + runner = CoroutineOrchestratorRunner(invalid_yield_workflow) + async_ctx = AsyncWorkflowContext(self.mock_base_ctx) + + gen = runner.to_generator(async_ctx, None) + + with pytest.raises(AsyncWorkflowError) as exc_info: + next(gen) + + assert "awaited unsupported object type" in str(exc_info.value) + + def test_workflow_with_direct_task_yield(self): + """Test workflow with custom awaitable that yields task directly.""" + + # Create a custom awaitable that yields task directly (current approach) + class DirectTaskAwaitable: + def __init__(self, task): + self.task = task + + def __await__(self): + result = yield self.task + return f"result: {result}" + + async def direct_task_workflow(ctx: AsyncWorkflowContext) -> str: + mock_task = Mock(spec=dt_task.Task) + result = await DirectTaskAwaitable(mock_task) + return result + + runner = CoroutineOrchestratorRunner(direct_task_workflow) + async_ctx = AsyncWorkflowContext(self.mock_base_ctx) + + gen = runner.to_generator(async_ctx, None) + + # Should yield the underlying task + yielded_task = next(gen) + assert isinstance(yielded_task, Mock) # The mock task + + # Send result back + try: + gen.send("operation_result") + except StopIteration as stop: + assert stop.value == "result: operation_result" + + def test_workflow_exception_handling(self): + 
"""Test workflow exception handling during execution.""" + + async def exception_workflow(ctx: AsyncWorkflowContext) -> str: + result = await ctx.call_activity("failing_activity") + return result + + # Mock the activity call + mock_task = Mock(spec=dt_task.Task) + self.mock_base_ctx.call_activity.return_value = mock_task + + runner = CoroutineOrchestratorRunner(exception_workflow) + async_ctx = AsyncWorkflowContext(self.mock_base_ctx) + + gen = runner.to_generator(async_ctx, None) + + # First yield should be the activity task + yielded_task = next(gen) + assert yielded_task is mock_task + + # Throw an exception + test_exception = Exception("Activity failed") + try: + gen.throw(test_exception) + except StopIteration: + pytest.fail("Expected exception to propagate") + except AsyncWorkflowError as e: + # The driver wraps the original exception in AsyncWorkflowError + assert "Activity failed" in str(e) + assert e.workflow_name == "exception_workflow" + + def test_workflow_step_tracking(self): + """Test that workflow steps are tracked for error reporting.""" + + # Test that the runner correctly tracks workflow name and steps + async def multi_step_workflow(ctx: AsyncWorkflowContext) -> str: + result1 = await ctx.call_activity("step1") + result2 = await ctx.call_activity("step2") + return f"{result1}+{result2}" + + # Mock the activity calls + mock_task = Mock(spec=dt_task.Task) + self.mock_base_ctx.call_activity.return_value = mock_task + + runner = CoroutineOrchestratorRunner(multi_step_workflow) + async_ctx = AsyncWorkflowContext(self.mock_base_ctx) + + # Verify workflow name is tracked + assert runner._workflow_name == "multi_step_workflow" + + gen = runner.to_generator(async_ctx, None) + + # First step + yielded_task = next(gen) + assert yielded_task is mock_task + + # Complete first step + yielded_task = gen.send("result1") + assert yielded_task is mock_task + + # Complete second step + try: + gen.send("result2") + except StopIteration as stop: + assert stop.value == "result1+result2" + + def test_runner_slots(self): + """Test that CoroutineOrchestratorRunner has __slots__.""" + assert hasattr(CoroutineOrchestratorRunner, "__slots__") + + def test_workflow_too_many_parameters(self): + """Test workflow with too many parameters.""" + + async def too_many_params_workflow( + ctx: AsyncWorkflowContext, input_data: Any, extra: Any + ) -> str: + return "result" + + # Should raise validation error during runner creation + with pytest.raises(WorkflowValidationError) as exc_info: + CoroutineOrchestratorRunner(too_many_params_workflow) + + assert "at most two parameters" in str(exc_info.value) + assert exc_info.value.validation_type == "function_signature" + + def test_workflow_not_callable(self): + """Test workflow that is not callable.""" + not_callable = "not a function" + + # Should raise validation error during runner creation + with pytest.raises(WorkflowValidationError) as exc_info: + CoroutineOrchestratorRunner(not_callable) + + assert "must be callable" in str(exc_info.value) + assert exc_info.value.validation_type == "function_type" + + def test_workflow_coroutine_instantiation_error(self): + """Test error during coroutine instantiation.""" + + async def problematic_workflow(ctx: AsyncWorkflowContext, input_data: Any) -> str: + return "result" + + # Mock the workflow to raise TypeError when called + runner = CoroutineOrchestratorRunner(problematic_workflow) + async_ctx = AsyncWorkflowContext(self.mock_base_ctx) + + # Replace the orchestrator with one that raises TypeError + def 
bad_orchestrator(*args, **kwargs): + raise TypeError("Bad instantiation") + + runner._async_orchestrator = bad_orchestrator + + gen = runner.to_generator(async_ctx, None) + + with pytest.raises(AsyncWorkflowError) as exc_info: + next(gen) + + assert "Failed to instantiate workflow coroutine" in str(exc_info.value) + assert exc_info.value.step == "initialization" + + def test_workflow_with_direct_task_awaitable(self): + """Test workflow that awaits a Task directly (tests Task branch in to_iter).""" + + async def direct_task_workflow(ctx: AsyncWorkflowContext) -> str: + # This will be caught by the to_iter function's Task branch + mock_task = Mock(spec=dt_task.Task) + # We need to make the coroutine return a Task directly, not await it + return mock_task + + runner = CoroutineOrchestratorRunner(direct_task_workflow) + async_ctx = AsyncWorkflowContext(self.mock_base_ctx) + + gen = runner.to_generator(async_ctx, None) + + # Should complete immediately since it's synchronous + try: + next(gen) + except StopIteration as stop: + assert isinstance(stop.value, Mock) + + def test_awaitable_completes_synchronously(self): + """Test awaitable that completes without yielding.""" + + class SyncAwaitable: + def __await__(self): + # Complete immediately without yielding + return + yield # unreachable but makes this a generator + + async def sync_awaitable_workflow(ctx: AsyncWorkflowContext) -> str: + await SyncAwaitable() + return "completed" + + runner = CoroutineOrchestratorRunner(sync_awaitable_workflow) + async_ctx = AsyncWorkflowContext(self.mock_base_ctx) + + gen = runner.to_generator(async_ctx, None) + + # Should complete without yielding any tasks + with pytest.raises(StopIteration) as exc_info: + next(gen) + + assert exc_info.value.value == "completed" + + def test_awaitable_yields_non_task(self): + """Test awaitable that yields non-Task object during execution.""" + + class BadAwaitable: + def __await__(self): + yield "not a task" # This should trigger the non-Task error + return "result" + + async def bad_awaitable_workflow(ctx: AsyncWorkflowContext) -> str: + result = await BadAwaitable() + return result + + runner = CoroutineOrchestratorRunner(bad_awaitable_workflow) + async_ctx = AsyncWorkflowContext(self.mock_base_ctx) + + gen = runner.to_generator(async_ctx, None) + + with pytest.raises(AsyncWorkflowError) as exc_info: + next(gen) + + assert "awaited unsupported object type" in str(exc_info.value) + assert exc_info.value.step == "awaitable_conversion" + + def test_awaitable_exception_handling_with_completion(self): + """Test exception handling where awaitable completes after exception.""" + + class ExceptionThenCompleteAwaitable: + def __init__(self): + self.threw = False + + def __await__(self): + task = Mock(spec=dt_task.Task) + try: + result = yield task + return f"normal: {result}" + except Exception as e: + self.threw = True + return f"exception handled: {e}" + + async def exception_handling_workflow(ctx: AsyncWorkflowContext) -> str: + awaitable = ExceptionThenCompleteAwaitable() + result = await awaitable + return result + + runner = CoroutineOrchestratorRunner(exception_handling_workflow) + async_ctx = AsyncWorkflowContext(self.mock_base_ctx) + + gen = runner.to_generator(async_ctx, None) + + # Get the task + _ = next(gen) + + # Throw an exception + test_exception = Exception("test error") + try: + gen.throw(test_exception) + except StopIteration as stop: + assert "exception handled: test error" in stop.value + + def test_awaitable_exception_propagation(self): + """Test exception 
propagation through awaitable.""" + + class ExceptionPropagatingAwaitable: + def __await__(self): + task = Mock(spec=dt_task.Task) + result = yield task + return result + + async def exception_propagation_workflow(ctx: AsyncWorkflowContext) -> str: + result = await ExceptionPropagatingAwaitable() + return result + + runner = CoroutineOrchestratorRunner(exception_propagation_workflow) + async_ctx = AsyncWorkflowContext(self.mock_base_ctx) + + gen = runner.to_generator(async_ctx, None) + + # Get the task + _ = next(gen) + + # Throw an exception that should propagate to the coroutine + test_exception = Exception("propagated error") + with pytest.raises(AsyncWorkflowError) as exc_info: + gen.throw(test_exception) + + assert "propagated error" in str(exc_info.value) + assert exc_info.value.step == "execution" + + def test_multi_yield_awaitable(self): + """Test awaitable that yields multiple tasks.""" + + class MultiYieldAwaitable: + def __await__(self): + task1 = Mock(spec=dt_task.Task) + task2 = Mock(spec=dt_task.Task) + result1 = yield task1 + result2 = yield task2 + return f"{result1}+{result2}" + + async def multi_yield_workflow(ctx: AsyncWorkflowContext) -> str: + result = await MultiYieldAwaitable() + return result + + runner = CoroutineOrchestratorRunner(multi_yield_workflow) + async_ctx = AsyncWorkflowContext(self.mock_base_ctx) + + gen = runner.to_generator(async_ctx, None) + + # First task + task1 = next(gen) + assert isinstance(task1, Mock) + + # Second task + task2 = gen.send("result1") + assert isinstance(task2, Mock) + + # Final result + try: + gen.send("result2") + except StopIteration as stop: + assert stop.value == "result1+result2" + + def test_multi_yield_awaitable_with_non_task(self): + """Test multi-yield awaitable that yields non-Task.""" + + class BadMultiYieldAwaitable: + def __await__(self): + task1 = Mock(spec=dt_task.Task) + result1 = yield task1 + yield "not a task" # This should cause error + return result1 + + async def bad_multi_yield_workflow(ctx: AsyncWorkflowContext) -> str: + result = await BadMultiYieldAwaitable() + return result + + runner = CoroutineOrchestratorRunner(bad_multi_yield_workflow) + async_ctx = AsyncWorkflowContext(self.mock_base_ctx) + + gen = runner.to_generator(async_ctx, None) + + # First task + _ = next(gen) + + # Send result, should get error on second yield + with pytest.raises(AsyncWorkflowError) as exc_info: + gen.send("result1") + + assert "awaited unsupported object type" in str(exc_info.value) + + def test_multi_yield_awaitable_exception_in_continuation(self): + """Test exception handling in multi-yield awaitable continuation.""" + + class ExceptionInContinuationAwaitable: + def __await__(self): + task1 = Mock(spec=dt_task.Task) + _ = yield task1 + # This will cause an exception when we try to continue + raise ValueError("continuation error") + + async def exception_continuation_workflow(ctx: AsyncWorkflowContext) -> str: + result = await ExceptionInContinuationAwaitable() + return result + + runner = CoroutineOrchestratorRunner(exception_continuation_workflow) + async_ctx = AsyncWorkflowContext(self.mock_base_ctx) + + gen = runner.to_generator(async_ctx, None) + + # First task + _ = next(gen) + + # Send result, should get error in continuation + with pytest.raises(AsyncWorkflowError) as exc_info: + gen.send("result1") + + assert "continuation error" in str(exc_info.value) + + def test_runner_properties(self): + """Test runner property getters.""" + + async def test_workflow(ctx: AsyncWorkflowContext) -> str: + return "result" + + 
runner = CoroutineOrchestratorRunner( + test_workflow, sandbox_mode="strict", workflow_name="custom_name" + ) + + assert runner.workflow_name == "custom_name" + assert runner.sandbox_mode == "strict" + + def test_runner_with_custom_workflow_name(self): + """Test runner with custom workflow name.""" + + async def test_workflow(ctx: AsyncWorkflowContext) -> str: + return "result" + + runner = CoroutineOrchestratorRunner(test_workflow, workflow_name="custom_workflow") + + assert runner._workflow_name == "custom_workflow" + + def test_runner_with_function_without_name(self): + """Test runner with function that has no __name__ attribute.""" + + async def test_workflow(ctx: AsyncWorkflowContext) -> str: + return "result" + + # Mock getattr to return None for __name__ + from unittest.mock import patch + + with patch("durabletask.aio.driver.getattr") as mock_getattr: + + def side_effect(obj, attr, default=None): + if attr == "__name__": + return None # Simulate missing __name__ + return getattr(obj, attr, default) + + mock_getattr.side_effect = side_effect + + runner = CoroutineOrchestratorRunner(test_workflow) + assert runner._workflow_name == "unknown" + + def test_awaitable_that_yields_task_then_non_task(self): + """Test awaitable that first yields a Task, then yields non-Task (hits line 269-277).""" + + class TaskThenNonTaskAwaitable: + def __await__(self): + task1 = Mock(spec=dt_task.Task) + result1 = yield task1 + # This second yield should trigger the non-Task error in the while loop + yield "not a task" + return result1 + + async def task_then_non_task_workflow(ctx: AsyncWorkflowContext) -> str: + result = await TaskThenNonTaskAwaitable() + return result + + runner = CoroutineOrchestratorRunner(task_then_non_task_workflow) + async_ctx = AsyncWorkflowContext(self.mock_base_ctx) + + gen = runner.to_generator(async_ctx, None) + + # First task should be yielded + task1 = next(gen) + assert isinstance(task1, Mock) + + # Send result, should get error on second yield + with pytest.raises(AsyncWorkflowError) as exc_info: + gen.send("result1") + + assert "awaited unsupported object type" in str(exc_info.value) + assert exc_info.value.step == "awaitable_conversion" + + def test_workflow_with_input_parameter(self): + """Test workflow that accepts input parameter.""" + + async def input_workflow(ctx: AsyncWorkflowContext, input_data: dict) -> str: + name = input_data.get("name", "world") + return f"Hello, {name}!" + + runner = CoroutineOrchestratorRunner(input_workflow) + async_ctx = AsyncWorkflowContext(self.mock_base_ctx) + + gen = runner.to_generator(async_ctx, {"name": "Alice"}) + + with pytest.raises(StopIteration) as exc_info: + next(gen) + + assert exc_info.value.value == "Hello, Alice!" 
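+
+    # The runner binds input based on the orchestrator's signature: a
+    # one-parameter function receives only the context, while a two-parameter
+    # function also receives the deserialized input (see the signature
+    # validation tests above).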
+ + def test_workflow_without_input_parameter(self): + """Test workflow that doesn't accept input parameter.""" + + async def no_input_workflow(ctx: AsyncWorkflowContext) -> str: + return "No input needed" + + runner = CoroutineOrchestratorRunner(no_input_workflow) + async_ctx = AsyncWorkflowContext(self.mock_base_ctx) + + # Should work with None input + gen = runner.to_generator(async_ctx, None) + + with pytest.raises(StopIteration) as exc_info: + next(gen) + + assert exc_info.value.value == "No input needed" + + # Should also work with actual input (will be ignored) + gen = runner.to_generator(async_ctx, {"ignored": "data"}) + + with pytest.raises(StopIteration) as exc_info: + next(gen) + + assert exc_info.value.value == "No input needed" + + def test_sandbox_mode_execution_with_activity(self): + """Test workflow execution with sandbox mode enabled.""" + + async def sandbox_workflow(ctx: AsyncWorkflowContext) -> str: + result = await ctx.call_activity("test_activity", input="test") + return f"Activity result: {result}" + + # Mock the activity call + mock_task = Mock(spec=dt_task.Task) + self.mock_base_ctx.call_activity.return_value = mock_task + + runner = CoroutineOrchestratorRunner(sandbox_workflow, sandbox_mode="best_effort") + async_ctx = AsyncWorkflowContext(self.mock_base_ctx) + + gen = runner.to_generator(async_ctx, None) + + # Should yield a task from the activity call + task = next(gen) + assert task is mock_task + + # Send result back + with pytest.raises(StopIteration) as exc_info: + gen.send("activity_result") + + assert exc_info.value.value == "Activity result: activity_result" + + def test_sandbox_mode_execution_with_exception(self): + """Test workflow exception handling with sandbox mode enabled.""" + + async def failing_sandbox_workflow(ctx: AsyncWorkflowContext) -> str: + result = await ctx.call_activity("test_activity", input="test") + if result == "bad": + raise ValueError("Bad result") + return result + + # Mock the activity call + mock_task = Mock(spec=dt_task.Task) + self.mock_base_ctx.call_activity.return_value = mock_task + + runner = CoroutineOrchestratorRunner(failing_sandbox_workflow, sandbox_mode="best_effort") + async_ctx = AsyncWorkflowContext(self.mock_base_ctx) + + gen = runner.to_generator(async_ctx, None) + + # Should yield a task from the activity call + task = next(gen) + assert task is mock_task + + # Send bad result that triggers exception + with pytest.raises(AsyncWorkflowError) as exc_info: + gen.send("bad") + + assert "Bad result" in str(exc_info.value) + assert exc_info.value.step == "execution" + + def test_sandbox_mode_synchronous_completion(self): + """Test synchronous workflow completion with sandbox mode.""" + + async def sync_sandbox_workflow(ctx: AsyncWorkflowContext) -> str: + return "sync_result" + + runner = CoroutineOrchestratorRunner(sync_sandbox_workflow, sandbox_mode="best_effort") + async_ctx = AsyncWorkflowContext(self.mock_base_ctx) + + gen = runner.to_generator(async_ctx, None) + + # Should complete immediately + with pytest.raises(StopIteration) as exc_info: + next(gen) + + assert exc_info.value.value == "sync_result" + + def test_custom_awaitable_with_await_method(self): + """Test custom awaitable class with __await__ method.""" + + class CustomAwaitable: + def __init__(self, value): + self.value = value + + def __await__(self): + task = Mock(spec=dt_task.Task) + result = yield task + return f"{self.value}: {result}" + + async def custom_awaitable_workflow(ctx: AsyncWorkflowContext) -> str: + result = await 
CustomAwaitable("custom") + return result + + runner = CoroutineOrchestratorRunner(custom_awaitable_workflow) + async_ctx = AsyncWorkflowContext(self.mock_base_ctx) + + gen = runner.to_generator(async_ctx, None) + + # Should yield the task from the custom awaitable + task = next(gen) + assert isinstance(task, Mock) + + # Send result + with pytest.raises(StopIteration) as exc_info: + gen.send("task_result") + + assert exc_info.value.value == "custom: task_result" + + def test_synchronous_awaitable_then_exception(self): + """Test exception after synchronous awaitable completion.""" + + class SyncAwaitable: + def __await__(self): + return + yield # unreachable but makes this a generator + + async def sync_then_fail_workflow(ctx: AsyncWorkflowContext) -> str: + await SyncAwaitable() + raise ValueError("Error after sync awaitable") + + runner = CoroutineOrchestratorRunner(sync_then_fail_workflow) + async_ctx = AsyncWorkflowContext(self.mock_base_ctx) + + gen = runner.to_generator(async_ctx, None) + + # Should raise AsyncWorkflowError wrapping the ValueError + with pytest.raises(AsyncWorkflowError) as exc_info: + next(gen) + + assert "Error after sync awaitable" in str(exc_info.value) + # Error happens during initialization since it's in the first send(None) + assert exc_info.value.step in ("initialization", "execution") + + def test_non_task_object_at_request_level(self): + """Test that non-Task objects yielded directly are caught.""" + + class BadAwaitable: + def __await__(self): + # Yield something that's not a Task + yield {"not": "a task"} + return "result" + + async def bad_request_workflow(ctx: AsyncWorkflowContext) -> str: + result = await BadAwaitable() + return result + + runner = CoroutineOrchestratorRunner(bad_request_workflow) + async_ctx = AsyncWorkflowContext(self.mock_base_ctx) + + gen = runner.to_generator(async_ctx, None) + + # Should raise AsyncWorkflowError about non-Task object + with pytest.raises(AsyncWorkflowError) as exc_info: + next(gen) + + assert "awaited unsupported object type" in str(exc_info.value) + + def test_multi_yield_awaitable_with_exception_in_middle(self): + """Test exception handling during multi-yield awaitable.""" + + class MultiYieldWithException: + def __await__(self): + task1 = Mock(spec=dt_task.Task) + task2 = Mock(spec=dt_task.Task) + result1 = yield task1 + # Exception might be thrown here + result2 = yield task2 + return f"{result1}+{result2}" + + async def multi_yield_exception_workflow(ctx: AsyncWorkflowContext) -> str: + result = await MultiYieldWithException() + return result + + runner = CoroutineOrchestratorRunner(multi_yield_exception_workflow) + async_ctx = AsyncWorkflowContext(self.mock_base_ctx) + + gen = runner.to_generator(async_ctx, None) + + # Get first task + task1 = next(gen) + assert isinstance(task1, Mock) + + # Send result for first task + task2 = gen.send("result1") + assert isinstance(task2, Mock) + + # Throw exception on second task + test_exception = RuntimeError("exception during multi-yield") + with pytest.raises(AsyncWorkflowError) as exc_info: + gen.throw(test_exception) + + assert "exception during multi-yield" in str(exc_info.value) + assert exc_info.value.step == "execution" + + def test_multi_yield_awaitable_exception_handled_then_rethrow(self): + """Test exception handling where awaitable catches then re-throws.""" + + class ExceptionRethrower: + def __await__(self): + task = Mock(spec=dt_task.Task) + try: + result = yield task + return result + except Exception as e: + # Catch and re-throw as different exception + 
raise ValueError(f"Transformed: {e}") from e + + async def rethrow_workflow(ctx: AsyncWorkflowContext) -> str: + result = await ExceptionRethrower() + return result + + runner = CoroutineOrchestratorRunner(rethrow_workflow) + async_ctx = AsyncWorkflowContext(self.mock_base_ctx) + + gen = runner.to_generator(async_ctx, None) + + # Get task + task = next(gen) + assert isinstance(task, Mock) + + # Throw exception + original_exception = RuntimeError("original error") + with pytest.raises(AsyncWorkflowError) as exc_info: + gen.throw(original_exception) + + assert "Transformed: original error" in str(exc_info.value) + assert exc_info.value.step == "execution" + + def test_multi_yield_consecutive_tasks(self): + """Test awaitable yielding multiple tasks consecutively.""" + + class ConsecutiveTaskYielder: + def __await__(self): + task1 = Mock(spec=dt_task.Task) + task2 = Mock(spec=dt_task.Task) + task3 = Mock(spec=dt_task.Task) + result1 = yield task1 + result2 = yield task2 + result3 = yield task3 + return f"{result1}+{result2}+{result3}" + + async def consecutive_workflow(ctx: AsyncWorkflowContext) -> str: + result = await ConsecutiveTaskYielder() + return result + + runner = CoroutineOrchestratorRunner(consecutive_workflow) + async_ctx = AsyncWorkflowContext(self.mock_base_ctx) + + gen = runner.to_generator(async_ctx, None) + + # First task + task1 = next(gen) + assert isinstance(task1, Mock) + + # Second task + task2 = gen.send("r1") + assert isinstance(task2, Mock) + + # Third task + task3 = gen.send("r2") + assert isinstance(task3, Mock) + + # Final result + with pytest.raises(StopIteration) as exc_info: + gen.send("r3") + + assert exc_info.value.value == "r1+r2+r3" + + def test_multi_yield_with_non_task_in_sequence(self): + """Test multi-yield that yields non-Task in the sequence.""" + + class BadMultiYield: + def __await__(self): + task1 = Mock(spec=dt_task.Task) + result1 = yield task1 + # Second yield is not a Task + result2 = yield "not a task" + return f"{result1}+{result2}" + + async def bad_multi_yield_workflow(ctx: AsyncWorkflowContext) -> str: + result = await BadMultiYield() + return result + + runner = CoroutineOrchestratorRunner(bad_multi_yield_workflow) + async_ctx = AsyncWorkflowContext(self.mock_base_ctx) + + gen = runner.to_generator(async_ctx, None) + + # First task succeeds + task1 = next(gen) + assert isinstance(task1, Mock) + + # Second yield should fail with non-Task error + with pytest.raises(AsyncWorkflowError) as exc_info: + gen.send("result1") + + # Error message varies based on where the non-Task is detected + assert "non-Task object" in str(exc_info.value) or "unsupported object type" in str( + exc_info.value + ) + assert exc_info.value.step in ("execution", "awaitable_conversion") + + def test_awaitable_exception_completion_with_sandbox(self): + """Test exception handling with sandbox mode enabled.""" + + class ExceptionHandlingAwaitable: + def __await__(self): + task = Mock(spec=dt_task.Task) + try: + result = yield task + return f"normal: {result}" + except Exception as e: + return f"handled: {e}" + + async def sandbox_exception_workflow(ctx: AsyncWorkflowContext) -> str: + result = await ExceptionHandlingAwaitable() + return result + + runner = CoroutineOrchestratorRunner(sandbox_exception_workflow, sandbox_mode="best_effort") + async_ctx = AsyncWorkflowContext(self.mock_base_ctx) + + gen = runner.to_generator(async_ctx, None) + + # Get task + task = next(gen) + assert isinstance(task, Mock) + + # Throw exception + test_exception = ValueError("test error") + 
with pytest.raises(StopIteration) as exc_info: + gen.throw(test_exception) + + assert "handled: test error" in exc_info.value.value + + def test_multiple_synchronous_awaitables_with_sandbox(self): + """Test multiple synchronous awaitables in sequence with sandbox mode.""" + + class SyncAwaitable: + def __init__(self, value): + self.value = value + + def __await__(self): + # Complete immediately without yielding + return self.value + yield # unreachable but makes this a generator + + async def multi_sync_workflow(ctx: AsyncWorkflowContext) -> str: + result1 = await SyncAwaitable("first") + result2 = await SyncAwaitable("second") + result3 = await SyncAwaitable("third") + return f"{result1}-{result2}-{result3}" + + runner = CoroutineOrchestratorRunner(multi_sync_workflow, sandbox_mode="best_effort") + async_ctx = AsyncWorkflowContext(self.mock_base_ctx) + + gen = runner.to_generator(async_ctx, None) + + # Should complete without yielding any tasks + with pytest.raises(StopIteration) as exc_info: + next(gen) + + assert exc_info.value.value == "first-second-third" + + def test_awaitable_yielding_many_tasks(self): + """Test awaitable that yields 5+ tasks to exercise inner loop.""" + + class ManyTaskYielder: + def __await__(self): + # Yield 6 tasks consecutively + results = [] + for i in range(6): + task = Mock(spec=dt_task.Task) + result = yield task + results.append(str(result)) + return "+".join(results) + + async def many_tasks_workflow(ctx: AsyncWorkflowContext) -> str: + result = await ManyTaskYielder() + return result + + runner = CoroutineOrchestratorRunner(many_tasks_workflow) + async_ctx = AsyncWorkflowContext(self.mock_base_ctx) + + gen = runner.to_generator(async_ctx, None) + + # First task is yielded from the outer loop + task = next(gen) + assert isinstance(task, Mock) + + # Send result and continue - remaining tasks are in the inner while loop + for i in range(1, 6): + task = gen.send(f"r{i}") + assert isinstance(task, Mock) + + # Send last result - workflow should complete + with pytest.raises(StopIteration) as exc_info: + gen.send("r6") + + assert exc_info.value.value == "r1+r2+r3+r4+r5+r6" + + def test_awaitable_burst_yielding_tasks(self): + """Test awaitable that yields multiple tasks consecutively without waiting (inner while loop).""" + + class BurstTaskYielder: + """Yields multiple tasks in rapid succession to exercise inner while loop at lines 270-278.""" + + def __await__(self): + # Yield 5 tasks consecutively - each yield statement is executed immediately + # This pattern exercises the inner while loop that processes consecutive task yields + task1 = Mock(spec=dt_task.Task) + task2 = Mock(spec=dt_task.Task) + task3 = Mock(spec=dt_task.Task) + task4 = Mock(spec=dt_task.Task) + task5 = Mock(spec=dt_task.Task) + + # All these yields happen in rapid succession + r1 = yield task1 + r2 = yield task2 + r3 = yield task3 + r4 = yield task4 + r5 = yield task5 + + return f"{r1}-{r2}-{r3}-{r4}-{r5}" + + async def burst_workflow(ctx: AsyncWorkflowContext) -> str: + result = await BurstTaskYielder() + return result + + runner = CoroutineOrchestratorRunner(burst_workflow) + async_ctx = AsyncWorkflowContext(self.mock_base_ctx) + + gen = runner.to_generator(async_ctx, None) + + # First task is yielded from outer loop (line 228) + task1 = next(gen) + assert isinstance(task1, Mock) + + # When we send result for task1, the awaitable immediately yields task2, task3, task4, task5 + # This enters the inner while loop at line 270 to process consecutive yields + task2 = gen.send("result1") + 
assert isinstance(task2, Mock) + + # Continue through the burst - all handled by inner while loop (line 270-278) + task3 = gen.send("result2") + assert isinstance(task3, Mock) + + task4 = gen.send("result3") + assert isinstance(task4, Mock) + + task5 = gen.send("result4") + assert isinstance(task5, Mock) + + # Final result completes the awaitable + with pytest.raises(StopIteration) as exc_info: + gen.send("result5") + + assert exc_info.value.value == "result1-result2-result3-result4-result5" diff --git a/tests/aio/test_e2e.py b/tests/aio/test_e2e.py new file mode 100644 index 0000000..afe1837 --- /dev/null +++ b/tests/aio/test_e2e.py @@ -0,0 +1,1123 @@ +# Copyright 2025 The Dapr Authors +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +End-to-end tests for durabletask.aio package. + +These tests require a running Dapr sidecar or DurableTask-Go emulator. +They test actual workflow execution against a real runtime. + +To run these tests: +1. Start Dapr sidecar: dapr run --app-id test-app --dapr-grpc-port 50001 +2. Or start DurableTask-Go emulator on localhost:4001 +3. Run: pytest tests/aio/test_e2e.py -m e2e +""" + +import asyncio +import json +import os +import time +from datetime import datetime + +import pytest + +from durabletask import client, task, worker +from durabletask.aio import AsyncWorkflowContext +from durabletask.client import TaskHubGrpcClient +from durabletask.worker import TaskHubGrpcWorker + +# Skip all tests in this module unless explicitly running e2e tests +pytestmark = pytest.mark.e2e + + +def _deserialize_result(result): + """Parse serialized_output as JSON and return the resulting object. + + Returns None if there is no output. 
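+    Propagates json.JSONDecodeError if the output is not valid JSON.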
+ """ + if result.serialized_output is None: + return None + return json.loads(result.serialized_output) + + +def _log_orchestration_progress( + hub_client: TaskHubGrpcClient, instance_id: str, max_seconds: int = 60 +) -> None: + """Helper to log orchestration status every second up to max_seconds.""" + deadline = time.time() + max_seconds + last_status = None + while time.time() < deadline: + try: + st = hub_client.get_orchestration_state(instance_id, fetch_payloads=True) + if st is None: + print("[async e2e] state: None") + else: + status_name = st.runtime_status.name + if status_name != last_status: + print(f"[async e2e] state: {status_name}") + last_status = status_name + if status_name in ("COMPLETED", "FAILED", "TERMINATED"): + print("[async e2e] reached terminal state during polling") + break + except Exception as e: + print(f"[async e2e] polling error: {e}") + time.sleep(1) + + +class TestAsyncWorkflowE2E: + """End-to-end tests for async workflows with real runtime.""" + + @classmethod + def setup_class(cls): + """Set up test class with worker and client.""" + # Use environment variable or default to localhost:4001 (DurableTask-Go) + grpc_endpoint = os.getenv("DURABLETASK_GRPC_ENDPOINT", "localhost:4001") + # Skip if runtime not available + if not is_runtime_available(grpc_endpoint): + import pytest as _pytest + + _pytest.skip(f"DurableTask runtime not available at {grpc_endpoint}") + + cls.worker = TaskHubGrpcWorker(host_address=grpc_endpoint) + cls.client = TaskHubGrpcClient(host_address=grpc_endpoint) + + # Register test activities and workflows + cls._register_test_functions() + + time.sleep(2) + + # Start worker and wait for ready + cls.worker.start() + + @classmethod + def teardown_class(cls): + """Clean up worker and client.""" + try: + if hasattr(cls.worker, "stop"): + cls.worker.stop() + except Exception: + pass + + @classmethod + def _register_test_functions(cls): + """Register test activities and workflows.""" + + # Test activity + def test_activity(ctx, input_data: str) -> str: + print(f"[E2E] test_activity input={input_data}") + return f"Activity processed: {input_data}" + + cls.worker._registry.add_named_activity("test_activity", test_activity) + cls.test_activity = test_activity + + # Test async workflow + @cls.worker.add_orchestrator + async def simple_async_workflow(ctx: AsyncWorkflowContext, input_data: str) -> str: + result = await ctx.call_activity(test_activity, input=input_data) + return f"Workflow result: {result}" + + cls.simple_async_workflow = simple_async_workflow + + # Multi-step async workflow + @cls.worker.add_async_orchestrator + async def multi_step_async_workflow(ctx: AsyncWorkflowContext, steps: int) -> dict: + results = [] + for i in range(steps): + result = await ctx.call_activity(test_activity, input=f"step_{i}") + results.append(result) + + return { + "instance_id": ctx.instance_id, + "steps_completed": len(results), + "results": results, + "timestamp": ctx.now().isoformat(), + } + + cls.multi_step_async_workflow = multi_step_async_workflow + + # Parallel workflow + @cls.worker.add_async_orchestrator + async def parallel_async_workflow(ctx: AsyncWorkflowContext, parallel_count: int) -> list: + tasks = [] + for i in range(parallel_count): + task = ctx.call_activity(test_activity, input=f"parallel_{i}") + tasks.append(task) + + results = await ctx.when_all(tasks) + return results + + cls.parallel_async_workflow = parallel_async_workflow + + @cls.worker.add_async_orchestrator(sandbox_mode="best_effort") + async def sandbox_when_all_workflow( + 
+            ctx: AsyncWorkflowContext, parallel_count: int
+        ) -> list[str]:
+            tasks = [
+                ctx.call_activity(test_activity, input=f"sandbox_{i}")
+                for i in range(parallel_count)
+            ]
+
+            results = await asyncio.gather(*tasks)
+
+            return list(results)
+
+        cls.sandbox_when_all_workflow = sandbox_when_all_workflow
+
+        # when_any with activities (register early)
+        @cls.worker.add_async_orchestrator
+        async def when_any_activities(ctx: AsyncWorkflowContext, _) -> dict:
+            t1 = ctx.call_activity(test_activity, input="a1")
+            t2 = ctx.call_activity(test_activity, input="a2")
+            idx, result = await ctx.when_any([t1, t2])
+            return {"result": result}
+
+        cls.when_any_activities = when_any_activities
+
+        # when_any mixing activity and timer (register early)
+        @cls.worker.add_async_orchestrator
+        async def when_any_with_timer(ctx: AsyncWorkflowContext, _) -> dict:
+            t_activity = ctx.call_activity(test_activity, input="wa")
+            t_timer = ctx.create_timer(0.1)
+            idx, res = await ctx.when_any([t_activity, t_timer])
+            return {"index": idx, "has_result": res is not None}
+
+        cls.when_any_with_timer = when_any_with_timer
+
+        # Timer workflow
+        @cls.worker.add_async_orchestrator
+        async def timer_async_workflow(ctx: AsyncWorkflowContext, delay_seconds: float) -> dict:
+            start_time = ctx.now()
+
+            # Wait for specified delay
+            await ctx.create_timer(delay_seconds)
+
+            end_time = ctx.now()
+
+            return {
+                "start_time": start_time.isoformat(),
+                "end_time": end_time.isoformat(),
+                "delay_seconds": delay_seconds,
+            }
+
+        cls.timer_async_workflow = timer_async_workflow
+
+        # Sub-orchestrator workflow
+        @cls.worker.add_async_orchestrator
+        async def child_async_workflow(ctx: AsyncWorkflowContext, input_data: str) -> str:
+            result = await ctx.call_activity(test_activity, input=input_data)
+            return f"Child: {result}"
+
+        cls.child_async_workflow = child_async_workflow
+
+        @cls.worker.add_async_orchestrator
+        async def parent_async_workflow(ctx: AsyncWorkflowContext, input_data: str) -> dict:
+            # Call child workflow
+            child_result = await ctx.call_sub_orchestrator(
+                child_async_workflow, input=input_data, instance_id=f"{ctx.instance_id}_child"
+            )
+
+            # Process child result
+            final_result = await ctx.call_activity(test_activity, input=child_result)
+
+            return {
+                "parent_instance": ctx.instance_id,
+                "child_result": child_result,
+                "final_result": final_result,
+            }
+
+        cls.parent_async_workflow = parent_async_workflow
+
+        # Additional orchestrators for specific tests
+        @cls.worker.add_async_orchestrator
+        async def suspend_resume_workflow(ctx: AsyncWorkflowContext, _):
+            val = await ctx.wait_for_external_event("x")
+            return val
+
+        cls.suspend_resume_workflow = suspend_resume_workflow
+
+        @cls.worker.add_async_orchestrator
+        async def sub_orch_child(ctx: AsyncWorkflowContext, x: int):
+            return x + 1
+
+        cls.sub_orch_child = sub_orch_child
+
+        @cls.worker.add_async_orchestrator
+        async def sub_orch_parent(ctx: AsyncWorkflowContext, x: int):
+            y = await ctx.call_sub_orchestrator(sub_orch_child, input=x)
+            return y * 2
+
+        cls.sub_orch_parent = sub_orch_parent
+
+        # Minimal workflow for debugging - no activities
+        @cls.worker.add_orchestrator
+        async def minimal_workflow(ctx: AsyncWorkflowContext, input_data: str) -> str:
+            return f"Minimal result: {input_data}"
+
+        cls.minimal_workflow = minimal_workflow
+
+        # Determinism test workflow
+        @cls.worker.add_orchestrator
+        async def deterministic_test_workflow(ctx: AsyncWorkflowContext, input_data: str) -> dict:
+            random_val = ctx.random().random()
+            uuid_val = str(ctx.uuid4())
+            string_val = ctx.random_string(10)
+            activity_result = await ctx.call_activity(test_activity, input=input_data)
+            return {
+                "random": random_val,
+                "uuid": uuid_val,
+                "string": string_val,
+                "activity": activity_result,
+                "timestamp": ctx.now().isoformat(),
+            }
+
+        cls.deterministic_test_workflow = deterministic_test_workflow
+
+        # Error handling workflow
+        def failing_activity(ctx, input_data: str) -> str:
+            raise ValueError(f"Activity failed with input: {input_data}")
+
+        cls.worker.add_activity(failing_activity)
+
+        @cls.worker.add_orchestrator
+        async def error_handling_workflow(ctx: AsyncWorkflowContext, input_data: str) -> dict:
+            try:
+                result = await ctx.call_activity(failing_activity, input=input_data)
+                return {"status": "success", "result": result}
+            except Exception as e:
+                return {"status": "error", "error": str(e)}
+
+        cls.error_handling_workflow = error_handling_workflow
+
+        # External event workflow
+        @cls.worker.add_orchestrator
+        async def external_event_workflow(ctx: AsyncWorkflowContext, event_name: str) -> dict:
+            initial_result = await ctx.call_activity(test_activity, input="initial")
+            event_data = await ctx.wait_for_external_event(event_name)
+            final_result = await ctx.call_activity(test_activity, input=f"event_{event_data}")
+            return {"initial": initial_result, "event_data": event_data, "final": final_result}
+
+        cls.external_event_workflow = external_event_workflow
+
+        # (moved earlier) when_any registrations
+
+        # when_any between external event and timeout
+        @cls.worker.add_async_orchestrator
+        async def when_any_event_or_timeout(ctx: AsyncWorkflowContext, event_name: str) -> dict:
+            print(f"[E2E] when_any_event_or_timeout start id={ctx.instance_id} evt={event_name}")
+            evt = ctx.wait_for_external_event(event_name)
+            timeout = ctx.create_timer(5.0)
+            idx, result = await ctx.when_any([evt, timeout])
+            if idx == 0:
+                print(f"[E2E] when_any_event_or_timeout winner=event val={result}")
+                return {"winner": "event", "val": result}
+            print("[E2E] when_any_event_or_timeout winner=timeout")
+            return {"winner": "timeout"}
+
+        cls.when_any_event_or_timeout = when_any_event_or_timeout
+
+        # Debug: list registered orchestrators
+        try:
+            reg = getattr(cls.worker, "_registry", None)
+            if reg is not None:
+                keys = list(getattr(reg, "orchestrators", {}).keys())
+                print(f"[E2E] registered orchestrators: {keys}")
+        except Exception:
+            pass
+
+    def setup_method(self):
+        """Set up each test method."""
+        # Worker is started in setup_class; nothing to do per-test
+        pass
+
+    @pytest.mark.e2e
+    def test_async_suspend_and_resume_dt_e2e(self):
+        """Async suspend/resume using class-level worker/client (more stable)."""
+        from durabletask import client as dt_client
+
+        # Schedule and wait for RUNNING
+        orch_id = self.client.schedule_new_orchestration(type(self).suspend_resume_workflow)
+        st = self.client.wait_for_orchestration_start(orch_id, timeout=30)
+        assert st is not None and st.runtime_status == dt_client.OrchestrationStatus.RUNNING
+
+        # Suspend
+        self.client.suspend_orchestration(orch_id)
+        # Wait until SUSPENDED (poll)
+        for _ in range(100):
+            st = self.client.get_orchestration_state(orch_id)
+            assert st is not None
+            if st.runtime_status == dt_client.OrchestrationStatus.SUSPENDED:
+                break
+            time.sleep(0.1)
+
+        # Raise event then resume
+        self.client.raise_orchestration_event(orch_id, "x", data=42)
+        self.client.resume_orchestration(orch_id)
+
+        # Prefer server-side wait, then log/poll fallback
+        try:
+            st = self.client.wait_for_orchestration_completion(orch_id, timeout=60)
+        except TimeoutError:
+            _log_orchestration_progress(self.client, orch_id, max_seconds=30)
+            st = self.client.get_orchestration_state(orch_id, fetch_payloads=True)
+
+        assert st is not None
+        assert st.runtime_status == dt_client.OrchestrationStatus.COMPLETED
+        assert st.serialized_output == "42"
+
+    @pytest.mark.e2e
+    def test_async_sub_orchestrator_dt_e2e(self):
+        """Async sub-orchestrator end-to-end with stable class-level worker/client."""
+        from durabletask import client as dt_client
+
+        orch_id = self.client.schedule_new_orchestration(type(self).sub_orch_parent, input=3)
+
+        try:
+            st = self.client.wait_for_orchestration_completion(orch_id, timeout=60)
+        except TimeoutError:
+            _log_orchestration_progress(self.client, orch_id, max_seconds=30)
+            st = self.client.get_orchestration_state(orch_id, fetch_payloads=True)
+
+        assert st is not None
+        assert st.runtime_status == dt_client.OrchestrationStatus.COMPLETED
+        assert st.failure_details is None
+        assert st.serialized_output == "8"
+
+    @pytest.mark.e2e
+    def test_simple_async_workflow_e2e(self):
+        """Test simple async workflow end-to-end."""
+        # Use class worker/client which are already started
+        instance_id = self.client.schedule_new_orchestration(
+            type(self).simple_async_workflow, input="test_input"
+        )
+        print(f"[async e2e] scheduled instance_id={instance_id}")
+        # Quick initial probe
+        try:
+            st = self.client.get_orchestration_state(instance_id, fetch_payloads=True)
+            print(f"[async e2e] initial state: {getattr(st, 'runtime_status', None)}")
+        except Exception as e:
+            print(f"[async e2e] initial get_orchestration_state failed: {e}")
+
+        # Prefer server-side wait; on timeout, log progress via polling without extending total time
+        start_ts = time.time()
+        try:
+            state = self.client.wait_for_orchestration_completion(instance_id, timeout=60)
+        except TimeoutError:
+            elapsed = time.time() - start_ts
+            remaining = max(0, int(60 - elapsed))
+            print(
+                f"[async e2e] server-side wait timed out after {elapsed:.1f}s; polling for remaining {remaining}s"
+            )
+            if remaining > 0:
+                _log_orchestration_progress(self.client, instance_id, max_seconds=remaining)
+            # Get final state once more before asserting
+            state = self.client.get_orchestration_state(instance_id, fetch_payloads=True)
+        assert state is not None
+        assert state.runtime_status.name == "COMPLETED"
+        assert "Activity processed: test_input" in (state.serialized_output or "")
+
+    @pytest.mark.asyncio
+    async def test_multi_step_async_workflow_e2e(self):
+        """Test multi-step async workflow end-to-end."""
+        instance_id = f"test_multi_step_{int(time.time())}"
+
+        # Start workflow
+        self.client.schedule_new_orchestration(
+            type(self).multi_step_async_workflow, input=3, instance_id=instance_id
+        )
+
+        # Wait for completion
+        result = self.client.wait_for_orchestration_completion(instance_id, timeout=30)
+
+        assert result is not None
+        result_data = _deserialize_result(result)
+
+        assert result_data["steps_completed"] == 3
+        assert len(result_data["results"]) == 3
+        assert result_data["instance_id"] == instance_id
+
+    @pytest.mark.asyncio
+    async def test_parallel_async_workflow_e2e(self):
+        """Test parallel async workflow end-to-end."""
+        instance_id = f"test_parallel_{int(time.time())}"
+
+        # Start workflow
+        self.client.schedule_new_orchestration(
+            type(self).parallel_async_workflow, input=3, instance_id=instance_id
+        )
+
+        # Wait for completion
+        result = self.client.wait_for_orchestration_completion(instance_id, timeout=30)
+
+        assert result is not None
+        result_data = _deserialize_result(result)
+
+        # Should have 3 parallel results
+        assert len(result_data) == 3
+        for i, res in enumerate(result_data):
+            assert f"parallel_{i}" in res
+
+    @pytest.mark.asyncio
+    async def test_sandbox_when_all_workflow_e2e(self):
+        """Test sandboxed gather bridging to when_all end-to-end."""
+        instance_id = f"test_sandbox_when_all_{int(time.time())}"
+        parallel_count = 3
+
+        self.client.schedule_new_orchestration(
+            type(self).sandbox_when_all_workflow,
+            input=parallel_count,
+            instance_id=instance_id,
+        )
+
+        result = self.client.wait_for_orchestration_completion(instance_id, timeout=30)
+
+        assert result is not None
+        result_data = _deserialize_result(result)
+
+        assert isinstance(result_data, list)
+        assert len(result_data) == parallel_count
+        for i, res in enumerate(result_data):
+            assert f"sandbox_{i}" in res
+
+    @pytest.mark.asyncio
+    async def test_timer_async_workflow_e2e(self):
+        """Test timer async workflow end-to-end."""
+        instance_id = f"test_timer_{int(time.time())}"
+        delay_seconds = 2.0
+
+        # Start workflow
+        self.client.schedule_new_orchestration(
+            type(self).timer_async_workflow, input=delay_seconds, instance_id=instance_id
+        )
+
+        # Wait for completion
+        result = self.client.wait_for_orchestration_completion(instance_id, timeout=30)
+
+        assert result is not None
+        result_data = _deserialize_result(result)
+
+        assert result_data["delay_seconds"] == delay_seconds
+        # Validate using orchestrator timestamps to avoid wall-clock skew
+        start_iso = result_data.get("start_time")
+        end_iso = result_data.get("end_time")
+        if isinstance(start_iso, str) and isinstance(end_iso, str):
+            start_dt = datetime.fromisoformat(start_iso)
+            end_dt = datetime.fromisoformat(end_iso)
+            elapsed = (end_dt - start_dt).total_seconds()
+            # Allow jitter from backend scheduling and timestamp rounding
+            assert elapsed >= (delay_seconds - 1.0)
+
+    @pytest.mark.asyncio
+    async def test_sub_orchestrator_async_workflow_e2e(self):
+        """Test sub-orchestrator async workflow end-to-end."""
+        instance_id = f"test_sub_orch_{int(time.time())}"
+
+        # Start parent workflow
+        self.client.schedule_new_orchestration(
+            type(self).parent_async_workflow, input="test_data", instance_id=instance_id
+        )
+
+        # Wait for completion
+        result = self.client.wait_for_orchestration_completion(instance_id, timeout=30)
+
+        assert result is not None
+        result_data = _deserialize_result(result)
+
+        assert result_data["parent_instance"] == instance_id
+        assert "Child: Activity processed: test_data" in result_data["child_result"]
+        assert "Activity processed: Child:" in result_data["final_result"]
+
+    @pytest.mark.asyncio
+    async def test_workflow_determinism_e2e(self):
+        """Test that async workflows are deterministic during replay."""
+        instance_id = f"test_determinism_{int(time.time())}"
+        # Start pre-registered workflow
+        self.client.schedule_new_orchestration(
+            type(self).deterministic_test_workflow,
+            input="determinism_test",
+            instance_id=instance_id,
+        )
+
+        # Wait for completion
+        result = self.client.wait_for_orchestration_completion(instance_id, timeout=30)
+
+        assert result is not None
+        result_data = _deserialize_result(result)
+
+        # Verify deterministic values are present
+        assert "random" in result_data
+        assert "uuid" in result_data
+        assert "string" in result_data
+        assert "Activity processed: determinism_test" in result_data["activity"]
+
+        # The values should be deterministic based on instance_id and orchestration time
+        # We can't easily test replay here, but the workflow should complete successfully
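+
+    # Editor's illustrative sketch (not part of the original suite): the comment
+    # above says the deterministic values derive from instance_id and orchestration
+    # time, so two distinct instance IDs should, with overwhelming probability,
+    # observe different deterministic UUIDs. Uses only helpers already present in
+    # this file (class-level client, deterministic_test_workflow, _deserialize_result);
+    # treat it as an assumption-level example rather than a guaranteed contract.
+    @pytest.mark.asyncio
+    async def test_determinism_scoped_to_instance_sketch(self):
+        outputs = []
+        for i in range(2):
+            iid = f"test_determinism_scope_{i}_{int(time.time())}"
+            self.client.schedule_new_orchestration(
+                type(self).deterministic_test_workflow, input="scope", instance_id=iid
+            )
+            result = self.client.wait_for_orchestration_completion(iid, timeout=30)
+            assert result is not None
+            outputs.append(_deserialize_result(result))
+        # Distinct instances are expected to be seeded differently
+        assert outputs[0]["uuid"] != outputs[1]["uuid"]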
+
+    @pytest.mark.asyncio
+    async def test_when_any_activities_e2e(self):
+        instance_id = f"test_when_any_acts_{int(time.time())}"
+        self.client.schedule_new_orchestration(
+            type(self).when_any_activities, input=None, instance_id=instance_id
+        )
+        # Ensure the sidecar has started processing this orchestration
+        try:
+            st = self.client.wait_for_orchestration_start(instance_id, timeout=30)
+        except Exception:
+            st = None
+        assert st is not None
+        result = self.client.wait_for_orchestration_completion(instance_id, timeout=30)
+        assert result is not None
+        if result.failure_details:
+            print(
+                "when_any_activities failure:",
+                result.failure_details.error_type,
+                result.failure_details.message,
+            )
+            pytest.fail("when_any_activities failed")
+        data = _deserialize_result(result)
+        assert isinstance(data, dict)
+        assert "Activity processed:" in data.get("result", "")
+
+    @pytest.mark.asyncio
+    async def test_when_any_with_timer_e2e(self):
+        instance_id = f"test_when_any_timer_{int(time.time())}"
+        self.client.schedule_new_orchestration(
+            type(self).when_any_with_timer, input=None, instance_id=instance_id
+        )
+        try:
+            _ = self.client.wait_for_orchestration_start(instance_id, timeout=30)
+        except Exception:
+            pass
+        result = self.client.wait_for_orchestration_completion(instance_id, timeout=30)
+        assert result is not None
+        data = _deserialize_result(result)
+        assert isinstance(data, dict)
+        assert data.get("index") in (0, 1)
+        assert isinstance(data.get("has_result"), bool)
+
+    @pytest.mark.asyncio
+    async def test_when_any_event_or_timeout_e2e(self):
+        instance_id = f"test_when_any_event_{int(time.time())}"
+        event_name = "evt"
+        self.client.schedule_new_orchestration(
+            type(self).when_any_event_or_timeout, input=event_name, instance_id=instance_id
+        )
+        try:
+            _ = self.client.wait_for_orchestration_start(instance_id, timeout=30)
+        except Exception:
+            pass
+        # Raise the event shortly after to ensure event wins
+        time.sleep(0.5)
+        self.client.raise_orchestration_event(instance_id, event_name, data="hello")
+        result = self.client.wait_for_orchestration_completion(instance_id, timeout=30)
+        assert result is not None
+        if result.failure_details:
+            print(
+                "when_any_event_or_timeout failure:",
+                result.failure_details.error_type,
+                result.failure_details.message,
+            )
+            pytest.fail("when_any_event_or_timeout failed")
+        data = _deserialize_result(result)
+        assert data.get("winner") == "event"
+        assert data.get("val") == "hello"
+
+    @pytest.mark.asyncio
+    async def test_async_workflow_error_handling_e2e(self):
+        """Test error handling in async workflows end-to-end."""
+        instance_id = f"test_error_{int(time.time())}"
+
+        # Start pre-registered workflow
+        self.client.schedule_new_orchestration(
+            type(self).error_handling_workflow, input="test_error_input", instance_id=instance_id
+        )
+
+        # Wait for completion
+        result = self.client.wait_for_orchestration_completion(instance_id, timeout=30)
+
+        assert result is not None
+        result_data = _deserialize_result(result)
+
+        # Should have handled the error gracefully
+        assert result_data["status"] == "error"
+        assert "Activity failed with input: test_error_input" in result_data["error"]
+
+    @pytest.mark.asyncio
+    async def test_async_workflow_with_external_event_e2e(self):
+        """Test async workflow with external events end-to-end."""
+        instance_id = f"test_external_event_{int(time.time())}"
+
+        # Start pre-registered workflow
+        self.client.schedule_new_orchestration(
+            type(self).external_event_workflow, input="test_event", instance_id=instance_id
+        )
+
+        # Give workflow time to start and wait for event
+        import asyncio
+
+        await asyncio.sleep(1)
+
+        # Send external event
+        self.client.raise_orchestration_event(
+            instance_id, "test_event", data={"message": "event_received"}
+        )
+
+        # Wait for completion
+        result = self.client.wait_for_orchestration_completion(instance_id, timeout=30)
+
+        assert result is not None
+        result_data = _deserialize_result(result)
+
+        assert "Activity processed: initial" in result_data["initial"]
+        assert result_data["event_data"]["message"] == "event_received"
+        assert "Activity processed: event_" in result_data["final"]
+        assert "event_received" in result_data["final"]
+
+
+class TestAsyncWorkflowPerformanceE2E:
+    """Performance tests for async workflows."""
+
+    @pytest.mark.e2e
+    @pytest.mark.asyncio
+    async def test_async_workflow_performance_baseline(self):
+        """Baseline performance test for async workflows."""
+        # This test would measure execution time for various workflow patterns
+        # and ensure they meet performance requirements
+
+        # For now, just ensure the test structure is in place
+        assert True  # Placeholder
+
+    @pytest.mark.e2e
+    @pytest.mark.asyncio
+    async def test_async_workflow_memory_usage(self):
+        """Test memory usage of async workflows."""
+        # This test would monitor memory usage during workflow execution
+        # to ensure no memory leaks or excessive usage
+
+        # For now, just ensure the test structure is in place
+        assert True  # Placeholder
+
+
+# Utility functions for E2E tests
+
+
+def is_runtime_available(endpoint: str = "localhost:4001") -> bool:
+    """Check if DurableTask runtime is available at the given endpoint."""
+    import socket
+
+    try:
+        host, port = endpoint.split(":")
+        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        sock.settimeout(1)
+        result = sock.connect_ex((host, int(port)))
+        sock.close()
+        return result == 0
+    except Exception:
+        return False
+
+
+def skip_if_no_runtime():
+    """Helper that skips the calling test if no runtime is available."""
+    endpoint = os.getenv("DURABLETASK_GRPC_ENDPOINT", "localhost:4001")
+    if not is_runtime_available(endpoint):
+        pytest.skip(f"DurableTask runtime not available at {endpoint}")
+
+
+def test_async_activity_retry_with_backoff():
+    """Test that activities are retried with proper backoff and max attempts."""
+    skip_if_no_runtime()
+
+    from datetime import timedelta
+
+    attempt_counter = 0
+
+    async def retry_orchestrator(ctx: AsyncWorkflowContext, _):
+        retry_policy = task.RetryPolicy(
+            first_retry_interval=timedelta(seconds=1),
+            max_number_of_attempts=3,
+            backoff_coefficient=2,
+            max_retry_interval=timedelta(seconds=10),
+            retry_timeout=timedelta(seconds=30),
+        )
+        result = await ctx.call_activity(failing_activity, retry_policy=retry_policy)
+        return result
+
+    def failing_activity(ctx, _):
+        nonlocal attempt_counter
+        attempt_counter += 1
+        raise RuntimeError(f"Attempt {attempt_counter} failed")
+
+    with TaskHubGrpcWorker() as worker:
+        worker.add_orchestrator(retry_orchestrator)
+        worker.add_activity(failing_activity)
+        worker.start()
+
+        client = TaskHubGrpcClient()
+        instance_id = client.schedule_new_orchestration(retry_orchestrator)
+        state = client.wait_for_orchestration_completion(instance_id, timeout=30)
+
+        assert state is not None
+        assert state.runtime_status.name == "FAILED"
+        assert state.failure_details is not None
+        assert "Attempt 3 failed" in state.failure_details.message
+        assert attempt_counter == 3
+
+
+def test_async_sub_orchestrator_retry():
+    """Test that sub-orchestrators are retried on failure."""
+    skip_if_no_runtime()
+
+    from datetime import timedelta
+
+    child_attempt_counter = 0
+    activity_attempt_counter = 0
+
+    async def parent_orchestrator(ctx: AsyncWorkflowContext, _):
+        retry_policy = task.RetryPolicy(
+            first_retry_interval=timedelta(seconds=1),
+            max_number_of_attempts=3,
+            backoff_coefficient=1,
+        )
+        result = await ctx.call_sub_orchestrator(child_orchestrator, retry_policy=retry_policy)
+        return result
+
+    async def child_orchestrator(ctx: AsyncWorkflowContext, _):
+        nonlocal child_attempt_counter
+        if not ctx.is_replaying:
+            child_attempt_counter += 1
+        retry_policy = task.RetryPolicy(
+            first_retry_interval=timedelta(seconds=1),
+            max_number_of_attempts=3,
+            backoff_coefficient=1,
+        )
+        result = await ctx.call_activity(failing_activity, retry_policy=retry_policy)
+        return result
+
+    def failing_activity(ctx, _):
+        nonlocal activity_attempt_counter
+        activity_attempt_counter += 1
+        raise RuntimeError("Kah-BOOOOM!!!")
+
+    with TaskHubGrpcWorker() as worker:
+        worker.add_orchestrator(parent_orchestrator)
+        worker.add_orchestrator(child_orchestrator)
+        worker.add_activity(failing_activity)
+        worker.start()
+
+        client = TaskHubGrpcClient()
+        instance_id = client.schedule_new_orchestration(parent_orchestrator)
+        state = client.wait_for_orchestration_completion(instance_id, timeout=40)
+
+        assert state is not None
+        assert state.runtime_status.name == "FAILED"
+        assert state.failure_details is not None
+        # Each child orchestrator attempt retries the activity 3 times
+        # 3 child attempts × 3 activity attempts = 9 total
+        assert activity_attempt_counter == 9
+        assert child_attempt_counter == 3
+
+
+def test_async_retry_timeout():
+    """Test that retry timeout limits the number of attempts."""
+    skip_if_no_runtime()
+
+    from datetime import timedelta
+
+    attempt_counter = 0
+
+    async def timeout_orchestrator(ctx: AsyncWorkflowContext, _):
+        # Max 5 attempts, but timeout at 14 seconds. Retry delays are 1s + 2s + 4s = 7s
+        # before the 4th attempt; the next delay (8s) would push past the 14s retry
+        # timeout, so only 4 attempts should happen.
+        retry_policy = task.RetryPolicy(
+            first_retry_interval=timedelta(seconds=1),
+            max_number_of_attempts=5,
+            backoff_coefficient=2,
+            max_retry_interval=timedelta(seconds=10),
+            retry_timeout=timedelta(seconds=14),
+        )
+        result = await ctx.call_activity(failing_activity, retry_policy=retry_policy)
+        return result
+
+    def failing_activity(ctx, _):
+        nonlocal attempt_counter
+        attempt_counter += 1
+        raise RuntimeError(f"Attempt {attempt_counter} failed")
+
+    with TaskHubGrpcWorker() as worker:
+        worker.add_orchestrator(timeout_orchestrator)
+        worker.add_activity(failing_activity)
+        worker.start()
+
+        client = TaskHubGrpcClient()
+        instance_id = client.schedule_new_orchestration(timeout_orchestrator)
+        state = client.wait_for_orchestration_completion(instance_id, timeout=40)
+
+        assert state is not None
+        assert state.runtime_status.name == "FAILED"
+        # Should only attempt 4 times due to timeout (the cumulative 1s + 2s + 4s + 8s delay would exceed 14s)
+        assert attempt_counter == 4
+
+
+def test_async_non_retryable_error():
+    """Test that NonRetryableError prevents retries."""
+    skip_if_no_runtime()
+
+    from datetime import timedelta
+
+    attempt_counter = 0
+
+    async def non_retryable_orchestrator(ctx: AsyncWorkflowContext, _):
+        retry_policy = task.RetryPolicy(
+            first_retry_interval=timedelta(seconds=1),
+            max_number_of_attempts=5,
+            backoff_coefficient=1,
+        )
+        result = await ctx.call_activity(non_retryable_activity, retry_policy=retry_policy)
+        return result
+
+    def non_retryable_activity(ctx, _):
+        nonlocal attempt_counter
+        attempt_counter += 1
+        raise task.NonRetryableError("This should not be retried")
+
+    with TaskHubGrpcWorker() as worker:
+        worker.add_orchestrator(non_retryable_orchestrator)
+        worker.add_activity(non_retryable_activity)
+        worker.start()
+
+        client = TaskHubGrpcClient()
+        instance_id = client.schedule_new_orchestration(non_retryable_orchestrator)
+        state = client.wait_for_orchestration_completion(instance_id, timeout=20)
+
+        assert state is not None
+        assert state.runtime_status.name == "FAILED"
+        assert state.failure_details is not None
+        assert "NonRetryableError" in state.failure_details.error_type
+        # Should only attempt once since it's non-retryable
+        assert attempt_counter == 1
+
+
+def test_async_successful_retry():
+    """Test that an activity succeeds after retries."""
+    skip_if_no_runtime()
+
+    from datetime import timedelta
+
+    attempt_counter = 0
+
+    async def successful_retry_orchestrator(ctx: AsyncWorkflowContext, _):
+        retry_policy = task.RetryPolicy(
+            first_retry_interval=timedelta(seconds=1),
+            max_number_of_attempts=5,
+            backoff_coefficient=1,
+        )
+        result = await ctx.call_activity(eventually_succeeds_activity, retry_policy=retry_policy)
+        return result
+
+    def eventually_succeeds_activity(ctx, _):
+        nonlocal attempt_counter
+        attempt_counter += 1
+        if attempt_counter < 3:
+            raise RuntimeError(f"Attempt {attempt_counter} failed")
+        return f"Success on attempt {attempt_counter}"
+
+    with TaskHubGrpcWorker() as worker:
+        worker.add_orchestrator(successful_retry_orchestrator)
+        worker.add_activity(eventually_succeeds_activity)
+        worker.start()
+
+        client = TaskHubGrpcClient()
+        instance_id = client.schedule_new_orchestration(successful_retry_orchestrator)
+        state = client.wait_for_orchestration_completion(instance_id, timeout=30)
+
+        assert state is not None
+        assert state.runtime_status.name == "COMPLETED"
+        assert state.serialized_output == '"Success on attempt 3"'
+        assert attempt_counter == 3
+
+
+def test_async_suspend_and_resume_e2e():
+    import os
+
+    async def orch(ctx, _):
+        val = await ctx.wait_for_external_event("x")
+        return val
+
+    # Respect pre-configured endpoint; default only if not set
+    os.environ.setdefault("DURABLETASK_GRPC_ENDPOINT", "localhost:4001")
+
+    with worker.TaskHubGrpcWorker() as w:
+        w.add_orchestrator(orch)
+        w.start()
+
+        with client.TaskHubGrpcClient() as c:
+            id = c.schedule_new_orchestration(orch)
+            state = c.wait_for_orchestration_start(id, timeout=30)
+            assert state is not None
+            assert state.runtime_status == client.OrchestrationStatus.RUNNING
+
+            # Suspend then poll (bounded) until it reports SUSPENDED
+            c.suspend_orchestration(id)
+            for _ in range(100):
+                st = c.get_orchestration_state(id)
+                assert st is not None
+                if st.runtime_status == client.OrchestrationStatus.SUSPENDED:
+                    break
+                time.sleep(0.1)
+            else:
+                pytest.fail("orchestration did not reach SUSPENDED within 10s")
+
+            # Raise event while suspended, then resume and expect completion
+            c.raise_orchestration_event(id, "x", data=42)
+            c.resume_orchestration(id)
+
+            state = c.wait_for_orchestration_completion(id, timeout=30, fetch_payloads=True)
+            assert state is not None
+            assert state.runtime_status == client.OrchestrationStatus.COMPLETED
+            assert state.serialized_output == json.dumps(42)
+
+
+def test_async_sub_orchestrator_e2e():
+    async def child(ctx, x: int):
+        return x + 1
+
+    async def parent(ctx, x: int):
+        y = await ctx.call_sub_orchestrator(child, input=x)
+        return y * 2
+
+    with worker.TaskHubGrpcWorker() as w:
+        w.add_orchestrator(child)
+        w.add_orchestrator(parent)
+        w.start()
+
+        with client.TaskHubGrpcClient() as c:
+            id = c.schedule_new_orchestration(parent, input=3)
+
+            state = c.wait_for_orchestration_completion(id, timeout=30, fetch_payloads=True)
+
+            assert state is not None
+            assert state.runtime_status == client.OrchestrationStatus.COMPLETED
+            assert state.failure_details is None
+            assert state.serialized_output == json.dumps(8)
+
+
+def test_now_with_sequence_ordering_e2e():
+    """
+    Test that now_with_sequence() maintains strict ordering across workflow execution.
+
+    This verifies:
+    1. Timestamps increment sequentially
+    2. Order is preserved across activity calls
+    3. Deterministic behavior (timestamps are consistent on replay)
+    """
+
+    def simple_activity(ctx, input_val: str):
+        return f"activity_{input_val}_done"
+
+    async def timestamp_ordering_workflow(ctx, _):
+        timestamps = []
+
+        # First timestamp before any activities
+        t1 = ctx.now_with_sequence()
+        timestamps.append(("t1_before_activities", t1.isoformat()))
+
+        # Call first activity
+        result1 = await ctx.call_activity(simple_activity, input="first")
+        timestamps.append(("activity_1_result", result1))
+
+        # Timestamp after first activity
+        t2 = ctx.now_with_sequence()
+        timestamps.append(("t2_after_activity_1", t2.isoformat()))
+
+        # Call second activity
+        result2 = await ctx.call_activity(simple_activity, input="second")
+        timestamps.append(("activity_2_result", result2))
+
+        # Timestamp after second activity
+        t3 = ctx.now_with_sequence()
+        timestamps.append(("t3_after_activity_2", t3.isoformat()))
+
+        # A few more rapid timestamps to test counter incrementing
+        t4 = ctx.now_with_sequence()
+        timestamps.append(("t4_rapid", t4.isoformat()))
+
+        t5 = ctx.now_with_sequence()
+        timestamps.append(("t5_rapid", t5.isoformat()))
+
+        t6 = ctx.now_with_sequence()
+        timestamps.append(("t6_rapid", t6.isoformat()))
+
+        # Return all timestamps for verification
+        return {
+            "timestamps": timestamps,
+            "t1": t1.isoformat(),
+            "t2": t2.isoformat(),
+            "t3": t3.isoformat(),
+            "t4": t4.isoformat(),
+            "t5": t5.isoformat(),
+            "t6": t6.isoformat(),
+        }
+
+    with worker.TaskHubGrpcWorker() as w:
+        w.add_orchestrator(timestamp_ordering_workflow)
+        w.add_activity(simple_activity)
+        w.start()
+
+        with client.TaskHubGrpcClient() as c:
+            instance_id = c.schedule_new_orchestration(timestamp_ordering_workflow)
+            state = c.wait_for_orchestration_completion(
+                instance_id, timeout=30, fetch_payloads=True
+            )
+
+            assert state is not None
+            assert state.runtime_status == client.OrchestrationStatus.COMPLETED
+            assert state.failure_details is None
+
+            # Parse result
+            result = _deserialize_result(state)
+            assert result is not None
+
+            # Verify all timestamps are present
+            assert "t1" in result
+            assert "t2" in result
+            assert "t3" in result
+            assert "t4" in result
+            assert "t5" in result
+            assert "t6" in result
+
+            # Parse timestamps back to datetime objects for comparison
+            from datetime import datetime
+
+            t1 = datetime.fromisoformat(result["t1"])
+            t2 = datetime.fromisoformat(result["t2"])
+            t3 = datetime.fromisoformat(result["t3"])
+            t4 = datetime.fromisoformat(result["t4"])
+            t5 = datetime.fromisoformat(result["t5"])
+            t6 = datetime.fromisoformat(result["t6"])
+
+            # Verify strict ordering: t1 < t2 < t3 < t4 < t5 < t6
+            # This is the key guarantee - timestamps must maintain order for tracing
+            assert t1 < t2, f"t1 ({t1}) should be < t2 ({t2})"
+            assert t2 < t3, f"t2 ({t2}) should be < t3 ({t3})"
+            assert t3 < t4, f"t3 ({t3}) should be < t4 ({t4})"
+            assert t4 < t5, f"t4 ({t4}) should be < t5 ({t5})"
+            assert t5 < t6, f"t5 ({t5}) should be < t6 ({t6})"
+
+            # Verify that timestamps called in rapid succession (t3 through t6, with no activities between)
+            # have exactly 1 microsecond deltas. These happen within the same replay execution.
+            delta_t3_t4 = (t4 - t3).total_seconds() * 1_000_000
+            delta_t4_t5 = (t5 - t4).total_seconds() * 1_000_000
+            delta_t5_t6 = (t6 - t5).total_seconds() * 1_000_000
+
+            assert delta_t3_t4 == 1.0, f"t3 to t4 should be 1 microsecond, got {delta_t3_t4}"
+            assert delta_t4_t5 == 1.0, f"t4 to t5 should be 1 microsecond, got {delta_t4_t5}"
+            assert delta_t5_t6 == 1.0, f"t5 to t6 should be 1 microsecond, got {delta_t5_t6}"
+
+            # Note: We don't check exact deltas for t1->t2 or t2->t3 because they span
+            # activity calls. During replay, current_utc_datetime changes based on event
+            # timestamps, so the base time shifts. However, ordering is still guaranteed.
diff --git a/tests/aio/test_integration.py b/tests/aio/test_integration.py
new file mode 100644
index 0000000..4cf3848
--- /dev/null
+++ b/tests/aio/test_integration.py
@@ -0,0 +1,721 @@
+# Copyright 2025 The Dapr Authors
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Integration tests for durabletask.aio package.
+
+These tests verify end-to-end functionality of async workflows,
+including the interaction between all components.
+
+Tests marked with @pytest.mark.e2e require a running Dapr sidecar
+or DurableTask-Go emulator and are skipped by default.
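+
+To opt in, select the marker explicitly (assuming the repo's standard pytest
+marker configuration), e.g.:
+
+    pytest -m e2e tests/aio/test_integration.py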
+""" + +from datetime import datetime +from unittest.mock import Mock + +import pytest + +from durabletask import task as dt_task +from durabletask.aio import ( + AsyncWorkflowContext, + AsyncWorkflowError, + CoroutineOrchestratorRunner, + WorkflowTimeoutError, +) + + +class FakeTask(dt_task.Task): + """Simple fake task for testing, based on python-sdk approach.""" + + def __init__(self, name: str): + super().__init__() + self.name = name + self._result = f"result_for_{name}" + + def get_result(self): + return self._result + + def complete_with_result(self, result): + """Helper method for tests to complete the task.""" + self._result = result + self._is_complete = True + + +class FakeCtx: + """Simple fake context for testing, based on python-sdk approach.""" + + def __init__(self): + self.current_utc_datetime = datetime(2024, 1, 1, 12, 0, 0) + self.instance_id = "test-instance" + self.is_replaying = False + self.workflow_name = "test-workflow" + self.is_suspended = False + + def call_activity(self, activity, *, input=None, retry_policy=None, metadata=None): + activity_name = getattr(activity, "__name__", str(activity)) + return FakeTask(f"activity:{activity_name}") + + def call_sub_orchestrator( + self, orchestrator, *, input=None, instance_id=None, retry_policy=None, metadata=None + ): + orchestrator_name = getattr(orchestrator, "__name__", str(orchestrator)) + return FakeTask(f"sub:{orchestrator_name}") + + def create_timer(self, fire_at): + return FakeTask("timer") + + def wait_for_external_event(self, name: str): + return FakeTask(f"event:{name}") + + def set_custom_status(self, custom_status): + pass + + def continue_as_new(self, new_input, *, save_events=False): + pass + + +def drive_workflow(gen, results_map=None): + """ + Drive a workflow generator, providing results for yielded tasks. + Based on python-sdk approach but adapted for durabletask. 
+ + Args: + gen: The workflow generator + results_map: Dict mapping task names to results, or callable that takes task and returns result + """ + results_map = results_map or {} + + try: + # Start the generator + task = next(gen) + + while True: + # Determine result for this task + if callable(results_map): + result = results_map(task) + elif hasattr(task, "name"): + result = results_map.get(task.name, f"result_for_{task.name}") + else: + result = "default_result" + + # Send result and get next task + task = gen.send(result) + + except StopIteration as stop: + return stop.value + + +class TestAsyncWorkflowIntegration: + """Integration tests for async workflow functionality.""" + + def setup_method(self): + """Set up test fixtures.""" + self.fake_ctx = FakeCtx() + + def test_simple_activity_workflow_integration(self): + """Test a simple workflow that calls one activity.""" + + async def simple_activity_workflow(ctx: AsyncWorkflowContext, input_data: str) -> str: + result = await ctx.call_activity("process_data", input=input_data) + return f"Processed: {result}" + + # Create runner and context + runner = CoroutineOrchestratorRunner(simple_activity_workflow) + async_ctx = AsyncWorkflowContext(self.fake_ctx) + + # Execute workflow using the drive helper + gen = runner.to_generator(async_ctx, "test_input") + result = drive_workflow(gen, {"activity:process_data": "activity_result"}) + + assert result == "Processed: activity_result" + + def test_multi_step_workflow_integration(self): + """Test a workflow with multiple sequential activities.""" + + async def multi_step_workflow(ctx: AsyncWorkflowContext, input_data: dict) -> dict: + # Step 1: Validate input + validation_result = await ctx.call_activity("validate_input", input=input_data) + + # Step 2: Process data + processing_result = await ctx.call_activity("process_data", input=validation_result) + + # Step 3: Save result + save_result = await ctx.call_activity("save_result", input=processing_result) + + return { + "validation": validation_result, + "processing": processing_result, + "save": save_result, + } + + runner = CoroutineOrchestratorRunner(multi_step_workflow) + async_ctx = AsyncWorkflowContext(self.fake_ctx) + + gen = runner.to_generator(async_ctx, {"data": "test"}) + + # Use drive_workflow with specific results for each activity + results_map = { + "activity:validate_input": "validated_data", + "activity:process_data": "processed_data", + "activity:save_result": "saved_data", + } + result = drive_workflow(gen, results_map) + + assert result == { + "validation": "validated_data", + "processing": "processed_data", + "save": "saved_data", + } + + def test_parallel_activities_workflow_integration(self): + """Test a workflow with parallel activities using when_all.""" + + async def parallel_workflow(ctx: AsyncWorkflowContext, input_data: list) -> list: + # Start multiple activities in parallel + tasks = [] + for i, item in enumerate(input_data): + task = ctx.call_activity(f"process_item_{i}", input=item) + tasks.append(task) + + # Wait for all to complete + results = await ctx.when_all(tasks) + return results + + runner = CoroutineOrchestratorRunner(parallel_workflow) + async_ctx = AsyncWorkflowContext(self.fake_ctx) + + input_data = ["item1", "item2", "item3"] + gen = runner.to_generator(async_ctx, input_data) + + # Use drive_workflow to handle the when_all task + result = drive_workflow(gen, lambda task: ["result1", "result2", "result3"]) + + assert result == ["result1", "result2", "result3"] + + def 
test_sub_orchestrator_workflow_integration(self): + """Test a workflow that calls a sub-orchestrator.""" + + async def parent_workflow(ctx: AsyncWorkflowContext, input_data: dict) -> dict: + # Call sub-orchestrator + sub_result = await ctx.call_sub_orchestrator( + "child_workflow", input=input_data["child_input"], instance_id="child-instance" + ) + + # Process sub-orchestrator result + final_result = await ctx.call_activity("finalize", input=sub_result) + + return {"sub_result": sub_result, "final": final_result} + + runner = CoroutineOrchestratorRunner(parent_workflow) + async_ctx = AsyncWorkflowContext(self.fake_ctx) + + gen = runner.to_generator(async_ctx, {"child_input": "test_data"}) + + # Use drive_workflow with specific results + results_map = { + "sub:child_workflow": "sub_orchestrator_result", + "activity:finalize": "final_result", + } + result = drive_workflow(gen, results_map) + + assert result == {"sub_result": "sub_orchestrator_result", "final": "final_result"} + + def test_timer_workflow_integration(self): + """Test a workflow that uses timers.""" + + async def timer_workflow(ctx: AsyncWorkflowContext, delay_seconds: float) -> str: + # Start some work + initial_result = await ctx.call_activity("start_work", input="begin") + + # Wait for specified delay + await ctx.create_timer(delay_seconds) + + # Complete work + final_result = await ctx.call_activity("complete_work", input=initial_result) + + return final_result + + runner = CoroutineOrchestratorRunner(timer_workflow) + async_ctx = AsyncWorkflowContext(self.fake_ctx) + + gen = runner.to_generator(async_ctx, 30.0) + + # Use drive_workflow with specific results + results_map = { + "activity:start_work": "work_started", + "timer": None, # Timer completion + "activity:complete_work": "work_completed", + } + result = drive_workflow(gen, results_map) + + assert result == "work_completed" + + def test_external_event_workflow_integration(self): + """Test a workflow that waits for external events.""" + + async def event_workflow(ctx: AsyncWorkflowContext, event_name: str) -> dict: + # Start processing + start_result = await ctx.call_activity("start_processing", input="begin") + + # Wait for external event + event_data = await ctx.wait_for_external_event(event_name) + + # Process event data + final_result = await ctx.call_activity( + "process_event", input={"start": start_result, "event": event_data} + ) + + return {"result": final_result, "event_data": event_data} + + runner = CoroutineOrchestratorRunner(event_workflow) + async_ctx = AsyncWorkflowContext(self.fake_ctx) + + gen = runner.to_generator(async_ctx, "approval_event") + + # Use drive_workflow with specific results + results_map = { + "activity:start_processing": "processing_started", + "event:approval_event": {"approved": True, "user": "admin"}, + "activity:process_event": "event_processed", + } + result = drive_workflow(gen, results_map) + + assert result == { + "result": "event_processed", + "event_data": {"approved": True, "user": "admin"}, + } + + def test_when_any_workflow_integration(self): + """Test a workflow using when_any for racing conditions.""" + + async def racing_workflow(ctx: AsyncWorkflowContext, timeout_seconds: float) -> dict: + # Start a long-running activity + work_task = ctx.call_activity("long_running_work", input="start") + + # Create a timeout + timeout_task = ctx.create_timer(timeout_seconds) + + # Race between work completion and timeout + idx, result = await ctx.when_any([work_task, timeout_task]) + + if idx == 0: + return {"status": "completed", 
"result": result} + else: + return {"status": "timeout", "result": None} + + runner = CoroutineOrchestratorRunner(racing_workflow) + async_ctx = AsyncWorkflowContext(self.fake_ctx) + + gen = runner.to_generator(async_ctx, 10.0) + + # Should yield when_any task + when_any_task = next(gen) + assert isinstance(when_any_task, dt_task.Task) + + # Simulate work completing first + mock_completed_task = Mock() + mock_completed_task.get_result.return_value = "work_done" + + try: + gen.send(mock_completed_task) + except StopIteration as stop: + result = stop.value + assert result["status"] == "completed" + assert result["result"] == "work_done" + + def test_timeout_workflow_integration(self): + """Test workflow with timeout functionality.""" + + async def timeout_workflow(ctx: AsyncWorkflowContext, data: str) -> str: + try: + # Activity with 5-second timeout + result = await ctx.with_timeout( + ctx.call_activity("slow_activity", input=data), + 5.0, + ) + return f"Success: {result}" + except WorkflowTimeoutError: + return "Timeout occurred" + + runner = CoroutineOrchestratorRunner(timeout_workflow) + async_ctx = AsyncWorkflowContext(self.fake_ctx) + + gen = runner.to_generator(async_ctx, "test_data") + + # Should yield when_any task (activity vs timeout) + when_any_task = next(gen) + assert isinstance(when_any_task, dt_task.Task) + + # Simulate timeout completing first + timeout_task = Mock() + timeout_task.get_result.return_value = None + + try: + gen.send(timeout_task) + except StopIteration as stop: + assert stop.value == "Timeout occurred" + + def test_deterministic_operations_integration(self): + """Test that deterministic operations work correctly in workflows.""" + + async def deterministic_workflow(ctx: AsyncWorkflowContext, count: int) -> dict: + # Generate deterministic random values + random_values = [] + for _ in range(count): + rng = ctx.random() + random_values.append(rng.random()) + + # Generate deterministic UUIDs + uuids = [] + for _ in range(count): + uuids.append(str(ctx.uuid4())) + + # Generate deterministic strings + strings = [] + for i in range(count): + strings.append(ctx.random_string(10)) + + return { + "random_values": random_values, + "uuids": uuids, + "strings": strings, + "timestamp": ctx.now().isoformat(), + } + + runner = CoroutineOrchestratorRunner(deterministic_workflow) + async_ctx = AsyncWorkflowContext(self.fake_ctx) + + gen = runner.to_generator(async_ctx, 3) + + # Should complete synchronously (no async operations) + try: + next(gen) + except StopIteration as stop: + result = stop.value + + # Verify structure + assert len(result["random_values"]) == 3 + assert len(result["uuids"]) == 3 + assert len(result["strings"]) == 3 + assert "timestamp" in result + + # Verify deterministic behavior - run again with fresh context (simulates replay) + async_ctx2 = AsyncWorkflowContext(self.fake_ctx) + gen2 = runner.to_generator(async_ctx2, 3) + try: + next(gen2) + except StopIteration as stop2: + result2 = stop2.value + + # Should be identical (deterministic behavior) + assert result == result2 + + def test_error_handling_integration(self): + """Test error handling throughout the workflow stack.""" + + async def error_prone_workflow(ctx: AsyncWorkflowContext, should_fail: bool) -> str: + if should_fail: + raise ValueError("Workflow intentionally failed") + + result = await ctx.call_activity("safe_activity", input="test") + return f"Success: {result}" + + runner = CoroutineOrchestratorRunner(error_prone_workflow) + async_ctx = AsyncWorkflowContext(self.fake_ctx) + + # Test 
successful case + gen_success = runner.to_generator(async_ctx, False) + _ = next(gen_success) + + try: + gen_success.send("activity_result") + except StopIteration as stop: + assert stop.value == "Success: activity_result" + + # Test error case + gen_error = runner.to_generator(async_ctx, True) + + with pytest.raises(AsyncWorkflowError) as exc_info: + next(gen_error) + + assert "Workflow intentionally failed" in str(exc_info.value) + assert exc_info.value.workflow_name == "error_prone_workflow" + + def test_sandbox_integration(self): + """Test sandbox integration with workflows.""" + + async def sandbox_workflow(ctx: AsyncWorkflowContext, mode: str) -> dict: + # Use deterministic operations + random_val = ctx.random().random() + uuid_val = str(ctx.uuid4()) + time_val = ctx.now().isoformat() + + # Call an activity + activity_result = await ctx.call_activity("test_activity", input="test") + + return { + "random": random_val, + "uuid": uuid_val, + "time": time_val, + "activity": activity_result, + } + + # Test with different sandbox modes + for mode in ["off", "best_effort", "strict"]: + runner = CoroutineOrchestratorRunner(sandbox_workflow, sandbox_mode=mode) + async_ctx = AsyncWorkflowContext(self.fake_ctx) + + gen = runner.to_generator(async_ctx, mode) + + # Should yield activity task + activity_task = next(gen) + # With FakeCtx, ensure we yielded the expected durable task token + assert isinstance(activity_task, dt_task.Task) + assert getattr(activity_task, "name", "") == "activity:test_activity" + + # Complete workflow + try: + gen.send("activity_done") + except StopIteration as stop: + result = stop.value + + # Verify structure + assert "random" in result + assert "uuid" in result + assert "time" in result + assert result["activity"] == "activity_done" + + def test_complex_workflow_integration(self): + """Test a complex workflow combining multiple features.""" + + async def complex_workflow(ctx: AsyncWorkflowContext, config: dict) -> dict: + # Step 1: Initialize + init_result = await ctx.call_activity("initialize", input=config) + + # Step 2: Parallel processing + parallel_tasks = [] + for i in range(config["parallel_count"]): + task = ctx.call_activity(f"process_batch_{i}", input=init_result) + parallel_tasks.append(task) + + batch_results = await ctx.when_all(parallel_tasks) + + # Step 3: Wait for approval with timeout + try: + approval = await ctx.with_timeout( + ctx.wait_for_external_event("approval"), + config["approval_timeout"], + ) + except WorkflowTimeoutError: + approval = {"approved": False, "reason": "timeout"} + + # Step 4: Conditional sub-orchestrator + if approval.get("approved", False): + sub_result = await ctx.call_sub_orchestrator( + "finalization_workflow", input={"batches": batch_results, "approval": approval} + ) + else: + sub_result = await ctx.call_activity("handle_rejection", input=approval) + + # Step 5: Generate report + report = { + "workflow_id": ctx.instance_id, + "timestamp": ctx.now().isoformat(), + "init": init_result, + "batches": batch_results, + "approval": approval, + "final": sub_result, + "random_id": str(ctx.uuid4()), + } + + return report + + runner = CoroutineOrchestratorRunner(complex_workflow) + async_ctx = AsyncWorkflowContext(self.fake_ctx) + + config = {"parallel_count": 2, "approval_timeout": 30.0} + + gen = runner.to_generator(async_ctx, config) + + # Step 1: Initialize + _ = next(gen) + + # Step 2: Parallel processing (when_all) + _ = gen.send("initialized") + + # Step 3: Approval with timeout (when_any) + _ = gen.send(["batch_1_result", 
"batch_2_result"]) + + # Simulate approval received + approval_data = {"approved": True, "user": "admin"} + + # Step 4: Sub-orchestrator + _ = gen.send(approval_data) + + # Complete workflow + try: + gen.send("finalization_complete") + except StopIteration as stop: + result = stop.value + + # Verify complex result structure + assert result["workflow_id"] == "test-instance" + assert result["init"] == "initialized" + assert result["batches"] == ["batch_1_result", "batch_2_result"] + assert result["approval"] == approval_data + assert result["final"] == "finalization_complete" + assert "timestamp" in result + assert "random_id" in result + + def test_workflow_replay_determinism(self): + """Test that workflows are deterministic during replay.""" + + async def replay_test_workflow(ctx: AsyncWorkflowContext, input_data: str) -> dict: + # Generate deterministic values + random_val = ctx.random().random() + uuid_val = str(ctx.uuid4()) + string_val = ctx.random_string(8) + + # Call activity + activity_result = await ctx.call_activity("test_activity", input=input_data) + + return { + "random": random_val, + "uuid": uuid_val, + "string": string_val, + "activity": activity_result, + } + + runner = CoroutineOrchestratorRunner(replay_test_workflow) + + # First execution + async_ctx1 = AsyncWorkflowContext(self.fake_ctx) + gen1 = runner.to_generator(async_ctx1, "test_input") + + _ = next(gen1) + + try: + gen1.send("activity_result") + except StopIteration as stop1: + result1 = stop1.value + + # Second execution (simulating replay) + async_ctx2 = AsyncWorkflowContext(self.fake_ctx) + gen2 = runner.to_generator(async_ctx2, "test_input") + + _ = next(gen2) + + try: + gen2.send("activity_result") + except StopIteration as stop2: + result2 = stop2.value + + # Results should be identical (deterministic) + assert result1 == result2 + + +class TestSandboxIntegration: + """Integration tests for sandbox functionality.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_base_ctx = Mock() + self.mock_base_ctx.instance_id = "test-instance" + self.mock_base_ctx.current_utc_datetime = datetime(2023, 1, 1, 12, 0, 0) + self.mock_base_ctx.call_activity.return_value = Mock(spec=dt_task.Task) + self.mock_base_ctx.create_timer.return_value = Mock(spec=dt_task.Task) + + def test_sandbox_with_async_workflow_context(self): + """Test sandbox integration with AsyncWorkflowContext.""" + import random + import time + import uuid + + from durabletask.aio.sandbox import _sandbox_scope + + async_ctx = AsyncWorkflowContext(self.mock_base_ctx) + + with _sandbox_scope(async_ctx, "best_effort"): + # Should work with real AsyncWorkflowContext + test_random = random.random() + test_uuid = uuid.uuid4() + test_time = time.time() + + assert isinstance(test_random, float) + assert isinstance(test_uuid, uuid.UUID) + assert isinstance(test_time, float) + + def test_sandbox_warning_detection(self): + """Test that sandbox properly issues warnings.""" + import warnings + + from durabletask.aio import NonDeterminismWarning + from durabletask.aio.sandbox import _sandbox_scope + + async_ctx = AsyncWorkflowContext(self.mock_base_ctx) + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + with _sandbox_scope(async_ctx, "best_effort"): + # This should potentially trigger warnings if non-deterministic calls are detected + pass + + # Check if any NonDeterminismWarning was issued + # May or may not have warnings depending on implementation + _ = [warning for warning in w if issubclass(warning.category, 
NonDeterminismWarning)] + + def test_sandbox_performance_impact(self): + """Test that sandbox doesn't have excessive performance impact.""" + import random + import time as time_module + + from durabletask.aio.sandbox import _sandbox_scope + + async_ctx = AsyncWorkflowContext(self.mock_base_ctx) + # Ensure debug mode is OFF for performance testing + async_ctx._debug_mode = False + + # Measure without sandbox + start = time_module.perf_counter() + for _ in range(1000): + random.random() + no_sandbox_time = time_module.perf_counter() - start + + # Measure with sandbox + start = time_module.perf_counter() + with _sandbox_scope(async_ctx, "best_effort"): + for _ in range(1000): + random.random() + sandbox_time = time_module.perf_counter() - start + + # Sandbox should not be more than 20x slower (reasonable overhead for patching + minimal tracing) + # In practice, the overhead comes from function call interception and deterministic RNG + assert sandbox_time < no_sandbox_time * 20, ( + f"Sandbox: {sandbox_time:.6f}s, No sandbox: {no_sandbox_time:.6f}s" + ) + + def test_sandbox_mode_validation(self): + """Test sandbox mode validation.""" + from durabletask.aio.sandbox import _sandbox_scope + + async_ctx = AsyncWorkflowContext(self.mock_base_ctx) + + # Valid modes should work + for mode in ["off", "best_effort", "strict"]: + with _sandbox_scope(async_ctx, mode): + pass + + # Invalid mode should raise error + with pytest.raises(ValueError): + with _sandbox_scope(async_ctx, "invalid"): + pass diff --git a/tests/aio/test_non_determinism_detection.py b/tests/aio/test_non_determinism_detection.py new file mode 100644 index 0000000..348cd4f --- /dev/null +++ b/tests/aio/test_non_determinism_detection.py @@ -0,0 +1,351 @@ +# Copyright 2025 The Dapr Authors +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests for non-determinism detection in async workflows. 
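+
+Covers the _NonDeterminismDetector helper and _sandbox_scope behavior across
+the three sandbox modes ("off", "best_effort", "strict").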
+""" + +import datetime +import warnings +from unittest.mock import Mock + +import pytest + +from durabletask import task as dt_task +from durabletask.aio import ( + AsyncWorkflowContext, + NonDeterminismWarning, + SandboxViolationError, + _NonDeterminismDetector, +) +from durabletask.aio.sandbox import _sandbox_scope + + +class TestNonDeterminismDetection: + """Test non-determinism detection and warnings.""" + + def setup_method(self): + self.mock_base_ctx = Mock(spec=dt_task.OrchestrationContext) + self.mock_base_ctx.instance_id = "test-instance-123" + self.mock_base_ctx.current_utc_datetime = datetime.datetime(2023, 1, 1, 12, 0, 0) + self.async_ctx = AsyncWorkflowContext(self.mock_base_ctx) + + def test_non_determinism_detector_context_manager(self): + """Test that the detector can be used as a context manager.""" + detector = _NonDeterminismDetector(self.async_ctx, "best_effort") + + with detector: + # Should not raise + pass + + def test_deterministic_alternative_suggestions(self): + """Test that appropriate alternatives are suggested.""" + detector = _NonDeterminismDetector(self.async_ctx, "best_effort") + + test_cases = [ + ("datetime.now", "ctx.now()"), + ("datetime.utcnow", "ctx.now()"), + ("time.time", "ctx.now().timestamp()"), + ("random.random", "ctx.random().random()"), + ("uuid.uuid4", "ctx.uuid4()"), + ("os.urandom", "ctx.random().randbytes() or ctx.random().getrandbits()"), + ("unknown.function", "a deterministic alternative"), + ] + + for call_sig, expected in test_cases: + result = detector._get_deterministic_alternative(call_sig) + assert result == expected + + def test_sandbox_with_non_determinism_detection_off(self): + """Test that detection is disabled when mode is 'off'.""" + with _sandbox_scope(self.async_ctx, "off"): + # Should not detect anything + import datetime as dt + + # This would normally trigger detection, but mode is off + current_time = dt.datetime.now() + assert current_time is not None + + def test_sandbox_with_non_determinism_detection_best_effort(self): + """Test that detection works in best_effort mode.""" + with warnings.catch_warnings(record=True): + warnings.simplefilter("always") + + with _sandbox_scope(self.async_ctx, "best_effort"): + # This should work without issues since we're just testing the context + pass + + # Note: The actual detection happens during function execution tracing + # which is complex to test in isolation + + def test_sandbox_with_non_determinism_detection_strict_mode(self): + """Test that strict mode blocks dangerous operations.""" + with pytest.raises(SandboxViolationError, match="File I/O operations are not allowed"): + with _sandbox_scope(self.async_ctx, "strict"): + open("test.txt", "w") + + def test_non_determinism_warning_class(self): + """Test that NonDeterminismWarning is a proper warning class.""" + warning = NonDeterminismWarning("Test warning") + assert isinstance(warning, UserWarning) + assert str(warning) == "Test warning" + + def test_detector_deduplication(self): + """Test that the detector doesn't warn about the same call multiple times.""" + detector = _NonDeterminismDetector(self.async_ctx, "best_effort") + + # Simulate multiple calls to the same function + detector.detected_calls.add("datetime.now") + + # This should not add a duplicate + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + # Create a mock frame for the call + mock_frame = Mock() + mock_frame.f_code.co_filename = "test.py" + mock_frame.f_lineno = 10 + detector._handle_non_deterministic_call("datetime.now", 
mock_frame) + + # Should not have issued a warning since it was already detected + assert len(w) == 0 + + def test_detector_strict_mode_raises_error(self): + """Test that strict mode raises AsyncWorkflowError instead of warning.""" + detector = _NonDeterminismDetector(self.async_ctx, "strict") + + with pytest.raises(SandboxViolationError) as exc_info: + # Create a mock frame for the call + mock_frame = Mock() + mock_frame.f_code.co_filename = "test.py" + mock_frame.f_lineno = 10 + detector._handle_non_deterministic_call("datetime.now", mock_frame) + + error = exc_info.value + assert "Non-deterministic function 'datetime.now' is not allowed" in str(error) + assert error.instance_id == "test-instance-123" + + def test_detector_logs_to_debug_info(self): + """Test that warnings are logged to debug info when debug mode is enabled.""" + # Enable debug mode + self.async_ctx._debug_mode = True + + detector = _NonDeterminismDetector(self.async_ctx, "best_effort") + + with warnings.catch_warnings(record=True): + warnings.simplefilter("always") + # Create a mock frame for the call + mock_frame = Mock() + mock_frame.f_code.co_filename = "test.py" + mock_frame.f_lineno = 10 + detector._handle_non_deterministic_call("datetime.now", mock_frame) + + # Check that debug message was printed (our current implementation just prints) + # The current implementation doesn't log to operation_history, it just prints debug messages + # This is acceptable behavior for debug mode + + +class TestNonDeterminismIntegration: + """Integration tests for non-determinism detection with actual workflows.""" + + def setup_method(self): + self.mock_base_ctx = Mock(spec=dt_task.OrchestrationContext) + self.mock_base_ctx.instance_id = "test-instance-123" + self.mock_base_ctx.current_utc_datetime = datetime.datetime(2023, 1, 1, 12, 0, 0) + + def test_sandbox_patches_work_correctly(self): + """Test that the sandbox patches actually work.""" + async_ctx = AsyncWorkflowContext(self.mock_base_ctx) + + with _sandbox_scope(async_ctx, "best_effort"): + import random + import time + import uuid + + # These should use deterministic versions + random_val = random.random() + uuid_val = uuid.uuid4() + time_val = time.time() + + # Values should be deterministic + assert isinstance(random_val, float) + assert isinstance(uuid_val, uuid.UUID) + assert isinstance(time_val, float) + + def test_datetime_limitation_documented(self): + """Test that datetime.now() limitation is properly documented.""" + async_ctx = AsyncWorkflowContext(self.mock_base_ctx) + + with _sandbox_scope(async_ctx, "best_effort"): + import datetime as dt + + # datetime.now() cannot be patched due to immutability + # This should return the actual current time, not the deterministic time + now_result = dt.datetime.now() + deterministic_time = async_ctx.now() + + # They will likely be different (unless run at exactly the same time) + # This documents the limitation + assert isinstance(now_result, datetime.datetime) + assert isinstance(deterministic_time, datetime.datetime) + + def test_rng_whitelist_and_global_random_determinism(self): + """ctx.random() methods allowed; global random.* is patched to deterministic in strict/best_effort.""" + import random + + async_ctx = AsyncWorkflowContext(self.mock_base_ctx) + + # Strict: ctx.random().randint allowed + with _sandbox_scope(async_ctx, "strict"): + rng = async_ctx.random() + assert isinstance(rng.randint(1, 3), int) + + # Strict: global random.randint patched and deterministic + with _sandbox_scope(async_ctx, "strict"): + v1 = 
+        with _sandbox_scope(async_ctx, "strict"):
+            v2 = random.randint(1, 1000000)
+        assert v1 == v2
+
+        # Best-effort: global random warns but still returns a value
+        with warnings.catch_warnings(record=True):
+            warnings.simplefilter("always")
+            with _sandbox_scope(async_ctx, "best_effort"):
+                val1 = random.random()
+            with _sandbox_scope(async_ctx, "best_effort"):
+                val2 = random.random()
+            assert isinstance(val1, float)
+            assert val1 == val2
+            # Note: we intentionally don't assert on collected warnings here to keep the test
+            # resilient across environments where tracing may not capture stdlib frames.
+
+    def test_uuid_and_os_urandom_strict_behavior(self):
+        """uuid.uuid4 is patched to be deterministic; os.urandom is blocked in strict; ctx.uuid4 is allowed."""
+        import os
+        import uuid as _uuid
+
+        async_ctx = AsyncWorkflowContext(self.mock_base_ctx)
+
+        # Allowed via deterministic helper
+        with _sandbox_scope(async_ctx, "strict"):
+            val = async_ctx.uuid4()
+            assert isinstance(val, _uuid.UUID)
+
+        # Patched global uuid.uuid4 is deterministic
+        with _sandbox_scope(async_ctx, "strict"):
+            u1 = _uuid.uuid4()
+        with _sandbox_scope(async_ctx, "strict"):
+            u2 = _uuid.uuid4()
+        assert isinstance(u1, _uuid.UUID)
+        assert u1 == u2
+
+        if hasattr(os, "urandom"):
+            with pytest.raises(SandboxViolationError):
+                with _sandbox_scope(async_ctx, "strict"):
+                    _ = os.urandom(8)
+
+    @pytest.mark.asyncio
+    async def test_create_task_blocked_in_strict_and_closed_coroutines(self):
+        """asyncio.create_task is blocked in strict; ensure no unawaited coroutine warning leaks."""
+        import asyncio
+
+        async_ctx = AsyncWorkflowContext(self.mock_base_ctx)
+
+        async def dummy():
+            return 42
+
+        # Blocked in strict
+        with warnings.catch_warnings(record=True) as w:
+            warnings.simplefilter("always")
+            with pytest.raises(SandboxViolationError):
+                with _sandbox_scope(async_ctx, "strict"):
+                    asyncio.create_task(dummy())
+            # Ensure no "coroutine was never awaited" RuntimeWarning leaked
+            assert not any("was never awaited" in str(rec.message) for rec in w)
+
+        # Also blocked when passing a ready Future (get_running_loop is the
+        # non-deprecated way to reach the loop from inside a coroutine)
+        fut = asyncio.get_running_loop().create_future()
+        fut.set_result(1)
+        with pytest.raises(SandboxViolationError):
+            with _sandbox_scope(async_ctx, "strict"):
+                asyncio.create_task(fut)  # type: ignore[arg-type]
+
+    @pytest.mark.asyncio
+    async def test_create_task_allowed_in_best_effort(self):
+        """In best_effort mode, create_task should be allowed and runnable."""
+        import asyncio
+
+        async_ctx = AsyncWorkflowContext(self.mock_base_ctx)
+
+        async def quick():
+            # sleep(0) is passed through to the original sleep in the sandbox
+            await asyncio.sleep(0)
+            return "ok"
+
+        with _sandbox_scope(async_ctx, "best_effort"):
+            t = asyncio.create_task(quick())
+            assert await t == "ok"
+
+    def test_helper_methods_allowed_in_strict(self):
+        """Ensure the helper methods use the whitelisted deterministic RNG in strict mode."""
+        async_ctx = AsyncWorkflowContext(self.mock_base_ctx)
+
+        with _sandbox_scope(async_ctx, "strict"):
+            s = async_ctx.random_string(5)
+            assert len(s) == 5
+            n = async_ctx.random_int(1, 3)
+            assert 1 <= n <= 3
+            choice = async_ctx.random_choice(["a", "b", "c"])
+            assert choice in {"a", "b", "c"}
+
+    @pytest.mark.asyncio
+    async def test_gather_variants_and_caching(self):
+        """Exercise patched asyncio.gather paths: empty, all-workflow, mixed with return_exceptions, and caching."""
+        import asyncio
+
+        async_ctx = AsyncWorkflowContext(self.mock_base_ctx)
+
+        with _sandbox_scope(async_ctx, "best_effort"):
+            # Empty gather returns [], cache 
replay on re-await + g0 = asyncio.gather() + r0a = await g0 + r0b = await g0 + assert r0a == [] and r0b == [] + + # All workflow awaitables (sleep -> WhenAll path) + a1 = async_ctx.create_timer(0) + a2 = async_ctx.create_timer(0) + g1 = asyncio.gather(a1, a2) + # Do not await g1: constructing it covers the all-workflow branch without + # requiring a real orchestrator; ensure it is awaitable (one-shot wrapper) + assert hasattr(g1, "__await__") + + # Mixed inputs with return_exceptions True + async def boom(): + raise RuntimeError("x") + + async def small(): + await asyncio.sleep(0) + return "ok" + + # Mixed native coroutines path (no workflow awaitables) + g2 = asyncio.gather(small(), boom(), return_exceptions=True) + r2 = await g2 + assert len(r2) == 2 and isinstance(r2[1], Exception) + + def test_invalid_mode_raises(self): + async_ctx = AsyncWorkflowContext(self.mock_base_ctx) + with pytest.raises(ValueError): + with _sandbox_scope(async_ctx, "invalid_mode"): + pass + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/aio/test_sandbox.py b/tests/aio/test_sandbox.py new file mode 100644 index 0000000..f5876f0 --- /dev/null +++ b/tests/aio/test_sandbox.py @@ -0,0 +1,1698 @@ +# Copyright 2025 The Dapr Authors +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Tests for sandbox functionality in durabletask.aio. 
+""" + +import asyncio +import datetime +import os +import random +import secrets +import time +import uuid +import warnings +from unittest.mock import Mock, patch + +import pytest + +from durabletask import task as dt_task +from durabletask.aio import NonDeterminismWarning, _NonDeterminismDetector +from durabletask.aio.errors import AsyncWorkflowError +from durabletask.aio.sandbox import _sandbox_scope + + +class TestNonDeterminismDetector: + """Test NonDeterminismWarning and _NonDeterminismDetector functionality.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_ctx = Mock() + self.mock_ctx.instance_id = "test-instance" + self.mock_ctx.current_utc_datetime = datetime.datetime(2023, 1, 1, 12, 0, 0) + + def test_non_determinism_warning_creation(self): + """Test creating NonDeterminismWarning.""" + warning = NonDeterminismWarning("Test warning message") + assert str(warning) == "Test warning message" + assert issubclass(NonDeterminismWarning, UserWarning) + + def test_detector_creation(self): + """Test creating _NonDeterminismDetector.""" + detector = _NonDeterminismDetector(self.mock_ctx, "best_effort") + assert detector.async_ctx is self.mock_ctx + assert detector.mode == "best_effort" + assert detector.detected_calls == set() + + def test_detector_context_manager_off_mode(self): + """Test detector context manager with off mode.""" + detector = _NonDeterminismDetector(self.mock_ctx, "off") + + with detector: + # Should not set up tracing in off mode + pass + + # Should complete without issues + + def test_detector_context_manager_best_effort_mode(self): + """Test detector context manager with best_effort mode.""" + detector = _NonDeterminismDetector(self.mock_ctx, "best_effort") + + import sys + + pre_trace = sys.gettrace() + with detector: + # Should set up tracing + original_trace = sys.gettrace() + assert original_trace is not pre_trace + + # After exit, original trace should be restored + assert sys.gettrace() is pre_trace + + def test_detector_trace_calls_detection(self): + """Test that detector can identify non-deterministic calls.""" + detector = _NonDeterminismDetector(self.mock_ctx, "best_effort") + + # Create a mock frame that looks like it's calling datetime.now + mock_frame = Mock() + mock_frame.f_code.co_filename = "/test/file.py" + mock_frame.f_code.co_name = "now" + mock_frame.f_locals = {"datetime": Mock(__module__="datetime")} + + # Test the frame checking logic + detector._check_frame_for_non_determinism(mock_frame) + + # Should detect the call (implementation may vary) + + def test_detector_strict_mode_raises_error(self): + """Test that detector raises error in strict mode.""" + detector = _NonDeterminismDetector(self.mock_ctx, "strict") + + # Create a mock frame for a non-deterministic call + mock_frame = Mock() + mock_frame.f_code.co_filename = "/test/file.py" + mock_frame.f_code.co_name = "random" + mock_frame.f_locals = {"random_module": Mock(__module__="random")} + + # Should raise error in strict mode when non-deterministic call detected + with pytest.raises(AsyncWorkflowError): + detector._handle_non_deterministic_call("random.random", mock_frame) + + def test_fast_map_random_whitelist_bound_self(self): + """random.* with deterministic bound self should be whitelisted in fast map.""" + # Prepare detector in strict (whitelist applies before error path) + detector = _NonDeterminismDetector(self.mock_ctx, "strict") + + class BoundSelf: + pass + + bs = BoundSelf() + bs._dt_deterministic = True + + frame = Mock() + frame.f_code.co_filename = 
"/test/rand.py" + frame.f_code.co_name = "random" # function name + frame.f_globals = {"__name__": "random"} + frame.f_locals = {"self": bs} + + # Should not raise or warn; returns early + detector._check_frame_for_non_determinism(frame) + + def test_fast_map_best_effort_warning_and_early_return(self): + """best_effort should warn once for fast-map hit (e.g., os.getenv) and return early.""" + detector = _NonDeterminismDetector(self.mock_ctx, "best_effort") + + frame = Mock() + frame.f_code.co_filename = "/test/osmod.py" + frame.f_code.co_name = "getenv" + frame.f_globals = {"__name__": "os"} + frame.f_locals = {} + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + detector._check_frame_for_non_determinism(frame) + assert any(issubclass(rec.category, NonDeterminismWarning) for rec in w) + + def test_fast_map_random_strict_raises_when_not_deterministic(self): + """random.* without deterministic bound self should trigger strict violation via fast map.""" + detector = _NonDeterminismDetector(self.mock_ctx, "strict") + + frame = Mock() + frame.f_code.co_filename = "/test/rand2.py" + frame.f_code.co_name = "randint" + frame.f_globals = {"__name__": "random"} + frame.f_locals = {"self": object()} # no _dt_deterministic + + with pytest.raises(AsyncWorkflowError): + detector._check_frame_for_non_determinism(frame) + + def test_detector_off_mode_no_tracing(self): + """Test detector in off mode doesn't set up tracing.""" + detector = _NonDeterminismDetector(self.mock_ctx, "off") + + import sys + + original_trace = sys.gettrace() + + with detector: + # Should not change trace function in off mode + assert sys.gettrace() is original_trace + + # Should still be the same after exit + assert sys.gettrace() is original_trace + + def test_detector_exception_in_globals_access(self): + """Test exception handling when accessing frame globals.""" + detector = _NonDeterminismDetector(self.mock_ctx, "best_effort") + + # Create a frame that raises exception when accessing f_globals + frame = Mock() + frame.f_code.co_filename = "/test/bad.py" + frame.f_code.co_name = "test_func" + frame.f_globals = Mock() + frame.f_globals.get.side_effect = Exception("globals access failed") + + # Should not raise, just handle gracefully + detector._check_frame_for_non_determinism(frame) + + def test_detector_exception_in_whitelist_check(self): + """Test exception handling in whitelist check.""" + detector = _NonDeterminismDetector(self.mock_ctx, "strict") + + frame = Mock() + frame.f_code.co_filename = "/test/rand3.py" + frame.f_code.co_name = "random" + frame.f_globals = {"__name__": "random"} + + # Create a self object that raises exception when accessing _dt_deterministic + class BadSelf: + @property + def _dt_deterministic(self): + raise Exception("attribute access failed") + + frame.f_locals = {"self": BadSelf()} + + # Should handle exception and continue to error path + with pytest.raises(AsyncWorkflowError): + detector._check_frame_for_non_determinism(frame) + + def test_detector_non_mapping_globals(self): + """Test handling of non-mapping f_globals.""" + detector = _NonDeterminismDetector(self.mock_ctx, "strict") + + frame = Mock() + frame.f_code.co_filename = "/test/bad_globals.py" + frame.f_code.co_name = "getenv" + frame.f_globals = "not a dict" # Non-mapping globals + frame.f_locals = {} + + # Should handle gracefully without raising + detector._check_frame_for_non_determinism(frame) + + def test_detector_exception_in_pattern_check(self): + """Test exception handling in pattern 
checking loop.""" + detector = _NonDeterminismDetector(self.mock_ctx, "strict") + + frame = Mock() + frame.f_code.co_filename = "/test/pattern.py" + frame.f_code.co_name = "time" + frame.f_globals = {"time.time": Mock(side_effect=Exception("access failed"))} + frame.f_locals = {} + + # Should handle exception and continue + detector._check_frame_for_non_determinism(frame) + + def test_detector_debug_mode_enabled(self): + """Test detector with debug mode enabled.""" + self.mock_ctx._debug_mode = True + detector = _NonDeterminismDetector(self.mock_ctx, "best_effort") + + frame = Mock() + frame.f_code.co_filename = "/test/debug.py" + frame.f_code.co_name = "now" + frame.f_lineno = 42 + + # Capture print output + import io + import sys + + captured_output = io.StringIO() + sys.stdout = captured_output + + try: + with pytest.warns(NonDeterminismWarning): + detector._handle_non_deterministic_call("datetime.now", frame) + output = captured_output.getvalue() + assert "[WORKFLOW DEBUG]" in output + assert "datetime.now" in output + finally: + sys.stdout = sys.__stdout__ + + def test_detector_noop_trace_method(self): + """Test _noop_trace method (line 56).""" + detector = _NonDeterminismDetector(self.mock_ctx, "best_effort") + + frame = Mock() + result = detector._noop_trace(frame, "call", None) + assert result is None + + def test_detector_trace_calls_non_call_event(self): + """Test _trace_calls with non-call event (lines 79-80).""" + detector = _NonDeterminismDetector(self.mock_ctx, "best_effort") + + frame = Mock() + + # Test with no original trace function + result = detector._trace_calls(frame, "return", None) + assert result is None + + # Test with original trace function + original_trace = Mock(return_value="original_result") + detector.original_trace_func = original_trace + result = detector._trace_calls(frame, "return", None) + assert result == "original_result" + original_trace.assert_called_once_with(frame, "return", None) + + def test_detector_trace_calls_with_original_func(self): + """Test _trace_calls returning original trace func (line 86).""" + detector = _NonDeterminismDetector(self.mock_ctx, "best_effort") + + frame = Mock() + frame.f_code.co_filename = "/test/safe.py" # Safe filename + frame.f_code.co_name = "safe_func" + frame.f_globals = {} + + # Test with original trace function + original_trace = Mock() + detector.original_trace_func = original_trace + result = detector._trace_calls(frame, "call", None) + assert result is original_trace + + +class TestSandboxScope: + """Test sandbox_scope functionality.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_ctx = Mock() + self.mock_ctx.instance_id = "test-instance" + self.mock_ctx.current_utc_datetime = datetime.datetime(2023, 1, 1, 12, 0, 0) + self.mock_ctx.random.return_value = random.Random(12345) + self.mock_ctx.uuid4.return_value = uuid.UUID("12345678-1234-5678-1234-567812345678") + self.mock_ctx.now.return_value = datetime.datetime(2023, 1, 1, 12, 0, 0) + + # Add _base_ctx for sandbox patching + self.mock_ctx._base_ctx = Mock() + self.mock_ctx._base_ctx.create_timer = Mock(return_value=Mock()) + + # Ensure detection is not disabled + self.mock_ctx._detection_disabled = False + + def test_sandbox_scope_off_mode(self): + """Test sandbox_scope with off mode.""" + original_sleep = asyncio.sleep + original_random = random.random + + with _sandbox_scope(self.mock_ctx, "off"): + # Should not patch anything in off mode + assert asyncio.sleep is original_sleep + assert random.random is original_random + + def 
test_sandbox_scope_invalid_mode(self): + """Test sandbox_scope with invalid mode.""" + with pytest.raises(ValueError, match="Invalid sandbox mode"): + with _sandbox_scope(self.mock_ctx, "invalid_mode"): + pass + + def test_sandbox_scope_best_effort_patches(self): + """Test sandbox_scope patches functions in best_effort mode.""" + original_sleep = asyncio.sleep + original_random = random.random + original_uuid4 = uuid.uuid4 + original_time = time.time + + with _sandbox_scope(self.mock_ctx, "best_effort"): + # Should patch functions + assert asyncio.sleep is not original_sleep + assert random.random is not original_random + assert uuid.uuid4 is not original_uuid4 + assert time.time is not original_time + + # Should restore originals + assert asyncio.sleep is original_sleep + assert random.random is original_random + assert uuid.uuid4 is original_uuid4 + assert time.time is original_time + + def test_sandbox_scope_strict_mode_blocks_dangerous_functions(self): + """Test sandbox_scope blocks dangerous functions in strict mode.""" + original_open = open + + with _sandbox_scope(self.mock_ctx, "strict"): + # Should block dangerous functions + with pytest.raises(AsyncWorkflowError, match="File I/O operations are not allowed"): + open("test.txt", "r") + + # Should restore original + assert open is original_open + + def test_strict_allows_ctx_random_methods_and_patched_global_random(self): + """Strict mode should allow ctx.random().randint and patched global random methods.""" + from durabletask.aio import AsyncWorkflowContext + + base_ctx = Mock() + base_ctx.instance_id = "rng-ctx" + base_ctx.current_utc_datetime = datetime.datetime(2023, 1, 1, 12, 0, 0) + async_ctx = AsyncWorkflowContext(base_ctx) + + with _sandbox_scope(async_ctx, "strict"): + # Allowed: via ctx.random() (detector should whitelist) + val1 = async_ctx.random().randint(1, 10) + assert isinstance(val1, int) + + # Also allowed: global random methods are patched deterministically in strict + val2 = random.randint(1, 10) + assert isinstance(val2, int) + + def test_strict_allows_all_deterministic_helpers(self): + """Strict mode should allow all ctx deterministic helpers without violations.""" + from durabletask.aio import AsyncWorkflowContext + + base_ctx = Mock() + base_ctx.instance_id = "det-helpers" + base_ctx.current_utc_datetime = datetime.datetime(2023, 1, 1, 12, 0, 0) + async_ctx = AsyncWorkflowContext(base_ctx) + + with _sandbox_scope(async_ctx, "strict"): + # now() + now_val = async_ctx.now() + assert isinstance(now_val, datetime.datetime) + + # uuid4() + uid = async_ctx.uuid4() + import uuid as _uuid + + assert isinstance(uid, _uuid.UUID) + + # random().random, randint, choice + rnd = async_ctx.random() + assert isinstance(rnd.random(), float) + assert isinstance(rnd.randint(1, 10), int) + assert isinstance(rnd.choice([1, 2, 3]), int) + + # random_string / random_int / random_choice + s = async_ctx.random_string(5) + assert isinstance(s, str) and len(s) == 5 + ri = async_ctx.random_int(1, 10) + assert isinstance(ri, int) + rc = async_ctx.random_choice(["a", "b"]) + assert rc in ["a", "b"] + + def test_sandbox_scope_patches_asyncio_sleep(self): + """Test that asyncio.sleep is properly patched within sandbox context.""" + with _sandbox_scope(self.mock_ctx, "best_effort"): + # Import asyncio within the sandbox context to get the patched version + import asyncio as sandboxed_asyncio + + # Call the patched sleep directly + patched_sleep_result = sandboxed_asyncio.sleep(1.0) + + # Should return our patched sleep awaitable + assert 
hasattr(patched_sleep_result, "__await__") + + # The awaitable should yield a timer task when awaited + awaitable_gen = patched_sleep_result.__await__() + try: + yielded_task = next(awaitable_gen) + # Should be the mock timer task + assert yielded_task is self.mock_ctx._base_ctx.create_timer.return_value + except StopIteration: + pass # Sleep completed immediately + + def test_sandbox_scope_patches_random_functions(self): + """Test that random functions are properly patched.""" + with _sandbox_scope(self.mock_ctx, "best_effort"): + # Should use deterministic random + val1 = random.random() + val2 = random.randint(1, 100) + val3 = random.randrange(10) + + assert isinstance(val1, float) + assert isinstance(val2, int) + assert isinstance(val3, int) + assert 1 <= val2 <= 100 + assert 0 <= val3 < 10 + + def test_sandbox_scope_patches_uuid4(self): + """Test that uuid.uuid4 is properly patched.""" + with _sandbox_scope(self.mock_ctx, "best_effort"): + test_uuid = uuid.uuid4() + assert isinstance(test_uuid, uuid.UUID) + assert test_uuid.version == 4 + + def test_sandbox_scope_patches_time_functions(self): + """Test that time functions are properly patched.""" + with _sandbox_scope(self.mock_ctx, "best_effort"): + current_time = time.time() + assert isinstance(current_time, float) + + if hasattr(time, "time_ns"): + current_time_ns = time.time_ns() + assert isinstance(current_time_ns, int) + + def test_patched_randrange_step_branch(self): + """Hit patched randrange path with step != 1 to cover the loop branch.""" + from durabletask.aio import AsyncWorkflowContext + + base_ctx = Mock() + base_ctx.instance_id = "step-test" + base_ctx.current_utc_datetime = datetime.datetime(2023, 1, 1, 12, 0, 0) + async_ctx = AsyncWorkflowContext(base_ctx) + with _sandbox_scope(async_ctx, "best_effort"): + v = random.randrange(1, 10, 3) + assert 1 <= v < 10 and (v - 1) % 3 == 0 + + def test_sandbox_scope_strict_mode_blocks_os_urandom(self): + """Test that os.urandom is blocked in strict mode.""" + with _sandbox_scope(self.mock_ctx, "strict"): + with pytest.raises(AsyncWorkflowError, match="os.urandom is not allowed"): + os.urandom(16) + + def test_sandbox_scope_strict_mode_blocks_secrets(self): + """Test that secrets module is blocked in strict mode.""" + with _sandbox_scope(self.mock_ctx, "strict"): + with pytest.raises(AsyncWorkflowError, match="secrets module is not allowed"): + secrets.token_bytes(16) + + with pytest.raises(AsyncWorkflowError, match="secrets module is not allowed"): + secrets.token_hex(16) + + def test_sandbox_scope_strict_mode_blocks_asyncio_create_task(self): + """Test that asyncio.create_task is blocked in strict mode.""" + + async def dummy_coro(): + return "test" + + with _sandbox_scope(self.mock_ctx, "strict"): + with pytest.raises(AsyncWorkflowError, match="asyncio.create_task is not allowed"): + asyncio.create_task(dummy_coro()) + + @pytest.mark.asyncio + async def test_asyncio_sleep_zero_passthrough(self): + """sleep(0) should use original asyncio.sleep (passthrough branch).""" + from durabletask.aio import AsyncWorkflowContext + + base_ctx = Mock() + base_ctx.instance_id = "sleep-zero" + base_ctx.current_utc_datetime = datetime.datetime(2023, 1, 1, 12, 0, 0) + async_ctx = AsyncWorkflowContext(base_ctx) + with _sandbox_scope(async_ctx, "best_effort"): + # Should not raise; executes passthrough branch in patched_sleep + await asyncio.sleep(0) + + def test_strict_restores_os_and_secrets_on_exit(self): + """Ensure strict mode restores os.urandom and secrets functions on exit.""" + 
orig_urandom = getattr(os, "urandom", None) + orig_token_bytes = getattr(secrets, "token_bytes", None) + orig_token_hex = getattr(secrets, "token_hex", None) + + with _sandbox_scope(self.mock_ctx, "strict"): + if orig_urandom is not None: + with pytest.raises(AsyncWorkflowError): + os.urandom(1) + if orig_token_bytes is not None: + with pytest.raises(AsyncWorkflowError): + secrets.token_bytes(1) + if orig_token_hex is not None: + with pytest.raises(AsyncWorkflowError): + secrets.token_hex(1) + + # After exit, originals should be restored + if orig_urandom is not None: + assert os.urandom is orig_urandom + if orig_token_bytes is not None: + assert secrets.token_bytes is orig_token_bytes + if orig_token_hex is not None: + assert secrets.token_hex is orig_token_hex + + @pytest.mark.asyncio + async def test_empty_gather_caching_replay(self): + """Empty gather should be awaitable and replay cached result on repeated awaits.""" + from durabletask.aio import AsyncWorkflowContext + + mock_base_ctx = Mock() + mock_base_ctx.instance_id = "gather-cache" + mock_base_ctx.current_utc_datetime = datetime.datetime(2023, 1, 1, 12, 0, 0) + async_ctx = AsyncWorkflowContext(mock_base_ctx) + with _sandbox_scope(async_ctx, "best_effort"): + g0 = asyncio.gather() + r0a = await g0 + r0b = await g0 + assert r0a == [] and r0b == [] + + def test_patched_datetime_now_with_tz(self): + """datetime.now(tz=UTC) should return aware UTC when patched.""" + from datetime import timezone + + from durabletask.aio import AsyncWorkflowContext + + base_ctx = Mock() + base_ctx.instance_id = "tz-test" + base_ctx.current_utc_datetime = datetime.datetime(2023, 1, 1, 12, 0, 0) + async_ctx = AsyncWorkflowContext(base_ctx) + with _sandbox_scope(async_ctx, "best_effort"): + now_utc = datetime.datetime.now(tz=timezone.utc) + assert now_utc.tzinfo is timezone.utc + + @pytest.mark.asyncio + async def test_create_task_allowed_in_best_effort(self): + """In best_effort mode, create_task should be allowed and runnable.""" + from durabletask.aio import AsyncWorkflowContext + + base_ctx = Mock() + base_ctx.instance_id = "best-effort-ct" + base_ctx.current_utc_datetime = datetime.datetime(2023, 1, 1, 12, 0, 0) + async_ctx = AsyncWorkflowContext(base_ctx) + + async def quick(): + await asyncio.sleep(0) + return "ok" + + with _sandbox_scope(async_ctx, "best_effort"): + t = asyncio.create_task(quick()) + assert await t == "ok" + + @pytest.mark.asyncio + async def test_create_task_blocked_strict_no_unawaited_warning(self): + """Strict mode: ensure blocked coroutine is closed (no 'never awaited' warnings).""" + from durabletask.aio import AsyncWorkflowContext + + base_ctx = Mock() + base_ctx.instance_id = "strict-ct" + base_ctx.current_utc_datetime = datetime.datetime(2023, 1, 1, 12, 0, 0) + async_ctx = AsyncWorkflowContext(base_ctx) + + async def dummy(): + return 1 + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + with pytest.raises(AsyncWorkflowError): + with _sandbox_scope(async_ctx, "strict"): + asyncio.create_task(dummy()) + assert not any("was never awaited" in str(rec.message) for rec in w) + + @pytest.mark.asyncio + async def test_env_disable_detection_allows_create_task(self): + """DAPR_WF_DISABLE_DETERMINISTIC_DETECTION=true forces mode off; create_task allowed.""" + import durabletask.aio.sandbox as sandbox_module + from durabletask.aio import AsyncWorkflowContext + + base_ctx = Mock() + base_ctx.instance_id = "env-off" + base_ctx.current_utc_datetime = datetime.datetime(2023, 1, 1, 12, 0, 0) + 
async_ctx = AsyncWorkflowContext(base_ctx)
+
+        async def quick():
+            await asyncio.sleep(0)
+            return "ok"
+
+        # Mock the module-level constant to simulate the environment variable being set
+        with patch.object(sandbox_module, "_DISABLE_DETECTION", True):
+            with _sandbox_scope(async_ctx, "strict"):
+                t = asyncio.create_task(quick())
+                assert await t == "ok"
+
+    def test_sandbox_scope_global_disable_env_var(self):
+        """Test that the DAPR_WF_DISABLE_DETERMINISTIC_DETECTION environment variable works."""
+        import durabletask.aio.sandbox as sandbox_module
+
+        original_random = random.random
+
+        # Mock the module-level constant to simulate the environment variable being set
+        with patch.object(sandbox_module, "_DISABLE_DETECTION", True):
+            with _sandbox_scope(self.mock_ctx, "best_effort"):
+                # Should not patch when globally disabled
+                assert random.random is original_random
+
+    def test_sandbox_scope_context_detection_disabled(self):
+        """Test that context-level detection disable works."""
+        self.mock_ctx._detection_disabled = True
+        original_random = random.random
+
+        with _sandbox_scope(self.mock_ctx, "best_effort"):
+            # Should not patch when disabled on the context
+            assert random.random is original_random
+
+    def test_rng_context_fallback_to_base_ctx(self):
+        """Sandbox should fall back to _base_ctx.instance_id/current_utc_datetime when missing on async_ctx.
+
+        Same context twice -> identical deterministic sequence
+        Change only instance_id -> different sequence
+        Change only current_utc_datetime -> different sequence
+        """
+
+        class MinimalCtx:
+            pass
+
+        fallback = MinimalCtx()
+        fallback._base_ctx = Mock()
+        fallback._base_ctx.instance_id = "fallback-instance"
+        fallback._base_ctx.current_utc_datetime = datetime.datetime(2023, 1, 1, 12, 0, 0)
+
+        # Ensure MinimalCtx lacks direct attributes
+        assert not hasattr(fallback, "instance_id")
+        assert not hasattr(fallback, "current_utc_datetime")
+        assert not hasattr(fallback, "now")
+
+        # Same fallback context twice -> identical deterministic sequence
+        with _sandbox_scope(fallback, "best_effort"):
+            seq1 = [random.random() for _ in range(3)]
+        with _sandbox_scope(fallback, "best_effort"):
+            seq2 = [random.random() for _ in range(3)]
+        assert seq1 == seq2
+
+        # Change only instance_id -> different sequence
+        fallback_id = MinimalCtx()
+        fallback_id._base_ctx = Mock()
+        fallback_id._base_ctx.instance_id = "fallback-instance-2"
+        fallback_id._base_ctx.current_utc_datetime = datetime.datetime(2023, 1, 1, 12, 0, 0)
+        with _sandbox_scope(fallback_id, "best_effort"):
+            seq_id = [random.random() for _ in range(3)]
+        assert seq_id != seq1
+
+        # Change only current_utc_datetime -> different sequence
+        fallback_time = MinimalCtx()
+        fallback_time._base_ctx = Mock()
+        fallback_time._base_ctx.instance_id = "fallback-instance"
+        fallback_time._base_ctx.current_utc_datetime = datetime.datetime(2023, 1, 1, 12, 0, 1)
+        with _sandbox_scope(fallback_time, "best_effort"):
+            seq_time = [random.random() for _ in range(3)]
+        assert seq_time != seq1
+
+    def test_sandbox_scope_nested_contexts(self):
+        """Test nested sandbox contexts."""
+        original_random = random.random
+
+        with _sandbox_scope(self.mock_ctx, "best_effort"):
+            patched_random_1 = random.random
+            assert patched_random_1 is not original_random
+
+            with _sandbox_scope(self.mock_ctx, "strict"):
+                patched_random_2 = random.random
+                # The inner scope must also be patched (never the original)
+                assert patched_random_2 is not original_random
+
+            # Should restore to first patch level
+            assert random.random is patched_random_1
+
+        # Should restore to original
+        
assert random.random is original_random + + def test_sandbox_scope_exception_handling(self): + """Test that sandbox properly restores functions even if exception occurs.""" + original_random = random.random + + try: + with _sandbox_scope(self.mock_ctx, "best_effort"): + assert random.random is not original_random + raise ValueError("Test exception") + except ValueError: + pass + + # Should still restore original even after exception + assert random.random is original_random + + def test_sandbox_scope_deterministic_behavior(self): + """Test that sandbox provides deterministic behavior.""" + results1 = [] + results2 = [] + + # First run + with _sandbox_scope(self.mock_ctx, "best_effort"): + results1.append(random.random()) + results1.append(random.randint(1, 100)) + results1.append(str(uuid.uuid4())) + results1.append(time.time()) + + # Second run with same context + with _sandbox_scope(self.mock_ctx, "best_effort"): + results2.append(random.random()) + results2.append(random.randint(1, 100)) + results2.append(str(uuid.uuid4())) + results2.append(time.time()) + + # Should be deterministic (same results) + assert results1 == results2 + + def test_sandbox_scope_different_contexts_different_results(self): + """Test that different contexts produce different results.""" + mock_ctx2 = Mock() + mock_ctx2.instance_id = "different-instance" + mock_ctx2.current_utc_datetime = datetime.datetime(2023, 1, 1, 12, 0, 0) + mock_ctx2.random.return_value = random.Random(54321) + mock_ctx2.uuid4.return_value = uuid.UUID("87654321-4321-8765-4321-876543218765") + mock_ctx2.now.return_value = datetime.datetime(2023, 1, 1, 12, 0, 0) + mock_ctx2._detection_disabled = False + + results1 = [] + results2 = [] + + # First context + with _sandbox_scope(self.mock_ctx, "best_effort"): + results1.append(random.random()) + results1.append(str(uuid.uuid4())) + + # Different context + with _sandbox_scope(mock_ctx2, "best_effort"): + results2.append(random.random()) + results2.append(str(uuid.uuid4())) + + # Should be different + assert results1 != results2 + + def test_sandbox_missing_context_attributes(self): + """Test sandbox with context missing various attributes.""" + + # Create context with missing attributes but proper fallbacks + minimal_ctx = Mock() + minimal_ctx._detection_disabled = False + minimal_ctx.instance_id = None # Will use empty string fallback + minimal_ctx._base_ctx = None # No base context + # Mock now() to return proper datetime + minimal_ctx.now = Mock(return_value=datetime.datetime(2023, 1, 1, 12, 0, 0)) + minimal_ctx.current_utc_datetime = datetime.datetime(2023, 1, 1, 12, 0, 0) + + with _sandbox_scope(minimal_ctx, "best_effort"): + # Should use fallback values + val = random.random() + assert isinstance(val, float) + + def test_sandbox_context_with_now_exception(self): + """Test sandbox when ctx.now() raises exception.""" + + ctx = Mock() + ctx._detection_disabled = False + ctx.instance_id = "test" + ctx.now = Mock(side_effect=Exception("now() failed")) + ctx.current_utc_datetime = datetime.datetime(2023, 1, 1, 12, 0, 0) + + with _sandbox_scope(ctx, "best_effort"): + # Should fallback to current_utc_datetime + val = random.random() + assert isinstance(val, float) + + def test_sandbox_context_missing_base_ctx(self): + """Test sandbox with context missing _base_ctx.""" + ctx = Mock() + ctx._detection_disabled = False + ctx.instance_id = None # No instance_id + ctx._base_ctx = None # No _base_ctx + ctx.current_utc_datetime = datetime.datetime(2023, 1, 1, 12, 0, 0) + # Mock now() to return proper 
datetime + ctx.now = Mock(return_value=datetime.datetime(2023, 1, 1, 12, 0, 0)) + + with _sandbox_scope(ctx, "best_effort"): + # Should use empty string fallback for instance_id + val = random.random() + assert isinstance(val, float) + + def test_sandbox_rng_setattr_exception(self): + """Test sandbox when setattr on rng fails.""" + from durabletask.aio import AsyncWorkflowContext + + base_ctx = Mock() + base_ctx.instance_id = "test" + base_ctx.current_utc_datetime = datetime.datetime(2023, 1, 1, 12, 0, 0) + async_ctx = AsyncWorkflowContext(base_ctx) + + # Mock deterministic_random to return an object that can't be modified + with patch("durabletask.aio.sandbox.deterministic_random") as mock_rng: + # Create a class that raises exception on setattr + class ImmutableRNG: + def __setattr__(self, name, value): + if name == "_dt_deterministic": + raise Exception("setattr failed") + super().__setattr__(name, value) + + def random(self): + return 0.5 + + mock_rng.return_value = ImmutableRNG() + + with _sandbox_scope(async_ctx, "best_effort"): + # Should handle setattr exception gracefully + val = random.random() + assert isinstance(val, float) + + def test_sandbox_missing_time_ns(self): + """Test sandbox when time.time_ns is not available.""" + import time as time_mod + + # Temporarily remove time_ns if it exists + original_time_ns = getattr(time_mod, "time_ns", None) + if hasattr(time_mod, "time_ns"): + delattr(time_mod, "time_ns") + + try: + with _sandbox_scope(self.mock_ctx, "best_effort"): + # Should work without time_ns + val = time_mod.time() + assert isinstance(val, float) + finally: + # Restore time_ns if it existed + if original_time_ns is not None: + time_mod.time_ns = original_time_ns + + def test_sandbox_missing_optional_functions(self): + """Test sandbox with missing optional functions.""" + import os + import secrets + + # Temporarily remove optional functions + original_urandom = getattr(os, "urandom", None) + original_token_bytes = getattr(secrets, "token_bytes", None) + original_token_hex = getattr(secrets, "token_hex", None) + + if hasattr(os, "urandom"): + delattr(os, "urandom") + if hasattr(secrets, "token_bytes"): + delattr(secrets, "token_bytes") + if hasattr(secrets, "token_hex"): + delattr(secrets, "token_hex") + + try: + with _sandbox_scope(self.mock_ctx, "strict"): + # Should work without the optional functions + val = random.random() + assert isinstance(val, float) + finally: + # Restore functions + if original_urandom is not None: + os.urandom = original_urandom + if original_token_bytes is not None: + secrets.token_bytes = original_token_bytes + if original_token_hex is not None: + secrets.token_hex = original_token_hex + + def test_sandbox_restore_missing_optional_functions(self): + """Test sandbox restore with missing optional functions.""" + import os + import secrets + + # Remove optional functions before entering sandbox + original_urandom = getattr(os, "urandom", None) + original_token_bytes = getattr(secrets, "token_bytes", None) + original_token_hex = getattr(secrets, "token_hex", None) + + if hasattr(os, "urandom"): + delattr(os, "urandom") + if hasattr(secrets, "token_bytes"): + delattr(secrets, "token_bytes") + if hasattr(secrets, "token_hex"): + delattr(secrets, "token_hex") + + try: + with _sandbox_scope(self.mock_ctx, "strict"): + val = random.random() + assert isinstance(val, float) + # Should exit cleanly even with missing functions + finally: + # Restore functions + if original_urandom is not None: + os.urandom = original_urandom + if 
original_token_bytes is not None: + secrets.token_bytes = original_token_bytes + if original_token_hex is not None: + secrets.token_hex = original_token_hex + + def test_sandbox_patched_sleep_with_base_ctx(self): + """Test patched sleep accessing _base_ctx (lines 325-343).""" + from durabletask.aio import AsyncWorkflowContext + + base_ctx = Mock() + base_ctx.instance_id = "sleep-test" + base_ctx.current_utc_datetime = datetime.datetime(2023, 1, 1, 12, 0, 0) + base_ctx.create_timer = Mock(return_value=Mock()) + + async_ctx = AsyncWorkflowContext(base_ctx) + + with _sandbox_scope(async_ctx, "best_effort"): + # Test positive delay - should use patched version + sleep_awaitable = asyncio.sleep(1.0) + assert hasattr(sleep_awaitable, "__await__") + + # Actually await it to avoid the warning + # The mock should make this complete immediately + try: + gen = sleep_awaitable.__await__() + next(gen) + except StopIteration: + pass # Expected when mock completes immediately + + # Test zero delay - should use original (passthrough) + zero_sleep = asyncio.sleep(0) + # This should be the original coroutine or our awaitable + assert hasattr(zero_sleep, "__await__") + + # Actually await it to avoid the warning + # The mock should make this complete immediately + try: + gen = zero_sleep.__await__() + next(gen) + except StopIteration: + pass # Expected when mock completes immediately + + def test_sandbox_strict_blocking_functions_coverage(self): + """Test strict mode blocking functions to hit lines 588-615.""" + import builtins + import os + import secrets + + with _sandbox_scope(self.mock_ctx, "strict"): + # Test blocked open function (lines 588-593) + with pytest.raises(AsyncWorkflowError, match="File I/O operations are not allowed"): + builtins.open("test.txt", "r") + + # Test blocked os.urandom (lines 595-600) - if available + if hasattr(os, "urandom"): + with pytest.raises(AsyncWorkflowError, match="os.urandom is not allowed"): + os.urandom(16) + + # Test blocked secrets functions (lines 602-607) - if available + if hasattr(secrets, "token_bytes"): + with pytest.raises(AsyncWorkflowError, match="secrets module is not allowed"): + secrets.token_bytes(16) + + if hasattr(secrets, "token_hex"): + with pytest.raises(AsyncWorkflowError, match="secrets module is not allowed"): + secrets.token_hex(16) + + def test_sandbox_restore_with_gather_and_create_task(self): + """Test restore functions with gather and create_task (lines 624-628).""" + import asyncio + + original_gather = asyncio.gather + original_create_task = getattr(asyncio, "create_task", None) + + with _sandbox_scope(self.mock_ctx, "best_effort"): + # gather should be patched in best_effort + assert asyncio.gather is not original_gather + # create_task is only patched in strict mode, not best_effort + + # Should be restored + assert asyncio.gather is original_gather + + # Test strict mode where create_task is also patched + with _sandbox_scope(self.mock_ctx, "strict"): + assert asyncio.gather is not original_gather + if original_create_task is not None: + assert asyncio.create_task is not original_create_task + + # Should be restored after strict mode too + assert asyncio.gather is original_gather + if original_create_task is not None: + assert asyncio.create_task is original_create_task + + def test_sandbox_best_effort_debug_mode_tracing(self): + """Test best_effort mode with debug mode enabled for full tracing (line 61).""" + self.mock_ctx._debug_mode = True + + import sys + + original_trace = sys.gettrace() + + detector = 
_NonDeterminismDetector(self.mock_ctx, "best_effort") + + with detector: + # Should set up full tracing in debug mode + current_trace = sys.gettrace() + assert current_trace is not original_trace + assert current_trace is not detector._noop_trace + + # Should restore original trace + assert sys.gettrace() is original_trace + + def test_sandbox_detector_exit_branch_coverage(self): + """Test detector __exit__ method branch (line 74).""" + detector = _NonDeterminismDetector(self.mock_ctx, "off") + + # In off mode, __exit__ should not restore trace function + import sys + + original_trace = sys.gettrace() + + with detector: + pass # off mode doesn't change trace + + # Should still be the same + assert sys.gettrace() is original_trace + + def test_sandbox_context_no_current_utc_datetime(self): + """Test sandbox with context missing current_utc_datetime (lines 358-364).""" + + # Create a minimal context object without current_utc_datetime + class MinimalCtx: + def __init__(self): + self._detection_disabled = False + self.instance_id = "test" + self._base_ctx = None + + def now(self): + raise Exception("now() failed") + + ctx = MinimalCtx() + + with _sandbox_scope(ctx, "best_effort"): + # Should use epoch fallback (line 364) + val = random.random() + assert isinstance(val, float) + + +class TestGatherMixedOptimization: + """Tests for mixed workflow/native awaitables optimization in patched gather.""" + + @pytest.mark.asyncio + async def test_mixed_groups_preserve_order_and_use_when_all(self, monkeypatch): + from durabletask.aio import AsyncWorkflowContext + from durabletask.aio.awaitables import AwaitableBase as _AwaitableBase + + # Create async context and enable sandbox + base_ctx = Mock() + base_ctx.instance_id = "mix-test" + base_ctx.current_utc_datetime = datetime.datetime(2023, 1, 1, 12, 0, 0) + + async_ctx = AsyncWorkflowContext(base_ctx) + + # Dummy workflow awaitable that should be batched into WhenAllAwaitable and not awaited individually + class DummyWF(_AwaitableBase[str]): + def _to_task(self): + # Would normally convert to a durable task; not needed in this test + return Mock(spec=dt_task.Task) + + # Patch WhenAllAwaitable to a fast fake that returns predictable results + recorded_items: list[list[object]] = [] + + class FakeWhenAll: + def __init__(self, items): + recorded_items.append(list(items)) + self._items = list(items) + + def __await__(self): + async def _coro(): + # Return results per-item deterministically + return [f"W{i}" for i, _ in enumerate(self._items)] + + return _coro().__await__() + + monkeypatch.setattr("durabletask.aio.awaitables.WhenAllAwaitable", FakeWhenAll) + + # Native coroutines + async def native(i: int): + await asyncio.sleep(0) + return f"N{i}" + + with _sandbox_scope(async_ctx, "best_effort"): + out = await asyncio.gather(DummyWF(), native(0), DummyWF(), native(1)) + + # Order preserved and batched results merged back correctly + assert out == ["W0", "N0", "W1", "N1"] + # Ensure WhenAll got only workflow awaitables (2 items) + assert recorded_items and len(recorded_items[0]) == 2 + + +class TestNowWithSequence: + """Tests for AsyncWorkflowContext.now_with_sequence.""" + + def test_now_with_sequence_increments(self): + from durabletask.aio import AsyncWorkflowContext + + base_ctx = Mock() + base_ctx.instance_id = "test-seq" + base_ctx.current_utc_datetime = datetime.datetime(2023, 1, 1, 12, 0, 0) + base_ctx.is_replaying = False + async_ctx = AsyncWorkflowContext(base_ctx) + + # Each call should increment by 1 microsecond + t1 = 
async_ctx.now_with_sequence() + t2 = async_ctx.now_with_sequence() + t3 = async_ctx.now_with_sequence() + + assert t1 == datetime.datetime(2023, 1, 1, 12, 0, 0, 0) + assert t2 == datetime.datetime(2023, 1, 1, 12, 0, 0, 1) + assert t3 == datetime.datetime(2023, 1, 1, 12, 0, 0, 2) + assert t1 < t2 < t3 + + def test_now_with_sequence_deterministic_on_replay(self): + from durabletask.aio import AsyncWorkflowContext + + # First execution + base_ctx1 = Mock() + base_ctx1.instance_id = "test-replay" + base_ctx1.current_utc_datetime = datetime.datetime(2023, 1, 1, 12, 0, 0) + base_ctx1.is_replaying = False + ctx1 = AsyncWorkflowContext(base_ctx1) + + t1_first = ctx1.now_with_sequence() + t2_first = ctx1.now_with_sequence() + + # Replay - counter resets (new context instance) + base_ctx2 = Mock() + base_ctx2.instance_id = "test-replay" + base_ctx2.current_utc_datetime = datetime.datetime(2023, 1, 1, 12, 0, 0) + base_ctx2.is_replaying = True + ctx2 = AsyncWorkflowContext(base_ctx2) + + t1_replay = ctx2.now_with_sequence() + t2_replay = ctx2.now_with_sequence() + + # Should produce identical timestamps (deterministic) + assert t1_first == t1_replay + assert t2_first == t2_replay + + def test_now_with_sequence_works_in_strict_mode(self): + from durabletask.aio import AsyncWorkflowContext + + base_ctx = Mock() + base_ctx.instance_id = "test-strict" + base_ctx.current_utc_datetime = datetime.datetime(2023, 1, 1, 12, 0, 0) + base_ctx.is_replaying = False + async_ctx = AsyncWorkflowContext(base_ctx) + + # Should work fine in strict sandbox mode (deterministic) + with _sandbox_scope(async_ctx, "strict"): + t1 = async_ctx.now_with_sequence() + t2 = async_ctx.now_with_sequence() + assert t1 < t2 + + @pytest.mark.asyncio + async def test_mixed_groups_return_exceptions_true(self, monkeypatch): + from durabletask.aio import AsyncWorkflowContext + from durabletask.aio.awaitables import AwaitableBase as _AwaitableBase + + base_ctx = Mock() + base_ctx.instance_id = "mix-exc" + base_ctx.current_utc_datetime = datetime.datetime(2023, 1, 1, 12, 0, 0) + async_ctx = AsyncWorkflowContext(base_ctx) + + class DummyWF(_AwaitableBase[str]): + def _to_task(self): + return Mock(spec=dt_task.Task) + + # Fake WhenAll that returns values, simulating exception swallowing already applied + class FakeWhenAll: + def __init__(self, items): + self._items = list(items) + + def __await__(self): + async def _coro(): + # Return placeholders for each workflow item + return ["W_OK" for _ in self._items] + + return _coro().__await__() + + monkeypatch.setattr("durabletask.aio.awaitables.WhenAllAwaitable", FakeWhenAll) + + async def native_ok(): + return "N_OK" + + async def native_fail(): + raise RuntimeError("boom") + + with _sandbox_scope(async_ctx, "best_effort"): + res = await asyncio.gather( + DummyWF(), native_fail(), native_ok(), return_exceptions=True + ) + + assert res[0] == "W_OK" + assert isinstance(res[1], RuntimeError) + assert res[2] == "N_OK" + + def test_sandbox_scope_asyncio_gather_patching(self): + """Test that asyncio.gather is properly patched.""" + + async def test_task(): + return "test" + + # Capture original gather before entering sandbox + original_gather = asyncio.gather + from durabletask.aio import AsyncWorkflowContext + + mock_base_ctx = Mock() + mock_base_ctx.instance_id = "gather-patch" + mock_base_ctx.current_utc_datetime = datetime.datetime(2023, 1, 1, 12, 0, 0) + async_ctx = AsyncWorkflowContext(mock_base_ctx) + with _sandbox_scope(async_ctx, "best_effort"): + # Should patch gather + assert asyncio.gather is 
not original_gather + + # Test empty gather + empty_gather = asyncio.gather() + assert hasattr(empty_gather, "__await__") + + def test_sandbox_scope_workflow_awaitables_detection(self): + """Test that sandbox can detect workflow awaitables.""" + from durabletask import task as dt_task + from durabletask.aio import AsyncWorkflowContext + from durabletask.aio.awaitables import ActivityAwaitable + + # Create a mock activity awaitable + mock_task = Mock(spec=dt_task.Task) + mock_base_ctx = Mock() + mock_base_ctx.call_activity.return_value = mock_task + mock_base_ctx.instance_id = "detect-wf" + mock_base_ctx.current_utc_datetime = datetime.datetime(2023, 1, 1, 12, 0, 0) + async_ctx = AsyncWorkflowContext(mock_base_ctx) + + activity_awaitable = ActivityAwaitable(mock_base_ctx, lambda: "test", input="test") + + with _sandbox_scope(async_ctx, "best_effort"): + # Should recognize workflow awaitables + gather_result = asyncio.gather(activity_awaitable) + assert hasattr(gather_result, "__await__") + + +class TestPatchedFunctionImplementations: + """Test that patched deterministic functions work correctly.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_ctx = Mock() + self.mock_ctx.instance_id = "test-instance" + self.mock_ctx.current_utc_datetime = datetime.datetime(2025, 1, 1, 12, 0, 0) + + def test_patched_random_functions(self): + """Test all patched random functions produce deterministic results.""" + with _sandbox_scope(self.mock_ctx, "best_effort"): + # Test random() + r1 = random.random() + assert isinstance(r1, float) + assert 0 <= r1 < 1 + + # Test randint() + ri = random.randint(1, 100) + assert isinstance(ri, int) + assert 1 <= ri <= 100 + + # Test getrandbits() + rb = random.getrandbits(8) + assert isinstance(rb, int) + assert 0 <= rb < 256 + + # Test randrange() with step + rr = random.randrange(0, 100, 5) + assert isinstance(rr, int) + assert 0 <= rr < 100 + assert rr % 5 == 0 + + # Test randrange() single arg + rr_single = random.randrange(50) + assert isinstance(rr_single, int) + assert 0 <= rr_single < 50 + + def test_patched_time_functions(self): + """Test patched time functions return deterministic values.""" + with _sandbox_scope(self.mock_ctx, "best_effort"): + t = time.time() + assert isinstance(t, float) + assert t > 0 + + # time_ns if available + if hasattr(time, "time_ns"): + tn = time.time_ns() + assert isinstance(tn, int) + assert tn > 0 + + def test_patched_datetime_now_with_timezone(self): + """Test patched datetime.now() with timezone argument.""" + import datetime as dt + + with _sandbox_scope(self.mock_ctx, "best_effort"): + # With timezone should still work + tz = dt.timezone.utc + now_tz = dt.datetime.now(tz) + assert isinstance(now_tz, dt.datetime) + assert now_tz.tzinfo is not None + + +class TestAsyncioSleepEdgeCases: + """Test asyncio.sleep patching edge cases.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_base_ctx = Mock(spec=dt_task.OrchestrationContext) + self.mock_base_ctx.instance_id = "test-instance" + self.mock_base_ctx.current_utc_datetime = datetime.datetime(2025, 1, 1, 12, 0, 0) + self.mock_base_ctx.create_timer = Mock() + + def test_asyncio_sleep_zero_delay_passthrough(self): + """Test that zero delay passes through to original asyncio.sleep.""" + from durabletask.aio import AsyncWorkflowContext + + async_ctx = AsyncWorkflowContext(self.mock_base_ctx) + + with _sandbox_scope(async_ctx, "best_effort"): + # Zero delay should pass through + result = asyncio.sleep(0) + # Should be a coroutine from 
original asyncio.sleep + assert asyncio.iscoroutine(result) + result.close() # Clean up + + def test_asyncio_sleep_negative_delay_passthrough(self): + """Test that negative delay passes through to original asyncio.sleep.""" + from durabletask.aio import AsyncWorkflowContext + + async_ctx = AsyncWorkflowContext(self.mock_base_ctx) + + with _sandbox_scope(async_ctx, "best_effort"): + # Negative delay should pass through + result = asyncio.sleep(-1) + assert asyncio.iscoroutine(result) + result.close() # Clean up + + def test_asyncio_sleep_positive_delay_uses_timer(self): + """Test that positive delay uses create_timer.""" + from durabletask.aio import AsyncWorkflowContext + + async_ctx = AsyncWorkflowContext(self.mock_base_ctx) + + with _sandbox_scope(async_ctx, "best_effort"): + # Positive delay should create patched awaitable + result = asyncio.sleep(5) + # Should have __await__ method + assert hasattr(result, "__await__") + + def test_asyncio_sleep_invalid_delay(self): + """Test asyncio.sleep with invalid delay value.""" + from durabletask.aio import AsyncWorkflowContext + + async_ctx = AsyncWorkflowContext(self.mock_base_ctx) + + with _sandbox_scope(async_ctx, "best_effort"): + # Invalid delay should still work (fallthrough to patched awaitable) + result = asyncio.sleep("invalid") + assert hasattr(result, "__await__") + + +class TestRNGContextFallbacks: + """Test RNG initialization with missing context attributes.""" + + def test_rng_missing_instance_id(self): + """Test RNG initialization when instance_id is missing.""" + mock_ctx = Mock() + # No instance_id attribute + mock_ctx.current_utc_datetime = datetime.datetime(2025, 1, 1, 12, 0, 0) + + with _sandbox_scope(mock_ctx, "best_effort"): + # Should use fallback and still work + r = random.random() + assert isinstance(r, float) + + def test_rng_missing_base_ctx_instance_id(self): + """Test RNG with no instance_id on main or base context.""" + mock_ctx = Mock() + mock_ctx._base_ctx = Mock() + # Neither has instance_id + mock_ctx.current_utc_datetime = datetime.datetime(2025, 1, 1, 12, 0, 0) + + with _sandbox_scope(mock_ctx, "best_effort"): + r = random.random() + assert isinstance(r, float) + + def test_rng_now_method_exception(self): + """Test RNG when now() method raises exception.""" + mock_ctx = Mock() + mock_ctx.instance_id = "test" + mock_ctx.now = Mock(side_effect=Exception("now() failed")) + mock_ctx.current_utc_datetime = datetime.datetime(2025, 1, 1, 12, 0, 0) + + with _sandbox_scope(mock_ctx, "best_effort"): + # Should fall back to current_utc_datetime + r = random.random() + assert isinstance(r, float) + + def test_rng_missing_current_utc_datetime(self): + """Test RNG when current_utc_datetime is missing.""" + mock_ctx = Mock(spec=[]) # No attributes + mock_ctx.instance_id = "test" + + with _sandbox_scope(mock_ctx, "best_effort"): + # Should use epoch fallback + r = random.random() + assert isinstance(r, float) + + def test_rng_base_ctx_current_utc_datetime(self): + """Test RNG uses base_ctx.current_utc_datetime as fallback.""" + mock_ctx = Mock(spec=["instance_id", "_base_ctx"]) + mock_ctx.instance_id = "test" + mock_ctx._base_ctx = Mock() + mock_ctx._base_ctx.current_utc_datetime = datetime.datetime(2025, 1, 1, 12, 0, 0) + + with _sandbox_scope(mock_ctx, "best_effort"): + r = random.random() + assert isinstance(r, float) + + def test_rng_setattr_exception_handling(self): + """Test RNG handles setattr exception gracefully.""" + + class ReadOnlyRNG: + def __setattr__(self, name, value): + raise AttributeError("Cannot set 
attribute") + + mock_ctx = Mock() + mock_ctx.instance_id = "test" + mock_ctx.current_utc_datetime = datetime.datetime(2025, 1, 1, 12, 0, 0) + + # Should not crash even if setattr fails + with _sandbox_scope(mock_ctx, "best_effort"): + r = random.random() + assert isinstance(r, float) + + +class TestSandboxLifecycle: + """Test _Sandbox class lifecycle and patch/restore operations.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_ctx = Mock() + self.mock_ctx.instance_id = "test-instance" + self.mock_ctx.current_utc_datetime = datetime.datetime(2025, 1, 1, 12, 0, 0) + + def test_sandbox_lifecycle_doesnt_crash(self): + """Test that sandbox lifecycle operations don't crash.""" + # Just verify the sandbox can be entered and exited without errors + with _sandbox_scope(self.mock_ctx, "best_effort"): + # Use a random function + r = random.random() + assert isinstance(r, float) + + # Verify no issues with nested contexts + with _sandbox_scope(self.mock_ctx, "best_effort"): + with _sandbox_scope(self.mock_ctx, "strict"): + r = random.random() + assert isinstance(r, float) + + # Verify exception doesn't break cleanup + try: + with _sandbox_scope(self.mock_ctx, "best_effort"): + raise ValueError("Test") + except ValueError: + pass + + def test_sandbox_restores_optional_missing_functions(self): + """Test sandbox handles missing optional functions during restore.""" + mock_ctx = Mock() + mock_ctx.instance_id = "test" + mock_ctx.current_utc_datetime = datetime.datetime(2025, 1, 1, 12, 0, 0) + + # Test with time_ns potentially missing + with _sandbox_scope(mock_ctx, "best_effort"): + # Should handle gracefully whether time_ns exists or not + pass + + # Should not crash during restore + + +class TestPatchedFunctionsInWorkflow: + """Test that patched functions are actually executed in workflow context.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_base_ctx = Mock(spec=dt_task.OrchestrationContext) + self.mock_base_ctx.instance_id = "test-instance-123" + self.mock_base_ctx.current_utc_datetime = datetime.datetime(2025, 1, 1, 12, 0, 0) + self.mock_base_ctx.create_timer = Mock() + self.mock_base_ctx.call_activity = Mock() + + # Ensure now() method exists and returns datetime + def mock_now(): + return datetime.datetime(2025, 1, 1, 12, 0, 0) + + self.mock_base_ctx.now = mock_now + + @pytest.mark.asyncio + async def test_workflow_calls_random_functions(self): + """Test workflow that calls random functions within sandbox.""" + from durabletask.aio import AsyncWorkflowContext + from durabletask.aio.driver import CoroutineOrchestratorRunner + + async_ctx = AsyncWorkflowContext(self.mock_base_ctx) + + async def workflow_with_random(ctx): + # Call various random functions + r = random.random() + ri = random.randint(1, 100) + rb = random.getrandbits(8) + rr = random.randrange(10, 50, 5) + rr2 = random.randrange(20) + return [r, ri, rb, rr, rr2] + + runner = CoroutineOrchestratorRunner(workflow_with_random, sandbox_mode="best_effort") + + # Generate and drive + gen = runner.to_generator(async_ctx) + try: + next(gen) + except StopIteration as e: + result = e.value + assert isinstance(result, list) + assert len(result) == 5 + assert isinstance(result[0], float) + assert isinstance(result[1], int) + assert isinstance(result[2], int) + assert isinstance(result[3], int) + assert isinstance(result[4], int) + + @pytest.mark.asyncio + async def test_workflow_calls_uuid4(self): + """Test workflow that calls uuid4 within sandbox.""" + from durabletask.aio import 
AsyncWorkflowContext + from durabletask.aio.driver import CoroutineOrchestratorRunner + + async_ctx = AsyncWorkflowContext(self.mock_base_ctx) + + async def workflow_with_uuid(ctx): + u1 = uuid.uuid4() + u2 = uuid.uuid4() + return [u1, u2] + + runner = CoroutineOrchestratorRunner(workflow_with_uuid, sandbox_mode="best_effort") + + # Generate and drive + gen = runner.to_generator(async_ctx) + try: + next(gen) + except StopIteration as e: + result = e.value + assert isinstance(result, list) + assert len(result) == 2 + assert isinstance(result[0], uuid.UUID) + assert isinstance(result[1], uuid.UUID) + + @pytest.mark.asyncio + async def test_workflow_calls_time_functions(self): + """Test workflow that calls time functions within sandbox.""" + from durabletask.aio import AsyncWorkflowContext + from durabletask.aio.driver import CoroutineOrchestratorRunner + + async_ctx = AsyncWorkflowContext(self.mock_base_ctx) + + async def workflow_with_time(ctx): + t = time.time() + results = [t] + if hasattr(time, "time_ns"): + tn = time.time_ns() + results.append(tn) + return results + + runner = CoroutineOrchestratorRunner(workflow_with_time, sandbox_mode="best_effort") + + # Generate and drive + gen = runner.to_generator(async_ctx) + try: + next(gen) + except StopIteration as e: + result = e.value + assert isinstance(result, list) + assert len(result) >= 1 + assert isinstance(result[0], float) + + @pytest.mark.asyncio + async def test_workflow_calls_datetime_functions(self): + """Test workflow that calls datetime functions within sandbox.""" + from durabletask.aio import AsyncWorkflowContext + from durabletask.aio.driver import CoroutineOrchestratorRunner + + async_ctx = AsyncWorkflowContext(self.mock_base_ctx) + + async def workflow_with_datetime(ctx): + now = datetime.datetime.now() + utcnow = datetime.datetime.now(datetime.timezone.utc) + now_tz = datetime.datetime.now(datetime.timezone.utc) + return [now, utcnow, now_tz] + + runner = CoroutineOrchestratorRunner(workflow_with_datetime, sandbox_mode="best_effort") + + # Generate and drive + gen = runner.to_generator(async_ctx) + try: + next(gen) + except StopIteration as e: + result = e.value + assert isinstance(result, list) + assert len(result) == 3 + assert all(isinstance(d, datetime.datetime) for d in result) + + @pytest.mark.asyncio + async def test_workflow_calls_all_random_variants(self): + """Test workflow that exercises all random function variants.""" + from durabletask.aio import AsyncWorkflowContext + from durabletask.aio.driver import CoroutineOrchestratorRunner + + async_ctx = AsyncWorkflowContext(self.mock_base_ctx) + + async def workflow_comprehensive(ctx): + results = {} + # Test all random variants + results["random"] = random.random() + results["randint"] = random.randint(50, 100) + results["getrandbits"] = random.getrandbits(16) + results["randrange_single"] = random.randrange(50) + results["randrange_two"] = random.randrange(10, 50) + results["randrange_step"] = random.randrange(0, 100, 5) + + # Test uuid + results["uuid4"] = str(uuid.uuid4()) + + # Test time + results["time"] = time.time() + if hasattr(time, "time_ns"): + results["time_ns"] = time.time_ns() + + # Test datetime + results["now"] = datetime.datetime.now() + results["utcnow"] = datetime.datetime.now(datetime.timezone.utc) + results["now_tz"] = datetime.datetime.now(datetime.timezone.utc) + + return results + + runner = CoroutineOrchestratorRunner(workflow_comprehensive, sandbox_mode="best_effort") + + # Generate and drive + gen = runner.to_generator(async_ctx) + try: 
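+            # next() drives the coroutine-backed generator; a workflow that never
+            # awaits a durable task completes on the first step, surfacing its
+            # return value through StopIteration.value.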
+ next(gen) + except StopIteration as e: + result = e.value + assert isinstance(result, dict) + assert "random" in result + assert "uuid4" in result + assert "time" in result diff --git a/tests/aio/test_worker_concurrency_loop_async.py b/tests/aio/test_worker_concurrency_loop_async.py new file mode 100644 index 0000000..8bc6e7d --- /dev/null +++ b/tests/aio/test_worker_concurrency_loop_async.py @@ -0,0 +1,99 @@ +# Copyright 2025 The Dapr Authors +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio + +from durabletask.worker import ConcurrencyOptions, TaskHubGrpcWorker + + +class DummyStub: + def __init__(self): + self.completed = [] + + def CompleteOrchestratorTask(self, res): + self.completed.append(("orchestrator", res)) + + def CompleteActivityTask(self, res): + self.completed.append(("activity", res)) + + +class DummyRequest: + def __init__(self, kind, instance_id): + self.kind = kind + self.instanceId = instance_id + self.orchestrationInstance = type("O", (), {"instanceId": instance_id}) + self.name = "dummy" + self.taskId = 1 + self.input = type("I", (), {"value": ""}) + self.pastEvents = [] + self.newEvents = [] + + def HasField(self, field): + return (field == "orchestratorRequest" and self.kind == "orchestrator") or ( + field == "activityRequest" and self.kind == "activity" + ) + + def WhichOneof(self, _): + return f"{self.kind}Request" + + +class DummyCompletionToken: + pass + + +def test_worker_concurrency_loop_async(): + options = ConcurrencyOptions( + maximum_concurrent_activity_work_items=2, + maximum_concurrent_orchestration_work_items=1, + maximum_thread_pool_workers=2, + ) + grpc_worker = TaskHubGrpcWorker(concurrency_options=options) + stub = DummyStub() + + async def dummy_orchestrator(req, stub, completionToken): + await asyncio.sleep(0.1) + stub.CompleteOrchestratorTask("ok") + + async def dummy_activity(req, stub, completionToken): + await asyncio.sleep(0.1) + stub.CompleteActivityTask("ok") + + # Patch the worker's _execute_orchestrator and _execute_activity + grpc_worker._execute_orchestrator = dummy_orchestrator + grpc_worker._execute_activity = dummy_activity + + orchestrator_requests = [DummyRequest("orchestrator", f"orch{i}") for i in range(3)] + activity_requests = [DummyRequest("activity", f"act{i}") for i in range(4)] + + async def run_test(): + # Clear stub state before each run + stub.completed.clear() + worker_task = asyncio.create_task(grpc_worker._async_worker_manager.run()) + for req in orchestrator_requests: + grpc_worker._async_worker_manager.submit_orchestration( + dummy_orchestrator, req, stub, DummyCompletionToken() + ) + for req in activity_requests: + grpc_worker._async_worker_manager.submit_activity( + dummy_activity, req, stub, DummyCompletionToken() + ) + await asyncio.sleep(1.0) + orchestrator_count = sum(1 for t, _ in stub.completed if t == "orchestrator") + activity_count = sum(1 for t, _ in stub.completed if t == "activity") + assert orchestrator_count == 3, ( + f"Expected 3 orchestrator completions, got {orchestrator_count}" + ) + assert 
activity_count == 4, f"Expected 4 activity completions, got {activity_count}" + grpc_worker._async_worker_manager._shutdown = True + await worker_task + + asyncio.run(run_test()) + asyncio.run(run_test()) diff --git a/tests/durabletask/test_activity_executor.py b/tests/durabletask/test_activity_executor.py index 996ae44..754c78e 100644 --- a/tests/durabletask/test_activity_executor.py +++ b/tests/durabletask/test_activity_executor.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +import asyncio import json import logging from typing import Any, Optional, Tuple @@ -26,7 +27,9 @@ def test_activity(ctx: task.ActivityContext, test_input: Any): activity_input = "Hello, 世界!" executor, name = _get_activity_executor(test_activity) - result = executor.execute(TEST_INSTANCE_ID, name, TEST_TASK_ID, json.dumps(activity_input)) + result = asyncio.run( + executor.execute(TEST_INSTANCE_ID, name, TEST_TASK_ID, json.dumps(activity_input)) + ) assert result is not None result_input, result_orchestration_id, result_task_id = json.loads(result) @@ -43,7 +46,7 @@ def test_activity(ctx: task.ActivityContext, _): caught_exception: Optional[Exception] = None try: - executor.execute(TEST_INSTANCE_ID, "Bogus", TEST_TASK_ID, None) + asyncio.run(executor.execute(TEST_INSTANCE_ID, "Bogus", TEST_TASK_ID, None)) except Exception as ex: caught_exception = ex @@ -56,3 +59,97 @@ def _get_activity_executor(fn: task.Activity) -> Tuple[worker._ActivityExecutor, name = registry.add_activity(fn) executor = worker._ActivityExecutor(registry, TEST_LOGGER) return executor, name + + +def test_async_activity_basic(): + """Validates basic async activity execution""" + + async def async_activity(ctx: task.ActivityContext, test_input: str): + # Simple async activity that returns modified input + return f"async:{test_input}" + + activity_input = "test" + executor, name = _get_activity_executor(async_activity) + + # Run the async executor + result = asyncio.run( + executor.execute(TEST_INSTANCE_ID, name, TEST_TASK_ID, json.dumps(activity_input)) + ) + assert result is not None + + result_output = json.loads(result) + assert result_output == "async:test" + + +def test_async_activity_with_input(): + """Validates async activity with complex input/output""" + + async def async_activity(ctx: task.ActivityContext, test_input: dict): + # Return all activity inputs back as the output + return { + "input": test_input, + "orchestration_id": ctx.orchestration_id, + "task_id": ctx.task_id, + "processed": True, + } + + activity_input = {"key": "value", "number": 42} + executor, name = _get_activity_executor(async_activity) + result = asyncio.run( + executor.execute(TEST_INSTANCE_ID, name, TEST_TASK_ID, json.dumps(activity_input)) + ) + assert result is not None + + result_data = json.loads(result) + assert result_data["input"] == activity_input + assert result_data["orchestration_id"] == TEST_INSTANCE_ID + assert result_data["task_id"] == TEST_TASK_ID + assert result_data["processed"] is True + + +def test_async_activity_with_await(): + """Validates async activity that performs async I/O""" + + async def async_activity_with_io(ctx: task.ActivityContext, delay: float): + # Simulate async I/O operation + await asyncio.sleep(delay) + return f"completed after {delay}s" + + activity_input = 0.01 # 10ms delay + executor, name = _get_activity_executor(async_activity_with_io) + result = asyncio.run( + executor.execute(TEST_INSTANCE_ID, name, TEST_TASK_ID, json.dumps(activity_input)) + ) + assert result is not None + 
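+    # executor.execute returns the activity's result JSON-serialized, so decode
+    # it before asserting on the payload.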
+ result_output = json.loads(result) + assert result_output == "completed after 0.01s" + + +def test_mixed_sync_async_activities(): + """Validates that sync and async activities work together""" + + def sync_activity(ctx: task.ActivityContext, test_input: str): + return f"sync:{test_input}" + + async def async_activity(ctx: task.ActivityContext, test_input: str): + return f"async:{test_input}" + + registry = worker._Registry() + sync_name = registry.add_activity(sync_activity) + async_name = registry.add_activity(async_activity) + executor = worker._ActivityExecutor(registry, TEST_LOGGER) + + activity_input = "test" + + # Execute sync activity + sync_result = asyncio.run( + executor.execute(TEST_INSTANCE_ID, sync_name, TEST_TASK_ID, json.dumps(activity_input)) + ) + assert json.loads(sync_result) == "sync:test" + + # Execute async activity + async_result = asyncio.run( + executor.execute(TEST_INSTANCE_ID, async_name, TEST_TASK_ID + 1, json.dumps(activity_input)) + ) + assert json.loads(async_result) == "async:test" diff --git a/tests/durabletask/test_orchestration_e2e_async.py b/tests/durabletask/test_orchestration_e2e_async.py index b71e70b..630e917 100644 --- a/tests/durabletask/test_orchestration_e2e_async.py +++ b/tests/durabletask/test_orchestration_e2e_async.py @@ -18,6 +18,142 @@ pytestmark = [pytest.mark.e2e, pytest.mark.asyncio] +async def test_orchestrator_with_async_activity(): + """Tests sync generator orchestrator calling async activity""" + + async def async_upper(ctx: task.ActivityContext, text: str) -> str: + # Async activity that converts text to uppercase + await asyncio.sleep(0.01) # Simulate async I/O + return text.upper() + + def orchestrator_with_async_activity(ctx: task.OrchestrationContext, input_text: str): + # Sync generator orchestrator calling async activity + result = yield ctx.call_activity(async_upper, input=input_text) + return result + + with worker.TaskHubGrpcWorker() as w: + w.add_orchestrator(orchestrator_with_async_activity) + w.add_activity(async_upper) + w.start() + + c = AsyncTaskHubGrpcClient() + input_text = "hello world" + id = await c.schedule_new_orchestration(orchestrator_with_async_activity, input=input_text) + state = await c.wait_for_orchestration_completion(id, timeout=30) + await c.aclose() + + assert state is not None + assert state.runtime_status == OrchestrationStatus.COMPLETED + assert state.serialized_output == json.dumps("HELLO WORLD") + + +async def test_async_workflow_with_async_activity(): + """Tests async workflow calling async activity""" + from durabletask.aio import AsyncWorkflowContext + + async def async_process(ctx: task.ActivityContext, data: dict) -> dict: + # Async activity that processes data + await asyncio.sleep(0.01) # Simulate async I/O + return {"processed": True, "data": data} + + async def async_workflow_with_activity(ctx: AsyncWorkflowContext, input_data: dict): + # Async workflow calling async activity + result = await ctx.call_activity(async_process, input=input_data) + return result + + with worker.TaskHubGrpcWorker() as w: + w.add_orchestrator(async_workflow_with_activity) + w.add_activity(async_process) + w.start() + + c = AsyncTaskHubGrpcClient() + input_data = {"key": "value"} + id = await c.schedule_new_orchestration(async_workflow_with_activity, input=input_data) + state = await c.wait_for_orchestration_completion(id, timeout=30) + await c.aclose() + + assert state is not None + assert state.runtime_status == OrchestrationStatus.COMPLETED + output = json.loads(state.serialized_output) + assert 
output["processed"] is True + assert output["data"] == input_data + + +async def test_async_activity_with_retry_policy(): + """Tests retry policy with async activity failures""" + + attempt_count = {"value": 0} + + async def async_flaky_activity(ctx: task.ActivityContext, _) -> str: + # Async activity that fails first two attempts + attempt_count["value"] += 1 + await asyncio.sleep(0.01) + if attempt_count["value"] < 3: + raise RuntimeError(f"Attempt {attempt_count['value']} failed") + return "success" + + def orchestrator_with_retry(ctx: task.OrchestrationContext, _): + retry_policy = task.RetryPolicy( + first_retry_interval=timedelta(milliseconds=100), + max_number_of_attempts=5, + ) + result = yield ctx.call_activity(async_flaky_activity, retry_policy=retry_policy) + return result + + with worker.TaskHubGrpcWorker() as w: + w.add_orchestrator(orchestrator_with_retry) + w.add_activity(async_flaky_activity) + w.start() + + c = AsyncTaskHubGrpcClient() + id = await c.schedule_new_orchestration(orchestrator_with_retry) + state = await c.wait_for_orchestration_completion(id, timeout=30) + await c.aclose() + + assert state is not None + assert state.runtime_status == OrchestrationStatus.COMPLETED + assert state.serialized_output == json.dumps("success") + assert attempt_count["value"] == 3 + + +async def test_async_activity_non_retryable_error(): + """Tests async activity with NonRetryableError""" + + attempt_count = {"value": 0} + + async def async_failing_activity(ctx: task.ActivityContext, _) -> str: + # Async activity that raises NonRetryableError + attempt_count["value"] += 1 + await asyncio.sleep(0.01) + raise task.NonRetryableError("This error should not be retried") + + def orchestrator_with_non_retryable(ctx: task.OrchestrationContext, _): + retry_policy = task.RetryPolicy( + first_retry_interval=timedelta(milliseconds=100), + max_number_of_attempts=5, + ) + result = yield ctx.call_activity(async_failing_activity, retry_policy=retry_policy) + return result + + with worker.TaskHubGrpcWorker() as w: + w.add_orchestrator(orchestrator_with_non_retryable) + w.add_activity(async_failing_activity) + w.start() + + c = AsyncTaskHubGrpcClient() + id = await c.schedule_new_orchestration(orchestrator_with_non_retryable) + state = await c.wait_for_orchestration_completion(id, timeout=30) + await c.aclose() + + assert state is not None + assert state.runtime_status == OrchestrationStatus.FAILED + assert state.failure_details is not None + assert state.failure_details.error_type == "TaskFailedError" + assert "should not be retried" in state.failure_details.message + # Most importantly: verify the activity only ran once (no retries despite retry_policy) + assert attempt_count["value"] == 1 + + async def test_empty_orchestration(): invoked = False diff --git a/tests/durabletask/test_worker_grpc_errors.py b/tests/durabletask/test_worker_grpc_errors.py new file mode 100644 index 0000000..a0c4bb2 --- /dev/null +++ b/tests/durabletask/test_worker_grpc_errors.py @@ -0,0 +1,116 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +""" + +from unittest.mock import MagicMock, Mock, patch + +import grpc + +from durabletask import worker + + +def test_execute_orchestrator_grpc_error_benign_cancelled(): + """Test that benign gRPC errors in orchestrator execution are handled gracefully.""" + w = worker.TaskHubGrpcWorker() + + # Add a dummy orchestrator + def test_orchestrator(ctx, input): + return "result" + + w.add_orchestrator(test_orchestrator) + + # Mock the stub to raise a benign error + mock_stub = MagicMock() + mock_error = grpc.RpcError() + mock_error.code = Mock(return_value=grpc.StatusCode.CANCELLED) + mock_stub.CompleteOrchestratorTask.side_effect = mock_error + + # Create a mock request with proper structure + mock_req = MagicMock() + mock_req.instanceId = "test-id" + mock_req.pastEvents = [] + mock_req.newEvents = [MagicMock()] + mock_req.newEvents[0].HasField = lambda x: x == "executionStarted" + mock_req.newEvents[0].executionStarted.name = "test_orchestrator" + mock_req.newEvents[0].executionStarted.input = None + mock_req.newEvents[0].router.targetAppID = None + mock_req.newEvents[0].router.sourceAppID = None + mock_req.newEvents[0].timestamp.ToDatetime = Mock(return_value=None) + + # Should not raise exception (benign error) + w._execute_orchestrator(mock_req, mock_stub, "token") + + +def test_execute_orchestrator_grpc_error_non_benign(): + """Test that non-benign gRPC errors in orchestrator execution are logged.""" + w = worker.TaskHubGrpcWorker() + + # Add a dummy orchestrator + def test_orchestrator(ctx, input): + return "result" + + w.add_orchestrator(test_orchestrator) + + # Mock the stub to raise a non-benign error + mock_stub = MagicMock() + mock_error = grpc.RpcError() + mock_error.code = Mock(return_value=grpc.StatusCode.INTERNAL) + mock_stub.CompleteOrchestratorTask.side_effect = mock_error + + # Create a mock request with proper structure + mock_req = MagicMock() + mock_req.instanceId = "test-id" + mock_req.pastEvents = [] + mock_req.newEvents = [MagicMock()] + mock_req.newEvents[0].HasField = lambda x: x == "executionStarted" + mock_req.newEvents[0].executionStarted.name = "test_orchestrator" + mock_req.newEvents[0].executionStarted.input = None + mock_req.newEvents[0].router.targetAppID = None + mock_req.newEvents[0].router.sourceAppID = None + mock_req.newEvents[0].timestamp.ToDatetime = Mock(return_value=None) + + # Should not raise exception (error is logged but handled) + with patch.object(w._logger, "exception") as mock_log: + w._execute_orchestrator(mock_req, mock_stub, "token") + # Verify error was logged + assert mock_log.called + + +def test_execute_activity_grpc_error_benign(): + """Test that benign gRPC errors in activity execution are handled gracefully.""" + w = worker.TaskHubGrpcWorker() + + # Add a dummy activity + def test_activity(ctx, input): + return "result" + + w.add_activity(test_activity) + + # Mock the stub to raise a benign error + mock_stub = MagicMock() + mock_error = grpc.RpcError() + mock_error.code = Mock(return_value=grpc.StatusCode.CANCELLED) + str_return = "unknown instance ID/task ID combo" + mock_error.__str__ = Mock(return_value=str_return) + mock_stub.CompleteActivityTask.side_effect = mock_error + + # Create a mock request + mock_req = MagicMock() + mock_req.orchestrationInstance.instanceId = "test-id" + mock_req.name = "test_activity" + mock_req.taskId = 1 + mock_req.input.value = '""' + + # Should not raise exception (benign error) + import asyncio + + 
asyncio.run(w._execute_activity(mock_req, mock_stub, "token"))
diff --git a/tox.ini b/tox.ini
index b6bc7ba..41ad883 100644
--- a/tox.ini
+++ b/tox.ini
@@ -10,8 +10,12 @@ runner = virtualenv
 [testenv]
 # you can run tox with the e2e pytest marker using tox factors:
+# # start the dapr sidecar (better than durabletask-go for multi-app executions)
+# dapr init # may not be needed if already initialized
+# dapr run --app-id test-app --dapr-grpc-port 4001 --resources-path ./examples/components/
+# # In a separate terminal, run the e2e tests (appends to .coverage)
 # tox -e py310-e2e
-# to use custom grpc endpoint:
+# to use a custom gRPC endpoint, set the DAPR_GRPC_ENDPOINT environment variable:
 # DAPR_GRPC_ENDPOINT=localhost:12345 tox -e py310-e2e
 setenv =
     PYTHONDONTWRITEBYTECODE=1