88from enum import Enum
99from threading import Lock
1010
11- from .typing import encode_type
11+ from .typing import encode_enriched_type , analyze_type_info , COLLECTION_TYPES
1212from . import _engine
1313
1414
@@ -57,16 +57,86 @@ def __call__(self, spec: dict[str, Any], *args, **kwargs):
5757 spec = self ._spec_cls (** spec )
5858 executor = self ._executor_cls (spec )
5959 result_type = executor .analyze (* args , ** kwargs )
60- return (encode_type (result_type ), executor )
60+ return (encode_enriched_type (result_type ), executor )
6161
6262def _to_engine_value (value : Any ) -> Any :
6363 """Convert a Python value to an engine value."""
6464 if dataclasses .is_dataclass (value ):
6565 return [_to_engine_value (getattr (value , f .name )) for f in dataclasses .fields (value )]
66- elif isinstance (value , list ) or isinstance ( value , tuple ):
66+ if isinstance (value , ( list , tuple ) ):
6767 return [_to_engine_value (v ) for v in value ]
6868 return value
6969
70+ def _make_engine_struct_value_converter (
71+ field_path : list [str ],
72+ src_fields : list [dict [str , Any ]],
73+ dst_dataclass_type : type ,
74+ ) -> Callable [[list ], Any ]:
75+ """Make a converter from an engine field values to a Python value."""
76+
77+ src_name_to_idx = {f ['name' ]: i for i , f in enumerate (src_fields )}
78+ def make_closure_for_value (name : str , param : inspect .Parameter ) -> Callable [[list ], Any ]:
79+ src_idx = src_name_to_idx .get (name )
80+ if src_idx is not None :
81+ field_path .append (f'.{ name } ' )
82+ field_converter = _make_engine_value_converter (
83+ field_path , src_fields [src_idx ]['type' ], param .annotation )
84+ field_path .pop ()
85+ return lambda values : field_converter (values [src_idx ])
86+
87+ default_value = param .default
88+ if default_value is inspect .Parameter .empty :
89+ raise ValueError (
90+ f"Field without default value is missing in input: { '' .join (field_path )} " )
91+
92+ return lambda _ : default_value
93+
94+ field_value_converters = [
95+ make_closure_for_value (name , param )
96+ for (name , param ) in inspect .signature (dst_dataclass_type ).parameters .items ()]
97+
98+ return lambda values : dst_dataclass_type (
99+ * (converter (values ) for converter in field_value_converters ))
100+
101+ def _make_engine_value_converter (
102+ field_path : list [str ],
103+ src_type : dict [str , Any ],
104+ dst_annotation ,
105+ ) -> Callable [[Any ], Any ]:
106+ """Make a converter from an engine value to a Python value."""
107+
108+ src_type_kind = src_type ['kind' ]
109+
110+ if dst_annotation is inspect .Parameter .empty :
111+ if src_type_kind == 'Struct' or src_type_kind in COLLECTION_TYPES :
112+ raise ValueError (f"Missing type annotation for `{ '' .join (field_path )} `."
113+ f"It's required for { src_type_kind } type." )
114+ return lambda value : value
115+
116+ dst_type_info = analyze_type_info (dst_annotation )
117+
118+ if src_type_kind != dst_type_info .kind :
119+ raise ValueError (
120+ f"Type mismatch for `{ '' .join (field_path )} `: "
121+ f"passed in { src_type_kind } , declared { dst_annotation } ({ dst_type_info .kind } )" )
122+
123+ if dst_type_info .dataclass_type is not None :
124+ return _make_engine_struct_value_converter (
125+ field_path , src_type ['fields' ], dst_type_info .dataclass_type )
126+
127+ if src_type_kind in COLLECTION_TYPES :
128+ field_path .append ('[*]' )
129+ elem_type_info = analyze_type_info (dst_type_info .elem_type )
130+ if elem_type_info .dataclass_type is None :
131+ raise ValueError (f"Type mismatch for `{ '' .join (field_path )} `: "
132+ f"declared `{ dst_type_info .kind } `, a dataclass type expected" )
133+ elem_converter = _make_engine_struct_value_converter (
134+ field_path , src_type ['row' ]['fields' ], elem_type_info .dataclass_type )
135+ field_path .pop ()
136+ return lambda value : [elem_converter (v ) for v in value ] if value is not None else None
137+
138+ return lambda value : value
139+
# Serializes GPU-bound op executions: GPU ops apply data-level parallelism
# internally, so only one task is dispatched to the GPU at a time (the lock
# is taken around the wrapped executor's __call__ when gpu=True).
_gpu_dispatch_lock = Lock()
71141
72142def executor_class (gpu : bool = False , cache : bool = False , behavior_version : int | None = None ) -> Callable [[type ], type ]:
@@ -105,6 +175,9 @@ def behavior_version(self):
105175 return behavior_version
106176
107177 class _WrappedClass (cls_type , _Fallback ):
178+ _args_converters : list [Callable [[Any ], Any ]]
179+ _kwargs_converters : dict [str , Callable [[str , Any ], Any ]]
180+
108181 def __init__ (self , spec ):
109182 super ().__init__ ()
110183 self .spec = spec
@@ -114,16 +187,19 @@ def analyze(self, *args, **kwargs):
114187 Analyze the spec and arguments. In this phase, argument types should be validated.
115188 It should return the expected result type for the current op.
116189 """
190+ self ._args_converters = []
191+ self ._kwargs_converters = {}
192+
117193 # Match arguments with parameters.
118194 next_param_idx = 0
119- for arg in args :
195+ for arg in args :
120196 if next_param_idx >= len (expected_args ):
121- raise ValueError (f"Too many arguments: { len (args )} > { len (expected_args )} " )
197+ raise ValueError (f"Too many arguments passed in : { len (args )} > { len (expected_args )} " )
122198 arg_name , arg_param = expected_args [next_param_idx ]
123199 if arg_param .kind == inspect .Parameter .KEYWORD_ONLY or arg_param .kind == inspect .Parameter .VAR_KEYWORD :
124- raise ValueError (f"Too many positional arguments: { len (args )} > { next_param_idx } " )
125- if arg_param . annotation is not inspect . Parameter . empty :
126- arg . validate_arg ( arg_name , encode_type ( arg_param .annotation ))
200+ raise ValueError (f"Too many positional arguments passed in : { len (args )} > { next_param_idx } " )
201+ self . _args_converters . append (
202+ _make_engine_value_converter ([ arg_name ], arg . value_type [ 'type' ], arg_param .annotation ))
127203 if arg_param .kind != inspect .Parameter .VAR_POSITIONAL :
128204 next_param_idx += 1
129205
@@ -136,10 +212,10 @@ def analyze(self, *args, **kwargs):
136212 or arg [1 ].kind == inspect .Parameter .VAR_KEYWORD ),
137213 None )
138214 if expected_arg is None :
139- raise ValueError (f"Unexpected keyword argument: { kwarg_name } " )
215+ raise ValueError (f"Unexpected keyword argument passed in : { kwarg_name } " )
140216 arg_param = expected_arg [1 ]
141- if arg_param . annotation is not inspect . Parameter . empty :
142- kwarg .validate_arg ( kwarg_name , encode_type ( arg_param .annotation ) )
217+ self . _kwargs_converters [ kwarg_name ] = _make_engine_value_converter (
218+ [ kwarg_name ], kwarg .value_type [ 'type' ], arg_param .annotation )
143219
144220 missing_args = [name for (name , arg ) in expected_kwargs
145221 if arg .default is inspect .Parameter .empty
@@ -164,15 +240,17 @@ def prepare(self):
164240 setup_method (self )
165241
def __call__(self, *args, **kwargs):
    """Convert engine arguments to Python values, run the wrapped executor,
    and convert its output back to an engine value."""
    positional = [conv(raw) for conv, raw in zip(self._args_converters, args)]
    keyword = {name: self._kwargs_converters[name](raw)
               for name, raw in kwargs.items()}
    if gpu:
        # GPU executions use data-level parallelism, so different tasks are
        # not run concurrently; multiprocessing also suits pytorch better.
        # For now a lock ensures only one task executes at a time.
        # TODO: Implement multi-processing dispatching.
        with _gpu_dispatch_lock:
            raw_output = super().__call__(*positional, **keyword)
    else:
        raw_output = super().__call__(*positional, **keyword)
    return _to_engine_value(raw_output)
177255
178256 _WrappedClass .__name__ = cls .__name__
0 commit comments