Skip to content

Commit 9998cd8

Browse files
authored
Support stacked decorators for async workflows (#192)
Currently, You can't declare an async function to be both a FastAPI handler and a DBOS workflow: ```python @app.get("/endpoint/{var1}/{var2}") @DBOS.workflow() async def test_endpoint(var1: str, var2: str) -> str: return f"{var1}, {var2}!" ``` The problem stems from the function returned by the `@DBOS.workflow` decorator. Both the decorated function and the function returned by the decorator return a coroutine, but the function returned by the decorator is not defined with `async def`. This causes FastAPI to mis-categorize the workflow function as sync when it is actually async. The function retuned by the decorator has to appear as a coroutine to `inspect.iscoroutinefunction`. For Python 3.12 and later, any function can be marked as a coroutine using [`inspect.markcoroutinefunction`](https://docs.python.org/3/library/inspect.html#inspect.markcoroutinefunction). For Python 3.11 and earlier, we have to wrap the coroutine returning function in an async function like this: ```python def _mark_coroutine(func: Callable[P, R]) -> Callable[P, R]: @wraps(func) async def async_wrapper(*args: Any, **kwargs: Any) -> R: return await func(*args, **kwargs) # type: ignore return async_wrapper # type: ignore ```
1 parent 862a59e commit 9998cd8

File tree

3 files changed

+632
-24
lines changed

3 files changed

+632
-24
lines changed

dbos/_core.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -488,6 +488,22 @@ def start_workflow(
488488
return WorkflowHandleFuture(new_wf_id, future, dbos)
489489

490490

491+
if sys.version_info < (3, 12):
492+
493+
def _mark_coroutine(func: Callable[P, R]) -> Callable[P, R]:
494+
@wraps(func)
495+
async def async_wrapper(*args: Any, **kwargs: Any) -> R:
496+
return await func(*args, **kwargs) # type: ignore
497+
498+
return async_wrapper # type: ignore
499+
500+
else:
501+
502+
def _mark_coroutine(func: Callable[P, R]) -> Callable[P, R]:
503+
inspect.markcoroutinefunction(func)
504+
return func
505+
506+
491507
def workflow_wrapper(
492508
dbosreg: "DBOSRegistry",
493509
func: Callable[P, R],
@@ -548,7 +564,7 @@ def init_wf() -> Callable[[Callable[[], R]], R]:
548564
)
549565
return outcome() # type: ignore
550566

551-
return wrapper
567+
return _mark_coroutine(wrapper) if inspect.iscoroutinefunction(func) else wrapper
552568

553569

554570
def decorate_workflow(
@@ -838,6 +854,10 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
838854
assert tempwf
839855
return tempwf(*args, **kwargs)
840856

857+
wrapper = (
858+
_mark_coroutine(wrapper) if inspect.iscoroutinefunction(func) else wrapper # type: ignore
859+
)
860+
841861
def temp_wf_sync(*args: Any, **kwargs: Any) -> Any:
842862
return wrapper(*args, **kwargs)
843863

0 commit comments

Comments
 (0)