@@ -95,6 +95,7 @@ def make_engine_value_decoder(
9595 field_path : list [str ],
9696 src_type : dict [str , Any ],
9797 dst_type_info : AnalyzedTypeInfo ,
98+ for_key : bool = False ,
9899) -> Callable [[Any ], Any ]:
99100 """
100101 Make a decoder from an engine value to a Python value.
@@ -123,6 +124,7 @@ def make_engine_value_decoder(
123124 field_path ,
124125 src_type ["fields" ],
125126 dst_type_info ,
127+ for_key = for_key ,
126128 )
127129
128130 if src_type_kind in TABLE_TYPES :
@@ -131,18 +133,18 @@ def make_engine_value_decoder(
131133
132134 if src_type_kind == "LTable" :
133135 if isinstance (dst_type_variant , AnalyzedAnyType ):
134- return _make_engine_ltable_to_list_dict_decoder (
135- field_path , engine_fields_schema
136- )
137- if not isinstance ( dst_type_variant , AnalyzedListType ) :
136+ dst_elem_type = Any
137+ elif isinstance ( dst_type_variant , AnalyzedListType ):
138+ dst_elem_type = dst_type_variant . elem_type
139+ else :
138140 raise ValueError (
139141 f"Type mismatch for `{ '' .join (field_path )} `: "
140142 f"declared `{ dst_type_info .core_type } `, a list type expected"
141143 )
142144 row_decoder = make_engine_struct_decoder (
143145 field_path ,
144146 engine_fields_schema ,
145- analyze_type_info (dst_type_variant . elem_type ),
147+ analyze_type_info (dst_elem_type ),
146148 )
147149
148150 def decode (value : Any ) -> Any | None :
@@ -152,10 +154,11 @@ def decode(value: Any) -> Any | None:
152154
153155 elif src_type_kind == "KTable" :
154156 if isinstance (dst_type_variant , AnalyzedAnyType ):
155- return _make_engine_ktable_to_dict_dict_decoder (
156- field_path , engine_fields_schema
157- )
158- if not isinstance (dst_type_variant , AnalyzedDictType ):
157+ key_type , value_type = Any , Any
158+ elif isinstance (dst_type_variant , AnalyzedDictType ):
159+ key_type = dst_type_variant .key_type
160+ value_type = dst_type_variant .value_type
161+ else :
159162 raise ValueError (
160163 f"Type mismatch for `{ '' .join (field_path )} `: "
161164 f"declared `{ dst_type_info .core_type } `, a dict type expected"
@@ -166,13 +169,14 @@ def decode(value: Any) -> Any | None:
166169 key_decoder = make_engine_value_decoder (
167170 field_path ,
168171 key_field_schema ["type" ],
169- analyze_type_info (dst_type_variant .key_type ),
172+ analyze_type_info (key_type ),
173+ for_key = True ,
170174 )
171175 field_path .pop ()
172176 value_decoder = make_engine_struct_decoder (
173177 field_path ,
174178 engine_fields_schema [1 :],
175- analyze_type_info (dst_type_variant . value_type ),
179+ analyze_type_info (value_type ),
176180 )
177181
178182 def decode (value : Any ) -> Any | None :
@@ -316,26 +320,26 @@ def make_engine_struct_decoder(
316320 field_path : list [str ],
317321 src_fields : list [dict [str , Any ]],
318322 dst_type_info : AnalyzedTypeInfo ,
323+ for_key : bool = False ,
319324) -> Callable [[list [Any ]], Any ]:
320325 """Make a decoder from an engine field values to a Python value."""
321326
322327 dst_type_variant = dst_type_info .variant
323328
324- use_dict = False
325329 if isinstance (dst_type_variant , AnalyzedAnyType ):
326- use_dict = True
330+ if for_key :
331+ return _make_engine_struct_to_tuple_decoder (field_path , src_fields )
332+ else :
333+ return _make_engine_struct_to_dict_decoder (field_path , src_fields , Any )
327334 elif isinstance (dst_type_variant , AnalyzedDictType ):
328335 analyzed_key_type = analyze_type_info (dst_type_variant .key_type )
329- analyzed_value_type = analyze_type_info (dst_type_variant .value_type )
330- use_dict = (
336+ if (
331337 isinstance (analyzed_key_type .variant , AnalyzedAnyType )
332- or (
333- isinstance (analyzed_key_type .variant , AnalyzedBasicType )
334- and analyzed_key_type .variant .kind == "Str"
338+ or analyzed_key_type .core_type is str
339+ ):
340+ return _make_engine_struct_to_dict_decoder (
341+ field_path , src_fields , dst_type_variant .value_type
335342 )
336- ) and isinstance (analyzed_value_type .variant , AnalyzedAnyType )
337- if use_dict :
338- return _make_engine_struct_to_dict_decoder (field_path , src_fields )
339343
340344 if not isinstance (dst_type_variant , AnalyzedStructType ):
341345 raise ValueError (
@@ -375,7 +379,7 @@ def make_closure_for_field(
375379 with ChildFieldPath (field_path , f".{ name } " ):
376380 if src_idx is not None :
377381 field_decoder = make_engine_value_decoder (
378- field_path , src_fields [src_idx ]["type" ], type_info
382+ field_path , src_fields [src_idx ]["type" ], type_info , for_key = for_key
379383 )
380384 return lambda values : field_decoder (values [src_idx ])
381385
@@ -409,17 +413,19 @@ def make_closure_for_field(
409413def _make_engine_struct_to_dict_decoder (
410414 field_path : list [str ],
411415 src_fields : list [dict [str , Any ]],
416+ value_type_annotation : Any ,
412417) -> Callable [[list [Any ] | None ], dict [str , Any ] | None ]:
413418 """Make a decoder from engine field values to a Python dict."""
414419
415420 field_decoders = []
416- for i , field_schema in enumerate (src_fields ):
421+ value_type_info = analyze_type_info (value_type_annotation )
422+ for field_schema in src_fields :
417423 field_name = field_schema ["name" ]
418424 with ChildFieldPath (field_path , f".{ field_name } " ):
419425 field_decoder = make_engine_value_decoder (
420426 field_path ,
421427 field_schema ["type" ],
422- analyze_type_info ( Any ), # Use Any for recursive decoding
428+ value_type_info ,
423429 )
424430 field_decoders .append ((field_name , field_decoder ))
425431
@@ -438,76 +444,37 @@ def decode_to_dict(values: list[Any] | None) -> dict[str, Any] | None:
438444 return decode_to_dict
439445
440446
441- def _make_engine_ltable_to_list_dict_decoder (
447+ def _make_engine_struct_to_tuple_decoder (
442448 field_path : list [str ],
443449 src_fields : list [dict [str , Any ]],
444- ) -> Callable [[list [Any ] | None ], list [dict [str , Any ]] | None ]:
445- """Make a decoder from engine LTable values to a list of dicts."""
446-
447- # Create a decoder for each row (struct) to dict
448- row_decoder = _make_engine_struct_to_dict_decoder (field_path , src_fields )
450+ ) -> Callable [[list [Any ] | None ], tuple [Any , ...] | None ]:
451+ """Make a decoder from engine field values to a Python tuple."""
449452
450- def decode_to_list_dict (values : list [Any ] | None ) -> list [dict [str , Any ]] | None :
451- if values is None :
452- return None
453- result = []
454- for i , row_values in enumerate (values ):
455- decoded_row = row_decoder (row_values )
456- if decoded_row is None :
457- raise ValueError (
458- f"LTable row at index { i } decoded to None, which is not allowed."
453+ field_decoders = []
454+ value_type_info = analyze_type_info (Any )
455+ for field_schema in src_fields :
456+ field_name = field_schema ["name" ]
457+ with ChildFieldPath (field_path , f".{ field_name } " ):
458+ field_decoders .append (
459+ make_engine_value_decoder (
460+ field_path ,
461+ field_schema ["type" ],
462+ value_type_info ,
459463 )
460- result .append (decoded_row )
461- return result
462-
463- return decode_to_list_dict
464-
465-
466- def _make_engine_ktable_to_dict_dict_decoder (
467- field_path : list [str ],
468- src_fields : list [dict [str , Any ]],
469- ) -> Callable [[list [Any ] | None ], dict [Any , dict [str , Any ]] | None ]:
470- """Make a decoder from engine KTable values to a dict of dicts."""
471-
472- if not src_fields :
473- raise ValueError ("KTable must have at least one field for the key" )
474-
475- # First field is the key, remaining fields are the value
476- key_field_schema = src_fields [0 ]
477- value_fields_schema = src_fields [1 :]
478-
479- # Create decoders
480- with ChildFieldPath (field_path , f".{ key_field_schema .get ('name' , KEY_FIELD_NAME )} " ):
481- key_decoder = make_engine_value_decoder (
482- field_path , key_field_schema ["type" ], analyze_type_info (Any )
483- )
484-
485- value_decoder = _make_engine_struct_to_dict_decoder (field_path , value_fields_schema )
464+ )
486465
487- def decode_to_dict_dict (
488- values : list [Any ] | None ,
489- ) -> dict [Any , dict [str , Any ]] | None :
466+ def decode_to_tuple (values : list [Any ] | None ) -> tuple [Any , ...] | None :
490467 if values is None :
491468 return None
492- result = {}
493- for row_values in values :
494- if not row_values :
495- raise ValueError ("KTable row must have at least 1 value (the key)" )
496- key = key_decoder (row_values [0 ])
497- if len (row_values ) == 1 :
498- value : dict [str , Any ] = {}
499- else :
500- tmp = value_decoder (row_values [1 :])
501- if tmp is None :
502- value = {}
503- else :
504- value = tmp
505- if isinstance (key , dict ):
506- key = tuple (key .values ())
507- result [key ] = value
508- return result
469+ if len (field_decoders ) != len (values ):
470+ raise ValueError (
471+ f"Field count mismatch: expected { len (field_decoders )} , got { len (values )} "
472+ )
473+ return tuple (
474+ field_decoder (value ) for value , field_decoder in zip (values , field_decoders )
475+ )
509476
510- return decode_to_dict_dict
477+ return decode_to_tuple
511478
512479
513480def dump_engine_object (v : Any ) -> Any :
0 commit comments