 import dataclasses
 import inspect
 
-from typing import get_type_hints, Protocol, Any, Callable, dataclass_transform
+from typing import get_type_hints, Protocol, Any, Callable, Awaitable, dataclass_transform
 from enum import Enum
-from threading import Lock
+from functools import partial
 
 from .typing import encode_enriched_type
 from .convert import to_engine_value, make_engine_value_converter
@@ -61,7 +61,7 @@ def __call__(self, spec: dict[str, Any], *args, **kwargs):
         return (encode_enriched_type(result_type), executor)
 
 
-_gpu_dispatch_lock = Lock()
+_gpu_dispatch_lock = asyncio.Lock()
 
 @dataclasses.dataclass
 class OpArgs:
@@ -75,11 +75,15 @@ class OpArgs:
     cache: bool = False
     behavior_version: int | None = None
 
+def _to_async_call(call: Callable) -> Callable[..., Awaitable[Any]]:
+    if inspect.iscoroutinefunction(call):
+        return call
+    return lambda *args, **kwargs: asyncio.to_thread(lambda: call(*args, **kwargs))
+
 def _register_op_factory(
     category: OpCategory,
     expected_args: list[tuple[str, inspect.Parameter]],
     expected_return,
-    is_async: bool,
     executor_cls: type,
     spec_cls: type,
     op_args: OpArgs,
@@ -97,10 +101,12 @@ def behavior_version(self):
     class _WrappedClass(executor_cls, _Fallback):
         _args_converters: list[Callable[[Any], Any]]
         _kwargs_converters: dict[str, Callable[[str, Any], Any]]
+        _acall: Callable
 
         def __init__(self, spec):
             super().__init__()
             self.spec = spec
+            self._acall = _to_async_call(super().__call__)
 
         def analyze(self, *args, **kwargs):
             """
@@ -157,42 +163,30 @@ def analyze(self, *args, **kwargs):
             else:
                 return expected_return
 
-        def prepare(self):
+        async def prepare(self):
             """
             Prepare for execution.
             It's executed after `analyze` and before any `__call__` execution.
             """
-            setup_method = getattr(executor_cls, 'prepare', None)
+            setup_method = getattr(super(), 'prepare', None)
             if setup_method is not None:
-                setup_method(self)
+                await _to_async_call(setup_method)()
 
-        def __call__(self, *args, **kwargs):
+        async def __call__(self, *args, **kwargs):
             converted_args = (converter(arg) for converter, arg in zip(self._args_converters, args))
             converted_kwargs = {arg_name: self._kwargs_converters[arg_name](arg)
                                 for arg_name, arg in kwargs.items()}
-            if is_async:
-                async def _inner():
-                    if op_args.gpu:
-                        await asyncio.to_thread(_gpu_dispatch_lock.acquire)
-                    try:
-                        output = await super(_WrappedClass, self).__call__(
-                            *converted_args, **converted_kwargs)
-                    finally:
-                        if op_args.gpu:
-                            _gpu_dispatch_lock.release()
-                    return to_engine_value(output)
-                return _inner()
 
             if op_args.gpu:
                 # For GPU executions, data-level parallelism is applied, so we don't want to
                 # execute different tasks in parallel.
                 # Besides, multiprocessing is more appropriate for pytorch.
                 # For now, we use a lock to ensure only one task is executed at a time.
                 # TODO: Implement multi-processing dispatching.
-                with _gpu_dispatch_lock:
-                    output = super().__call__(*converted_args, **converted_kwargs)
+                async with _gpu_dispatch_lock:
+                    output = await self._acall(*converted_args, **converted_kwargs)
             else:
-                output = super().__call__(*converted_args, **converted_kwargs)
+                output = await self._acall(*converted_args, **converted_kwargs)
             return to_engine_value(output)
 
     _WrappedClass.__name__ = executor_cls.__name__
@@ -203,9 +197,7 @@ async def _inner():
 
     if category == OpCategory.FUNCTION:
         _engine.register_function_factory(
-            spec_cls.__name__,
-            _FunctionExecutorFactory(spec_cls, _WrappedClass),
-            is_async)
+            spec_cls.__name__, _FunctionExecutorFactory(spec_cls, _WrappedClass))
     else:
         raise ValueError(f"Unsupported executor type {category}")
 
@@ -230,7 +222,6 @@ def _inner(cls: type[Executor]) -> type:
             category=spec_cls._op_category,
             expected_args=list(sig.parameters.items())[1:],  # First argument is `self`
             expected_return=sig.return_annotation,
-            is_async=inspect.iscoroutinefunction(cls.__call__),
             executor_cls=cls,
             spec_cls=spec_cls,
             op_args=op_args)
@@ -266,7 +257,6 @@ class _Spec(FunctionSpec):
             category=OpCategory.FUNCTION,
             expected_args=list(sig.parameters.items()),
             expected_return=sig.return_annotation,
-            is_async=inspect.iscoroutinefunction(fn),
             executor_cls=_Executor,
             spec_cls=_Spec,
             op_args=op_args)
0 commit comments