22Utilities to convert between Python and engine values.
33"""
44
5+ from __future__ import annotations
6+
57import dataclasses
68import datetime
79import inspect
2830 is_struct_type ,
2931)
3032
33+ class ChildFieldPath :
34+ """Context manager to append a field to field_path on enter and pop it on exit."""
35+
36+ _field_path : list [str ]
37+ _field_name : str
38+
39+ def __init__ (self , field_path : list [str ], field_name : str ):
40+ self ._field_path : list [str ] = field_path
41+ self ._field_name = field_name
42+
43+ def __enter__ (self ) -> ChildFieldPath :
44+ self ._field_path .append (self ._field_name )
45+ return self
46+
47+ def __exit__ (self , _exc_type : Any , _exc_val : Any , _exc_tb : Any ) -> None :
48+ self ._field_path .pop ()
49+
50+
3151_CONVERTIBLE_KINDS = {
3252 ("Float32" , "Float64" ),
3353 ("LocalDateTime" , "OffsetDateTime" ),
@@ -48,7 +68,6 @@ def _encode_engine_value_core(
4868 type_variant : AnalyzedTypeInfo | None = None ,
4969) -> Any :
5070 """Core encoding logic for converting Python values to engine values."""
51-
5271 if dataclasses .is_dataclass (value ):
5372 fields = dataclasses .fields (value )
5473 return [
@@ -200,66 +219,65 @@ def make_engine_value_decoder(
200219 )
201220
202221 if src_type_kind == "Struct" :
203- return _make_engine_struct_value_decoder (
222+ return make_engine_struct_decoder (
204223 field_path ,
205224 src_type ["fields" ],
206225 dst_type_info ,
207226 )
208227
209228 if src_type_kind in TABLE_TYPES :
210- field_path . append ( "[*]" )
211- engine_fields_schema = src_type ["row" ]["fields" ]
229+ with ChildFieldPath ( field_path , "[*]" ):
230+ engine_fields_schema = src_type ["row" ]["fields" ]
212231
213- if src_type_kind == "LTable" :
214- if isinstance (dst_type_variant , AnalyzedAnyType ):
215- return _make_engine_ltable_to_list_dict_decoder (
216- field_path , engine_fields_schema
217- )
218- if not isinstance (dst_type_variant , AnalyzedListType ):
219- raise ValueError (
220- f"Type mismatch for `{ '' .join (field_path )} `: "
221- f"declared `{ dst_type_info .core_type } `, a list type expected"
232+ if src_type_kind == "LTable" :
233+ if isinstance (dst_type_variant , AnalyzedAnyType ):
234+ return _make_engine_ltable_to_list_dict_decoder (
235+ field_path , engine_fields_schema
236+ )
237+ if not isinstance (dst_type_variant , AnalyzedListType ):
238+ raise ValueError (
239+ f"Type mismatch for `{ '' .join (field_path )} `: "
240+ f"declared `{ dst_type_info .core_type } `, a list type expected"
241+ )
242+ row_decoder = make_engine_struct_decoder (
243+ field_path ,
244+ engine_fields_schema ,
245+ analyze_type_info (dst_type_variant .elem_type ),
222246 )
223- row_decoder = _make_engine_struct_value_decoder (
224- field_path ,
225- engine_fields_schema ,
226- analyze_type_info (dst_type_variant .elem_type ),
227- )
228247
229- def decode (value : Any ) -> Any | None :
230- if value is None :
231- return None
232- return [row_decoder (v ) for v in value ]
248+ def decode (value : Any ) -> Any | None :
249+ if value is None :
250+ return None
251+ return [row_decoder (v ) for v in value ]
233252
234- elif src_type_kind == "KTable" :
235- if isinstance (dst_type_variant , AnalyzedAnyType ):
236- return _make_engine_ktable_to_dict_dict_decoder (
237- field_path , engine_fields_schema
253+ elif src_type_kind == "KTable" :
254+ if isinstance (dst_type_variant , AnalyzedAnyType ):
255+ return _make_engine_ktable_to_dict_dict_decoder (
256+ field_path , engine_fields_schema
257+ )
258+ if not isinstance (dst_type_variant , AnalyzedDictType ):
259+ raise ValueError (
260+ f"Type mismatch for `{ '' .join (field_path )} `: "
261+ f"declared `{ dst_type_info .core_type } `, a dict type expected"
262+ )
263+
264+ key_field_schema = engine_fields_schema [0 ]
265+ field_path .append (f".{ key_field_schema .get ('name' , KEY_FIELD_NAME )} " )
266+ key_decoder = make_engine_value_decoder (
267+ field_path , key_field_schema ["type" ], dst_type_variant .key_type
238268 )
239- if not isinstance (dst_type_variant , AnalyzedDictType ):
240- raise ValueError (
241- f"Type mismatch for `{ '' .join (field_path )} `: "
242- f"declared `{ dst_type_info .core_type } `, a dict type expected"
269+ field_path .pop ()
270+ value_decoder = make_engine_struct_decoder (
271+ field_path ,
272+ engine_fields_schema [1 :],
273+ analyze_type_info (dst_type_variant .value_type ),
243274 )
244275
245- key_field_schema = engine_fields_schema [0 ]
246- field_path .append (f".{ key_field_schema .get ('name' , KEY_FIELD_NAME )} " )
247- key_decoder = make_engine_value_decoder (
248- field_path , key_field_schema ["type" ], dst_type_variant .key_type
249- )
250- field_path .pop ()
251- value_decoder = _make_engine_struct_value_decoder (
252- field_path ,
253- engine_fields_schema [1 :],
254- analyze_type_info (dst_type_variant .value_type ),
255- )
256-
257- def decode (value : Any ) -> Any | None :
258- if value is None :
259- return None
260- return {key_decoder (v [0 ]): value_decoder (v [1 :]) for v in value }
276+ def decode (value : Any ) -> Any | None :
277+ if value is None :
278+ return None
279+ return {key_decoder (v [0 ]): value_decoder (v [1 :]) for v in value }
261280
262- field_path .pop ()
263281 return decode
264282
265283 if src_type_kind == "Union" :
@@ -274,22 +292,22 @@ def decode(value: Any) -> Any | None:
274292 src_type_variants = src_type ["types" ]
275293 decoders = []
276294 for i , src_type_variant in enumerate (src_type_variants ):
277- src_field_path = field_path + [f"[{ i } ]" ]
278- decoder = None
279- for dst_type_variant in dst_type_variants :
280- try :
281- decoder = make_engine_value_decoder (
282- src_field_path , src_type_variant , dst_type_variant
295+ with ChildFieldPath (field_path , f"[{ i } ]" ):
296+ decoder = None
297+ for dst_type_variant in dst_type_variants :
298+ try :
299+ decoder = make_engine_value_decoder (
300+ field_path , src_type_variant , dst_type_variant
301+ )
302+ break
303+ except ValueError :
304+ pass
305+ if decoder is None :
306+ raise ValueError (
307+ f"Type mismatch for `{ '' .join (field_path )} `: "
308+ f"cannot find matched target type for source type variant { src_type_variant } "
283309 )
284- break
285- except ValueError :
286- pass
287- if decoder is None :
288- raise ValueError (
289- f"Type mismatch for `{ '' .join (field_path )} `: "
290- f"cannot find matched target type for source type variant { src_type_variant } "
291- )
292- decoders .append (decoder )
310+ decoders .append (decoder )
293311 return lambda value : decoders [value [0 ]](value [1 ])
294312
295313 if isinstance (dst_type_variant , AnalyzedAnyType ):
@@ -368,7 +386,7 @@ def decode_scalar(value: Any) -> Any | None:
368386 return lambda value : value
369387
370388
371- def _make_engine_struct_value_decoder (
389+ def make_engine_struct_decoder (
372390 field_path : list [str ],
373391 src_fields : list [dict [str , Any ]],
374392 dst_type_info : AnalyzedTypeInfo ,
@@ -426,25 +444,24 @@ def make_closure_for_value(
426444 name : str , param : inspect .Parameter
427445 ) -> Callable [[list [Any ]], Any ]:
428446 src_idx = src_name_to_idx .get (name )
429- if src_idx is not None :
430- field_path .append (f".{ name } " )
431- field_decoder = make_engine_value_decoder (
432- field_path , src_fields [src_idx ]["type" ], param .annotation
433- )
434- field_path .pop ()
435- return (
436- lambda values : field_decoder (values [src_idx ])
437- if len (values ) > src_idx
438- else param .default
439- )
447+ with ChildFieldPath (field_path , f".{ name } " ):
448+ if src_idx is not None :
449+ field_decoder = make_engine_value_decoder (
450+ field_path , src_fields [src_idx ]["type" ], param .annotation
451+ )
452+ return (
453+ lambda values : field_decoder (values [src_idx ])
454+ if len (values ) > src_idx
455+ else param .default
456+ )
440457
441- default_value = param .default
442- if default_value is inspect .Parameter .empty :
443- raise ValueError (
444- f"Field without default value is missing in input: { '' .join (field_path )} "
445- )
458+ default_value = param .default
459+ if default_value is inspect .Parameter .empty :
460+ raise ValueError (
461+ f"Field without default value is missing in input: { '' .join (field_path )} "
462+ )
446463
447- return lambda _ : default_value
464+ return lambda _ : default_value
448465
449466 field_value_decoder = [
450467 make_closure_for_value (name , param ) for (name , param ) in parameters .items ()
@@ -464,13 +481,12 @@ def _make_engine_struct_to_dict_decoder(
464481 field_decoders = []
465482 for i , field_schema in enumerate (src_fields ):
466483 field_name = field_schema ["name" ]
467- field_path .append (f".{ field_name } " )
468- field_decoder = make_engine_value_decoder (
469- field_path ,
470- field_schema ["type" ],
471- Any , # Use Any for recursive decoding
472- )
473- field_path .pop ()
484+ with ChildFieldPath (field_path , f".{ field_name } " ):
485+ field_decoder = make_engine_value_decoder (
486+ field_path ,
487+ field_schema ["type" ],
488+ Any , # Use Any for recursive decoding
489+ )
474490 field_decoders .append ((field_name , field_decoder ))
475491
476492 def decode_to_dict (values : list [Any ] | None ) -> dict [str , Any ] | None :
@@ -527,9 +543,10 @@ def _make_engine_ktable_to_dict_dict_decoder(
527543 value_fields_schema = src_fields [1 :]
528544
529545 # Create decoders
530- field_path .append (f".{ key_field_schema .get ('name' , KEY_FIELD_NAME )} " )
531- key_decoder = make_engine_value_decoder (field_path , key_field_schema ["type" ], Any )
532- field_path .pop ()
546+ with ChildFieldPath (field_path , f".{ key_field_schema .get ('name' , KEY_FIELD_NAME )} " ):
547+ key_decoder = make_engine_value_decoder (
548+ field_path , key_field_schema ["type" ], Any
549+ )
533550
534551 value_decoder = _make_engine_struct_to_dict_decoder (field_path , value_fields_schema )
535552
0 commit comments