24
24
25
25
import grpclib .client
26
26
import grpclib .const
27
+ import stringcase
27
28
28
29
# Proto 3 data types
29
30
TYPE_ENUM = "enum"
101
102
WIRE_LEN_DELIM_TYPES = [TYPE_STRING , TYPE_BYTES , TYPE_MESSAGE , TYPE_MAP ]
102
103
103
104
105
+ class Casing (enum .Enum ):
106
+ """Casing constants for serialization."""
107
+
108
+ CAMEL = stringcase .camelcase
109
+ SNAKE = stringcase .snakecase
110
+
111
+
104
112
class _PLACEHOLDER :
105
113
pass
106
114
@@ -624,48 +632,50 @@ def parse(self: T, data: bytes) -> T:
624
632
def FromString (cls : Type [T ], data : bytes ) -> T :
625
633
return cls ().parse (data )
626
634
627
- def to_dict (self ) -> dict :
635
+ def to_dict (self , casing : Casing = Casing . CAMEL ) -> dict :
628
636
"""
629
637
Returns a dict representation of this message instance which can be
630
- used to serialize to e.g. JSON.
638
+ used to serialize to e.g. JSON. Defaults to camel casing for
639
+ compatibility but can be set to other modes.
631
640
"""
632
641
output : Dict [str , Any ] = {}
633
642
for field in dataclasses .fields (self ):
634
643
meta = FieldMetadata .get (field )
635
644
v = getattr (self , field .name )
645
+ cased_name = casing (field .name )
636
646
if meta .proto_type == "message" :
637
647
if isinstance (v , list ):
638
648
# Convert each item.
639
649
v = [i .to_dict () for i in v ]
640
- output [field . name ] = v
650
+ output [cased_name ] = v
641
651
elif v ._serialized_on_wire :
642
- output [field . name ] = v .to_dict ()
652
+ output [cased_name ] = v .to_dict ()
643
653
elif meta .proto_type == "map" :
644
654
for k in v :
645
655
if hasattr (v [k ], "to_dict" ):
646
656
v [k ] = v [k ].to_dict ()
647
657
648
658
if v :
649
- output [field . name ] = v
659
+ output [cased_name ] = v
650
660
elif v != get_default (meta .proto_type ):
651
661
if meta .proto_type in INT_64_TYPES :
652
662
if isinstance (v , list ):
653
- output [field . name ] = [str (n ) for n in v ]
663
+ output [cased_name ] = [str (n ) for n in v ]
654
664
else :
655
- output [field . name ] = str (v )
665
+ output [cased_name ] = str (v )
656
666
elif meta .proto_type == TYPE_BYTES :
657
667
if isinstance (v , list ):
658
- output [field . name ] = [b64encode (b ).decode ("utf8" ) for b in v ]
668
+ output [cased_name ] = [b64encode (b ).decode ("utf8" ) for b in v ]
659
669
else :
660
- output [field . name ] = b64encode (v ).decode ("utf8" )
670
+ output [cased_name ] = b64encode (v ).decode ("utf8" )
661
671
elif meta .proto_type == TYPE_ENUM :
662
672
enum_values = list (self ._cls_for (field ))
663
673
if isinstance (v , list ):
664
- output [field . name ] = [enum_values [e ].name for e in v ]
674
+ output [cased_name ] = [enum_values [e ].name for e in v ]
665
675
else :
666
- output [field . name ] = enum_values [v ].name
676
+ output [cased_name ] = enum_values [v ].name
667
677
else :
668
- output [field . name ] = v
678
+ output [cased_name ] = v
669
679
return output
670
680
671
681
def from_dict (self : T , value : dict ) -> T :
@@ -674,44 +684,49 @@ def from_dict(self: T, value: dict) -> T:
674
684
returns the instance itself and is therefore assignable and chainable.
675
685
"""
676
686
self ._serialized_on_wire = True
677
- for field in dataclasses .fields (self ):
678
- meta = FieldMetadata .get (field )
679
- if field .name in value and value [field .name ] is not None :
680
- if meta .proto_type == "message" :
681
- v = getattr (self , field .name )
682
- # print(v, value[field.name])
683
- if isinstance (v , list ):
684
- cls = self ._cls_for (field )
685
- for i in range (len (value [field .name ])):
686
- v .append (cls ().from_dict (value [field .name ][i ]))
687
- else :
688
- v .from_dict (value [field .name ])
689
- elif meta .map_types and meta .map_types [1 ] == TYPE_MESSAGE :
690
- v = getattr (self , field .name )
691
- cls = self ._cls_for (field , index = 1 )
692
- for k in value [field .name ]:
693
- v [k ] = cls ().from_dict (value [field .name ][k ])
694
- else :
695
- v = value [field .name ]
696
- if meta .proto_type in INT_64_TYPES :
697
- if isinstance (value [field .name ], list ):
698
- v = [int (n ) for n in value [field .name ]]
699
- else :
700
- v = int (value [field .name ])
701
- elif meta .proto_type == TYPE_BYTES :
702
- if isinstance (value [field .name ], list ):
703
- v = [b64decode (n ) for n in value [field .name ]]
704
- else :
705
- v = b64decode (value [field .name ])
706
- elif meta .proto_type == TYPE_ENUM :
707
- enum_cls = self ._cls_for (field )
708
- if isinstance (v , list ):
709
- v = [enum_cls .from_string (e ) for e in v ]
710
- elif isinstance (v , str ):
711
- v = enum_cls .from_string (v )
687
+ fields_by_name = {f .name : f for f in dataclasses .fields (self )}
688
+ for key in value :
689
+ snake_cased = stringcase .snakecase (key )
690
+ if snake_cased in fields_by_name :
691
+ field = fields_by_name [snake_cased ]
692
+ meta = FieldMetadata .get (field )
712
693
713
- if v is not None :
714
- setattr (self , field .name , v )
694
+ if value [key ] is not None :
695
+ if meta .proto_type == "message" :
696
+ v = getattr (self , field .name )
697
+ # print(v, value[key])
698
+ if isinstance (v , list ):
699
+ cls = self ._cls_for (field )
700
+ for i in range (len (value [key ])):
701
+ v .append (cls ().from_dict (value [key ][i ]))
702
+ else :
703
+ v .from_dict (value [key ])
704
+ elif meta .map_types and meta .map_types [1 ] == TYPE_MESSAGE :
705
+ v = getattr (self , field .name )
706
+ cls = self ._cls_for (field , index = 1 )
707
+ for k in value [key ]:
708
+ v [k ] = cls ().from_dict (value [key ][k ])
709
+ else :
710
+ v = value [key ]
711
+ if meta .proto_type in INT_64_TYPES :
712
+ if isinstance (value [key ], list ):
713
+ v = [int (n ) for n in value [key ]]
714
+ else :
715
+ v = int (value [key ])
716
+ elif meta .proto_type == TYPE_BYTES :
717
+ if isinstance (value [key ], list ):
718
+ v = [b64decode (n ) for n in value [key ]]
719
+ else :
720
+ v = b64decode (value [key ])
721
+ elif meta .proto_type == TYPE_ENUM :
722
+ enum_cls = self ._cls_for (field )
723
+ if isinstance (v , list ):
724
+ v = [enum_cls .from_string (e ) for e in v ]
725
+ elif isinstance (v , str ):
726
+ v = enum_cls .from_string (v )
727
+
728
+ if v is not None :
729
+ setattr (self , field .name , v )
715
730
return self
716
731
717
732
def to_json (self , indent : Union [None , int , str ] = None ) -> str :
0 commit comments