Skip to content

Commit b07037c

Browse files
authored
DBOS.asyncio_wait (#609)
`DBOS.asyncio_wait` is a durable implementation of [`asyncio.wait`](https://docs.python.org/3/library/asyncio-task.html#asyncio.wait) with the same interface and semantics. In particular, it lets you durably wait for the first of many steps to finish (using `asyncio.FIRST_COMPLETED`), or wait for step completion with a timeout.
1 parent 71b76c8 commit b07037c

File tree

2 files changed

+261
-0
lines changed

2 files changed

+261
-0
lines changed

dbos/_dbos.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
TYPE_CHECKING,
1616
Any,
1717
AsyncGenerator,
18+
Awaitable,
1819
Callable,
1920
Coroutine,
2021
Dict,
@@ -38,6 +39,8 @@
3839
DefaultSerializer,
3940
Serializer,
4041
WorkflowSerializationFormat,
42+
deserialize_value,
43+
serialize_value,
4144
)
4245
from dbos._sys_db import SystemDatabase, WorkflowStatus
4346
from dbos._utils import INTERNAL_QUEUE_NAME, GlobalParams, generate_uuid
@@ -1283,6 +1286,80 @@ async def sleep_async(cls, seconds: float) -> None:
12831286
)
12841287
await asyncio.sleep(duration)
12851288

1289+
@classmethod
1290+
async def asyncio_wait(
1291+
cls,
1292+
fs: List[Awaitable[Any]],
1293+
*,
1294+
timeout: Optional[float] = None,
1295+
return_when: str = asyncio.ALL_COMPLETED,
1296+
) -> tuple[set[asyncio.Task[Any]], set[asyncio.Task[Any]]]:
1297+
"""
1298+
Durable version of asyncio.wait.
1299+
1300+
Checkpoints which tasks are done vs pending so the result is
1301+
deterministic during workflow recovery.
1302+
1303+
When called outside a workflow, falls back to regular `asyncio.wait`.
1304+
"""
1305+
cur_ctx = snapshot_step_context()
1306+
fs_list = [asyncio.ensure_future(f) for f in fs]
1307+
if cur_ctx is None or not cur_ctx.is_workflow():
1308+
result: tuple[set[Any], set[Any]] = await asyncio.wait(
1309+
fs_list, timeout=timeout, return_when=return_when
1310+
)
1311+
return result
1312+
1313+
await cls._configure_asyncio_thread_pool()
1314+
dbos = _get_dbos_instance()
1315+
attributes: TracedAttributes = {"name": "asyncio_wait"}
1316+
1317+
with EnterDBOSStepCtx(attributes, cur_ctx) as ctx:
1318+
recorded = await asyncio.to_thread(
1319+
dbos._sys_db.check_operation_execution,
1320+
ctx.workflow_id,
1321+
ctx.curr_step_function_id,
1322+
"DBOS.asyncio_wait",
1323+
)
1324+
1325+
if recorded is not None:
1326+
recorded_indices: set[int] = set(
1327+
deserialize_value(
1328+
recorded["output"],
1329+
recorded["serialization"],
1330+
dbos._serializer,
1331+
)
1332+
or []
1333+
)
1334+
done = {fs_list[i] for i in recorded_indices}
1335+
pending = {
1336+
fs_list[i] for i in range(len(fs_list)) if i not in recorded_indices
1337+
}
1338+
if done:
1339+
await asyncio.wait(done)
1340+
return done, pending
1341+
else:
1342+
done, pending = await asyncio.wait(
1343+
fs_list, timeout=timeout, return_when=return_when
1344+
)
1345+
done_idx_list = [i for i, f in enumerate(fs_list) if f in done]
1346+
serialized_output, serialization = serialize_value(
1347+
done_idx_list, None, dbos._serializer
1348+
)
1349+
await asyncio.to_thread(
1350+
dbos._sys_db.record_operation_result,
1351+
{
1352+
"workflow_uuid": ctx.workflow_id,
1353+
"function_id": ctx.curr_step_function_id,
1354+
"function_name": "DBOS.asyncio_wait",
1355+
"output": serialized_output,
1356+
"error": None,
1357+
"serialization": serialization,
1358+
"started_at_epoch_ms": int(time.time() * 1000),
1359+
},
1360+
)
1361+
return done, pending
1362+
12861363
@classmethod
12871364
def set_event(
12881365
cls,

tests/test_async.py

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -691,6 +691,190 @@ async def parent_workflow() -> tuple[str, str, str]:
691691
assert async_child_status.parent_workflow_id == parent_id
692692

693693

694+
@pytest.mark.asyncio
695+
async def test_asyncio_wait(dbos: DBOS) -> None:
696+
step_counter: int = 0
697+
gate = asyncio.Event()
698+
699+
@DBOS.step()
700+
async def fast_step(val: str) -> str:
701+
nonlocal step_counter
702+
step_counter += 1
703+
return val + "_done"
704+
705+
@DBOS.step()
706+
async def slow_step(val: str) -> str:
707+
nonlocal step_counter
708+
step_counter += 1
709+
await gate.wait()
710+
return val + "_done"
711+
712+
@DBOS.workflow()
713+
async def wait_workflow() -> None:
714+
done, pending = await DBOS.asyncio_wait(
715+
[fast_step("fast"), slow_step("slow")],
716+
return_when=asyncio.FIRST_COMPLETED,
717+
)
718+
719+
assert len(done) == 1
720+
assert len(pending) == 1
721+
assert [t.result() for t in done] == ["fast_done"]
722+
723+
# Let the slow step finish and wait for it
724+
gate.set()
725+
done2, pending2 = await DBOS.asyncio_wait(list(pending))
726+
assert len(done2) == 1
727+
assert len(pending2) == 0
728+
assert [t.result() for t in done2] == ["slow_done"]
729+
730+
handle = await DBOS.start_workflow_async(wait_workflow)
731+
await handle.get_result()
732+
assert step_counter == 2
733+
734+
# Verify recorded steps
735+
steps = await DBOS.list_workflow_steps_async(handle.workflow_id)
736+
assert len(steps) == 4
737+
# Step 1: first asyncio_wait snapshots its context before the tasks run.
738+
# Recorded done indices [0] means the first future (fast_step) completed.
739+
assert steps[0]["function_id"] == 1
740+
assert steps[0]["function_name"] == "DBOS.asyncio_wait"
741+
assert steps[0]["output"] == [0]
742+
# Steps 2 & 3: the step coroutines execute inside the asyncio tasks
743+
assert steps[1]["function_id"] == 2
744+
assert steps[1]["function_name"] == fast_step.__qualname__
745+
assert steps[1]["output"] == "fast_done"
746+
assert steps[2]["function_id"] == 3
747+
assert steps[2]["function_name"] == slow_step.__qualname__
748+
assert steps[2]["output"] == "slow_done"
749+
# Step 4: second asyncio_wait on the pending set
750+
assert steps[3]["function_id"] == 4
751+
assert steps[3]["function_name"] == "DBOS.asyncio_wait"
752+
assert steps[3]["output"] == [0]
753+
754+
# Fork from a high step to replay everything from DB (OAOO)
755+
forked = await DBOS.fork_workflow_async(handle.workflow_id, 100)
756+
await forked.get_result()
757+
assert step_counter == 2
758+
759+
760+
@pytest.mark.asyncio
761+
async def test_asyncio_wait_all_completed(dbos: DBOS) -> None:
762+
step_counter: int = 0
763+
764+
@DBOS.step()
765+
async def my_step(val: str) -> str:
766+
nonlocal step_counter
767+
step_counter += 1
768+
return val
769+
770+
@DBOS.workflow()
771+
async def wait_all_workflow() -> None:
772+
done, pending = await DBOS.asyncio_wait(
773+
[my_step("a"), my_step("b")], return_when=asyncio.ALL_COMPLETED
774+
)
775+
776+
assert len(done) == 2
777+
assert len(pending) == 0
778+
assert sorted([t.result() for t in done]) == ["a", "b"]
779+
780+
handle = await DBOS.start_workflow_async(wait_all_workflow)
781+
await handle.get_result()
782+
assert step_counter == 2
783+
784+
# Fork from a high step to replay everything from DB (OAOO)
785+
forked = await DBOS.fork_workflow_async(handle.workflow_id, 100)
786+
await forked.get_result()
787+
assert step_counter == 2
788+
789+
790+
@pytest.mark.asyncio
791+
async def test_asyncio_wait_first_exception(dbos: DBOS) -> None:
792+
step_counter: int = 0
793+
gate = asyncio.Event()
794+
795+
@DBOS.step()
796+
async def error_step() -> str:
797+
nonlocal step_counter
798+
step_counter += 1
799+
raise ValueError("boom")
800+
801+
@DBOS.step()
802+
async def slow_step(val: str) -> str:
803+
nonlocal step_counter
804+
step_counter += 1
805+
await gate.wait()
806+
return val
807+
808+
@DBOS.workflow()
809+
async def wait_exception_workflow() -> None:
810+
done, pending = await DBOS.asyncio_wait(
811+
[error_step(), slow_step("ok")],
812+
return_when=asyncio.FIRST_EXCEPTION,
813+
)
814+
815+
assert len(done) == 1
816+
assert len(pending) == 1
817+
task = next(iter(done))
818+
assert isinstance(task.exception(), ValueError)
819+
assert "boom" in str(task.exception())
820+
821+
# Let the slow step finish and wait for it
822+
gate.set()
823+
done2, pending2 = await DBOS.asyncio_wait(list(pending))
824+
assert len(done2) == 1
825+
assert len(pending2) == 0
826+
assert next(iter(done2)).result() == "ok"
827+
828+
handle = await DBOS.start_workflow_async(wait_exception_workflow)
829+
await handle.get_result()
830+
assert step_counter == 2
831+
832+
# Fork from a high step to replay everything from DB (OAOO)
833+
forked = await DBOS.fork_workflow_async(handle.workflow_id, 100)
834+
await forked.get_result()
835+
assert step_counter == 2
836+
837+
838+
@pytest.mark.asyncio
839+
async def test_asyncio_wait_timeout(dbos: DBOS) -> None:
840+
step_counter: int = 0
841+
gate = asyncio.Event()
842+
843+
@DBOS.step()
844+
async def blocked_step(val: str) -> str:
845+
nonlocal step_counter
846+
step_counter += 1
847+
await gate.wait()
848+
return val
849+
850+
@DBOS.workflow()
851+
async def wait_timeout_workflow() -> None:
852+
done, pending = await DBOS.asyncio_wait(
853+
[blocked_step("a"), blocked_step("b")],
854+
timeout=0.1,
855+
)
856+
857+
# Both should still be pending after timeout
858+
assert len(done) == 0
859+
assert len(pending) == 2
860+
861+
# Unblock and wait for all
862+
gate.set()
863+
done2, pending2 = await DBOS.asyncio_wait(list(pending))
864+
assert len(done2) == 2
865+
assert len(pending2) == 0
866+
assert sorted([t.result() for t in done2]) == ["a", "b"]
867+
868+
handle = await DBOS.start_workflow_async(wait_timeout_workflow)
869+
await handle.get_result()
870+
assert step_counter == 2
871+
872+
# Fork from a high step to replay everything from DB (OAOO)
873+
forked = await DBOS.fork_workflow_async(handle.workflow_id, 100)
874+
await forked.get_result()
875+
assert step_counter == 2
876+
877+
694878
@pytest.mark.asyncio
695879
async def test_workflow_recovery_async(dbos: DBOS, config: DBOSConfig) -> None:
696880
DBOS.destroy(destroy_registry=True)

0 commit comments

Comments (0)