1313from .typing import (
1414 KEY_FIELD_NAME ,
1515 TABLE_TYPES ,
16- DtypeRegistry ,
1716 analyze_type_info ,
1817 encode_enriched_type ,
19- extract_ndarray_scalar_dtype ,
2018 is_namedtuple_type ,
2119 is_struct_type ,
20+ AnalyzedTypeInfo ,
21+ AnalyzedAnyType ,
22+ AnalyzedDictType ,
23+ AnalyzedListType ,
24+ AnalyzedBasicType ,
25+ AnalyzedUnionType ,
26+ AnalyzedUnknownType ,
27+ AnalyzedStructType ,
28+ is_numpy_number_type ,
2229)
2330
2431
@@ -79,46 +86,88 @@ def make_engine_value_decoder(
7986 Returns:
8087 A decoder from an engine value to a Python value.
8188 """
89+
8290 src_type_kind = src_type ["kind" ]
8391
84- dst_is_any = (
85- dst_annotation is None
86- or dst_annotation is inspect .Parameter .empty
87- or dst_annotation is Any
88- )
89- if dst_is_any :
90- if src_type_kind == "Union" :
91- return lambda value : value [1 ]
92- if src_type_kind == "Struct" :
93- return _make_engine_struct_to_dict_decoder (field_path , src_type ["fields" ])
94- if src_type_kind in TABLE_TYPES :
95- if src_type_kind == "LTable" :
92+ dst_type_info = analyze_type_info (dst_annotation )
93+ dst_type_variant = dst_type_info .variant
94+
95+ if isinstance (dst_type_variant , AnalyzedUnknownType ):
96+ raise ValueError (
97+ f"Type mismatch for `{ '' .join (field_path )} `: "
98+ f"declared `{ dst_type_info .core_type } `, an unsupported type"
99+ )
100+
101+ if src_type_kind == "Struct" :
102+ return _make_engine_struct_value_decoder (
103+ field_path ,
104+ src_type ["fields" ],
105+ dst_type_info ,
106+ )
107+
108+ if src_type_kind in TABLE_TYPES :
109+ field_path .append ("[*]" )
110+ engine_fields_schema = src_type ["row" ]["fields" ]
111+
112+ if src_type_kind == "LTable" :
113+ if isinstance (dst_type_variant , AnalyzedAnyType ):
96114 return _make_engine_ltable_to_list_dict_decoder (
97- field_path , src_type ["row" ]["fields" ]
115+ field_path , engine_fields_schema
116+ )
117+ if not isinstance (dst_type_variant , AnalyzedListType ):
118+ raise ValueError (
119+ f"Type mismatch for `{ '' .join (field_path )} `: "
120+ f"declared `{ dst_type_info .core_type } `, a list type expected"
98121 )
99- elif src_type_kind == "KTable" :
122+ row_decoder = _make_engine_struct_value_decoder (
123+ field_path ,
124+ engine_fields_schema ,
125+ analyze_type_info (dst_type_variant .elem_type ),
126+ )
127+
128+ def decode (value : Any ) -> Any | None :
129+ if value is None :
130+ return None
131+ return [row_decoder (v ) for v in value ]
132+
133+ elif src_type_kind == "KTable" :
134+ if isinstance (dst_type_variant , AnalyzedAnyType ):
100135 return _make_engine_ktable_to_dict_dict_decoder (
101- field_path , src_type ["row" ]["fields" ]
136+ field_path , engine_fields_schema
137+ )
138+ if not isinstance (dst_type_variant , AnalyzedDictType ):
139+ raise ValueError (
140+ f"Type mismatch for `{ '' .join (field_path )} `: "
141+ f"declared `{ dst_type_info .core_type } `, a dict type expected"
102142 )
103- return lambda value : value
104143
105- # Handle struct -> dict binding for explicit dict annotations
106- is_dict_annotation = False
107- if dst_annotation is dict :
108- is_dict_annotation = True
109- elif getattr (dst_annotation , "__origin__" , None ) is dict :
110- args = getattr (dst_annotation , "__args__" , ())
111- if args == (str , Any ):
112- is_dict_annotation = True
113- if is_dict_annotation and src_type_kind == "Struct" :
114- return _make_engine_struct_to_dict_decoder (field_path , src_type ["fields" ])
144+ key_field_schema = engine_fields_schema [0 ]
145+ field_path .append (f".{ key_field_schema .get ('name' , KEY_FIELD_NAME )} " )
146+ key_decoder = make_engine_value_decoder (
147+ field_path , key_field_schema ["type" ], dst_type_variant .key_type
148+ )
149+ field_path .pop ()
150+ value_decoder = _make_engine_struct_value_decoder (
151+ field_path ,
152+ engine_fields_schema [1 :],
153+ analyze_type_info (dst_type_variant .value_type ),
154+ )
115155
116- dst_type_info = analyze_type_info (dst_annotation )
156+ def decode (value : Any ) -> Any | None :
157+ if value is None :
158+ return None
159+ return {key_decoder (v [0 ]): value_decoder (v [1 :]) for v in value }
160+
161+ field_path .pop ()
162+ return decode
117163
118164 if src_type_kind == "Union" :
165+ if isinstance (dst_type_variant , AnalyzedAnyType ):
166+ return lambda value : value [1 ]
167+
119168 dst_type_variants = (
120- dst_type_info . union_variant_types
121- if dst_type_info . union_variant_types is not None
169+ dst_type_variant . variant_types
170+ if isinstance ( dst_type_variant , AnalyzedUnionType )
122171 else [dst_annotation ]
123172 )
124173 src_type_variants = src_type ["types" ]
@@ -142,43 +191,36 @@ def make_engine_value_decoder(
142191 decoders .append (decoder )
143192 return lambda value : decoders [value [0 ]](value [1 ])
144193
145- if not _is_type_kind_convertible_to (src_type_kind , dst_type_info .kind ):
146- raise ValueError (
147- f"Type mismatch for `{ '' .join (field_path )} `: "
148- f"passed in { src_type_kind } , declared { dst_annotation } ({ dst_type_info .kind } )"
149- )
150-
151- if dst_type_info .kind in ("Float32" , "Float64" , "Int64" ):
152- dst_core_type = dst_type_info .core_type
153-
154- def decode_scalar (value : Any ) -> Any | None :
155- if value is None :
156- if dst_type_info .nullable :
157- return None
158- raise ValueError (
159- f"Received null for non-nullable scalar `{ '' .join (field_path )} `"
160- )
161- return dst_core_type (value )
162-
163- return decode_scalar
194+ if isinstance (dst_type_variant , AnalyzedAnyType ):
195+ return lambda value : value
164196
165197 if src_type_kind == "Vector" :
166198 field_path_str = "" .join (field_path )
199+ if not isinstance (dst_type_variant , AnalyzedListType ):
200+ raise ValueError (
201+ f"Type mismatch for `{ '' .join (field_path )} `: "
202+ f"declared `{ dst_type_info .core_type } `, a list type expected"
203+ )
167204 expected_dim = (
168- dst_type_info .vector_info .dim if dst_type_info .vector_info else None
205+ dst_type_variant .vector_info .dim
206+ if dst_type_variant and dst_type_variant .vector_info
207+ else None
169208 )
170209
171- elem_decoder = None
210+ vec_elem_decoder = None
172211 scalar_dtype = None
173- if dst_type_info .np_number_type is None : # for Non-NDArray vector
174- elem_decoder = make_engine_value_decoder (
212+ if (
213+ dst_type_variant
214+ and is_numpy_number_type (dst_type_variant .elem_type )
215+ and dst_type_info .base_type is np .ndarray
216+ ):
217+ scalar_dtype = dst_type_variant .elem_type
218+ else :
219+ vec_elem_decoder = make_engine_value_decoder (
175220 field_path + ["[*]" ],
176221 src_type ["element_type" ],
177- dst_type_info .elem_type ,
222+ dst_type_variant and dst_type_variant .elem_type ,
178223 )
179- else : # for NDArray vector
180- scalar_dtype = extract_ndarray_scalar_dtype (dst_type_info .np_number_type )
181- _ = DtypeRegistry .validate_dtype_and_get_kind (scalar_dtype )
182224
183225 def decode_vector (value : Any ) -> Any | None :
184226 if value is None :
@@ -197,66 +239,70 @@ def decode_vector(value: Any) -> Any | None:
197239 f"expected { expected_dim } , got { len (value )} "
198240 )
199241
200- if elem_decoder is not None : # for Non-NDArray vector
201- return [elem_decoder (v ) for v in value ]
242+ if vec_elem_decoder is not None : # for Non-NDArray vector
243+ return [vec_elem_decoder (v ) for v in value ]
202244 else : # for NDArray vector
203245 return np .array (value , dtype = scalar_dtype )
204246
205247 return decode_vector
206248
207- if dst_type_info .struct_type is not None :
208- return _make_engine_struct_value_decoder (
209- field_path , src_type ["fields" ], dst_type_info .struct_type
210- )
211-
212- if src_type_kind in TABLE_TYPES :
213- field_path .append ("[*]" )
214- elem_type_info = analyze_type_info (dst_type_info .elem_type )
215- if elem_type_info .struct_type is None :
249+ if isinstance (dst_type_variant , AnalyzedBasicType ):
250+ if not _is_type_kind_convertible_to (src_type_kind , dst_type_variant .kind ):
216251 raise ValueError (
217252 f"Type mismatch for `{ '' .join (field_path )} `: "
218- f"declared `{ dst_type_info .kind } `, a dataclass or NamedTuple type expected"
219- )
220- engine_fields_schema = src_type ["row" ]["fields" ]
221- if elem_type_info .key_type is not None :
222- key_field_schema = engine_fields_schema [0 ]
223- field_path .append (f".{ key_field_schema .get ('name' , KEY_FIELD_NAME )} " )
224- key_decoder = make_engine_value_decoder (
225- field_path , key_field_schema ["type" ], elem_type_info .key_type
226- )
227- field_path .pop ()
228- value_decoder = _make_engine_struct_value_decoder (
229- field_path , engine_fields_schema [1 :], elem_type_info .struct_type
253+ f"passed in { src_type_kind } , declared { dst_annotation } ({ dst_type_variant .kind } )"
230254 )
231255
232- def decode (value : Any ) -> Any | None :
233- if value is None :
234- return None
235- return {key_decoder (v [0 ]): value_decoder (v [1 :]) for v in value }
236- else :
237- elem_decoder = _make_engine_struct_value_decoder (
238- field_path , engine_fields_schema , elem_type_info .struct_type
239- )
256+ if dst_type_variant .kind in ("Float32" , "Float64" , "Int64" ):
257+ dst_core_type = dst_type_info .core_type
240258
241- def decode (value : Any ) -> Any | None :
259+ def decode_scalar (value : Any ) -> Any | None :
242260 if value is None :
243- return None
244- return [elem_decoder (v ) for v in value ]
261+ if dst_type_info .nullable :
262+ return None
263+ raise ValueError (
264+ f"Received null for non-nullable scalar `{ '' .join (field_path )} `"
265+ )
266+ return dst_core_type (value )
245267
246- field_path .pop ()
247- return decode
268+ return decode_scalar
248269
249270 return lambda value : value
250271
251272
252273def _make_engine_struct_value_decoder (
253274 field_path : list [str ],
254275 src_fields : list [dict [str , Any ]],
255- dst_struct_type : type ,
276+ dst_type_info : AnalyzedTypeInfo ,
256277) -> Callable [[list [Any ]], Any ]:
257278 """Make a decoder from an engine field values to a Python value."""
258279
280+ dst_type_variant = dst_type_info .variant
281+
282+ use_dict = False
283+ if isinstance (dst_type_variant , AnalyzedAnyType ):
284+ use_dict = True
285+ elif isinstance (dst_type_variant , AnalyzedDictType ):
286+ analyzed_key_type = analyze_type_info (dst_type_variant .key_type )
287+ analyzed_value_type = analyze_type_info (dst_type_variant .value_type )
288+ use_dict = (
289+ isinstance (analyzed_key_type .variant , AnalyzedAnyType )
290+ or (
291+ isinstance (analyzed_key_type .variant , AnalyzedBasicType )
292+ and analyzed_key_type .variant .kind == "Str"
293+ )
294+ ) and isinstance (analyzed_value_type .variant , AnalyzedAnyType )
295+ if use_dict :
296+ return _make_engine_struct_to_dict_decoder (field_path , src_fields )
297+
298+ if not isinstance (dst_type_variant , AnalyzedStructType ):
299+ raise ValueError (
300+ f"Type mismatch for `{ '' .join (field_path )} `: "
301+ f"declared `{ dst_type_info .core_type } `, a dataclass, NamedTuple or dict[str, Any] expected"
302+ )
303+
259304 src_name_to_idx = {f ["name" ]: i for i , f in enumerate (src_fields )}
305+ dst_struct_type = dst_type_variant .struct_type
260306
261307 parameters : Mapping [str , inspect .Parameter ]
262308 if dataclasses .is_dataclass (dst_struct_type ):
0 commit comments