33"""
44
55from decimal import Decimal
6- from typing import Any , Dict , Optional
6+ from typing import Any , Dict , Optional , Union
77
88import numpy as np
9+ from pydantic import (
10+ BaseModel ,
11+ ConfigDict ,
12+ SerializationInfo ,
13+ SerializerFunctionWrapHandler ,
14+ WrapSerializer ,
15+ field_validator ,
16+ model_serializer ,
17+ )
18+ from typing_extensions import Annotated
19+
20+
21+ def reduce_complex (data ):
22+ # Reduce Complex
23+ if isinstance (data , complex ):
24+ return [data .real , data .imag ]
25+ # Fallback
26+ return data
27+
28+
29+ def keep_decimal_cast_ndarray_complex (
30+ v : Any , nxt : SerializerFunctionWrapHandler , info : SerializationInfo
31+ ) -> Union [list , Decimal , float ]:
32+ """
33+ Ensure Decimal types are preserved on the way out
34+
35+ This arose because Decimal was serialized to string and "dump" is equal to "serialize" in v2 pydantic
36+ https://docs.pydantic.dev/latest/migration/#changes-to-json-schema-generation
37+
38+
39+ This also checks against NumPy Arrays and complex numbers in the instance of being in JSON mode
40+ """
41+ if isinstance (v , Decimal ):
42+ return v
43+ if info .mode == "json" :
44+ if isinstance (v , complex ):
45+ return nxt (reduce_complex (v ))
46+ if isinstance (v , np .ndarray ):
47+ # Handle NDArray and complex NDArray
48+ flat_list = v .flatten ().tolist ()
49+ reduced_list = list (map (reduce_complex , flat_list ))
50+ return nxt (reduced_list )
51+ try :
52+ # Cast NumPy scalar data types to native Python data type
53+ v = v .item ()
54+ except (AttributeError , ValueError ):
55+ pass
56+ return nxt (v )
57+
958
10- try :
11- from pydantic .v1 import BaseModel , validator
12- except ImportError : # Will also trap ModuleNotFoundError
13- from pydantic import BaseModel , validator
59+ # Only 1 serializer is allowed. You can't chain wrap serializers.
60+ AnyArrayComplex = Annotated [Any , WrapSerializer (keep_decimal_cast_ndarray_complex )]
1461
1562
1663class Datum (BaseModel ):
@@ -38,15 +85,15 @@ class Datum(BaseModel):
3885 numeric : bool
3986 label : str
4087 units : str
41- data : Any
88+ data : AnyArrayComplex
4289 comment : str = ""
4390 doi : Optional [str ] = None
4491 glossary : str = ""
4592
46- class Config :
47- extra = "forbid"
48- allow_mutation = False
49- json_encoders = { np . ndarray : lambda v : v . flatten (). tolist (), complex : lambda v : ( v . real , v . imag )}
93+ model_config = ConfigDict (
94+ extra = "forbid" ,
95+ frozen = True ,
96+ )
5097
5198 def __init__ (self , label , units , data , * , comment = None , doi = None , glossary = None , numeric = True ):
5299 kwargs = {"label" : label , "units" : units , "data" : data , "numeric" : numeric }
@@ -59,20 +106,21 @@ def __init__(self, label, units, data, *, comment=None, doi=None, glossary=None,
59106
60107 super ().__init__ (** kwargs )
61108
62- @validator ("data" )
63- def must_be_numerical (cls , v , values , ** kwargs ):
109+ @field_validator ("data" )
110+ @classmethod
111+ def must_be_numerical (cls , v , info ):
64112 try :
65113 1.0 * v
66114 except TypeError :
67115 try :
68116 Decimal ("1.0" ) * v
69117 except TypeError :
70- if values ["numeric" ]:
118+ if info . data ["numeric" ]:
71119 raise ValueError (f"Datum data should be float, Decimal, or np.ndarray, not { type (v )} ." )
72120 else :
73- values ["numeric" ] = True
121+ info . data ["numeric" ] = True
74122 else :
75- values ["numeric" ] = True
123+ info . data ["numeric" ] = True
76124
77125 return v
78126
@@ -90,8 +138,35 @@ def __str__(self, label=""):
90138 text .append ("-" * width )
91139 return "\n " .join (text )
92140
141+ @model_serializer (mode = "wrap" )
142+ def _serialize_model (self , handler ) -> Dict [str , Any ]:
143+ """
144+ Customize the serialization output. Does duplicate with some code in model_dump, but handles the case of nested
145+ models and any model config options.
146+
147+ Encoding is handled at the `model_dump` level and not here as that should happen only after EVERYTHING has been
148+ dumped/de-pydantic-ized.
149+ """
150+
151+ # Get the default return, let the model_dump handle kwarg
152+ default_result = handler (self )
153+ # Exclude unset always
154+ output_dict = {key : value for key , value in default_result .items () if key in self .model_fields_set }
155+ return output_dict
156+
93157 def dict (self , * args , ** kwargs ):
94- return super ().dict (* args , ** {** kwargs , ** {"exclude_unset" : True }})
158+ """
159+ Passthrough to model_dump without deprecation warning
160+ exclude_unset is forced through the model_serializer
161+ """
162+ return super ().model_dump (* args , ** kwargs )
163+
164+ def json (self , * args , ** kwargs ):
165+ """
166+ Passthrough to model_dump_sjon without deprecation warning
167+ exclude_unset is forced through the model_serializer
168+ """
169+ return super ().model_dump_json (* args , ** kwargs )
95170
96171 def to_units (self , units = None ):
97172 from .physical_constants import constants
0 commit comments