@@ -132,15 +132,140 @@ def _make_engine_value_converter(
132132
133133_gpu_dispatch_lock = Lock ()
134134
135- def executor_class (gpu : bool = False , cache : bool = False , behavior_version : int | None = None ) -> Callable [[type ], type ]:
135+ @dataclasses .dataclass
136+ class OpArgs :
136137 """
137- Decorate a class to provide an executor for an op.
138+ - gpu: Whether the executor will be executed on GPU.
139+ - cache: Whether the executor will be cached.
140+ - behavior_version: The behavior version of the executor. Cache will be invalidated if it
141+ changes. Must be provided if `cache` is True.
142+ """
143+ gpu : bool = False
144+ cache : bool = False
145+ behavior_version : int | None = None
146+
147+ def _register_op_factory (
148+ category : OpCategory ,
149+ expected_args : list [tuple [str , inspect .Parameter ]],
150+ expected_return ,
151+ executor_cls : type ,
152+ spec_cls : type ,
153+ op_args : OpArgs ,
154+ ):
155+ """
156+ Register an op factory.
157+ """
158+ class _Fallback :
159+ def enable_cache (self ):
160+ return op_args .cache
161+
162+ def behavior_version (self ):
163+ return op_args .behavior_version
164+
165+ class _WrappedClass (executor_cls , _Fallback ):
166+ _args_converters : list [Callable [[Any ], Any ]]
167+ _kwargs_converters : dict [str , Callable [[str , Any ], Any ]]
168+
169+ def __init__ (self , spec ):
170+ super ().__init__ ()
171+ self .spec = spec
172+
173+ def analyze (self , * args , ** kwargs ):
174+ """
175+ Analyze the spec and arguments. In this phase, argument types should be validated.
176+ It should return the expected result type for the current op.
177+ """
178+ self ._args_converters = []
179+ self ._kwargs_converters = {}
180+
181+ # Match arguments with parameters.
182+ next_param_idx = 0
183+ for arg in args :
184+ if next_param_idx >= len (expected_args ):
185+ raise ValueError (
186+ f"Too many arguments passed in: { len (args )} > { len (expected_args )} " )
187+ arg_name , arg_param = expected_args [next_param_idx ]
188+ if arg_param .kind in (
189+ inspect .Parameter .KEYWORD_ONLY , inspect .Parameter .VAR_KEYWORD ):
190+ raise ValueError (
191+ f"Too many positional arguments passed in: { len (args )} > { next_param_idx } " )
192+ self ._args_converters .append (
193+ _make_engine_value_converter (
194+ [arg_name ], arg .value_type ['type' ], arg_param .annotation ))
195+ if arg_param .kind != inspect .Parameter .VAR_POSITIONAL :
196+ next_param_idx += 1
197+
198+ expected_kwargs = expected_args [next_param_idx :]
199+
200+ for kwarg_name , kwarg in kwargs .items ():
201+ expected_arg = next (
202+ (arg for arg in expected_kwargs
203+ if (arg [0 ] == kwarg_name and arg [1 ].kind in (
204+ inspect .Parameter .KEYWORD_ONLY , inspect .Parameter .POSITIONAL_OR_KEYWORD ))
205+ or arg [1 ].kind == inspect .Parameter .VAR_KEYWORD ),
206+ None )
207+ if expected_arg is None :
208+ raise ValueError (f"Unexpected keyword argument passed in: { kwarg_name } " )
209+ arg_param = expected_arg [1 ]
210+ self ._kwargs_converters [kwarg_name ] = _make_engine_value_converter (
211+ [kwarg_name ], kwarg .value_type ['type' ], arg_param .annotation )
212+
213+ missing_args = [name for (name , arg ) in expected_kwargs
214+ if arg .default is inspect .Parameter .empty
215+ and (arg .kind == inspect .Parameter .POSITIONAL_ONLY or
216+ (arg .kind in (inspect .Parameter .KEYWORD_ONLY ,
217+ inspect .Parameter .POSITIONAL_OR_KEYWORD )
218+ and name not in kwargs ))]
219+ if len (missing_args ) > 0 :
220+ raise ValueError (f"Missing arguments: { ', ' .join (missing_args )} " )
221+
222+ prepare_method = getattr (executor_cls , 'analyze' , None )
223+ if prepare_method is not None :
224+ return prepare_method (self , * args , ** kwargs )
225+ else :
226+ return expected_return
227+
228+ def prepare (self ):
229+ """
230+ Prepare for execution.
231+ It's executed after `analyze` and before any `__call__` execution.
232+ """
233+ setup_method = getattr (executor_cls , 'prepare' , None )
234+ if setup_method is not None :
235+ setup_method (self )
236+
237+ def __call__ (self , * args , ** kwargs ):
238+ converted_args = (converter (arg ) for converter , arg in zip (self ._args_converters , args ))
239+ converted_kwargs = {arg_name : self ._kwargs_converters [arg_name ](arg )
240+ for arg_name , arg in kwargs .items ()}
241+ if op_args .gpu :
242+ # For GPU executions, data-level parallelism is applied, so we don't want to
243+ # execute different tasks in parallel.
244+ # Besides, multiprocessing is more appropriate for pytorch.
245+ # For now, we use a lock to ensure only one task is executed at a time.
246+ # TODO: Implement multi-processing dispatching.
247+ with _gpu_dispatch_lock :
248+ output = super ().__call__ (* converted_args , ** converted_kwargs )
249+ else :
250+ output = super ().__call__ (* converted_args , ** converted_kwargs )
251+ return to_engine_value (output )
252+
253+ _WrappedClass .__name__ = executor_cls .__name__
138254
139- Args:
140- gpu: Whether the executor will be executed on GPU.
141- cache: Whether the executor will be cached.
142- behavior_version: The behavior version of the executor. Cache will be invalidated if it changes. Must be provided if `cache` is True.
255+ if category == OpCategory .FUNCTION :
256+ _engine .register_function_factory (
257+ spec_cls .__name__ ,
258+ _FunctionExecutorFactory (spec_cls , _WrappedClass ))
259+ else :
260+ raise ValueError (f"Unsupported executor type { category } " )
261+
262+ return _WrappedClass
263+
264+ def executor_class (** args ) -> Callable [[type ], type ]:
143265 """
266+ Decorate a class to provide an executor for an op.
267+ """
268+ op_args = OpArgs (** args )
144269
145270 def _inner (cls : type [Executor ]) -> type :
146271 """
@@ -149,110 +274,46 @@ def _inner(cls: type[Executor]) -> type:
149274 type_hints = get_type_hints (cls )
150275 if 'spec' not in type_hints :
151276 raise TypeError ("Expect a `spec` field with type hint" )
152-
153277 spec_cls = type_hints ['spec' ]
154- op_name = spec_cls .__name__
155- category = spec_cls ._op_category
156-
157278 sig = inspect .signature (cls .__call__ )
158- expected_args = list (sig .parameters .items ())[1 :] # First argument is `self`
159- expected_return = sig .return_annotation
160-
161- cls_type : type = cls
162-
163- class _Fallback :
164- def enable_cache (self ):
165- return cache
166-
167- def behavior_version (self ):
168- return behavior_version
169-
170- class _WrappedClass (cls_type , _Fallback ):
171- _args_converters : list [Callable [[Any ], Any ]]
172- _kwargs_converters : dict [str , Callable [[str , Any ], Any ]]
173-
174- def __init__ (self , spec ):
175- super ().__init__ ()
176- self .spec = spec
177-
178- def analyze (self , * args , ** kwargs ):
179- """
180- Analyze the spec and arguments. In this phase, argument types should be validated.
181- It should return the expected result type for the current op.
182- """
183- self ._args_converters = []
184- self ._kwargs_converters = {}
185-
186- # Match arguments with parameters.
187- next_param_idx = 0
188- for arg in args :
189- if next_param_idx >= len (expected_args ):
190- raise ValueError (f"Too many arguments passed in: { len (args )} > { len (expected_args )} " )
191- arg_name , arg_param = expected_args [next_param_idx ]
192- if arg_param .kind == inspect .Parameter .KEYWORD_ONLY or arg_param .kind == inspect .Parameter .VAR_KEYWORD :
193- raise ValueError (f"Too many positional arguments passed in: { len (args )} > { next_param_idx } " )
194- self ._args_converters .append (
195- _make_engine_value_converter ([arg_name ], arg .value_type ['type' ], arg_param .annotation ))
196- if arg_param .kind != inspect .Parameter .VAR_POSITIONAL :
197- next_param_idx += 1
198-
199- expected_kwargs = expected_args [next_param_idx :]
200-
201- for kwarg_name , kwarg in kwargs .items ():
202- expected_arg = next (
203- (arg for arg in expected_kwargs
204- if (arg [0 ] == kwarg_name and arg [1 ].kind in (inspect .Parameter .KEYWORD_ONLY , inspect .Parameter .POSITIONAL_OR_KEYWORD ))
205- or arg [1 ].kind == inspect .Parameter .VAR_KEYWORD ),
206- None )
207- if expected_arg is None :
208- raise ValueError (f"Unexpected keyword argument passed in: { kwarg_name } " )
209- arg_param = expected_arg [1 ]
210- self ._kwargs_converters [kwarg_name ] = _make_engine_value_converter (
211- [kwarg_name ], kwarg .value_type ['type' ], arg_param .annotation )
212-
213- missing_args = [name for (name , arg ) in expected_kwargs
214- if arg .default is inspect .Parameter .empty
215- and (arg .kind == inspect .Parameter .POSITIONAL_ONLY or
216- (arg .kind in (inspect .Parameter .KEYWORD_ONLY , inspect .Parameter .POSITIONAL_OR_KEYWORD ) and name not in kwargs ))]
217- if len (missing_args ) > 0 :
218- raise ValueError (f"Missing arguments: { ', ' .join (missing_args )} " )
219-
220- prepare_method = getattr (cls_type , 'analyze' , None )
221- if prepare_method is not None :
222- return prepare_method (self , * args , ** kwargs )
223- else :
224- return expected_return
225-
226- def prepare (self ):
227- """
228- Prepare for execution.
229- It's executed after `analyze` and before any `__call__` execution.
230- """
231- setup_method = getattr (cls_type , 'prepare' , None )
232- if setup_method is not None :
233- setup_method (self )
279+ return _register_op_factory (
280+ category = spec_cls ._op_category ,
281+ expected_args = list (sig .parameters .items ())[1 :], # First argument is `self`
282+ expected_return = sig .return_annotation ,
283+ executor_cls = cls ,
284+ spec_cls = spec_cls ,
285+ op_args = op_args )
286+
287+ return _inner
288+
289+ def function (** args ) -> Callable [[Callable ], FunctionSpec ]:
290+ """
291+ Decorate a function to provide a function for an op.
292+ """
293+ op_args = OpArgs (** args )
294+
295+ def _inner (fn : Callable ) -> FunctionSpec :
296+
297+ # Convert snake case to camel case.
298+ op_name = '' .join (word .capitalize () for word in fn .__name__ .split ('_' ))
299+ sig = inspect .signature (fn )
234300
301+ class _Executor :
235302 def __call__ (self , * args , ** kwargs ):
236- converted_args = (converter (arg ) for converter , arg in zip (self ._args_converters , args ))
237- converted_kwargs = {arg_name : self ._kwargs_converters [arg_name ](arg ) for arg_name , arg in kwargs .items ()}
238- if gpu :
239- # For GPU executions, data-level parallelism is applied, so we don't want to execute different tasks in parallel.
240- # Besides, multiprocessing is more appropriate for pytorch.
241- # For now, we use a lock to ensure only one task is executed at a time.
242- # TODO: Implement multi-processing dispatching.
243- with _gpu_dispatch_lock :
244- output = super ().__call__ (* converted_args , ** converted_kwargs )
245- else :
246- output = super ().__call__ (* converted_args , ** converted_kwargs )
247- return to_engine_value (output )
303+ return fn (* args , ** kwargs )
248304
249- _WrappedClass .__name__ = cls .__name__
305+ class _Spec (FunctionSpec ):
306+ pass
307+ _Spec .__name__ = op_name
250308
251- if category == OpCategory .FUNCTION :
252- _engine .register_function_factory (op_name , _FunctionExecutorFactory (spec_cls , _WrappedClass ))
253- else :
254- raise ValueError (f"Unsupported executor type { category } " )
309+ _register_op_factory (
310+ category = OpCategory .FUNCTION ,
311+ expected_args = list (sig .parameters .items ()),
312+ expected_return = sig .return_annotation ,
313+ executor_cls = _Executor ,
314+ spec_cls = _Spec ,
315+ op_args = op_args )
255316
256- return _WrappedClass
317+ return _Spec ()
257318
258319 return _inner
0 commit comments