99import inspect
1010import warnings
1111from enum import Enum
12- from typing import Any , Callable , Mapping , Sequence , Type , get_origin
12+ from typing import Any , Callable , Mapping , get_origin
1313
1414import numpy as np
1515
1616from .typing import (
17- TABLE_TYPES ,
1817 AnalyzedAnyType ,
1918 AnalyzedBasicType ,
2019 AnalyzedDictType ,
2726 encode_enriched_type ,
2827 is_namedtuple_type ,
2928 is_numpy_number_type ,
29+ ValueType ,
30+ FieldSchema ,
31+ BasicValueType ,
32+ StructType ,
33+ TableType ,
3034)
3135
3236
@@ -172,7 +176,7 @@ def encode_basic_value(value: Any) -> Any:
172176
173177def make_engine_key_decoder (
174178 field_path : list [str ],
175- key_fields_schema : list [dict [ str , Any ] ],
179+ key_fields_schema : list [FieldSchema ],
176180 dst_type_info : AnalyzedTypeInfo ,
177181) -> Callable [[Any ], Any ]:
178182 """
@@ -183,7 +187,7 @@ def make_engine_key_decoder(
183187 ):
184188 single_key_decoder = make_engine_value_decoder (
185189 field_path ,
186- key_fields_schema [0 ][ " type" ] ,
190+ key_fields_schema [0 ]. value_type . type ,
187191 dst_type_info ,
188192 for_key = True ,
189193 )
@@ -203,7 +207,7 @@ def key_decoder(value: list[Any]) -> Any:
203207
204208def make_engine_value_decoder (
205209 field_path : list [str ],
206- src_type : dict [ str , Any ] ,
210+ src_type : ValueType ,
207211 dst_type_info : AnalyzedTypeInfo ,
208212 for_key : bool = False ,
209213) -> Callable [[Any ], Any ]:
@@ -219,7 +223,7 @@ def make_engine_value_decoder(
219223 A decoder from an engine value to a Python value.
220224 """
221225
222- src_type_kind = src_type [ " kind" ]
226+ src_type_kind = src_type . kind
223227
224228 dst_type_variant = dst_type_info .variant
225229
@@ -229,19 +233,19 @@ def make_engine_value_decoder(
229233 f"declared `{ dst_type_info .core_type } `, an unsupported type"
230234 )
231235
232- if src_type_kind == "Struct" :
236+ if isinstance ( src_type , StructType ): # type: ignore[redundant-cast]
233237 return make_engine_struct_decoder (
234238 field_path ,
235- src_type [ " fields" ] ,
239+ src_type . fields ,
236240 dst_type_info ,
237241 for_key = for_key ,
238242 )
239243
240- if src_type_kind in TABLE_TYPES :
244+ if isinstance ( src_type , TableType ): # type: ignore[redundant-cast]
241245 with ChildFieldPath (field_path , "[*]" ):
242- engine_fields_schema = src_type [ " row" ][ " fields" ]
246+ engine_fields_schema = src_type . row . fields
243247
244- if src_type_kind == "LTable" :
248+ if src_type . kind == "LTable" :
245249 if isinstance (dst_type_variant , AnalyzedAnyType ):
246250 dst_elem_type = Any
247251 elif isinstance (dst_type_variant , AnalyzedListType ):
@@ -262,7 +266,7 @@ def decode(value: Any) -> Any | None:
262266 return None
263267 return [row_decoder (v ) for v in value ]
264268
265- elif src_type_kind == "KTable" :
269+ elif src_type . kind == "KTable" :
266270 if isinstance (dst_type_variant , AnalyzedAnyType ):
267271 key_type , value_type = Any , Any
268272 elif isinstance (dst_type_variant , AnalyzedDictType ):
@@ -274,7 +278,7 @@ def decode(value: Any) -> Any | None:
274278 f"declared `{ dst_type_info .core_type } `, a dict type expected"
275279 )
276280
277- num_key_parts = src_type .get ( " num_key_parts" , 1 )
281+ num_key_parts = src_type .num_key_parts or 1
278282 key_decoder = make_engine_key_decoder (
279283 field_path ,
280284 engine_fields_schema [0 :num_key_parts ],
@@ -298,7 +302,7 @@ def decode(value: Any) -> Any | None:
298302
299303 return decode
300304
301- if src_type_kind == "Union" :
305+ if isinstance ( src_type , BasicValueType ) and src_type . kind == "Union" :
302306 if isinstance (dst_type_variant , AnalyzedAnyType ):
303307 return lambda value : value [1 ]
304308
@@ -307,7 +311,10 @@ def decode(value: Any) -> Any | None:
307311 if isinstance (dst_type_variant , AnalyzedUnionType )
308312 else [dst_type_info ]
309313 )
310- src_type_variants = src_type ["types" ]
314+ # mypy: union info exists for Union kind
315+ assert src_type .union is not None # type: ignore[unreachable]
316+ src_type_variants_basic : list [BasicValueType ] = src_type .union .variants
317+ src_type_variants = src_type_variants_basic
311318 decoders = []
312319 for i , src_type_variant in enumerate (src_type_variants ):
313320 with ChildFieldPath (field_path , f"[{ i } ]" ):
@@ -331,7 +338,7 @@ def decode(value: Any) -> Any | None:
331338 if isinstance (dst_type_variant , AnalyzedAnyType ):
332339 return lambda value : value
333340
334- if src_type_kind == "Vector" :
341+ if isinstance ( src_type , BasicValueType ) and src_type . kind == "Vector" :
335342 field_path_str = "" .join (field_path )
336343 if not isinstance (dst_type_variant , AnalyzedListType ):
337344 raise ValueError (
@@ -350,9 +357,11 @@ def decode(value: Any) -> Any | None:
350357 if is_numpy_number_type (dst_type_variant .elem_type ):
351358 scalar_dtype = dst_type_variant .elem_type
352359 else :
360+ # mypy: vector info exists for Vector kind
361+ assert src_type .vector is not None # type: ignore[unreachable]
353362 vec_elem_decoder = make_engine_value_decoder (
354363 field_path + ["[*]" ],
355- src_type [ " element_type" ] ,
364+ src_type . vector . element_type ,
356365 analyze_type_info (
357366 dst_type_variant .elem_type if dst_type_variant else Any
358367 ),
@@ -432,7 +441,7 @@ def _get_auto_default_for_type(
432441
433442def make_engine_struct_decoder (
434443 field_path : list [str ],
435- src_fields : list [dict [ str , Any ] ],
444+ src_fields : list [FieldSchema ],
436445 dst_type_info : AnalyzedTypeInfo ,
437446 for_key : bool = False ,
438447) -> Callable [[list [Any ]], Any ]:
@@ -461,7 +470,7 @@ def make_engine_struct_decoder(
461470 f"declared `{ dst_type_info .core_type } `, a dataclass, NamedTuple or dict[str, Any] expected"
462471 )
463472
464- src_name_to_idx = {f [ " name" ] : i for i , f in enumerate (src_fields )}
473+ src_name_to_idx = {f . name : i for i , f in enumerate (src_fields )}
465474 dst_struct_type = dst_type_variant .struct_type
466475
467476 parameters : Mapping [str , inspect .Parameter ]
@@ -493,7 +502,10 @@ def make_closure_for_field(
493502 with ChildFieldPath (field_path , f".{ name } " ):
494503 if src_idx is not None :
495504 field_decoder = make_engine_value_decoder (
496- field_path , src_fields [src_idx ]["type" ], type_info , for_key = for_key
505+ field_path ,
506+ src_fields [src_idx ].value_type .type ,
507+ type_info ,
508+ for_key = for_key ,
497509 )
498510 return lambda values : field_decoder (values [src_idx ])
499511
@@ -526,19 +538,19 @@ def make_closure_for_field(
526538
527539def _make_engine_struct_to_dict_decoder (
528540 field_path : list [str ],
529- src_fields : list [dict [ str , Any ] ],
541+ src_fields : list [FieldSchema ],
530542 value_type_annotation : Any ,
531543) -> Callable [[list [Any ] | None ], dict [str , Any ] | None ]:
532544 """Make a decoder from engine field values to a Python dict."""
533545
534546 field_decoders = []
535547 value_type_info = analyze_type_info (value_type_annotation )
536548 for field_schema in src_fields :
537- field_name = field_schema [ " name" ]
549+ field_name = field_schema . name
538550 with ChildFieldPath (field_path , f".{ field_name } " ):
539551 field_decoder = make_engine_value_decoder (
540552 field_path ,
541- field_schema [ " type" ] ,
553+ field_schema . value_type . type ,
542554 value_type_info ,
543555 )
544556 field_decoders .append ((field_name , field_decoder ))
@@ -560,19 +572,19 @@ def decode_to_dict(values: list[Any] | None) -> dict[str, Any] | None:
560572
561573def _make_engine_struct_to_tuple_decoder (
562574 field_path : list [str ],
563- src_fields : list [dict [ str , Any ] ],
575+ src_fields : list [FieldSchema ],
564576) -> Callable [[list [Any ] | None ], tuple [Any , ...] | None ]:
565577 """Make a decoder from engine field values to a Python tuple."""
566578
567579 field_decoders = []
568580 value_type_info = analyze_type_info (Any )
569581 for field_schema in src_fields :
570- field_name = field_schema [ " name" ]
582+ field_name = field_schema . name
571583 with ChildFieldPath (field_path , f".{ field_name } " ):
572584 field_decoders .append (
573585 make_engine_value_decoder (
574586 field_path ,
575- field_schema [ " type" ] ,
587+ field_schema . value_type . type ,
576588 value_type_info ,
577589 )
578590 )
0 commit comments