@@ -118,18 +118,6 @@ class _PLACEHOLDER:
118
118
PLACEHOLDER : Any = _PLACEHOLDER ()
119
119
120
120
121
- def get_default (proto_type : str ) -> Any :
122
- """Get the default (zero value) for a given type."""
123
- return {
124
- TYPE_BOOL : False ,
125
- TYPE_FLOAT : 0.0 ,
126
- TYPE_DOUBLE : 0.0 ,
127
- TYPE_STRING : "" ,
128
- TYPE_BYTES : b"" ,
129
- TYPE_MAP : {},
130
- }.get (proto_type , 0 )
131
-
132
-
133
121
@dataclasses .dataclass (frozen = True )
134
122
class FieldMetadata :
135
123
"""Stores internal metadata used for parsing & serialization."""
@@ -467,11 +455,22 @@ def __bytes__(self) -> bytes:
467
455
if meta .group and self ._group_map ["groups" ][meta .group ]["current" ] == field :
468
456
selected_in_group = True
469
457
470
- if isinstance (value , list ):
471
- if not len (value ) and not selected_in_group :
472
- # Empty values are not serialized
473
- continue
458
+ serialize_empty = False
459
+ if isinstance (value , Message ) and value ._serialized_on_wire :
460
+ # Empty messages can still be sent on the wire if they were
461
+ # set (or received empty).
462
+ serialize_empty = True
463
+
464
+ if value == self ._get_field_default (field , meta ) and not (
465
+ selected_in_group or serialize_empty
466
+ ):
467
+ # Default (zero) values are not serialized. Two exceptions are
468
+ # if this is the selected oneof item or if we know we have to
469
+ # serialize an empty message (i.e. zero value was explicitly
470
+ # set by the user).
471
+ continue
474
472
473
+ if isinstance (value , list ):
475
474
if meta .proto_type in PACKED_TYPES :
476
475
# Packed lists look like a length-delimited field. First,
477
476
# preprocess/encode each value into a buffer and then
@@ -484,23 +483,12 @@ def __bytes__(self) -> bytes:
484
483
for item in value :
485
484
output += _serialize_single (meta .number , meta .proto_type , item )
486
485
elif isinstance (value , dict ):
487
- if not len (value ) and not selected_in_group :
488
- # Empty values are not serialized
489
- continue
490
-
491
486
for k , v in value .items ():
492
487
assert meta .map_types
493
488
sk = _serialize_single (1 , meta .map_types [0 ], k )
494
489
sv = _serialize_single (2 , meta .map_types [1 ], v )
495
490
output += _serialize_single (meta .number , meta .proto_type , sk + sv )
496
491
else :
497
- if value == get_default (meta .proto_type ) and not selected_in_group :
498
- # Default (zero) values are not serialized
499
- continue
500
-
501
- serialize_empty = False
502
- if isinstance (value , Message ) and value ._serialized_on_wire :
503
- serialize_empty = True
504
492
output += _serialize_single (
505
493
meta .number , meta .proto_type , value , serialize_empty = serialize_empty
506
494
)
@@ -510,30 +498,42 @@ def __bytes__(self) -> bytes:
510
498
# For compatibility with other libraries
511
499
SerializeToString = __bytes__
512
500
513
- def _cls_for (self , field : dataclasses .Field , index : int = 0 ) -> Type :
514
- """Get the message class for a field from the type hints."""
501
+ def _type_hint (self , field_name : str ) -> Type :
515
502
module = inspect .getmodule (self .__class__ )
516
503
type_hints = get_type_hints (self .__class__ , vars (module ))
517
- cls = type_hints [field .name ]
504
+ return type_hints [field_name ]
505
+
506
+ def _cls_for (self , field : dataclasses .Field , index : int = 0 ) -> Type :
507
+ """Get the message class for a field from the type hints."""
508
+ cls = self ._type_hint (field .name )
518
509
if hasattr (cls , "__args__" ) and index >= 0 :
519
- cls = type_hints [ field . name ] .__args__ [index ]
510
+ cls = cls .__args__ [index ]
520
511
return cls
521
512
522
513
def _get_field_default (self , field : dataclasses .Field , meta : FieldMetadata ) -> Any :
523
- t = self ._cls_for (field , index = - 1 )
514
+ t = self ._type_hint (field . name )
524
515
525
516
value : Any = 0
526
- if meta .proto_type == TYPE_MAP :
527
- # Maps cannot be repeated, so we check these first.
528
- value = {}
529
- elif hasattr (t , "__args__" ) and len (t .__args__ ) == 1 :
530
- # Anything else with type args is a list.
531
- value = []
532
- elif meta .proto_type == TYPE_MESSAGE :
533
- # Message means creating an instance of the right type.
534
- value = t ()
517
+ if hasattr (t , "__origin__" ):
518
+ if t .__origin__ == dict :
519
+ # This is some kind of map (dict in Python).
520
+ value = {}
521
+ elif t .__origin__ == list :
522
+ # This is some kind of list (repeated) field.
523
+ value = []
524
+ elif t .__origin__ == Union and t .__args__ [1 ] == type (None ):
525
+ # This is an optional (wrapped) field. For setting the default we
526
+ # really don't care what kind of field it is.
527
+ value = None
528
+ else :
529
+ value = t ()
530
+ elif issubclass (t , Enum ):
531
+ # Enums always default to zero.
532
+ value = 0
535
533
else :
536
- value = get_default (meta .proto_type )
534
+ # This is either a primitive scalar or another message type. Calling
535
+ # it should result in its zero value.
536
+ value = t ()
537
537
538
538
return value
539
539
@@ -659,7 +659,7 @@ def to_dict(self, casing: Casing = Casing.CAMEL) -> dict:
659
659
660
660
if v :
661
661
output [cased_name ] = v
662
- elif v != get_default ( meta . proto_type ):
662
+ elif v != self . _get_field_default ( field , meta ):
663
663
if meta .proto_type in INT_64_TYPES :
664
664
if isinstance (v , list ):
665
665
output [cased_name ] = [str (n ) for n in v ]
0 commit comments