22import contextlib
33import inspect
44import time
5- from typing import Any , Callable , Coroutine , Optional , Protocol , TypeVar , Union , cast
5+ from typing import (
6+ TYPE_CHECKING ,
7+ Any ,
8+ Callable ,
9+ Coroutine ,
10+ Optional ,
11+ Protocol ,
12+ TypeVar ,
13+ Union ,
14+ cast ,
15+ )
616
717from dbos ._context import EnterDBOSStepRetry
18+ from dbos ._error import DBOSException
19+ from dbos ._registrations import get_dbos_func_name
20+
21+ if TYPE_CHECKING :
22+ from ._dbos import DBOS
823
924T = TypeVar ("T" )
1025R = TypeVar ("R" )
@@ -24,10 +39,15 @@ def __new__(cls, *args: Any, **kwargs: Any) -> "NoResult":
2439class Outcome (Protocol [T ]):
2540
2641 def wrap (
27- self , before : Callable [[], Callable [[Callable [[], T ]], R ]]
42+ self ,
43+ before : Callable [[], Callable [[Callable [[], T ]], R ]],
44+ * ,
45+ dbos : Optional ["DBOS" ] = None ,
2846 ) -> "Outcome[R]" : ...
2947
30- def then (self , next : Callable [[Callable [[], T ]], R ]) -> "Outcome[R]" : ...
48+ def then (
49+ self , next : Callable [[Callable [[], T ]], R ], * , dbos : Optional ["DBOS" ] = None
50+ ) -> "Outcome[R]" : ...
3151
3252 def also (
3353 self , cm : contextlib .AbstractContextManager [Any , bool ]
@@ -41,7 +61,10 @@ def retry(
4161 ) -> "Outcome[T]" : ...
4262
4363 def intercept (
44- self , interceptor : Callable [[], Union [NoResult , T ]]
64+ self ,
65+ interceptor : Callable [[], Union [NoResult , T ]],
66+ * ,
67+ dbos : Optional ["DBOS" ] = None ,
4568 ) -> "Outcome[T]" : ...
4669
4770 def __call__ (self ) -> Union [T , Coroutine [Any , Any , T ]]: ...
@@ -63,11 +86,17 @@ class Immediate(Outcome[T]):
6386 def __init__ (self , func : Callable [[], T ]):
6487 self ._func = func
6588
66- def then (self , next : Callable [[Callable [[], T ]], R ]) -> "Immediate[R]" :
89+ def then (
90+ self ,
91+ next : Callable [[Callable [[], T ]], R ],
92+ dbos : Optional ["DBOS" ] = None ,
93+ ) -> "Immediate[R]" :
6794 return Immediate (lambda : next (self ._func ))
6895
6996 def wrap (
70- self , before : Callable [[], Callable [[Callable [[], T ]], R ]]
97+ self ,
98+ before : Callable [[], Callable [[Callable [[], T ]], R ]],
99+ dbos : Optional ["DBOS" ] = None ,
71100 ) -> "Immediate[R]" :
72101 return Immediate (lambda : before ()(self ._func ))
73102
@@ -79,7 +108,10 @@ def _intercept(
79108 return intercepted if not isinstance (intercepted , NoResult ) else func ()
80109
81110 def intercept (
82- self , interceptor : Callable [[], Union [NoResult , T ]]
111+ self ,
112+ interceptor : Callable [[], Union [NoResult , T ]],
113+ * ,
114+ dbos : Optional ["DBOS" ] = None ,
83115 ) -> "Immediate[T]" :
84116 return Immediate [T ](lambda : Immediate ._intercept (self ._func , interceptor ))
85117
@@ -142,7 +174,12 @@ def _raise(ex: BaseException) -> T:
142174 async def _wrap (
143175 func : Callable [[], Coroutine [Any , Any , T ]],
144176 before : Callable [[], Callable [[Callable [[], T ]], R ]],
177+ * ,
178+ dbos : Optional ["DBOS" ] = None ,
145179 ) -> R :
180+ # Make sure the executor pool is configured correctly
181+ if dbos is not None :
182+ await dbos ._configure_asyncio_thread_pool ()
146183 after = await asyncio .to_thread (before )
147184 try :
148185 value = await func ()
@@ -151,12 +188,17 @@ async def _wrap(
151188 return await asyncio .to_thread (after , lambda : Pending ._raise (exp ))
152189
153190 def wrap (
154- self , before : Callable [[], Callable [[Callable [[], T ]], R ]]
191+ self ,
192+ before : Callable [[], Callable [[Callable [[], T ]], R ]],
193+ * ,
194+ dbos : Optional ["DBOS" ] = None ,
155195 ) -> "Pending[R]" :
156- return Pending [R ](lambda : Pending ._wrap (self ._func , before ))
196+ return Pending [R ](lambda : Pending ._wrap (self ._func , before , dbos = dbos ))
157197
158- def then (self , next : Callable [[Callable [[], T ]], R ]) -> "Pending[R]" :
159- return Pending [R ](lambda : Pending ._wrap (self ._func , lambda : next ))
198+ def then (
199+ self , next : Callable [[Callable [[], T ]], R ], * , dbos : Optional ["DBOS" ] = None
200+ ) -> "Pending[R]" :
201+ return Pending [R ](lambda : Pending ._wrap (self ._func , lambda : next , dbos = dbos ))
160202
161203 @staticmethod
162204 async def _also ( # type: ignore
@@ -173,12 +215,24 @@ def also(self, cm: contextlib.AbstractContextManager[Any, bool]) -> "Pending[T]"
173215 async def _intercept (
174216 func : Callable [[], Coroutine [Any , Any , T ]],
175217 interceptor : Callable [[], Union [NoResult , T ]],
218+ * ,
219+ dbos : Optional ["DBOS" ] = None ,
176220 ) -> T :
221+ # Make sure the executor pool is configured correctly
222+ if dbos is not None :
223+ await dbos ._configure_asyncio_thread_pool ()
177224 intercepted = await asyncio .to_thread (interceptor )
178225 return intercepted if not isinstance (intercepted , NoResult ) else await func ()
179226
180- def intercept (self , interceptor : Callable [[], Union [NoResult , T ]]) -> "Pending[T]" :
181- return Pending [T ](lambda : Pending ._intercept (self ._func , interceptor ))
227+ def intercept (
228+ self ,
229+ interceptor : Callable [[], Union [NoResult , T ]],
230+ * ,
231+ dbos : Optional ["DBOS" ] = None ,
232+ ) -> "Pending[T]" :
233+ return Pending [T ](
234+ lambda : Pending ._intercept (self ._func , interceptor , dbos = dbos )
235+ )
182236
183237 @staticmethod
184238 async def _retry (
0 commit comments