5
5
# license information.
6
6
# --------------------------------------------------------------------------
7
7
# pylint: disable=protected-access, arguments-differ, signature-differs, broad-except
8
- # pyright: reportGeneralTypeIssues=false
9
8
10
9
import calendar
10
+ import decimal
11
11
import functools
12
12
import sys
13
13
import logging
14
14
import base64
15
15
import re
16
16
import copy
17
17
import typing
18
- import email
18
+ import enum
19
+ import email .utils
19
20
from datetime import datetime , date , time , timedelta , timezone
20
21
from json import JSONEncoder
22
+ from typing_extensions import Self
21
23
import isodate
22
24
from azure .core .exceptions import DeserializationError
23
25
from azure .core import CaseInsensitiveEnumMeta
34
36
__all__ = ["SdkJSONEncoder" , "Model" , "rest_field" , "rest_discriminator" ]
35
37
36
38
TZ_UTC = timezone .utc
39
+ _T = typing .TypeVar ("_T" )
37
40
38
41
39
42
def _timedelta_as_isostr (td : timedelta ) -> str :
@@ -144,6 +147,8 @@ def default(self, o): # pylint: disable=too-many-return-statements
144
147
except TypeError :
145
148
if isinstance (o , _Null ):
146
149
return None
150
+ if isinstance (o , decimal .Decimal ):
151
+ return float (o )
147
152
if isinstance (o , (bytes , bytearray )):
148
153
return _serialize_bytes (o , self .format )
149
154
try :
@@ -239,7 +244,7 @@ def _deserialize_date(attr: typing.Union[str, date]) -> date:
239
244
# This must NOT use defaultmonth/defaultday. Using None ensure this raises an exception.
240
245
if isinstance (attr , date ):
241
246
return attr
242
- return isodate .parse_date (attr , defaultmonth = None , defaultday = None )
247
+ return isodate .parse_date (attr , defaultmonth = None , defaultday = None ) # type: ignore
243
248
244
249
245
250
def _deserialize_time (attr : typing .Union [str , time ]) -> time :
@@ -275,6 +280,12 @@ def _deserialize_duration(attr):
275
280
return isodate .parse_duration (attr )
276
281
277
282
283
+ def _deserialize_decimal (attr ):
284
+ if isinstance (attr , decimal .Decimal ):
285
+ return attr
286
+ return decimal .Decimal (str (attr ))
287
+
288
+
278
289
_DESERIALIZE_MAPPING = {
279
290
datetime : _deserialize_datetime ,
280
291
date : _deserialize_date ,
@@ -283,6 +294,7 @@ def _deserialize_duration(attr):
283
294
bytearray : _deserialize_bytes ,
284
295
timedelta : _deserialize_duration ,
285
296
typing .Any : lambda x : x ,
297
+ decimal .Decimal : _deserialize_decimal ,
286
298
}
287
299
288
300
_DESERIALIZE_MAPPING_WITHFORMAT = {
@@ -373,8 +385,12 @@ def get(self, key: str, default: typing.Any = None) -> typing.Any:
373
385
except KeyError :
374
386
return default
375
387
376
- @typing .overload # type: ignore
377
- def pop (self , key : str ) -> typing .Any : # pylint: disable=no-member
388
+ @typing .overload
389
+ def pop (self , key : str ) -> typing .Any :
390
+ ...
391
+
392
+ @typing .overload
393
+ def pop (self , key : str , default : _T ) -> _T :
378
394
...
379
395
380
396
@typing .overload
@@ -395,8 +411,8 @@ def clear(self) -> None:
395
411
def update (self , * args : typing .Any , ** kwargs : typing .Any ) -> None :
396
412
self ._data .update (* args , ** kwargs )
397
413
398
- @typing .overload # type: ignore
399
- def setdefault (self , key : str ) -> typing . Any :
414
+ @typing .overload
415
+ def setdefault (self , key : str , default : None = None ) -> None :
400
416
...
401
417
402
418
@typing .overload
@@ -434,6 +450,10 @@ def _serialize(o, format: typing.Optional[str] = None): # pylint: disable=too-m
434
450
return tuple (_serialize (x , format ) for x in o )
435
451
if isinstance (o , (bytes , bytearray )):
436
452
return _serialize_bytes (o , format )
453
+ if isinstance (o , decimal .Decimal ):
454
+ return float (o )
455
+ if isinstance (o , enum .Enum ):
456
+ return o .value
437
457
try :
438
458
# First try datetime.datetime
439
459
return _serialize_datetime (o , format )
@@ -458,7 +478,13 @@ def _get_rest_field(
458
478
459
479
460
480
def _create_value (rf : typing .Optional ["_RestField" ], value : typing .Any ) -> typing .Any :
461
- return _deserialize (rf ._type , value ) if (rf and rf ._is_model ) else _serialize (value , rf ._format if rf else None )
481
+ if not rf :
482
+ return _serialize (value , None )
483
+ if rf ._is_multipart_file_input :
484
+ return value
485
+ if rf ._is_model :
486
+ return _deserialize (rf ._type , value )
487
+ return _serialize (value , rf ._format )
462
488
463
489
464
490
class Model (_MyMutableMapping ):
@@ -494,7 +520,7 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None:
494
520
def copy (self ) -> "Model" :
495
521
return Model (self .__dict__ )
496
522
497
- def __new__ (cls , * args : typing .Any , ** kwargs : typing .Any ) -> "Model" : # pylint: disable=unused-argument
523
+ def __new__ (cls , * args : typing .Any , ** kwargs : typing .Any ) -> Self : # pylint: disable=unused-argument
498
524
# we know the last three classes in mro are going to be 'Model', 'dict', and 'object'
499
525
mros = cls .__mro__ [:- 3 ][::- 1 ] # ignore model, dict, and object parents, and reverse the mro order
500
526
attr_to_rest_field : typing .Dict [str , _RestField ] = { # map attribute name to rest_field property
@@ -536,7 +562,7 @@ def _deserialize(cls, data, exist_discriminators):
536
562
exist_discriminators .append (discriminator )
537
563
mapped_cls = cls .__mapping__ .get (
538
564
data .get (discriminator ), cls
539
- ) # pylint: disable=no-member
565
+ ) # pyright: ignore # pylint: disable=no-member
540
566
if mapped_cls == cls :
541
567
return cls (data )
542
568
return mapped_cls ._deserialize (data , exist_discriminators ) # pylint: disable=protected-access
@@ -553,20 +579,25 @@ def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.
553
579
if exclude_readonly :
554
580
readonly_props = [p ._rest_name for p in self ._attr_to_rest_field .values () if _is_readonly (p )]
555
581
for k , v in self .items ():
556
- if exclude_readonly and k in readonly_props : # pyright: ignore[reportUnboundVariable]
582
+ if exclude_readonly and k in readonly_props : # pyright: ignore
557
583
continue
558
- result [k ] = Model ._as_dict_value (v , exclude_readonly = exclude_readonly )
584
+ is_multipart_file_input = False
585
+ try :
586
+ is_multipart_file_input = next (rf for rf in self ._attr_to_rest_field .values () if rf ._rest_name == k )._is_multipart_file_input
587
+ except StopIteration :
588
+ pass
589
+ result [k ] = v if is_multipart_file_input else Model ._as_dict_value (v , exclude_readonly = exclude_readonly )
559
590
return result
560
591
561
592
@staticmethod
562
593
def _as_dict_value (v : typing .Any , exclude_readonly : bool = False ) -> typing .Any :
563
594
if v is None or isinstance (v , _Null ):
564
595
return None
565
596
if isinstance (v , (list , tuple , set )):
566
- return [
597
+ return type ( v )(
567
598
Model ._as_dict_value (x , exclude_readonly = exclude_readonly )
568
599
for x in v
569
- ]
600
+ )
570
601
if isinstance (v , dict ):
571
602
return {
572
603
dk : Model ._as_dict_value (dv , exclude_readonly = exclude_readonly )
@@ -607,29 +638,22 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj
607
638
return obj
608
639
return _deserialize (model_deserializer , obj )
609
640
610
- return functools .partial (_deserialize_model , annotation )
641
+ return functools .partial (_deserialize_model , annotation ) # pyright: ignore
611
642
except Exception :
612
643
pass
613
644
614
645
# is it a literal?
615
646
try :
616
- if sys .version_info >= (3 , 8 ):
617
- from typing import (
618
- Literal ,
619
- ) # pylint: disable=no-name-in-module, ungrouped-imports
620
- else :
621
- from typing_extensions import Literal # type: ignore # pylint: disable=ungrouped-imports
622
-
623
- if annotation .__origin__ == Literal :
647
+ if annotation .__origin__ is typing .Literal : # pyright: ignore
624
648
return None
625
649
except AttributeError :
626
650
pass
627
651
628
652
# is it optional?
629
653
try :
630
- if any (a for a in annotation .__args__ if a == type (None )):
654
+ if any (a for a in annotation .__args__ if a == type (None )): # pyright: ignore
631
655
if_obj_deserializer = _get_deserialize_callable_from_annotation (
632
- next (a for a in annotation .__args__ if a != type (None )), module , rf
656
+ next (a for a in annotation .__args__ if a != type (None )), module , rf # pyright: ignore
633
657
)
634
658
635
659
def _deserialize_with_optional (if_obj_deserializer : typing .Optional [typing .Callable ], obj ):
@@ -642,7 +666,13 @@ def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Calla
642
666
pass
643
667
644
668
if getattr (annotation , "__origin__" , None ) is typing .Union :
645
- deserializers = [_get_deserialize_callable_from_annotation (arg , module , rf ) for arg in annotation .__args__ ]
669
+ # initial ordering is we make `string` the last deserialization option, because it is often them most generic
670
+ deserializers = [
671
+ _get_deserialize_callable_from_annotation (arg , module , rf )
672
+ for arg in sorted (
673
+ annotation .__args__ , key = lambda x : hasattr (x , "__name__" ) and x .__name__ == "str" # pyright: ignore
674
+ )
675
+ ]
646
676
647
677
def _deserialize_with_union (deserializers , obj ):
648
678
for deserializer in deserializers :
@@ -655,32 +685,31 @@ def _deserialize_with_union(deserializers, obj):
655
685
return functools .partial (_deserialize_with_union , deserializers )
656
686
657
687
try :
658
- if annotation ._name == "Dict" :
659
- key_deserializer = _get_deserialize_callable_from_annotation (annotation .__args__ [0 ], module , rf )
660
- value_deserializer = _get_deserialize_callable_from_annotation (annotation .__args__ [1 ], module , rf )
688
+ if annotation ._name == "Dict" : # pyright: ignore
689
+ value_deserializer = _get_deserialize_callable_from_annotation (
690
+ annotation .__args__ [1 ], module , rf # pyright: ignore
691
+ )
661
692
662
693
def _deserialize_dict (
663
- key_deserializer : typing .Optional [typing .Callable ],
664
694
value_deserializer : typing .Optional [typing .Callable ],
665
695
obj : typing .Dict [typing .Any , typing .Any ],
666
696
):
667
697
if obj is None :
668
698
return obj
669
699
return {
670
- _deserialize ( key_deserializer , k , module ) : _deserialize (value_deserializer , v , module )
700
+ k : _deserialize (value_deserializer , v , module )
671
701
for k , v in obj .items ()
672
702
}
673
703
674
704
return functools .partial (
675
705
_deserialize_dict ,
676
- key_deserializer ,
677
706
value_deserializer ,
678
707
)
679
708
except (AttributeError , IndexError ):
680
709
pass
681
710
try :
682
- if annotation ._name in ["List" , "Set" , "Tuple" , "Sequence" ]:
683
- if len (annotation .__args__ ) > 1 :
711
+ if annotation ._name in ["List" , "Set" , "Tuple" , "Sequence" ]: # pyright: ignore
712
+ if len (annotation .__args__ ) > 1 : # pyright: ignore
684
713
685
714
def _deserialize_multiple_sequence (
686
715
entry_deserializers : typing .List [typing .Optional [typing .Callable ]],
@@ -694,10 +723,12 @@ def _deserialize_multiple_sequence(
694
723
)
695
724
696
725
entry_deserializers = [
697
- _get_deserialize_callable_from_annotation (dt , module , rf ) for dt in annotation .__args__
726
+ _get_deserialize_callable_from_annotation (dt , module , rf ) for dt in annotation .__args__ # pyright: ignore
698
727
]
699
728
return functools .partial (_deserialize_multiple_sequence , entry_deserializers )
700
- deserializer = _get_deserialize_callable_from_annotation (annotation .__args__ [0 ], module , rf )
729
+ deserializer = _get_deserialize_callable_from_annotation (
730
+ annotation .__args__ [0 ], module , rf # pyright: ignore
731
+ )
701
732
702
733
def _deserialize_sequence (
703
734
deserializer : typing .Optional [typing .Callable ],
@@ -712,27 +743,29 @@ def _deserialize_sequence(
712
743
pass
713
744
714
745
def _deserialize_default (
715
- annotation ,
716
- deserializer_from_mapping ,
746
+ deserializer ,
717
747
obj ,
718
748
):
719
749
if obj is None :
720
750
return obj
721
751
try :
722
- return _deserialize_with_callable (annotation , obj )
752
+ return _deserialize_with_callable (deserializer , obj )
723
753
except Exception :
724
754
pass
725
- return _deserialize_with_callable ( deserializer_from_mapping , obj )
755
+ return obj
726
756
727
- return functools .partial (_deserialize_default , annotation , get_deserializer (annotation , rf ))
757
+ if get_deserializer (annotation , rf ):
758
+ return functools .partial (_deserialize_default , get_deserializer (annotation , rf ))
759
+
760
+ return functools .partial (_deserialize_default , annotation )
728
761
729
762
730
763
def _deserialize_with_callable (
731
764
deserializer : typing .Optional [typing .Callable [[typing .Any ], typing .Any ]],
732
765
value : typing .Any ,
733
766
):
734
767
try :
735
- if value is None :
768
+ if value is None or isinstance ( value , _Null ) :
736
769
return None
737
770
if deserializer is None :
738
771
return value
@@ -760,7 +793,8 @@ def _deserialize(
760
793
value = value .http_response .json ()
761
794
if rf is None and format :
762
795
rf = _RestField (format = format )
763
- deserializer = _get_deserialize_callable_from_annotation (deserializer , module , rf )
796
+ if not isinstance (deserializer , functools .partial ):
797
+ deserializer = _get_deserialize_callable_from_annotation (deserializer , module , rf )
764
798
return _deserialize_with_callable (deserializer , value )
765
799
766
800
@@ -774,6 +808,7 @@ def __init__(
774
808
visibility : typing .Optional [typing .List [str ]] = None ,
775
809
default : typing .Any = _UNSET ,
776
810
format : typing .Optional [str ] = None ,
811
+ is_multipart_file_input : bool = False ,
777
812
):
778
813
self ._type = type
779
814
self ._rest_name_input = name
@@ -783,6 +818,11 @@ def __init__(
783
818
self ._is_model = False
784
819
self ._default = default
785
820
self ._format = format
821
+ self ._is_multipart_file_input = is_multipart_file_input
822
+
823
+ @property
824
+ def _class_type (self ) -> typing .Any :
825
+ return getattr (self ._type , "args" , [None ])[0 ]
786
826
787
827
@property
788
828
def _rest_name (self ) -> str :
@@ -828,8 +868,9 @@ def rest_field(
828
868
visibility : typing .Optional [typing .List [str ]] = None ,
829
869
default : typing .Any = _UNSET ,
830
870
format : typing .Optional [str ] = None ,
871
+ is_multipart_file_input : bool = False ,
831
872
) -> typing .Any :
832
- return _RestField (name = name , type = type , visibility = visibility , default = default , format = format )
873
+ return _RestField (name = name , type = type , visibility = visibility , default = default , format = format , is_multipart_file_input = is_multipart_file_input )
833
874
834
875
835
876
def rest_discriminator (
0 commit comments