5
5
# license information.
6
6
# --------------------------------------------------------------------------
7
7
# pylint: disable=protected-access, arguments-differ, signature-differs, broad-except
8
- # pyright: reportGeneralTypeIssues=false
9
8
10
9
import calendar
11
10
import decimal
16
15
import re
17
16
import copy
18
17
import typing
19
- import email
18
+ import enum
19
+ import email .utils
20
20
from datetime import datetime , date , time , timedelta , timezone
21
21
from json import JSONEncoder
22
+ from typing_extensions import Self
22
23
import isodate
23
24
from azure .core .exceptions import DeserializationError
24
25
from azure .core import CaseInsensitiveEnumMeta
35
36
__all__ = ["SdkJSONEncoder" , "Model" , "rest_field" , "rest_discriminator" ]
36
37
37
38
TZ_UTC = timezone .utc
39
+ _T = typing .TypeVar ("_T" )
38
40
39
41
40
42
def _timedelta_as_isostr (td : timedelta ) -> str :
@@ -242,7 +244,7 @@ def _deserialize_date(attr: typing.Union[str, date]) -> date:
242
244
# This must NOT use defaultmonth/defaultday. Using None ensure this raises an exception.
243
245
if isinstance (attr , date ):
244
246
return attr
245
- return isodate .parse_date (attr , defaultmonth = None , defaultday = None )
247
+ return isodate .parse_date (attr , defaultmonth = None , defaultday = None ) # type: ignore
246
248
247
249
248
250
def _deserialize_time (attr : typing .Union [str , time ]) -> time :
@@ -375,8 +377,12 @@ def get(self, key: str, default: typing.Any = None) -> typing.Any:
375
377
except KeyError :
376
378
return default
377
379
378
- @typing .overload # type: ignore
379
- def pop (self , key : str ) -> typing .Any : # pylint: disable=no-member
380
+ @typing .overload
381
+ def pop (self , key : str ) -> typing .Any :
382
+ ...
383
+
384
+ @typing .overload
385
+ def pop (self , key : str , default : _T ) -> _T :
380
386
...
381
387
382
388
@typing .overload
@@ -397,8 +403,8 @@ def clear(self) -> None:
397
403
def update (self , * args : typing .Any , ** kwargs : typing .Any ) -> None :
398
404
self ._data .update (* args , ** kwargs )
399
405
400
- @typing .overload # type: ignore
401
- def setdefault (self , key : str ) -> typing . Any :
406
+ @typing .overload
407
+ def setdefault (self , key : str , default : None = None ) -> None :
402
408
...
403
409
404
410
@typing .overload
@@ -438,6 +444,8 @@ def _serialize(o, format: typing.Optional[str] = None): # pylint: disable=too-m
438
444
return _serialize_bytes (o , format )
439
445
if isinstance (o , decimal .Decimal ):
440
446
return float (o )
447
+ if isinstance (o , enum .Enum ):
448
+ return o .value
441
449
try :
442
450
# First try datetime.datetime
443
451
return _serialize_datetime (o , format )
@@ -462,7 +470,13 @@ def _get_rest_field(
462
470
463
471
464
472
def _create_value (rf : typing .Optional ["_RestField" ], value : typing .Any ) -> typing .Any :
465
- return _deserialize (rf ._type , value ) if (rf and rf ._is_model ) else _serialize (value , rf ._format if rf else None )
473
+ if not rf :
474
+ return _serialize (value , None )
475
+ if rf ._is_multipart_file_input :
476
+ return value
477
+ if rf ._is_model :
478
+ return _deserialize (rf ._type , value )
479
+ return _serialize (value , rf ._format )
466
480
467
481
468
482
class Model (_MyMutableMapping ):
@@ -498,7 +512,7 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None:
498
512
def copy (self ) -> "Model" :
499
513
return Model (self .__dict__ )
500
514
501
- def __new__ (cls , * args : typing .Any , ** kwargs : typing .Any ) -> "Model" : # pylint: disable=unused-argument
515
+ def __new__ (cls , * args : typing .Any , ** kwargs : typing .Any ) -> Self : # pylint: disable=unused-argument
502
516
# we know the last three classes in mro are going to be 'Model', 'dict', and 'object'
503
517
mros = cls .__mro__ [:- 3 ][::- 1 ] # ignore model, dict, and object parents, and reverse the mro order
504
518
attr_to_rest_field : typing .Dict [str , _RestField ] = { # map attribute name to rest_field property
@@ -540,7 +554,7 @@ def _deserialize(cls, data, exist_discriminators):
540
554
return cls (data )
541
555
discriminator = cls ._get_discriminator (exist_discriminators )
542
556
exist_discriminators .append (discriminator )
543
- mapped_cls = cls .__mapping__ .get (data .get (discriminator ), cls ) # pylint: disable=no-member
557
+ mapped_cls = cls .__mapping__ .get (data .get (discriminator ), cls ) # pyright: ignore # pylint: disable=no-member
544
558
if mapped_cls == cls :
545
559
return cls (data )
546
560
return mapped_cls ._deserialize (data , exist_discriminators ) # pylint: disable=protected-access
@@ -557,17 +571,24 @@ def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.
557
571
if exclude_readonly :
558
572
readonly_props = [p ._rest_name for p in self ._attr_to_rest_field .values () if _is_readonly (p )]
559
573
for k , v in self .items ():
560
- if exclude_readonly and k in readonly_props : # pyright: ignore[reportUnboundVariable]
574
+ if exclude_readonly and k in readonly_props : # pyright: ignore
561
575
continue
562
- result [k ] = Model ._as_dict_value (v , exclude_readonly = exclude_readonly )
576
+ is_multipart_file_input = False
577
+ try :
578
+ is_multipart_file_input = next (
579
+ rf for rf in self ._attr_to_rest_field .values () if rf ._rest_name == k
580
+ )._is_multipart_file_input
581
+ except StopIteration :
582
+ pass
583
+ result [k ] = v if is_multipart_file_input else Model ._as_dict_value (v , exclude_readonly = exclude_readonly )
563
584
return result
564
585
565
586
@staticmethod
566
587
def _as_dict_value (v : typing .Any , exclude_readonly : bool = False ) -> typing .Any :
567
588
if v is None or isinstance (v , _Null ):
568
589
return None
569
590
if isinstance (v , (list , tuple , set )):
570
- return [ Model ._as_dict_value (x , exclude_readonly = exclude_readonly ) for x in v ]
591
+ return type ( v )( Model ._as_dict_value (x , exclude_readonly = exclude_readonly ) for x in v )
571
592
if isinstance (v , dict ):
572
593
return {dk : Model ._as_dict_value (dv , exclude_readonly = exclude_readonly ) for dk , dv in v .items ()}
573
594
return v .as_dict (exclude_readonly = exclude_readonly ) if hasattr (v , "as_dict" ) else v
@@ -605,29 +626,22 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj
605
626
return obj
606
627
return _deserialize (model_deserializer , obj )
607
628
608
- return functools .partial (_deserialize_model , annotation )
629
+ return functools .partial (_deserialize_model , annotation ) # pyright: ignore
609
630
except Exception :
610
631
pass
611
632
612
633
# is it a literal?
613
634
try :
614
- if sys .version_info >= (3 , 8 ):
615
- from typing import (
616
- Literal ,
617
- ) # pylint: disable=no-name-in-module, ungrouped-imports
618
- else :
619
- from typing_extensions import Literal # type: ignore # pylint: disable=ungrouped-imports
620
-
621
- if annotation .__origin__ == Literal :
635
+ if annotation .__origin__ is typing .Literal : # pyright: ignore
622
636
return None
623
637
except AttributeError :
624
638
pass
625
639
626
640
# is it optional?
627
641
try :
628
- if any (a for a in annotation .__args__ if a == type (None )):
642
+ if any (a for a in annotation .__args__ if a == type (None )): # pyright: ignore
629
643
if_obj_deserializer = _get_deserialize_callable_from_annotation (
630
- next (a for a in annotation .__args__ if a != type (None )), module , rf
644
+ next (a for a in annotation .__args__ if a != type (None )), module , rf # pyright: ignore
631
645
)
632
646
633
647
def _deserialize_with_optional (if_obj_deserializer : typing .Optional [typing .Callable ], obj ):
@@ -640,7 +654,13 @@ def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Calla
640
654
pass
641
655
642
656
if getattr (annotation , "__origin__" , None ) is typing .Union :
643
- deserializers = [_get_deserialize_callable_from_annotation (arg , module , rf ) for arg in annotation .__args__ ]
657
+ # initial ordering is we make `string` the last deserialization option, because it is often them most generic
658
+ deserializers = [
659
+ _get_deserialize_callable_from_annotation (arg , module , rf )
660
+ for arg in sorted (
661
+ annotation .__args__ , key = lambda x : hasattr (x , "__name__" ) and x .__name__ == "str" # pyright: ignore
662
+ )
663
+ ]
644
664
645
665
def _deserialize_with_union (deserializers , obj ):
646
666
for deserializer in deserializers :
@@ -653,8 +673,10 @@ def _deserialize_with_union(deserializers, obj):
653
673
return functools .partial (_deserialize_with_union , deserializers )
654
674
655
675
try :
656
- if annotation ._name == "Dict" :
657
- value_deserializer = _get_deserialize_callable_from_annotation (annotation .__args__ [1 ], module , rf )
676
+ if annotation ._name == "Dict" : # pyright: ignore
677
+ value_deserializer = _get_deserialize_callable_from_annotation (
678
+ annotation .__args__ [1 ], module , rf # pyright: ignore
679
+ )
658
680
659
681
def _deserialize_dict (
660
682
value_deserializer : typing .Optional [typing .Callable ],
@@ -671,8 +693,8 @@ def _deserialize_dict(
671
693
except (AttributeError , IndexError ):
672
694
pass
673
695
try :
674
- if annotation ._name in ["List" , "Set" , "Tuple" , "Sequence" ]:
675
- if len (annotation .__args__ ) > 1 :
696
+ if annotation ._name in ["List" , "Set" , "Tuple" , "Sequence" ]: # pyright: ignore
697
+ if len (annotation .__args__ ) > 1 : # pyright: ignore
676
698
677
699
def _deserialize_multiple_sequence (
678
700
entry_deserializers : typing .List [typing .Optional [typing .Callable ]],
@@ -686,10 +708,13 @@ def _deserialize_multiple_sequence(
686
708
)
687
709
688
710
entry_deserializers = [
689
- _get_deserialize_callable_from_annotation (dt , module , rf ) for dt in annotation .__args__
711
+ _get_deserialize_callable_from_annotation (dt , module , rf )
712
+ for dt in annotation .__args__ # pyright: ignore
690
713
]
691
714
return functools .partial (_deserialize_multiple_sequence , entry_deserializers )
692
- deserializer = _get_deserialize_callable_from_annotation (annotation .__args__ [0 ], module , rf )
715
+ deserializer = _get_deserialize_callable_from_annotation (
716
+ annotation .__args__ [0 ], module , rf # pyright: ignore
717
+ )
693
718
694
719
def _deserialize_sequence (
695
720
deserializer : typing .Optional [typing .Callable ],
@@ -769,6 +794,7 @@ def __init__(
769
794
visibility : typing .Optional [typing .List [str ]] = None ,
770
795
default : typing .Any = _UNSET ,
771
796
format : typing .Optional [str ] = None ,
797
+ is_multipart_file_input : bool = False ,
772
798
):
773
799
self ._type = type
774
800
self ._rest_name_input = name
@@ -778,6 +804,11 @@ def __init__(
778
804
self ._is_model = False
779
805
self ._default = default
780
806
self ._format = format
807
+ self ._is_multipart_file_input = is_multipart_file_input
808
+
809
+ @property
810
+ def _class_type (self ) -> typing .Any :
811
+ return getattr (self ._type , "args" , [None ])[0 ]
781
812
782
813
@property
783
814
def _rest_name (self ) -> str :
@@ -823,8 +854,16 @@ def rest_field(
823
854
visibility : typing .Optional [typing .List [str ]] = None ,
824
855
default : typing .Any = _UNSET ,
825
856
format : typing .Optional [str ] = None ,
857
+ is_multipart_file_input : bool = False ,
826
858
) -> typing .Any :
827
- return _RestField (name = name , type = type , visibility = visibility , default = default , format = format )
859
+ return _RestField (
860
+ name = name ,
861
+ type = type ,
862
+ visibility = visibility ,
863
+ default = default ,
864
+ format = format ,
865
+ is_multipart_file_input = is_multipart_file_input ,
866
+ )
828
867
829
868
830
869
def rest_discriminator (
0 commit comments