1111 Awaitable ,
1212 Callable ,
1313 Protocol ,
14+ ParamSpec ,
15+ TypeVar ,
16+ Type ,
17+ cast ,
1418 dataclass_transform ,
1519 Annotated ,
1620 get_args ,
1721)
1822
1923from . import _engine # type: ignore
24+ from .subprocess_exec import executor_stub
2025from .convert import (
2126 make_engine_value_encoder ,
2227 make_engine_value_decoder ,
@@ -85,11 +90,13 @@ class Executor(Protocol):
8590 op_category : OpCategory
8691
8792
88- def _load_spec_from_engine (spec_cls : type , spec : dict [str , Any ]) -> Any :
93+ def _load_spec_from_engine (
94+ spec_loader : Callable [..., Any ], spec : dict [str , Any ]
95+ ) -> Any :
8996 """
9097 Load a spec from the engine.
9198 """
92- return spec_cls (** spec )
99+ return spec_loader (** spec )
93100
94101
95102def _get_required_method (cls : type , name : str ) -> Callable [..., Any ]:
@@ -101,18 +108,18 @@ def _get_required_method(cls: type, name: str) -> Callable[..., Any]:
101108 return method
102109
103110
104- class _FunctionExecutorFactory :
105- _spec_cls : type
111+ class _EngineFunctionExecutorFactory :
112+ _spec_loader : Callable [..., Any ]
106113 _executor_cls : type
107114
108- def __init__ (self , spec_cls : type , executor_cls : type ):
109- self ._spec_cls = spec_cls
115+ def __init__ (self , spec_loader : Callable [..., Any ] , executor_cls : type ):
116+ self ._spec_loader = spec_loader
110117 self ._executor_cls = executor_cls
111118
112119 def __call__ (
113120 self , spec : dict [str , Any ], * args : Any , ** kwargs : Any
114121 ) -> tuple [dict [str , Any ], Executor ]:
115- spec = _load_spec_from_engine (self ._spec_cls , spec )
122+ spec = _load_spec_from_engine (self ._spec_loader , spec )
116123 executor = self ._executor_cls (spec )
117124 result_type = executor .analyze_schema (* args , ** kwargs )
118125 return (result_type , executor )
@@ -166,8 +173,8 @@ def _register_op_factory(
166173 category : OpCategory ,
167174 expected_args : list [tuple [str , inspect .Parameter ]],
168175 expected_return : Any ,
169- executor_cls : type ,
170- spec_cls : type ,
176+ executor_factory : Any ,
177+ spec_loader : Callable [..., Any ] ,
171178 op_kind : str ,
172179 op_args : OpArgs ,
173180) -> None :
@@ -179,15 +186,19 @@ class _WrappedExecutor:
179186 _executor : Any
180187 _args_info : list [_ArgInfo ]
181188 _kwargs_info : dict [str , _ArgInfo ]
182- _acall : Callable [..., Awaitable [Any ]]
183189 _result_encoder : Callable [[Any ], Any ]
190+ _acall : Callable [..., Awaitable [Any ]] | None = None
184191
185192 def __init__ (self , spec : Any ) -> None :
186- executor : Any = executor_class ()
187- executor = executor_cls ()
188- executor .spec = spec
193+ executor : Any
194+
195+ if op_args .gpu :
196+ executor = executor_stub (executor_factory , spec )
197+ else :
198+ executor = executor_factory ()
199+ executor .spec = spec
200+
189201 self ._executor = executor
190- self ._acall = _to_async_call (executor .__call__ )
191202
192203 def analyze_schema (
193204 self , * args : _engine .OpArgSchema , ** kwargs : _engine .OpArgSchema
@@ -293,7 +304,7 @@ def process_arg(
293304
294305 base_analyze_method = getattr (self ._executor , "analyze" , None )
295306 if base_analyze_method is not None :
296- result_type = base_analyze_method (* args , ** kwargs )
307+ result_type = base_analyze_method ()
297308 else :
298309 result_type = expected_return
299310 if len (attributes ) > 0 :
@@ -316,6 +327,7 @@ async def prepare(self) -> None:
316327 prepare_method = getattr (self ._executor , "prepare" , None )
317328 if prepare_method is not None :
318329 await _to_async_call (prepare_method )()
330+ self ._acall = _to_async_call (self ._executor .__call__ )
319331
320332 async def __call__ (self , * args : Any , ** kwargs : Any ) -> Any :
321333 decoded_args = []
@@ -335,6 +347,7 @@ async def __call__(self, *args: Any, **kwargs: Any) -> Any:
335347 return None
336348 decoded_kwargs [kwarg_name ] = kwarg_info .decoder (arg )
337349
350+ assert self ._acall is not None
338351 if op_args .gpu :
339352 # For GPU executions, data-level parallelism is applied, so we don't want to
340353 # execute different tasks in parallel.
@@ -355,7 +368,7 @@ def behavior_version(self) -> int | None:
355368
356369 if category == OpCategory .FUNCTION :
357370 _engine .register_function_factory (
358- op_kind , _FunctionExecutorFactory ( spec_cls , _WrappedExecutor )
371+ op_kind , _EngineFunctionExecutorFactory ( spec_loader , _WrappedExecutor )
359372 )
360373 else :
361374 raise ValueError (f"Unsupported executor type { category } " )
@@ -381,8 +394,8 @@ def _inner(cls: type[Executor]) -> type:
381394 category = spec_cls ._op_category ,
382395 expected_args = list (sig .parameters .items ())[1 :], # First argument is `self`
383396 expected_return = sig .return_annotation ,
384- executor_cls = cls ,
385- spec_cls = spec_cls ,
397+ executor_factory = cls ,
398+ spec_loader = spec_cls ,
386399 op_kind = spec_cls .__name__ ,
387400 op_args = op_args ,
388401 )
@@ -395,6 +408,13 @@ class _EmptyFunctionSpec(FunctionSpec):
395408 pass
396409
397410
411+ class _SimpleFunctionExecutor :
412+ spec : Any
413+
414+ def prepare (self ) -> None :
415+ self .__call__ = self .spec .__call__
416+
417+
398418def function (** args : Any ) -> Callable [[Callable [..., Any ]], FunctionSpec ]:
399419 """
400420 Decorate a function to provide a function for an op.
@@ -405,29 +425,32 @@ def _inner(fn: Callable[..., Any]) -> FunctionSpec:
405425 # Convert snake case to camel case.
406426 op_name = "" .join (word .capitalize () for word in fn .__name__ .split ("_" ))
407427 sig = inspect .signature (fn )
428+ full_name = f"{ fn .__module__ } .{ fn .__qualname__ } "
408429
409- class _SpecExecutor ( _EmptyFunctionSpec ):
410- def __call__ ( self , * args : Any , ** kwargs : Any ) -> Any :
411- return fn ( * args , ** kwargs )
430+ # An object that is both callable and can act as a FunctionSpec.
431+ class _CallableSpec ( _EmptyFunctionSpec ) :
432+ __call__ = staticmethod ( fn )
412433
413434 def __reduce__ (self ) -> str | tuple [Any , ...]:
414- return fn . __qualname__
435+ return full_name
415436
416- _SpecExecutor .__name__ = op_name
417- _SpecExecutor .__doc__ = fn .__doc__
418- _SpecExecutor .__module__ = fn .__module__
437+ _CallableSpec .__name__ = op_name
438+ _CallableSpec .__doc__ = fn .__doc__
439+ _CallableSpec .__qualname__ = fn .__qualname__
440+ _CallableSpec .__module__ = fn .__module__
441+ callable_spec = _CallableSpec ()
419442
420443 _register_op_factory (
421444 category = OpCategory .FUNCTION ,
422445 expected_args = list (sig .parameters .items ()),
423446 expected_return = sig .return_annotation ,
424- executor_cls = _SpecExecutor ,
425- spec_cls = _EmptyFunctionSpec ,
447+ executor_factory = _SimpleFunctionExecutor ,
448+ spec_loader = lambda : callable_spec ,
426449 op_kind = op_name ,
427450 op_args = op_args ,
428451 )
429452
430- return _SpecExecutor ()
453+ return callable_spec
431454
432455 return _inner
433456
0 commit comments