11""" Functions for decoding dataclass fields from "raw" values (e.g. from json).
22"""
3+ from __future__ import annotations
4+
35import inspect
46import warnings
57from collections import OrderedDict
911from functools import lru_cache , partial
1012from logging import getLogger
1113from pathlib import Path
12- from typing import Any , Callable , Dict , List , Optional , Set , Tuple , Type , TypeVar , Union
14+ from typing import Any , Callable , TypeVar
1315
1416from simple_parsing .annotation_utils .get_field_annotations import (
1517 evaluate_string_annotation ,
1618)
1719from simple_parsing .utils import (
1820 get_bound ,
21+ get_forward_arg ,
1922 get_type_arguments ,
23+ is_dataclass_type ,
2024 is_dict ,
2125 is_enum ,
2226 is_forward_ref ,
3539V = TypeVar ("V" )
3640
3741# Dictionary mapping from types/type annotations to their decoding functions.
38- _decoding_fns : Dict [ Type [T ], Callable [[Any ], T ]] = {
42+ _decoding_fns : dict [ type [T ], Callable [[Any ], T ]] = {
3943 # the 'primitive' types are decoded using the type fn as a constructor.
4044 t : t
4145 for t in [str , float , int , bytes ]
@@ -51,7 +55,7 @@ def decode_bool(v: Any) -> bool:
5155_decoding_fns [bool ] = decode_bool
5256
5357
54- def decode_field (field : Field , raw_value : Any , containing_dataclass : Optional [ type ] = None ) -> Any :
58+ def decode_field (field : Field , raw_value : Any , containing_dataclass : type | None = None ) -> Any :
5559 """Converts a "raw" value (e.g. from json file) to the type of the `field`.
5660
5761 When serializing a dataclass to json, all objects are converted to dicts.
@@ -84,7 +88,7 @@ def decode_field(field: Field, raw_value: Any, containing_dataclass: Optional[ty
8488
8589
8690@lru_cache (maxsize = 100 )
87- def get_decoding_fn (t : Type [T ]) -> Callable [[ Any ] , T ]:
91+ def get_decoding_fn (type_annotation : type [T ] | str ) -> Callable [... , T ]:
8892 """Fetches/Creates a decoding function for the given type annotation.
8993
9094 This decoding function can then be used to create an instance of the type
@@ -111,67 +115,54 @@ def get_decoding_fn(t: Type[T]) -> Callable[[Any], T]:
111115 A function that decodes a 'raw' value to an instance of type `t`.
112116
113117 """
114- # cache_info = get_decoding_fn.cache_info()
115- # logger.debug(f"called for type {t}! Cache info: {cache_info}")
116-
117- def _get_potential_keys (annotation : str ) -> List [str ]:
118- # Type annotation is a string.
119- # This can happen when the `from __future__ import annotations` feature is used.
120- potential_keys : List [Type ] = []
121- for key in _decoding_fns :
122- if inspect .isclass (key ):
123- if key .__qualname__ == annotation :
124- # Qualname is more specific, there can't possibly be another match, so break.
125- potential_keys .append (key )
126- break
127- if key .__qualname__ == annotation :
128- # For just __name__, there could be more than one match.
129- potential_keys .append (key )
130- return potential_keys
131-
132- if isinstance (t , str ):
133- if t in _decoding_fns :
134- return _decoding_fns [t ]
135-
136- potential_keys = _get_potential_keys (t )
137-
138- if not potential_keys :
139- # Try to replace the new-style annotation str with the old style syntax, and see if we
140- # find a match.
141- # try:
142- try :
143- evaluated_t = evaluate_string_annotation (t )
144- # NOTE: We now have a 'live'/runtime type annotation object from the typing module.
145- except (ValueError , TypeError ) as err :
146- logger .error (f"Unable to evaluate the type annotation string { t } : { err } ." )
147- else :
148- if evaluated_t in _decoding_fns :
149- return _decoding_fns [evaluated_t ]
150- # If we still don't have this annotation stored in our dict of known functions, we
151- # recurse, to try to deconstruct this annotation into its parts, and construct the
152- # decoding function for the annotation. If this doesn't work, we just raise the
153- # errors.
154- return get_decoding_fn (evaluated_t )
155-
156- raise ValueError (
157- f"Couldn't find a decoding function for the string annotation '{ t } '.\n "
158- f"This is probably a bug. If it is, please make an issue on GitHub so we can get "
159- f"to work on fixing it.\n "
160- f"Types with a known decoding function: { list (_decoding_fns .keys ())} "
118+ from .serializable import from_dict
119+
120+ logger .debug (f"Getting the decoding function for { type_annotation !r} " )
121+
122+ if isinstance (type_annotation , str ):
123+ # Check first if there are any matching registered decoding functions.
124+ # TODO: Might be better to actually use the scope of the field, right?
125+ matching_entries = {
126+ key : decoding_fn
127+ for key , decoding_fn in _decoding_fns .items ()
128+ if (inspect .isclass (key ) and key .__name__ == type_annotation )
129+ }
130+ if len (matching_entries ) == 1 :
131+ _ , decoding_fn = matching_entries .popitem ()
132+ return decoding_fn
133+ elif len (matching_entries ) > 1 :
134+ # Multiple decoding functions match the type. Can't tell.
135+ logger .warning (
136+ RuntimeWarning (
137+ f"More than one potential decoding functions were found for types that match "
138+ f"the string annotation { type_annotation !r} . This will simply try each one "
139+ f"and return the first one that works."
140+ )
161141 )
162- if len (potential_keys ) == 1 :
163- t = potential_keys [0 ]
142+ return try_functions (* (decoding_fn for _ , decoding_fn in matching_entries .items ()))
164143 else :
165- raise ValueError (
166- f"Multiple decoding functions registered for a type { t } : { potential_keys } \n "
167- f"This could be a bug, but try to use different names for each type, or add the "
168- f"modules they come from as a prefix, perhaps?"
169- )
144+ # Try to evaluate the string annotation.
145+ t = evaluate_string_annotation (type_annotation )
146+
147+ elif is_forward_ref (type_annotation ):
148+ forward_arg : str = get_forward_arg (type_annotation )
149+ # Recurse until we've resolved the forward reference.
150+ return get_decoding_fn (forward_arg )
151+
152+ else :
153+ t = type_annotation
154+
155+ logger .debug (f"{ type_annotation !r} -> { t !r} " )
156+
157+ # T should now be a type or one of the objects from the typing module.
170158
171159 if t in _decoding_fns :
172160 # The type has a dedicated decoding function.
173161 return _decoding_fns [t ]
174162
163+ if is_dataclass_type (t ):
164+ return partial (from_dict , t )
165+
175166 if t is Any :
176167 logger .debug (f"Decoding an Any type: { t } " )
177168 return no_op
@@ -214,31 +205,6 @@ def _get_potential_keys(annotation: str) -> List[str]:
214205 logger .debug (f"Decoding an Enum field: { t } " )
215206 return decode_enum (t )
216207
217- from .serializable import SerializableMixin , get_dataclass_types_from_forward_ref
218-
219- if is_forward_ref (t ):
220- dcs = get_dataclass_types_from_forward_ref (t )
221- if len (dcs ) == 1 :
222- dc = dcs [0 ]
223- return dc .from_dict
224- if len (dcs ) > 1 :
225- logger .warning (
226- RuntimeWarning (
227- f"More than one potential Serializable dataclass was found with a name matching "
228- f"the type annotation { t } . This will simply try each one, and return the "
229- f"first one that works. Potential classes: { dcs } "
230- )
231- )
232- return try_functions (* [partial (dc .from_dict , drop_extra_fields = False ) for dc in dcs ])
233- else :
234- # No idea what the forward ref refers to!
235- logger .warning (
236- f"Unable to find a dataclass that matches the forward ref { t } inside the "
237- f"registered { SerializableMixin } subclasses. Leaving the value as-is."
238- f"(Consider using Serializable or FrozenSerializable as a base class?)."
239- )
240- return no_op
241-
242208 if is_typevar (t ):
243209 bound = get_bound (t )
244210 logger .debug (f"Decoding a typevar: { t } , bound type is { bound } ." )
@@ -256,31 +222,31 @@ def _get_potential_keys(annotation: str) -> List[str]:
256222 return try_constructor (t )
257223
258224
259- def _register (t : Type , func : Callable ) -> None :
225+ def _register (t : type , func : Callable ) -> None :
260226 if t not in _decoding_fns :
261227 # logger.debug(f"Registering the type {t} with decoding function {func}")
262228 _decoding_fns [t ] = func
263229
264230
265- def register_decoding_fn (some_type : Type [T ], function : Callable [[Any ], T ]) -> None :
231+ def register_decoding_fn (some_type : type [T ], function : Callable [[Any ], T ]) -> None :
266232 """Register a decoding function for the type `some_type`."""
267233 _register (some_type , function )
268234
269235
270- def decode_optional (t : Type [T ]) -> Callable [[Optional [ Any ]], Optional [ T ] ]:
236+ def decode_optional (t : type [T ]) -> Callable [[Any | None ], T | None ]:
271237 decode = get_decoding_fn (t )
272238
273- def _decode_optional (val : Optional [ Any ] ) -> Optional [ T ] :
239+ def _decode_optional (val : Any | None ) -> T | None :
274240 return val if val is None else decode (val )
275241
276242 return _decode_optional
277243
278244
279- def try_functions (* funcs : Callable [[Any ], T ]) -> Callable [[Any ], Union [ T , Any ] ]:
245+ def try_functions (* funcs : Callable [[Any ], T ]) -> Callable [[Any ], T | Any ]:
280246 """Tries to use the functions in succession, else returns the same value unchanged."""
281247
282- def _try_functions (val : Any ) -> Union [ T , Any ] :
283- e : Optional [ Exception ] = None
248+ def _try_functions (val : Any ) -> T | Any :
249+ e : Exception | None = None
284250 for func in funcs :
285251 try :
286252 return func (val )
@@ -293,30 +259,30 @@ def _try_functions(val: Any) -> Union[T, Any]:
293259 return _try_functions
294260
295261
296- def decode_union (* types : Type [T ]) -> Callable [[Any ], Union [ T , Any ] ]:
262+ def decode_union (* types : type [T ]) -> Callable [[Any ], T | Any ]:
297263 types = list (types )
298264 optional = type (None ) in types
299265 # Partition the Union into None and non-None types.
300266 while type (None ) in types :
301267 types .remove (type (None ))
302268
303- decoding_fns : List [Callable [[Any ], T ]] = [
269+ decoding_fns : list [Callable [[Any ], T ]] = [
304270 decode_optional (t ) if optional else get_decoding_fn (t ) for t in types
305271 ]
306272 # Try using each of the non-None types, in succession. Worst case, return the value.
307273 return try_functions (* decoding_fns )
308274
309275
310- def decode_list (t : Type [T ]) -> Callable [[List [Any ]], List [T ]]:
276+ def decode_list (t : type [T ]) -> Callable [[list [Any ]], list [T ]]:
311277 decode_item = get_decoding_fn (t )
312278
313- def _decode_list (val : List [Any ]) -> List [T ]:
279+ def _decode_list (val : list [Any ]) -> list [T ]:
314280 return [decode_item (v ) for v in val ]
315281
316282 return _decode_list
317283
318284
319- def decode_tuple (* tuple_item_types : Type [T ]) -> Callable [[List [T ]], Tuple [T , ...]]:
285+ def decode_tuple (* tuple_item_types : type [T ]) -> Callable [[list [T ]], tuple [T , ...]]:
320286 """Makes a parsing function for creating tuples.
321287
322288 Can handle tuples with different item types, for instance:
@@ -338,7 +304,7 @@ def decode_tuple(*tuple_item_types: Type[T]) -> Callable[[List[T]], Tuple[T, ...
338304 # Note, if there are more values than types in the tuple type, then the
339305 # last type is used.
340306
341- def _decode_tuple (val : Tuple [Any , ...]) -> Tuple [T , ...]:
307+ def _decode_tuple (val : tuple [Any , ...]) -> tuple [T , ...]:
342308 if has_ellipsis :
343309 return tuple (decoding_fn (v ) for v in val )
344310 else :
@@ -347,7 +313,7 @@ def _decode_tuple(val: Tuple[Any, ...]) -> Tuple[T, ...]:
347313 return _decode_tuple
348314
349315
350- def decode_set (item_type : Type [T ]) -> Callable [[List [T ]], Set [T ]]:
316+ def decode_set (item_type : type [T ]) -> Callable [[list [T ]], set [T ]]:
351317 """Makes a parsing function for creating sets with items of type `item_type`.
352318
353319 Args:
@@ -359,13 +325,13 @@ def decode_set(item_type: Type[T]) -> Callable[[List[T]], Set[T]]:
359325 # Get the parse fn for a list of items of type `item_type`.
360326 parse_list_fn = decode_list (item_type )
361327
362- def _decode_set (val : List [Any ]) -> Set [T ]:
328+ def _decode_set (val : list [Any ]) -> set [T ]:
363329 return set (parse_list_fn (val ))
364330
365331 return _decode_set
366332
367333
368- def decode_dict (K_ : Type [K ], V_ : Type [V ]) -> Callable [[List [ Tuple [Any , Any ]]], Dict [K , V ]]:
334+ def decode_dict (K_ : type [K ], V_ : type [V ]) -> Callable [[list [ tuple [Any , Any ]]], dict [K , V ]]:
369335 """Creates a decoding function for a dict type. Works with OrderedDict too.
370336
371337 Args:
@@ -379,8 +345,8 @@ def decode_dict(K_: Type[K], V_: Type[V]) -> Callable[[List[Tuple[Any, Any]]], D
379345 decode_k = get_decoding_fn (K_ )
380346 decode_v = get_decoding_fn (V_ )
381347
382- def _decode_dict (val : Union [ Dict [ Any , Any ], List [ Tuple [Any , Any ]]] ) -> Dict [K , V ]:
383- result : Dict [K , V ] = {}
348+ def _decode_dict (val : dict [ Any , Any ] | list [ tuple [Any , Any ]]) -> dict [K , V ]:
349+ result : dict [K , V ] = {}
384350 if isinstance (val , list ):
385351 result = OrderedDict ()
386352 items = val
@@ -399,7 +365,7 @@ def _decode_dict(val: Union[Dict[Any, Any], List[Tuple[Any, Any]]]) -> Dict[K, V
399365 return _decode_dict
400366
401367
402- def decode_enum (item_type : Type [Enum ]) -> Callable [[str ], Enum ]:
368+ def decode_enum (item_type : type [Enum ]) -> Callable [[str ], Enum ]:
403369 """
404370 Creates a decoding function for an enum type.
405371
@@ -428,7 +394,7 @@ def no_op(v: T) -> T:
428394 return v
429395
430396
431- def try_constructor (t : Type [T ]) -> Callable [[Any ], Union [ T , Any ] ]:
397+ def try_constructor (t : type [T ]) -> Callable [[Any ], T | Any ]:
432398 """Tries to use the type as a constructor. If that fails, returns the value as-is.
433399
434400 Args:
0 commit comments