6
6
# --------------------------------------------------------------------------
7
7
# pylint: disable=protected-access, arguments-differ, signature-differs, broad-except
8
8
9
+ import copy
9
10
import calendar
10
11
import decimal
11
12
import functools
12
13
import sys
13
14
import logging
14
15
import base64
15
16
import re
16
- import copy
17
17
import typing
18
18
import enum
19
19
import email .utils
@@ -339,7 +339,7 @@ def _get_model(module_name: str, model_name: str):
339
339
340
340
class _MyMutableMapping (MutableMapping [str , typing .Any ]): # pylint: disable=unsubscriptable-object
341
341
def __init__ (self , data : typing .Dict [str , typing .Any ]) -> None :
342
- self ._data = copy . deepcopy ( data )
342
+ self ._data = data
343
343
344
344
def __contains__ (self , key : typing .Any ) -> bool :
345
345
return key in self ._data
@@ -378,16 +378,13 @@ def get(self, key: str, default: typing.Any = None) -> typing.Any:
378
378
return default
379
379
380
380
@typing .overload
381
- def pop (self , key : str ) -> typing .Any :
382
- ...
381
+ def pop (self , key : str ) -> typing .Any : ...
383
382
384
383
@typing .overload
385
- def pop (self , key : str , default : _T ) -> _T :
386
- ...
384
+ def pop (self , key : str , default : _T ) -> _T : ...
387
385
388
386
@typing .overload
389
- def pop (self , key : str , default : typing .Any ) -> typing .Any :
390
- ...
387
+ def pop (self , key : str , default : typing .Any ) -> typing .Any : ...
391
388
392
389
def pop (self , key : str , default : typing .Any = _UNSET ) -> typing .Any :
393
390
if default is _UNSET :
@@ -404,12 +401,10 @@ def update(self, *args: typing.Any, **kwargs: typing.Any) -> None:
404
401
self ._data .update (* args , ** kwargs )
405
402
406
403
@typing .overload
407
- def setdefault (self , key : str , default : None = None ) -> None :
408
- ...
404
+ def setdefault (self , key : str , default : None = None ) -> None : ...
409
405
410
406
@typing .overload
411
- def setdefault (self , key : str , default : typing .Any ) -> typing .Any :
412
- ...
407
+ def setdefault (self , key : str , default : typing .Any ) -> typing .Any : ...
413
408
414
409
def setdefault (self , key : str , default : typing .Any = _UNSET ) -> typing .Any :
415
410
if default is _UNSET :
@@ -594,6 +589,64 @@ def _as_dict_value(v: typing.Any, exclude_readonly: bool = False) -> typing.Any:
594
589
return v .as_dict (exclude_readonly = exclude_readonly ) if hasattr (v , "as_dict" ) else v
595
590
596
591
592
+ def _deserialize_model (model_deserializer : typing .Optional [typing .Callable ], obj ):
593
+ if _is_model (obj ):
594
+ return obj
595
+ return _deserialize (model_deserializer , obj )
596
+
597
+
598
+ def _deserialize_with_optional (if_obj_deserializer : typing .Optional [typing .Callable ], obj ):
599
+ if obj is None :
600
+ return obj
601
+ return _deserialize_with_callable (if_obj_deserializer , obj )
602
+
603
+
604
+ def _deserialize_with_union (deserializers , obj ):
605
+ for deserializer in deserializers :
606
+ try :
607
+ return _deserialize (deserializer , obj )
608
+ except DeserializationError :
609
+ pass
610
+ raise DeserializationError ()
611
+
612
+
613
+ def _deserialize_dict (
614
+ value_deserializer : typing .Optional [typing .Callable ],
615
+ module : typing .Optional [str ],
616
+ obj : typing .Dict [typing .Any , typing .Any ],
617
+ ):
618
+ if obj is None :
619
+ return obj
620
+ return {k : _deserialize (value_deserializer , v , module ) for k , v in obj .items ()}
621
+
622
+
623
+ def _deserialize_multiple_sequence (
624
+ entry_deserializers : typing .List [typing .Optional [typing .Callable ]],
625
+ module : typing .Optional [str ],
626
+ obj ,
627
+ ):
628
+ if obj is None :
629
+ return obj
630
+ return type (obj )(_deserialize (deserializer , entry , module ) for entry , deserializer in zip (obj , entry_deserializers ))
631
+
632
+
633
+ def _deserialize_sequence (
634
+ deserializer : typing .Optional [typing .Callable ],
635
+ module : typing .Optional [str ],
636
+ obj ,
637
+ ):
638
+ if obj is None :
639
+ return obj
640
+ return type (obj )(_deserialize (deserializer , entry , module ) for entry in obj )
641
+
642
+
643
+ def _sorted_annotations (types : typing .List [typing .Any ]) -> typing .List [typing .Any ]:
644
+ return sorted (
645
+ types ,
646
+ key = lambda x : hasattr (x , "__name__" ) and x .__name__ .lower () in ("str" , "float" , "int" , "bool" ),
647
+ )
648
+
649
+
597
650
def _get_deserialize_callable_from_annotation ( # pylint: disable=R0911, R0915, R0912
598
651
annotation : typing .Any ,
599
652
module : typing .Optional [str ],
@@ -621,11 +674,6 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915,
621
674
if rf :
622
675
rf ._is_model = True
623
676
624
- def _deserialize_model (model_deserializer : typing .Optional [typing .Callable ], obj ):
625
- if _is_model (obj ):
626
- return obj
627
- return _deserialize (model_deserializer , obj )
628
-
629
677
return functools .partial (_deserialize_model , annotation ) # pyright: ignore
630
678
except Exception :
631
679
pass
@@ -640,36 +688,27 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj
640
688
# is it optional?
641
689
try :
642
690
if any (a for a in annotation .__args__ if a == type (None )): # pyright: ignore
643
- if_obj_deserializer = _get_deserialize_callable_from_annotation (
644
- next ( a for a in annotation . __args__ if a != type ( None )), module , rf # pyright: ignore
645
- )
646
-
647
- def _deserialize_with_optional ( if_obj_deserializer : typing . Optional [ typing . Callable ], obj ):
648
- if obj is None :
649
- return obj
650
- return _deserialize_with_callable ( if_obj_deserializer , obj )
651
-
652
- return functools . partial ( _deserialize_with_optional , if_obj_deserializer )
691
+ if len ( annotation . __args__ ) <= 2 : # pyright: ignore
692
+ if_obj_deserializer = _get_deserialize_callable_from_annotation (
693
+ next ( a for a in annotation . __args__ if a != type ( None )), module , rf # pyright: ignore
694
+ )
695
+
696
+ return functools . partial ( _deserialize_with_optional , if_obj_deserializer )
697
+ # the type is Optional[Union[...]], we need to remove the None type from the Union
698
+ annotation_copy = copy . copy ( annotation )
699
+ annotation_copy . __args__ = [ a for a in annotation_copy . __args__ if a != type ( None )] # pyright: ignore
700
+ return _get_deserialize_callable_from_annotation ( annotation_copy , module , rf )
653
701
except AttributeError :
654
702
pass
655
703
704
+ # is it union?
656
705
if getattr (annotation , "__origin__" , None ) is typing .Union :
657
706
# initial ordering is we make `string` the last deserialization option, because it is often them most generic
658
707
deserializers = [
659
708
_get_deserialize_callable_from_annotation (arg , module , rf )
660
- for arg in sorted (
661
- annotation .__args__ , key = lambda x : hasattr (x , "__name__" ) and x .__name__ == "str" # pyright: ignore
662
- )
709
+ for arg in _sorted_annotations (annotation .__args__ ) # pyright: ignore
663
710
]
664
711
665
- def _deserialize_with_union (deserializers , obj ):
666
- for deserializer in deserializers :
667
- try :
668
- return _deserialize (deserializer , obj )
669
- except DeserializationError :
670
- pass
671
- raise DeserializationError ()
672
-
673
712
return functools .partial (_deserialize_with_union , deserializers )
674
713
675
714
try :
@@ -678,53 +717,27 @@ def _deserialize_with_union(deserializers, obj):
678
717
annotation .__args__ [1 ], module , rf # pyright: ignore
679
718
)
680
719
681
- def _deserialize_dict (
682
- value_deserializer : typing .Optional [typing .Callable ],
683
- obj : typing .Dict [typing .Any , typing .Any ],
684
- ):
685
- if obj is None :
686
- return obj
687
- return {k : _deserialize (value_deserializer , v , module ) for k , v in obj .items ()}
688
-
689
720
return functools .partial (
690
721
_deserialize_dict ,
691
722
value_deserializer ,
723
+ module ,
692
724
)
693
725
except (AttributeError , IndexError ):
694
726
pass
695
727
try :
696
728
if annotation ._name in ["List" , "Set" , "Tuple" , "Sequence" ]: # pyright: ignore
697
729
if len (annotation .__args__ ) > 1 : # pyright: ignore
698
730
699
- def _deserialize_multiple_sequence (
700
- entry_deserializers : typing .List [typing .Optional [typing .Callable ]],
701
- obj ,
702
- ):
703
- if obj is None :
704
- return obj
705
- return type (obj )(
706
- _deserialize (deserializer , entry , module )
707
- for entry , deserializer in zip (obj , entry_deserializers )
708
- )
709
-
710
731
entry_deserializers = [
711
732
_get_deserialize_callable_from_annotation (dt , module , rf )
712
733
for dt in annotation .__args__ # pyright: ignore
713
734
]
714
- return functools .partial (_deserialize_multiple_sequence , entry_deserializers )
735
+ return functools .partial (_deserialize_multiple_sequence , entry_deserializers , module )
715
736
deserializer = _get_deserialize_callable_from_annotation (
716
737
annotation .__args__ [0 ], module , rf # pyright: ignore
717
738
)
718
739
719
- def _deserialize_sequence (
720
- deserializer : typing .Optional [typing .Callable ],
721
- obj ,
722
- ):
723
- if obj is None :
724
- return obj
725
- return type (obj )(_deserialize (deserializer , entry , module ) for entry in obj )
726
-
727
- return functools .partial (_deserialize_sequence , deserializer )
740
+ return functools .partial (_deserialize_sequence , deserializer , module )
728
741
except (TypeError , IndexError , AttributeError , SyntaxError ):
729
742
pass
730
743
0 commit comments