Skip to content

Commit 6167422

Browse files
authored
Support Async WF and Step functions (#168)
This PR introduces `Outcome[T]` - a type to allow us to polymorphically compose decorator behavior for sync and async methods. Updates WF and Step decorators to use `Outcome[T]` in order to support sync and async WF/step functions. Adds async versions of DBOS APIs `send/recv/get_event/set_event/sleep` for use in async workflows. Note, this PR does *NOT* add support for async transactions. Due to design of SQLAlchemy's async support, adding async tx function support is more involved and will be addressed in a future PR. fixes #112
1 parent 2aa453a commit 6167422

File tree

8 files changed

+965
-104
lines changed

8 files changed

+965
-104
lines changed

dbos/_core.py

Lines changed: 175 additions & 93 deletions
Large diffs are not rendered by default.

dbos/_dbos.py

Lines changed: 101 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import asyncio
34
import atexit
45
import json
56
import os
@@ -13,6 +14,7 @@
1314
TYPE_CHECKING,
1415
Any,
1516
Callable,
17+
Coroutine,
1618
Generic,
1719
List,
1820
Literal,
@@ -21,6 +23,9 @@
2123
Tuple,
2224
Type,
2325
TypeVar,
26+
Union,
27+
cast,
28+
overload,
2429
)
2530

2631
from opentelemetry.trace import Span
@@ -71,6 +76,7 @@
7176
from ._admin_server import AdminServer
7277
from ._app_db import ApplicationDatabase
7378
from ._context import (
79+
DBOSContext,
7480
EnterDBOSStep,
7581
TracedAttributes,
7682
assert_current_dbos_context,
@@ -432,7 +438,7 @@ def register_instance(cls, inst: object) -> None:
432438
@classmethod
433439
def workflow(
434440
cls, *, max_recovery_attempts: int = DEFAULT_MAX_RECOVERY_ATTEMPTS
435-
) -> Callable[[F], F]:
441+
) -> Callable[[Callable[P, R]], Callable[P, R]]:
436442
"""Decorate a function for use as a DBOS workflow."""
437443
return decorate_workflow(_get_or_create_dbos_registry(), max_recovery_attempts)
438444

@@ -457,7 +463,7 @@ def step(
457463
interval_seconds: float = 1.0,
458464
max_attempts: int = 3,
459465
backoff_rate: float = 2.0,
460-
) -> Callable[[F], F]:
466+
) -> Callable[[Callable[P, R]], Callable[P, R]]:
461467
"""
462468
Decorate and configure a function for use as a DBOS step.
463469
@@ -542,15 +548,36 @@ def kafka_consumer(
542548
f"{e.name} dependency not found. Please install {e.name} via your package manager."
543549
) from e
544550

551+
@overload
552+
@classmethod
553+
def start_workflow(
554+
cls,
555+
func: Workflow[P, Coroutine[Any, Any, R]],
556+
*args: P.args,
557+
**kwargs: P.kwargs,
558+
) -> WorkflowHandle[R]: ...
559+
560+
@overload
545561
@classmethod
546562
def start_workflow(
547563
cls,
548564
func: Workflow[P, R],
549565
*args: P.args,
550566
**kwargs: P.kwargs,
567+
) -> WorkflowHandle[R]: ...
568+
569+
@classmethod
570+
def start_workflow(
571+
cls,
572+
func: Workflow[P, Union[R, Coroutine[Any, Any, R]]],
573+
*args: P.args,
574+
**kwargs: P.kwargs,
551575
) -> WorkflowHandle[R]:
552576
"""Invoke a workflow function in the background, returning a handle to the ongoing execution."""
553-
return start_workflow(_get_dbos_instance(), func, None, True, *args, **kwargs)
577+
return cast(
578+
WorkflowHandle[R],
579+
start_workflow(_get_dbos_instance(), func, None, True, *args, **kwargs),
580+
)
554581

555582
@classmethod
556583
def get_workflow_status(cls, workflow_id: str) -> Optional[WorkflowStatus]:
@@ -602,6 +629,13 @@ def send(
602629
"""Send a message to a workflow execution."""
603630
return send(_get_dbos_instance(), destination_id, message, topic)
604631

632+
@classmethod
633+
async def send_async(
634+
cls, destination_id: str, message: Any, topic: Optional[str] = None
635+
) -> None:
636+
"""Send a message to a workflow execution."""
637+
await asyncio.to_thread(lambda: DBOS.send(destination_id, message, topic))
638+
605639
@classmethod
606640
def recv(cls, topic: Optional[str] = None, timeout_seconds: float = 60) -> Any:
607641
"""
@@ -612,13 +646,25 @@ def recv(cls, topic: Optional[str] = None, timeout_seconds: float = 60) -> Any:
612646
"""
613647
return recv(_get_dbos_instance(), topic, timeout_seconds)
614648

649+
@classmethod
650+
async def recv_async(
651+
cls, topic: Optional[str] = None, timeout_seconds: float = 60
652+
) -> Any:
653+
"""
654+
Receive a workflow message.
655+
656+
This function is to be called from within a workflow.
657+
`recv_async` will return the message sent on `topic`, asyncronously waiting if necessary.
658+
"""
659+
return await asyncio.to_thread(lambda: DBOS.recv(topic, timeout_seconds))
660+
615661
@classmethod
616662
def sleep(cls, seconds: float) -> None:
617663
"""
618664
Sleep for the specified time (in seconds).
619665
620-
It is important to use `DBOS.sleep` (as opposed to any other sleep) within workflows,
621-
as the `DBOS.sleep`s are durable and completed sleeps will be skipped during recovery.
666+
It is important to use `DBOS.sleep` or `DBOS.sleep_async` (as opposed to any other sleep) within workflows,
667+
as the DBOS sleep methods are durable and completed sleeps will be skipped during recovery.
622668
"""
623669
if seconds <= 0:
624670
return
@@ -631,25 +677,34 @@ def sleep(cls, seconds: float) -> None:
631677
attributes: TracedAttributes = {
632678
"name": "sleep",
633679
}
634-
with EnterDBOSStep(attributes) as ctx:
680+
with EnterDBOSStep(attributes):
681+
ctx = assert_current_dbos_context()
635682
_get_dbos_instance()._sys_db.sleep(
636683
ctx.workflow_id, ctx.curr_step_function_id, seconds
637684
)
638685
else:
639686
# Cannot call it from outside of a workflow
640687
raise DBOSException("sleep() must be called from within a workflow")
641688

689+
@classmethod
690+
async def sleep_async(cls, seconds: float) -> None:
691+
"""
692+
Sleep for the specified time (in seconds).
693+
694+
It is important to use `DBOS.sleep` or `DBOS.sleep_async` (as opposed to any other sleep) within workflows,
695+
as the DBOS sleep methods are durable and completed sleeps will be skipped during recovery.
696+
"""
697+
await asyncio.to_thread(lambda: DBOS.sleep(seconds))
698+
642699
@classmethod
643700
def set_event(cls, key: str, value: Any) -> None:
644701
"""
645702
Set a workflow event.
646703
647-
This function is to be called from within a workflow.
648-
649704
`set_event` sets the `value` of `key` for the current workflow instance ID.
650705
This `value` can then be retrieved by other functions, using `get_event` below.
651-
652-
Each workflow invocation should only call set_event once per `key`.
706+
If the event `key` already exists, its `value` is updated.
707+
This function can only be called from within a workflow.
653708
654709
Args:
655710
key(str): The event key / name within the workflow
@@ -658,6 +713,23 @@ def set_event(cls, key: str, value: Any) -> None:
658713
"""
659714
return set_event(_get_dbos_instance(), key, value)
660715

716+
@classmethod
717+
async def set_event_async(cls, key: str, value: Any) -> None:
718+
"""
719+
Set a workflow event.
720+
721+
`set_event_async` sets the `value` of `key` for the current workflow instance ID.
722+
This `value` can then be retrieved by other functions, using `get_event` below.
723+
If the event `key` already exists, its `value` is updated.
724+
This function can only be called from within a workflow.
725+
726+
Args:
727+
key(str): The event key / name within the workflow
728+
value(Any): A serializable value to associate with the key
729+
730+
"""
731+
await asyncio.to_thread(lambda: DBOS.set_event(key, value))
732+
661733
@classmethod
662734
def get_event(cls, workflow_id: str, key: str, timeout_seconds: float = 60) -> Any:
663735
"""
@@ -673,6 +745,25 @@ def get_event(cls, workflow_id: str, key: str, timeout_seconds: float = 60) -> A
673745
"""
674746
return get_event(_get_dbos_instance(), workflow_id, key, timeout_seconds)
675747

748+
@classmethod
749+
async def get_event_async(
750+
cls, workflow_id: str, key: str, timeout_seconds: float = 60
751+
) -> Any:
752+
"""
753+
Return the `value` of a workflow event, waiting for it to occur if necessary.
754+
755+
`get_event_async` waits for a corresponding `set_event` by the workflow with ID `workflow_id` with the same `key`.
756+
757+
Args:
758+
workflow_id(str): The workflow instance ID that is expected to call `set_event` on `key`
759+
key(str): The event key / name within the workflow
760+
timeout_seconds(float): The amount of time to wait, in case `set_event` has not yet been called byt the workflow
761+
762+
"""
763+
return await asyncio.to_thread(
764+
lambda: DBOS.get_event(workflow_id, key, timeout_seconds)
765+
)
766+
676767
@classmethod
677768
def execute_workflow_id(cls, workflow_id: str) -> WorkflowHandle[Any]:
678769
"""Execute a workflow by ID (for recovery)."""

0 commit comments

Comments
 (0)