Skip to content

Commit 381dc4b

Browse files
authored
Fix destroy for asyncio (#451)
This PR fixes a bug where `DBOS.destroy()` shuts down the executor pool but leaves it attached to the event loop. As a result, if DBOS is re-initialized and async functions are invoked directly, the loop still tries to use the destroyed pool and raises a runtime error. The fix is to let the wrapper functions optionally accept a `dbos` instance. This makes sure the event loop is always configured correctly, even when async workflows are called directly after re-initialization. Added a test for destroy semantics for asyncio.
1 parent cb55431 commit 381dc4b

File tree

3 files changed

+101
-18
lines changed

3 files changed

+101
-18
lines changed

dbos/_core.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@
1919
cast,
2020
)
2121

22-
import psycopg
23-
2422
from dbos._outcome import Immediate, NoResult, Outcome, Pending
2523
from dbos._utils import GlobalParams, retriable_postgres_exception
2624

@@ -831,10 +829,10 @@ def record_get_result(func: Callable[[], R]) -> R:
831829
return r
832830

833831
outcome = (
834-
wfOutcome.wrap(init_wf)
832+
wfOutcome.wrap(init_wf, dbos=dbos)
835833
.also(DBOSAssumeRole(rr))
836834
.also(enterWorkflowCtxMgr(attributes))
837-
.then(record_get_result)
835+
.then(record_get_result, dbos=dbos)
838836
)
839837
return outcome() # type: ignore
840838

@@ -1146,7 +1144,7 @@ def check_existing_result() -> Union[NoResult, R]:
11461144

11471145
outcome = (
11481146
stepOutcome.then(record_step_result)
1149-
.intercept(check_existing_result)
1147+
.intercept(check_existing_result, dbos=dbos)
11501148
.also(EnterDBOSStep(attributes))
11511149
)
11521150
return outcome()

dbos/_outcome.py

Lines changed: 67 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,24 @@
22
import contextlib
33
import inspect
44
import time
5-
from typing import Any, Callable, Coroutine, Optional, Protocol, TypeVar, Union, cast
5+
from typing import (
6+
TYPE_CHECKING,
7+
Any,
8+
Callable,
9+
Coroutine,
10+
Optional,
11+
Protocol,
12+
TypeVar,
13+
Union,
14+
cast,
15+
)
616

717
from dbos._context import EnterDBOSStepRetry
18+
from dbos._error import DBOSException
19+
from dbos._registrations import get_dbos_func_name
20+
21+
if TYPE_CHECKING:
22+
from ._dbos import DBOS
823

924
T = TypeVar("T")
1025
R = TypeVar("R")
@@ -24,10 +39,15 @@ def __new__(cls, *args: Any, **kwargs: Any) -> "NoResult":
2439
class Outcome(Protocol[T]):
2540

2641
def wrap(
27-
self, before: Callable[[], Callable[[Callable[[], T]], R]]
42+
self,
43+
before: Callable[[], Callable[[Callable[[], T]], R]],
44+
*,
45+
dbos: Optional["DBOS"] = None,
2846
) -> "Outcome[R]": ...
2947

30-
def then(self, next: Callable[[Callable[[], T]], R]) -> "Outcome[R]": ...
48+
def then(
49+
self, next: Callable[[Callable[[], T]], R], *, dbos: Optional["DBOS"] = None
50+
) -> "Outcome[R]": ...
3151

3252
def also(
3353
self, cm: contextlib.AbstractContextManager[Any, bool]
@@ -41,7 +61,10 @@ def retry(
4161
) -> "Outcome[T]": ...
4262

4363
def intercept(
44-
self, interceptor: Callable[[], Union[NoResult, T]]
64+
self,
65+
interceptor: Callable[[], Union[NoResult, T]],
66+
*,
67+
dbos: Optional["DBOS"] = None,
4568
) -> "Outcome[T]": ...
4669

4770
def __call__(self) -> Union[T, Coroutine[Any, Any, T]]: ...
@@ -63,11 +86,17 @@ class Immediate(Outcome[T]):
6386
def __init__(self, func: Callable[[], T]):
6487
self._func = func
6588

66-
def then(self, next: Callable[[Callable[[], T]], R]) -> "Immediate[R]":
89+
def then(
90+
self,
91+
next: Callable[[Callable[[], T]], R],
92+
dbos: Optional["DBOS"] = None,
93+
) -> "Immediate[R]":
6794
return Immediate(lambda: next(self._func))
6895

6996
def wrap(
70-
self, before: Callable[[], Callable[[Callable[[], T]], R]]
97+
self,
98+
before: Callable[[], Callable[[Callable[[], T]], R]],
99+
dbos: Optional["DBOS"] = None,
71100
) -> "Immediate[R]":
72101
return Immediate(lambda: before()(self._func))
73102

@@ -79,7 +108,10 @@ def _intercept(
79108
return intercepted if not isinstance(intercepted, NoResult) else func()
80109

81110
def intercept(
82-
self, interceptor: Callable[[], Union[NoResult, T]]
111+
self,
112+
interceptor: Callable[[], Union[NoResult, T]],
113+
*,
114+
dbos: Optional["DBOS"] = None,
83115
) -> "Immediate[T]":
84116
return Immediate[T](lambda: Immediate._intercept(self._func, interceptor))
85117

@@ -142,7 +174,12 @@ def _raise(ex: BaseException) -> T:
142174
async def _wrap(
143175
func: Callable[[], Coroutine[Any, Any, T]],
144176
before: Callable[[], Callable[[Callable[[], T]], R]],
177+
*,
178+
dbos: Optional["DBOS"] = None,
145179
) -> R:
180+
# Make sure the executor pool is configured correctly
181+
if dbos is not None:
182+
await dbos._configure_asyncio_thread_pool()
146183
after = await asyncio.to_thread(before)
147184
try:
148185
value = await func()
@@ -151,12 +188,17 @@ async def _wrap(
151188
return await asyncio.to_thread(after, lambda: Pending._raise(exp))
152189

153190
def wrap(
154-
self, before: Callable[[], Callable[[Callable[[], T]], R]]
191+
self,
192+
before: Callable[[], Callable[[Callable[[], T]], R]],
193+
*,
194+
dbos: Optional["DBOS"] = None,
155195
) -> "Pending[R]":
156-
return Pending[R](lambda: Pending._wrap(self._func, before))
196+
return Pending[R](lambda: Pending._wrap(self._func, before, dbos=dbos))
157197

158-
def then(self, next: Callable[[Callable[[], T]], R]) -> "Pending[R]":
159-
return Pending[R](lambda: Pending._wrap(self._func, lambda: next))
198+
def then(
199+
self, next: Callable[[Callable[[], T]], R], *, dbos: Optional["DBOS"] = None
200+
) -> "Pending[R]":
201+
return Pending[R](lambda: Pending._wrap(self._func, lambda: next, dbos=dbos))
160202

161203
@staticmethod
162204
async def _also( # type: ignore
@@ -173,12 +215,24 @@ def also(self, cm: contextlib.AbstractContextManager[Any, bool]) -> "Pending[T]"
173215
async def _intercept(
174216
func: Callable[[], Coroutine[Any, Any, T]],
175217
interceptor: Callable[[], Union[NoResult, T]],
218+
*,
219+
dbos: Optional["DBOS"] = None,
176220
) -> T:
221+
# Make sure the executor pool is configured correctly
222+
if dbos is not None:
223+
await dbos._configure_asyncio_thread_pool()
177224
intercepted = await asyncio.to_thread(interceptor)
178225
return intercepted if not isinstance(intercepted, NoResult) else await func()
179226

180-
def intercept(self, interceptor: Callable[[], Union[NoResult, T]]) -> "Pending[T]":
181-
return Pending[T](lambda: Pending._intercept(self._func, interceptor))
227+
def intercept(
228+
self,
229+
interceptor: Callable[[], Union[NoResult, T]],
230+
*,
231+
dbos: Optional["DBOS"] = None,
232+
) -> "Pending[T]":
233+
return Pending[T](
234+
lambda: Pending._intercept(self._func, interceptor, dbos=dbos)
235+
)
182236

183237
@staticmethod
184238
async def _retry(

tests/test_dbos.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1250,12 +1250,43 @@ def test_workflow(var: str) -> str:
12501250
var = "test"
12511251
assert test_workflow(var) == var
12521252

1253+
# Start the workflow asynchornously
1254+
wf = dbos.start_workflow(test_workflow, var)
1255+
assert wf.get_result() == var
1256+
12531257
DBOS.destroy()
12541258
DBOS(config=config)
12551259
DBOS.launch()
12561260

12571261
assert test_workflow(var) == var
12581262

1263+
wf = dbos.start_workflow(test_workflow, var)
1264+
assert wf.get_result() == var
1265+
1266+
1267+
@pytest.mark.asyncio
1268+
async def test_destroy_semantics_async(dbos: DBOS, config: DBOSConfig) -> None:
1269+
1270+
@DBOS.workflow()
1271+
async def test_workflow(var: str) -> str:
1272+
return var
1273+
1274+
var = "test"
1275+
assert await test_workflow(var) == var
1276+
1277+
# Start the workflow asynchornously
1278+
wf = await dbos.start_workflow_async(test_workflow, var)
1279+
assert await wf.get_result() == var
1280+
1281+
DBOS.destroy()
1282+
DBOS(config=config)
1283+
DBOS.launch()
1284+
1285+
assert await test_workflow(var) == var
1286+
1287+
wf = await dbos.start_workflow_async(test_workflow, var)
1288+
assert await wf.get_result() == var
1289+
12591290

12601291
def test_double_decoration(dbos: DBOS) -> None:
12611292
with pytest.raises(

0 commit comments

Comments
 (0)