@@ -86,6 +86,14 @@ def make_engine_value_decoder(
8686 or dst_annotation is inspect .Parameter .empty
8787 or dst_annotation is Any
8888 )
89+ # Handle struct -> dict binding for explicit dict annotations
90+ is_dict_annotation = False
91+ if dst_annotation is dict :
92+ is_dict_annotation = True
93+ elif getattr (dst_annotation , "__origin__" , None ) is dict :
94+ args = getattr (dst_annotation , "__args__" , ())
95+ if args == (str , Any ):
96+ is_dict_annotation = True
8997 if dst_is_any :
9098 if src_type_kind == "Union" :
9199 return lambda value : value [1 ]
@@ -97,21 +105,11 @@ def make_engine_value_decoder(
97105 f"It's required for { src_type_kind } type."
98106 )
99107 return lambda value : value
108+ if is_dict_annotation and src_type_kind == "Struct" :
109+ return _make_engine_struct_to_dict_decoder (field_path , src_type ["fields" ])
100110
101111 dst_type_info = analyze_type_info (dst_annotation )
102112
103- # Handle struct -> dict binding for explicit dict annotations
104- if (
105- src_type_kind == "Struct"
106- and dst_type_info .kind == "KTable"
107- and dst_type_info .elem_type
108- and isinstance (dst_type_info .elem_type , tuple )
109- and len (dst_type_info .elem_type ) == 2
110- and dst_type_info .elem_type [0 ] is str
111- and dst_type_info .elem_type [1 ] is Any
112- ):
113- return _make_engine_struct_to_dict_decoder (field_path , src_type ["fields" ])
114-
115113 if src_type_kind == "Union" :
116114 dst_type_variants = (
117115 dst_type_info .union_variant_types
@@ -311,7 +309,7 @@ def make_closure_for_value(
311309def _make_engine_struct_to_dict_decoder (
312310 field_path : list [str ],
313311 src_fields : list [dict [str , Any ]],
314- ) -> Callable [[list [Any ]], dict [str , Any ]]:
312+ ) -> Callable [[list [Any ] | None ], dict [str , Any ] | None ]:
315313 """Make a decoder from engine field values to a Python dict."""
316314
317315 field_decoders = []
@@ -326,12 +324,17 @@ def _make_engine_struct_to_dict_decoder(
326324 field_path .pop ()
327325 field_decoders .append ((field_name , field_decoder ))
328326
329- def decode_to_dict (values : list [Any ]) -> dict [str , Any ]:
330- result = {}
331- for i , (field_name , field_decoder ) in enumerate (field_decoders ):
332- if i < len (values ):
333- result [field_name ] = field_decoder (values [i ])
334- return result
327+ def decode_to_dict (values : list [Any ] | None ) -> dict [str , Any ] | None :
328+ if values is None :
329+ return None
330+ if len (field_decoders ) != len (values ):
331+ raise ValueError (
332+ f"Field count mismatch: expected { len (field_decoders )} , got { len (values )} "
333+ )
334+ return {
335+ field_name : field_decoder (value )
336+ for value , (field_name , field_decoder ) in zip (values , field_decoders )
337+ }
335338
336339 return decode_to_dict
337340
0 commit comments