99import inspect
1010import warnings
1111from enum import Enum
12- from typing import Any , Callable , Mapping , get_origin
12+ from typing import Any , Callable , Mapping , get_origin , TypeVar , overload
1313
1414import numpy as np
1515
2626 encode_enriched_type ,
2727 is_namedtuple_type ,
2828 is_numpy_number_type ,
29+ extract_ndarray_elem_dtype ,
2930 ValueType ,
3031 FieldSchema ,
3132 BasicValueType ,
3435)
3536
3637
38+ T = TypeVar ("T" )
39+
40+
3741class ChildFieldPath :
3842 """Context manager to append a field to field_path on enter and pop it on exit."""
3943
@@ -616,7 +620,7 @@ def dump_engine_object(v: Any) -> Any:
616620 secs = int (total_secs )
617621 nanos = int ((total_secs - secs ) * 1e9 )
618622 return {"secs" : secs , "nanos" : nanos }
619- elif hasattr (v , "__dict__" ):
623+ elif hasattr (v , "__dict__" ): # for dataclass-like objects
620624 s = {}
621625 for k , val in v .__dict__ .items ():
622626 if val is None :
@@ -633,3 +637,128 @@ def dump_engine_object(v: Any) -> Any:
633637 elif isinstance (v , dict ):
634638 return {k : dump_engine_object (v ) for k , v in v .items ()}
635639 return v
640+
641+
642+ @overload
643+ def load_engine_object (expected_type : type [T ], v : Any ) -> T : ...
644+ @overload
645+ def load_engine_object (expected_type : Any , v : Any ) -> Any : ...
646+ def load_engine_object (expected_type : Any , v : Any ) -> Any :
647+ """Recursively load an object that was produced by dump_engine_object().
648+
649+ Args:
650+ expected_type: The Python type annotation to reconstruct to.
651+ v: The engine-facing Pythonized object (e.g., dict/list/primitive) to convert.
652+
653+ Returns:
654+ A Python object matching the expected_type where possible.
655+ """
656+ # Fast path
657+ if v is None :
658+ return None
659+
660+ type_info = analyze_type_info (expected_type )
661+ variant = type_info .variant
662+
663+ # Any or unknown → return as-is
664+ if isinstance (variant , AnalyzedAnyType ) or type_info .base_type is Any :
665+ return v
666+
667+ # Enum handling
668+ if isinstance (expected_type , type ) and issubclass (expected_type , Enum ):
669+ return expected_type (v )
670+
671+ # TimeDelta special form {secs, nanos}
672+ if isinstance (variant , AnalyzedBasicType ) and variant .kind == "TimeDelta" :
673+ if isinstance (v , Mapping ) and "secs" in v and "nanos" in v :
674+ secs = int (v ["secs" ]) # type: ignore[index]
675+ nanos = int (v ["nanos" ]) # type: ignore[index]
676+ return datetime .timedelta (seconds = secs , microseconds = nanos / 1_000 )
677+ return v
678+
679+ # List, NDArray (Vector-ish), or general sequences
680+ if isinstance (variant , AnalyzedListType ):
681+ elem_type = variant .elem_type if variant .elem_type else Any
682+ if type_info .base_type is np .ndarray :
683+ # Reconstruct NDArray with appropriate dtype if available
684+ try :
685+ dtype = extract_ndarray_elem_dtype (type_info .core_type )
686+ except (TypeError , ValueError , AttributeError ):
687+ dtype = None
688+ return np .array (v , dtype = dtype )
689+ # Regular Python list
690+ return [load_engine_object (elem_type , item ) for item in v ]
691+
692+ # Dict / Mapping
693+ if isinstance (variant , AnalyzedDictType ):
694+ key_t = variant .key_type
695+ val_t = variant .value_type
696+ return {
697+ load_engine_object (key_t , k ): load_engine_object (val_t , val )
698+ for k , val in v .items ()
699+ }
700+
701+ # Structs (dataclass or NamedTuple)
702+ if isinstance (variant , AnalyzedStructType ):
703+ struct_type = variant .struct_type
704+ if dataclasses .is_dataclass (struct_type ):
705+ # Drop auxiliary discriminator "kind" if present
706+ src = dict (v ) if isinstance (v , Mapping ) else v
707+ if isinstance (src , Mapping ):
708+ init_kwargs : dict [str , Any ] = {}
709+ field_types = {f .name : f .type for f in dataclasses .fields (struct_type )}
710+ for name , f_type in field_types .items ():
711+ if name in src :
712+ init_kwargs [name ] = load_engine_object (f_type , src [name ])
713+ # Construct with defaults for missing fields
714+ return struct_type (** init_kwargs )
715+ elif is_namedtuple_type (struct_type ):
716+ # NamedTuple is dumped as list/tuple of items
717+ annotations = getattr (struct_type , "__annotations__" , {})
718+ field_names = list (getattr (struct_type , "_fields" , ()))
719+ values : list [Any ] = []
720+ for name in field_names :
721+ f_type = annotations .get (name , Any )
722+ # Assume v is a sequence aligned with fields
723+ if isinstance (v , (list , tuple )):
724+ idx = field_names .index (name )
725+ values .append (load_engine_object (f_type , v [idx ]))
726+ elif isinstance (v , Mapping ):
727+ values .append (load_engine_object (f_type , v .get (name )))
728+ else :
729+ values .append (v )
730+ return struct_type (* values )
731+ return v
732+
733+ # Union with discriminator support via "kind"
734+ if isinstance (variant , AnalyzedUnionType ):
735+ if isinstance (v , Mapping ) and "kind" in v :
736+ discriminator = v ["kind" ]
737+ for typ in variant .variant_types :
738+ t_info = analyze_type_info (typ )
739+ if isinstance (t_info .variant , AnalyzedStructType ):
740+ t_struct = t_info .variant .struct_type
741+ candidate_kind = getattr (t_struct , "kind" , None )
742+ if candidate_kind == discriminator :
743+ # Remove discriminator for constructor
744+ v_wo_kind = dict (v )
745+ v_wo_kind .pop ("kind" , None )
746+ return load_engine_object (t_struct , v_wo_kind )
747+ # Fallback: try each variant until one succeeds
748+ for typ in variant .variant_types :
749+ try :
750+ return load_engine_object (typ , v )
751+ except (TypeError , ValueError ):
752+ continue
753+ return v
754+
755+ # Basic types and everything else: handle numpy scalars and passthrough
756+ if isinstance (v , np .ndarray ) and type_info .base_type is list :
757+ return v .tolist ()
758+ if isinstance (v , (list , tuple )) and type_info .base_type not in (list , tuple ):
759+ # If a non-sequence basic type expected, attempt direct cast
760+ try :
761+ return type_info .core_type (v )
762+ except (TypeError , ValueError ):
763+ return v
764+ return v
0 commit comments