1111 Awaitable ,
1212 Callable ,
1313 Protocol ,
14+ ParamSpec ,
15+ TypeVar ,
16+ Type ,
17+ cast ,
1418 dataclass_transform ,
1519 Annotated ,
1620 get_args ,
1721)
1822
1923from . import _engine # type: ignore
24+ from .subprocess_exec import executor_stub
2025from .convert import (
2126 make_engine_value_encoder ,
2227 make_engine_value_decoder ,
@@ -85,11 +90,13 @@ class Executor(Protocol):
8590 op_category : OpCategory
8691
8792
88- def _load_spec_from_engine (spec_cls : type , spec : dict [str , Any ]) -> Any :
93+ def _load_spec_from_engine (
94+ spec_loader : Callable [..., Any ], spec : dict [str , Any ]
95+ ) -> Any :
8996 """
9097 Load a spec from the engine.
9198 """
92- return spec_cls (** spec )
99+ return spec_loader (** spec )
93100
94101
95102def _get_required_method (cls : type , name : str ) -> Callable [..., Any ]:
@@ -101,18 +108,18 @@ def _get_required_method(cls: type, name: str) -> Callable[..., Any]:
101108 return method
102109
103110
104- class _FunctionExecutorFactory :
105- _spec_cls : type
111+ class _EngineFunctionExecutorFactory :
112+ _spec_loader : Callable [..., Any ]
106113 _executor_cls : type
107114
108- def __init__ (self , spec_cls : type , executor_cls : type ):
109- self ._spec_cls = spec_cls
115+ def __init__ (self , spec_loader : Callable [..., Any ] , executor_cls : type ):
116+ self ._spec_loader = spec_loader
110117 self ._executor_cls = executor_cls
111118
112119 def __call__ (
113120 self , spec : dict [str , Any ], * args : Any , ** kwargs : Any
114121 ) -> tuple [dict [str , Any ], Executor ]:
115- spec = _load_spec_from_engine (self ._spec_cls , spec )
122+ spec = _load_spec_from_engine (self ._spec_loader , spec )
116123 executor = self ._executor_cls (spec )
117124 result_type = executor .analyze_schema (* args , ** kwargs )
118125 return (result_type , executor )
@@ -166,31 +173,32 @@ def _register_op_factory(
166173 category : OpCategory ,
167174 expected_args : list [tuple [str , inspect .Parameter ]],
168175 expected_return : Any ,
169- executor_cls : type ,
170- spec_cls : type ,
176+ executor_factory : Any ,
177+ spec_loader : Callable [..., Any ],
178+ op_kind : str ,
171179 op_args : OpArgs ,
172- ) -> type :
180+ ) -> None :
173181 """
174182 Register an op factory.
175183 """
176184
177- class _Fallback :
178- def enable_cache (self ) -> bool :
179- return op_args .cache
180-
181- def behavior_version (self ) -> int | None :
182- return op_args .behavior_version
183-
184- class _WrappedClass (executor_cls , _Fallback ): # type: ignore[misc]
185+ class _WrappedExecutor :
186+ _executor : Any
185187 _args_info : list [_ArgInfo ]
186188 _kwargs_info : dict [str , _ArgInfo ]
187- _acall : Callable [..., Awaitable [Any ]]
188189 _result_encoder : Callable [[Any ], Any ]
190+ _acall : Callable [..., Awaitable [Any ]] | None = None
189191
190192 def __init__ (self , spec : Any ) -> None :
191- super ().__init__ ()
192- self .spec = spec
193- self ._acall = _to_async_call (super ().__call__ )
193+ executor : Any
194+
195+ if op_args .gpu :
196+ executor = executor_stub (executor_factory , spec )
197+ else :
198+ executor = executor_factory ()
199+ executor .spec = spec
200+
201+ self ._executor = executor
194202
195203 def analyze_schema (
196204 self , * args : _engine .OpArgSchema , ** kwargs : _engine .OpArgSchema
@@ -294,9 +302,9 @@ def process_arg(
294302 if len (missing_args ) > 0 :
295303 raise ValueError (f"Missing arguments: { ', ' .join (missing_args )} " )
296304
297- base_analyze_method = getattr (self , "analyze" , None )
305+ base_analyze_method = getattr (self . _executor , "analyze" , None )
298306 if base_analyze_method is not None :
299- result_type = base_analyze_method (* args , ** kwargs )
307+ result_type = base_analyze_method ()
300308 else :
301309 result_type = expected_return
302310 if len (attributes ) > 0 :
@@ -316,9 +324,10 @@ async def prepare(self) -> None:
316324 Prepare for execution.
317325 It's executed after `analyze` and before any `__call__` execution.
318326 """
319- prepare_method = getattr (super () , "prepare" , None )
327+ prepare_method = getattr (self . _executor , "prepare" , None )
320328 if prepare_method is not None :
321329 await _to_async_call (prepare_method )()
330+ self ._acall = _to_async_call (self ._executor .__call__ )
322331
323332 async def __call__ (self , * args : Any , ** kwargs : Any ) -> Any :
324333 decoded_args = []
@@ -338,6 +347,7 @@ async def __call__(self, *args: Any, **kwargs: Any) -> Any:
338347 return None
339348 decoded_kwargs [kwarg_name ] = kwarg_info .decoder (arg )
340349
350+ assert self ._acall is not None
341351 if op_args .gpu :
342352 # For GPU executions, data-level parallelism is applied, so we don't want to
343353 # execute different tasks in parallel.
@@ -350,21 +360,19 @@ async def __call__(self, *args: Any, **kwargs: Any) -> Any:
350360 output = await self ._acall (* decoded_args , ** decoded_kwargs )
351361 return self ._result_encoder (output )
352362
353- _WrappedClass . __name__ = executor_cls . __name__
354- _WrappedClass . __doc__ = executor_cls . __doc__
355- _WrappedClass . __module__ = executor_cls . __module__
356- _WrappedClass . __qualname__ = executor_cls . __qualname__
357- _WrappedClass . __wrapped__ = executor_cls
363+ def enable_cache ( self ) -> bool :
364+ return op_args . cache
365+
366+ def behavior_version ( self ) -> int | None :
367+ return op_args . behavior_version
358368
359369 if category == OpCategory .FUNCTION :
360370 _engine .register_function_factory (
361- spec_cls . __name__ , _FunctionExecutorFactory ( spec_cls , _WrappedClass )
371+ op_kind , _EngineFunctionExecutorFactory ( spec_loader , _WrappedExecutor )
362372 )
363373 else :
364374 raise ValueError (f"Unsupported executor type { category } " )
365375
366- return _WrappedClass
367-
368376
369377def executor_class (** args : Any ) -> Callable [[type ], type ]:
370378 """
@@ -382,18 +390,31 @@ def _inner(cls: type[Executor]) -> type:
382390 raise TypeError ("Expect a `spec` field with type hint" )
383391 spec_cls = resolve_forward_ref (type_hints ["spec" ])
384392 sig = inspect .signature (cls .__call__ )
385- return _register_op_factory (
393+ _register_op_factory (
386394 category = spec_cls ._op_category ,
387395 expected_args = list (sig .parameters .items ())[1 :], # First argument is `self`
388396 expected_return = sig .return_annotation ,
389- executor_cls = cls ,
390- spec_cls = spec_cls ,
397+ executor_factory = cls ,
398+ spec_loader = spec_cls ,
399+ op_kind = spec_cls .__name__ ,
391400 op_args = op_args ,
392401 )
402+ return cls
393403
394404 return _inner
395405
396406
407+ class _EmptyFunctionSpec (FunctionSpec ):
408+ pass
409+
410+
411+ class _SimpleFunctionExecutor :
412+ spec : Any
413+
414+ def prepare (self ) -> None :
415+ self .__call__ = self .spec .__call__
416+
417+
397418def function (** args : Any ) -> Callable [[Callable [..., Any ]], FunctionSpec ]:
398419 """
399420 Decorate a function to provide a function for an op.
@@ -404,30 +425,32 @@ def _inner(fn: Callable[..., Any]) -> FunctionSpec:
404425 # Convert snake case to camel case.
405426 op_name = "" .join (word .capitalize () for word in fn .__name__ .split ("_" ))
406427 sig = inspect .signature (fn )
428+ full_name = f"{ fn .__module__ } .{ fn .__qualname__ } "
407429
408- class _Executor :
409- def __call__ ( self , * args : Any , ** kwargs : Any ) -> Any :
410- return fn ( * args , ** kwargs )
430+ # An object that is both callable and can act as a FunctionSpec.
431+ class _CallableSpec ( _EmptyFunctionSpec ) :
432+ __call__ = staticmethod ( fn )
411433
412- class _Spec (FunctionSpec ):
413- def __call__ (self , * args : Any , ** kwargs : Any ) -> Any :
414- return fn (* args , ** kwargs )
434+ def __reduce__ (self ) -> str | tuple [Any , ...]:
435+ return full_name
415436
416- _Spec .__name__ = op_name
417- _Spec .__doc__ = fn .__doc__
418- _Spec .__module__ = fn .__module__
419- _Spec .__qualname__ = fn .__qualname__
437+ _CallableSpec .__name__ = op_name
438+ _CallableSpec .__doc__ = fn .__doc__
439+ _CallableSpec .__qualname__ = fn .__qualname__
440+ _CallableSpec .__module__ = fn .__module__
441+ callable_spec = _CallableSpec ()
420442
421443 _register_op_factory (
422444 category = OpCategory .FUNCTION ,
423445 expected_args = list (sig .parameters .items ()),
424446 expected_return = sig .return_annotation ,
425- executor_cls = _Executor ,
426- spec_cls = _Spec ,
447+ executor_factory = _SimpleFunctionExecutor ,
448+ spec_loader = lambda : callable_spec ,
449+ op_kind = op_name ,
427450 op_args = op_args ,
428451 )
429452
430- return _Spec ()
453+ return callable_spec
431454
432455 return _inner
433456
0 commit comments