44
55from __future__ import annotations
66
7- import dataclasses
87import inspect
98import warnings
10- from typing import Any , Callable , Mapping , TypeVar
9+ from typing import Any , Callable , TypeVar
1110
1211import numpy as np
1312from .typing import (
1918 AnalyzedTypeInfo ,
2019 AnalyzedUnionType ,
2120 AnalyzedUnknownType ,
21+ AnalyzedStructFieldInfo ,
2222 analyze_type_info ,
23- is_namedtuple_type ,
2423 is_pydantic_model ,
2524 is_numpy_number_type ,
2625 ValueType ,
@@ -124,69 +123,20 @@ def encode_struct_dict(value: Any) -> Any:
124123 return encode_struct_dict
125124
126125 if isinstance (variant , AnalyzedStructType ):
127- struct_type = variant .struct_type
128-
129- if dataclasses .is_dataclass (struct_type ):
130- fields = dataclasses .fields (struct_type )
131- field_encoders = [
132- make_engine_value_encoder (analyze_type_info (f .type )) for f in fields
133- ]
134- field_names = [f .name for f in fields ]
135-
136- def encode_dataclass (value : Any ) -> Any :
137- if value is None :
138- return None
139- return [
140- encoder (getattr (value , name ))
141- for encoder , name in zip (field_encoders , field_names )
142- ]
143-
144- return encode_dataclass
145-
146- elif is_namedtuple_type (struct_type ):
147- annotations = struct_type .__annotations__
148- field_names = list (getattr (struct_type , "_fields" , ()))
149- field_encoders = [
150- make_engine_value_encoder (
151- analyze_type_info (annotations [name ])
152- if name in annotations
153- else ANY_TYPE_INFO
154- )
155- for name in field_names
156- ]
157-
158- def encode_namedtuple (value : Any ) -> Any :
159- if value is None :
160- return None
161- return [
162- encoder (getattr (value , name ))
163- for encoder , name in zip (field_encoders , field_names )
164- ]
165-
166- return encode_namedtuple
167-
168- elif is_pydantic_model (struct_type ):
169- # Type guard: ensure we have model_fields attribute
170- if hasattr (struct_type , "model_fields" ):
171- field_names = list (struct_type .model_fields .keys ()) # type: ignore[attr-defined]
172- field_encoders = [
173- make_engine_value_encoder (
174- analyze_type_info (struct_type .model_fields [name ].annotation ) # type: ignore[attr-defined]
175- )
176- for name in field_names
177- ]
178- else :
179- raise ValueError (f"Invalid Pydantic model: { struct_type } " )
126+ field_encoders = [
127+ (
128+ field_info .name ,
129+ make_engine_value_encoder (analyze_type_info (field_info .type_hint )),
130+ )
131+ for field_info in variant .fields
132+ ]
180133
181- def encode_pydantic (value : Any ) -> Any :
182- if value is None :
183- return None
184- return [
185- encoder (getattr (value , name ))
186- for encoder , name in zip (field_encoders , field_names )
187- ]
134+ def encode_struct (value : Any ) -> Any :
135+ if value is None :
136+ return None
137+ return [encoder (getattr (value , name )) for name , encoder in field_encoders ]
188138
189- return encode_pydantic
139+ return encode_struct
190140
191141 def encode_basic_value (value : Any ) -> Any :
192142 if isinstance (value , np .number ):
@@ -475,51 +425,12 @@ def make_engine_struct_decoder(
475425 src_name_to_idx = {f .name : i for i , f in enumerate (src_fields )}
476426 dst_struct_type = dst_type_variant .struct_type
477427
478- parameters : Mapping [str , inspect .Parameter ]
479- if dataclasses .is_dataclass (dst_struct_type ):
480- parameters = inspect .signature (dst_struct_type ).parameters
481- elif is_namedtuple_type (dst_struct_type ):
482- defaults = getattr (dst_struct_type , "_field_defaults" , {})
483- fields = getattr (dst_struct_type , "_fields" , ())
484- parameters = {
485- name : inspect .Parameter (
486- name = name ,
487- kind = inspect .Parameter .POSITIONAL_OR_KEYWORD ,
488- default = defaults .get (name , inspect .Parameter .empty ),
489- annotation = dst_struct_type .__annotations__ .get (
490- name , inspect .Parameter .empty
491- ),
492- )
493- for name in fields
494- }
495- elif is_pydantic_model (dst_struct_type ):
496- # For Pydantic models, we can use model_fields to get field information
497- parameters = {}
498- # Type guard: ensure we have model_fields attribute
499- if hasattr (dst_struct_type , "model_fields" ):
500- model_fields = dst_struct_type .model_fields # type: ignore[attr-defined]
501- else :
502- model_fields = {}
503- for name , field_info in model_fields .items ():
504- default_value = (
505- field_info .default
506- if field_info .default is not ...
507- else inspect .Parameter .empty
508- )
509- parameters [name ] = inspect .Parameter (
510- name = name ,
511- kind = inspect .Parameter .POSITIONAL_OR_KEYWORD ,
512- default = default_value ,
513- annotation = field_info .annotation ,
514- )
515- else :
516- raise ValueError (f"Unsupported struct type: { dst_struct_type } " )
517-
518428 def make_closure_for_field (
519- name : str , param : inspect . Parameter
429+ field_info : AnalyzedStructFieldInfo ,
520430 ) -> Callable [[list [Any ]], Any ]:
431+ name = field_info .name
521432 src_idx = src_name_to_idx .get (name )
522- type_info = analyze_type_info (param . annotation )
433+ type_info = analyze_type_info (field_info . type_hint )
523434
524435 with ChildFieldPath (field_path , f".{ name } " ):
525436 if src_idx is not None :
@@ -531,42 +442,44 @@ def make_closure_for_field(
531442 )
532443 return lambda values : field_decoder (values [src_idx ])
533444
534- default_value = param . default
445+ default_value = field_info . default_value
535446 if default_value is not inspect .Parameter .empty :
536447 return lambda _ : default_value
537448
538449 auto_default , is_supported = get_auto_default_for_type (type_info )
539450 if is_supported :
540451 warnings .warn (
541- f"Field '{ name } ' (type { param . annotation } ) without default value is missing in input: "
452+ f"Field '{ name } ' (type { field_info . type_hint } ) without default value is missing in input: "
542453 f"{ '' .join (field_path )} . Auto-assigning default value: { auto_default } " ,
543454 UserWarning ,
544455 stacklevel = 4 ,
545456 )
546457 return lambda _ : auto_default
547458
548459 raise ValueError (
549- f"Field '{ name } ' (type { param . annotation } ) without default value is missing in input: { '' .join (field_path )} "
460+ f"Field '{ name } ' (type { field_info . type_hint } ) without default value is missing in input: { '' .join (field_path )} "
550461 )
551462
552- field_value_decoder = [
553- make_closure_for_field (name , param ) for (name , param ) in parameters .items ()
554- ]
555-
556463 # Different construction for different struct types
557464 if is_pydantic_model (dst_struct_type ):
558465 # Pydantic models prefer keyword arguments
559- field_names = list (parameters .keys ())
466+ pydantic_fields_decoder = [
467+ (field_info .name , make_closure_for_field (field_info ))
468+ for field_info in dst_type_variant .fields
469+ ]
560470 return lambda values : dst_struct_type (
561471 ** {
562- field_names [ i ] : decoder (values )
563- for i , decoder in enumerate ( field_value_decoder )
472+ field_name : decoder (values )
473+ for field_name , decoder in pydantic_fields_decoder
564474 }
565475 )
566476 else :
477+ struct_fields_decoder = [
478+ make_closure_for_field (field_info ) for field_info in dst_type_variant .fields
479+ ]
567480 # Dataclasses and NamedTuples can use positional arguments
568481 return lambda values : dst_struct_type (
569- * (decoder (values ) for decoder in field_value_decoder )
482+ * (decoder (values ) for decoder in struct_fields_decoder )
570483 )
571484
572485
0 commit comments