Skip to content

Commit 6b15f51

Browse files
committed
allow async activity definitions
Signed-off-by: Filinto Duran <[email protected]>
1 parent 5136aa2 commit 6b15f51

File tree

8 files changed

+343
-11
lines changed

8 files changed

+343
-11
lines changed

README.md

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,35 @@ Orchestrations are implemented using ordinary Python functions that take an `Orc
138138

139139
Activities are implemented using ordinary Python functions that take an `ActivityContext` as their first parameter. Activity functions are scheduled by orchestrations and have at-least-once execution guarantees, meaning that they will be executed at least once but may be executed multiple times in the event of a transient failure. Activity functions are where the real "work" of any orchestration is done.
140140

141+
#### Async Activities
142+
143+
Activities can be either synchronous or asynchronous functions. Async activities are useful for I/O-bound operations like HTTP requests, database queries, or file operations:
144+
145+
```python
146+
from durabletask.task import ActivityContext
147+
148+
# Synchronous activity
149+
def sync_activity(ctx: ActivityContext, data: str) -> str:
150+
return data.upper()
151+
152+
# Asynchronous activity
153+
async def async_activity(ctx: ActivityContext, data: str) -> str:
154+
# Perform async I/O operations
155+
async with aiohttp.ClientSession() as session:
156+
async with session.get(f"https://api.example.com/{data}") as response:
157+
result = await response.json()
158+
return result
159+
```
160+
161+
Both sync and async activities are registered the same way:
162+
163+
```python
164+
worker.add_activity(sync_activity)
165+
worker.add_activity(async_activity)
166+
```
167+
168+
Orchestrators call them identically regardless of whether they're sync or async - the SDK handles the execution automatically.
169+
141170
### Durable timers
142171

143172
Orchestrations can schedule durable timers using the `create_timer` API. These timers are durable, meaning that they will survive orchestrator restarts and will fire even if the orchestrator is not actively in memory. Durable timers can be of any duration, from milliseconds to months.

durabletask/aio/ASYNCIO_INTERNALS.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ Adding sandbox coverage:
292292

293293
## Interop Checklist (Async ↔ Generator)
294294

295-
- Activities: identical behavior; only authoring differs (`yield` vs `await`).
295+
- Activities: identical behavior; only authoring differs (`yield` vs `await`). Activities themselves can be either sync or async functions and work identically from both generator and async orchestrators.
296296
- Timers: map to the same `createTimer` actions.
297297
- External events: same semantics for buffering and completion.
298298
- Sub‑orchestrators: same create/complete/fail events.

durabletask/task.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import math
88
from abc import ABC, abstractmethod
99
from datetime import datetime, timedelta
10-
from typing import Any, Callable, Generator, Generic, Optional, TypeVar, Union
10+
from typing import Any, Awaitable, Callable, Generator, Generic, Optional, TypeVar, Union
1111

1212
import durabletask.internal.helpers as pbh
1313
import durabletask.internal.orchestrator_service_pb2 as pb
@@ -499,7 +499,10 @@ def task_id(self) -> int:
499499
Orchestrator = Callable[[OrchestrationContext, TInput], Union[Generator[Task, Any, Any], TOutput]]
500500

501501
# Activities are simple functions that can be scheduled by orchestrators
502-
Activity = Callable[[ActivityContext, TInput], TOutput]
502+
Activity = Union[
503+
Callable[[ActivityContext, TInput], TOutput],
504+
Callable[[ActivityContext, TInput], Awaitable[TOutput]],
505+
]
503506

504507

505508
class RetryPolicy:

durabletask/worker.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -692,7 +692,7 @@ def _execute_orchestrator(
692692
f"Failed to deliver orchestrator response for '{req.instanceId}' to sidecar: {ex}"
693693
)
694694

695-
def _execute_activity(
695+
async def _execute_activity(
696696
self,
697697
req: pb.ActivityRequest,
698698
stub: stubs.TaskHubSidecarServiceStub,
@@ -701,7 +701,7 @@ def _execute_activity(
701701
instance_id = req.orchestrationInstance.instanceId
702702
try:
703703
executor = _ActivityExecutor(self._registry, self._logger)
704-
result = executor.execute(instance_id, req.name, req.taskId, req.input.value)
704+
result = await executor.execute(instance_id, req.name, req.taskId, req.input.value)
705705
res = pb.ActivityResponse(
706706
instanceId=instance_id,
707707
taskId=req.taskId,
@@ -1441,7 +1441,7 @@ def __init__(self, registry: _Registry, logger: logging.Logger):
14411441
self._registry = registry
14421442
self._logger = logger
14431443

1444-
def execute(
1444+
async def execute(
14451445
self,
14461446
orchestration_id: str,
14471447
name: str,
@@ -1459,8 +1459,11 @@ def execute(
14591459
activity_input = shared.from_json(encoded_input) if encoded_input else None
14601460
ctx = task.ActivityContext(orchestration_id, task_id)
14611461

1462-
# Execute the activity function
1463-
activity_output = fn(ctx, activity_input)
1462+
# Execute the activity function (sync or async)
1463+
if inspect.iscoroutinefunction(fn):
1464+
activity_output = await fn(ctx, activity_input)
1465+
else:
1466+
activity_output = fn(ctx, activity_input)
14641467

14651468
encoded_output = shared.to_json(activity_output) if activity_output is not None else None
14661469
chars = len(encoded_output) if encoded_output else 0

tests/aio/test_context_simple.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
These tests focus on the actual implementation rather than expected features.
1616
"""
1717

18+
import asyncio
1819
import random
1920
import uuid
2021
from datetime import datetime, timedelta
@@ -308,3 +309,64 @@ def test_context_repr(self):
308309
repr_str = repr(self.ctx)
309310
assert "AsyncWorkflowContext" in repr_str
310311
assert "test-instance-123" in repr_str
312+
313+
314+
class TestAsyncActivities:
315+
"""Test async activities called from AsyncWorkflowContext."""
316+
317+
def setup_method(self):
318+
"""Set up test fixtures."""
319+
self.mock_base_ctx = Mock(spec=dt_task.OrchestrationContext)
320+
self.mock_base_ctx.instance_id = "test-instance-123"
321+
self.mock_base_ctx.current_utc_datetime = datetime(2023, 1, 1, 12, 0, 0)
322+
self.mock_base_ctx.is_replaying = False
323+
self.mock_base_ctx.is_suspended = False
324+
325+
def test_async_activity_call(self):
326+
"""Test AsyncWorkflowContext calling async activity"""
327+
328+
async def async_activity(ctx: dt_task.ActivityContext, input_data: str):
329+
await asyncio.sleep(0.001)
330+
return input_data.upper()
331+
332+
ctx = AsyncWorkflowContext(self.mock_base_ctx)
333+
awaitable = ctx.call_activity(async_activity, input="test")
334+
335+
assert isinstance(awaitable, ActivityAwaitable)
336+
assert awaitable._activity_fn == async_activity
337+
assert awaitable._input == "test"
338+
339+
def test_async_activity_with_when_all(self):
340+
"""Test when_all with async activities"""
341+
342+
async def async_activity(ctx: dt_task.ActivityContext, input_data: int):
343+
await asyncio.sleep(0.001)
344+
return input_data * 2
345+
346+
ctx = AsyncWorkflowContext(self.mock_base_ctx)
347+
348+
# Create multiple async activity awaitables
349+
awaitables = [ctx.call_activity(async_activity, input=i) for i in range(3)]
350+
when_all = ctx.when_all(awaitables)
351+
352+
assert isinstance(when_all, WhenAllAwaitable)
353+
354+
def test_async_activity_with_when_any(self):
355+
"""Test when_any with async activities"""
356+
357+
async def async_activity_fast(ctx: dt_task.ActivityContext, _):
358+
await asyncio.sleep(0.001)
359+
return "fast"
360+
361+
async def async_activity_slow(ctx: dt_task.ActivityContext, _):
362+
await asyncio.sleep(0.1)
363+
return "slow"
364+
365+
ctx = AsyncWorkflowContext(self.mock_base_ctx)
366+
367+
# Create async activity awaitables
368+
fast_awaitable = ctx.call_activity(async_activity_fast, input=None)
369+
slow_awaitable = ctx.call_activity(async_activity_slow, input=None)
370+
when_any = ctx.when_any([fast_awaitable, slow_awaitable])
371+
372+
assert isinstance(when_any, WhenAnyAwaitable)

tests/durabletask/test_activity_executor.py

Lines changed: 99 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT License.
33

4+
import asyncio
45
import json
56
import logging
67
from typing import Any, Optional, Tuple
@@ -26,7 +27,9 @@ def test_activity(ctx: task.ActivityContext, test_input: Any):
2627

2728
activity_input = "Hello, 世界!"
2829
executor, name = _get_activity_executor(test_activity)
29-
result = executor.execute(TEST_INSTANCE_ID, name, TEST_TASK_ID, json.dumps(activity_input))
30+
result = asyncio.run(
31+
executor.execute(TEST_INSTANCE_ID, name, TEST_TASK_ID, json.dumps(activity_input))
32+
)
3033
assert result is not None
3134

3235
result_input, result_orchestration_id, result_task_id = json.loads(result)
@@ -43,7 +46,7 @@ def test_activity(ctx: task.ActivityContext, _):
4346

4447
caught_exception: Optional[Exception] = None
4548
try:
46-
executor.execute(TEST_INSTANCE_ID, "Bogus", TEST_TASK_ID, None)
49+
asyncio.run(executor.execute(TEST_INSTANCE_ID, "Bogus", TEST_TASK_ID, None))
4750
except Exception as ex:
4851
caught_exception = ex
4952

@@ -56,3 +59,97 @@ def _get_activity_executor(fn: task.Activity) -> Tuple[worker._ActivityExecutor,
5659
name = registry.add_activity(fn)
5760
executor = worker._ActivityExecutor(registry, TEST_LOGGER)
5861
return executor, name
62+
63+
64+
def test_async_activity_basic():
65+
"""Validates basic async activity execution"""
66+
67+
async def async_activity(ctx: task.ActivityContext, test_input: str):
68+
# Simple async activity that returns modified input
69+
return f"async:{test_input}"
70+
71+
activity_input = "test"
72+
executor, name = _get_activity_executor(async_activity)
73+
74+
# Run the async executor
75+
result = asyncio.run(
76+
executor.execute(TEST_INSTANCE_ID, name, TEST_TASK_ID, json.dumps(activity_input))
77+
)
78+
assert result is not None
79+
80+
result_output = json.loads(result)
81+
assert result_output == "async:test"
82+
83+
84+
def test_async_activity_with_input():
85+
"""Validates async activity with complex input/output"""
86+
87+
async def async_activity(ctx: task.ActivityContext, test_input: dict):
88+
# Return all activity inputs back as the output
89+
return {
90+
"input": test_input,
91+
"orchestration_id": ctx.orchestration_id,
92+
"task_id": ctx.task_id,
93+
"processed": True,
94+
}
95+
96+
activity_input = {"key": "value", "number": 42}
97+
executor, name = _get_activity_executor(async_activity)
98+
result = asyncio.run(
99+
executor.execute(TEST_INSTANCE_ID, name, TEST_TASK_ID, json.dumps(activity_input))
100+
)
101+
assert result is not None
102+
103+
result_data = json.loads(result)
104+
assert result_data["input"] == activity_input
105+
assert result_data["orchestration_id"] == TEST_INSTANCE_ID
106+
assert result_data["task_id"] == TEST_TASK_ID
107+
assert result_data["processed"] is True
108+
109+
110+
def test_async_activity_with_await():
111+
"""Validates async activity that performs async I/O"""
112+
113+
async def async_activity_with_io(ctx: task.ActivityContext, delay: float):
114+
# Simulate async I/O operation
115+
await asyncio.sleep(delay)
116+
return f"completed after {delay}s"
117+
118+
activity_input = 0.01 # 10ms delay
119+
executor, name = _get_activity_executor(async_activity_with_io)
120+
result = asyncio.run(
121+
executor.execute(TEST_INSTANCE_ID, name, TEST_TASK_ID, json.dumps(activity_input))
122+
)
123+
assert result is not None
124+
125+
result_output = json.loads(result)
126+
assert result_output == "completed after 0.01s"
127+
128+
129+
def test_mixed_sync_async_activities():
130+
"""Validates that sync and async activities work together"""
131+
132+
def sync_activity(ctx: task.ActivityContext, test_input: str):
133+
return f"sync:{test_input}"
134+
135+
async def async_activity(ctx: task.ActivityContext, test_input: str):
136+
return f"async:{test_input}"
137+
138+
registry = worker._Registry()
139+
sync_name = registry.add_activity(sync_activity)
140+
async_name = registry.add_activity(async_activity)
141+
executor = worker._ActivityExecutor(registry, TEST_LOGGER)
142+
143+
activity_input = "test"
144+
145+
# Execute sync activity
146+
sync_result = asyncio.run(
147+
executor.execute(TEST_INSTANCE_ID, sync_name, TEST_TASK_ID, json.dumps(activity_input))
148+
)
149+
assert json.loads(sync_result) == "sync:test"
150+
151+
# Execute async activity
152+
async_result = asyncio.run(
153+
executor.execute(TEST_INSTANCE_ID, async_name, TEST_TASK_ID + 1, json.dumps(activity_input))
154+
)
155+
assert json.loads(async_result) == "async:test"

0 commit comments

Comments
 (0)