66# --------------------------------------------------------------------------
77# pylint: disable=protected-access, arguments-differ, signature-differs, broad-except
88
9+ import copy
910import calendar
1011import decimal
1112import functools
1213import sys
1314import logging
1415import base64
1516import re
16- import copy
1717import typing
1818import enum
1919import email .utils
@@ -339,7 +339,7 @@ def _get_model(module_name: str, model_name: str):
339339
340340class _MyMutableMapping (MutableMapping [str , typing .Any ]): # pylint: disable=unsubscriptable-object
341341 def __init__ (self , data : typing .Dict [str , typing .Any ]) -> None :
342- self ._data = copy . deepcopy ( data )
342+ self ._data = data
343343
344344 def __contains__ (self , key : typing .Any ) -> bool :
345345 return key in self ._data
@@ -378,16 +378,13 @@ def get(self, key: str, default: typing.Any = None) -> typing.Any:
378378 return default
379379
380380 @typing .overload
381- def pop (self , key : str ) -> typing .Any :
382- ...
381+ def pop (self , key : str ) -> typing .Any : ...
383382
384383 @typing .overload
385- def pop (self , key : str , default : _T ) -> _T :
386- ...
384+ def pop (self , key : str , default : _T ) -> _T : ...
387385
388386 @typing .overload
389- def pop (self , key : str , default : typing .Any ) -> typing .Any :
390- ...
387+ def pop (self , key : str , default : typing .Any ) -> typing .Any : ...
391388
392389 def pop (self , key : str , default : typing .Any = _UNSET ) -> typing .Any :
393390 if default is _UNSET :
@@ -404,12 +401,10 @@ def update(self, *args: typing.Any, **kwargs: typing.Any) -> None:
404401 self ._data .update (* args , ** kwargs )
405402
406403 @typing .overload
407- def setdefault (self , key : str , default : None = None ) -> None :
408- ...
404+ def setdefault (self , key : str , default : None = None ) -> None : ...
409405
410406 @typing .overload
411- def setdefault (self , key : str , default : typing .Any ) -> typing .Any :
412- ...
407+ def setdefault (self , key : str , default : typing .Any ) -> typing .Any : ...
413408
414409 def setdefault (self , key : str , default : typing .Any = _UNSET ) -> typing .Any :
415410 if default is _UNSET :
@@ -594,6 +589,64 @@ def _as_dict_value(v: typing.Any, exclude_readonly: bool = False) -> typing.Any:
594589 return v .as_dict (exclude_readonly = exclude_readonly ) if hasattr (v , "as_dict" ) else v
595590
596591
592+ def _deserialize_model (model_deserializer : typing .Optional [typing .Callable ], obj ):
593+ if _is_model (obj ):
594+ return obj
595+ return _deserialize (model_deserializer , obj )
596+
597+
598+ def _deserialize_with_optional (if_obj_deserializer : typing .Optional [typing .Callable ], obj ):
599+ if obj is None :
600+ return obj
601+ return _deserialize_with_callable (if_obj_deserializer , obj )
602+
603+
604+ def _deserialize_with_union (deserializers , obj ):
605+ for deserializer in deserializers :
606+ try :
607+ return _deserialize (deserializer , obj )
608+ except DeserializationError :
609+ pass
610+ raise DeserializationError ()
611+
612+
613+ def _deserialize_dict (
614+ value_deserializer : typing .Optional [typing .Callable ],
615+ module : typing .Optional [str ],
616+ obj : typing .Dict [typing .Any , typing .Any ],
617+ ):
618+ if obj is None :
619+ return obj
620+ return {k : _deserialize (value_deserializer , v , module ) for k , v in obj .items ()}
621+
622+
623+ def _deserialize_multiple_sequence (
624+ entry_deserializers : typing .List [typing .Optional [typing .Callable ]],
625+ module : typing .Optional [str ],
626+ obj ,
627+ ):
628+ if obj is None :
629+ return obj
630+ return type (obj )(_deserialize (deserializer , entry , module ) for entry , deserializer in zip (obj , entry_deserializers ))
631+
632+
633+ def _deserialize_sequence (
634+ deserializer : typing .Optional [typing .Callable ],
635+ module : typing .Optional [str ],
636+ obj ,
637+ ):
638+ if obj is None :
639+ return obj
640+ return type (obj )(_deserialize (deserializer , entry , module ) for entry in obj )
641+
642+
643+ def _sorted_annotations (types : typing .List [typing .Any ]) -> typing .List [typing .Any ]:
644+ return sorted (
645+ types ,
646+ key = lambda x : hasattr (x , "__name__" ) and x .__name__ .lower () in ("str" , "float" , "int" , "bool" ),
647+ )
648+
649+
597650def _get_deserialize_callable_from_annotation ( # pylint: disable=R0911, R0915, R0912
598651 annotation : typing .Any ,
599652 module : typing .Optional [str ],
@@ -621,11 +674,6 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915,
621674 if rf :
622675 rf ._is_model = True
623676
624- def _deserialize_model (model_deserializer : typing .Optional [typing .Callable ], obj ):
625- if _is_model (obj ):
626- return obj
627- return _deserialize (model_deserializer , obj )
628-
629677 return functools .partial (_deserialize_model , annotation ) # pyright: ignore
630678 except Exception :
631679 pass
@@ -640,36 +688,27 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj
640688 # is it optional?
641689 try :
642690 if any (a for a in annotation .__args__ if a == type (None )): # pyright: ignore
643- if_obj_deserializer = _get_deserialize_callable_from_annotation (
644- next ( a for a in annotation . __args__ if a != type ( None )), module , rf # pyright: ignore
645- )
646-
647- def _deserialize_with_optional ( if_obj_deserializer : typing . Optional [ typing . Callable ], obj ):
648- if obj is None :
649- return obj
650- return _deserialize_with_callable ( if_obj_deserializer , obj )
651-
652- return functools . partial ( _deserialize_with_optional , if_obj_deserializer )
691+ if len ( annotation . __args__ ) <= 2 : # pyright: ignore
692+ if_obj_deserializer = _get_deserialize_callable_from_annotation (
693+ next ( a for a in annotation . __args__ if a != type ( None )), module , rf # pyright: ignore
694+ )
695+
696+ return functools . partial ( _deserialize_with_optional , if_obj_deserializer )
697+ # the type is Optional[Union[...]], we need to remove the None type from the Union
698+ annotation_copy = copy . copy ( annotation )
699+ annotation_copy . __args__ = [ a for a in annotation_copy . __args__ if a != type ( None )] # pyright: ignore
700+ return _get_deserialize_callable_from_annotation ( annotation_copy , module , rf )
653701 except AttributeError :
654702 pass
655703
704+ # is it union?
656705 if getattr (annotation , "__origin__" , None ) is typing .Union :
657706 # initial ordering is we make `string` the last deserialization option, because it is often them most generic
658707 deserializers = [
659708 _get_deserialize_callable_from_annotation (arg , module , rf )
660- for arg in sorted (
661- annotation .__args__ , key = lambda x : hasattr (x , "__name__" ) and x .__name__ == "str" # pyright: ignore
662- )
709+ for arg in _sorted_annotations (annotation .__args__ ) # pyright: ignore
663710 ]
664711
665- def _deserialize_with_union (deserializers , obj ):
666- for deserializer in deserializers :
667- try :
668- return _deserialize (deserializer , obj )
669- except DeserializationError :
670- pass
671- raise DeserializationError ()
672-
673712 return functools .partial (_deserialize_with_union , deserializers )
674713
675714 try :
@@ -678,53 +717,27 @@ def _deserialize_with_union(deserializers, obj):
678717 annotation .__args__ [1 ], module , rf # pyright: ignore
679718 )
680719
681- def _deserialize_dict (
682- value_deserializer : typing .Optional [typing .Callable ],
683- obj : typing .Dict [typing .Any , typing .Any ],
684- ):
685- if obj is None :
686- return obj
687- return {k : _deserialize (value_deserializer , v , module ) for k , v in obj .items ()}
688-
689720 return functools .partial (
690721 _deserialize_dict ,
691722 value_deserializer ,
723+ module ,
692724 )
693725 except (AttributeError , IndexError ):
694726 pass
695727 try :
696728 if annotation ._name in ["List" , "Set" , "Tuple" , "Sequence" ]: # pyright: ignore
697729 if len (annotation .__args__ ) > 1 : # pyright: ignore
698730
699- def _deserialize_multiple_sequence (
700- entry_deserializers : typing .List [typing .Optional [typing .Callable ]],
701- obj ,
702- ):
703- if obj is None :
704- return obj
705- return type (obj )(
706- _deserialize (deserializer , entry , module )
707- for entry , deserializer in zip (obj , entry_deserializers )
708- )
709-
710731 entry_deserializers = [
711732 _get_deserialize_callable_from_annotation (dt , module , rf )
712733 for dt in annotation .__args__ # pyright: ignore
713734 ]
714- return functools .partial (_deserialize_multiple_sequence , entry_deserializers )
735+ return functools .partial (_deserialize_multiple_sequence , entry_deserializers , module )
715736 deserializer = _get_deserialize_callable_from_annotation (
716737 annotation .__args__ [0 ], module , rf # pyright: ignore
717738 )
718739
719- def _deserialize_sequence (
720- deserializer : typing .Optional [typing .Callable ],
721- obj ,
722- ):
723- if obj is None :
724- return obj
725- return type (obj )(_deserialize (deserializer , entry , module ) for entry in obj )
726-
727- return functools .partial (_deserialize_sequence , deserializer )
740+ return functools .partial (_deserialize_sequence , deserializer , module )
728741 except (TypeError , IndexError , AttributeError , SyntaxError ):
729742 pass
730743
0 commit comments