92
92
WIRE_LEN_DELIM_TYPES = [TYPE_STRING , TYPE_BYTES , TYPE_MESSAGE , TYPE_MAP ]
93
93
94
94
95
+ def get_default (proto_type : int ) -> Any :
96
+ """Get the default (zero value) for a given type."""
97
+ return {
98
+ TYPE_BOOL : False ,
99
+ TYPE_FLOAT : 0.0 ,
100
+ TYPE_DOUBLE : 0.0 ,
101
+ TYPE_STRING : "" ,
102
+ TYPE_BYTES : b"" ,
103
+ TYPE_MAP : {},
104
+ }.get (proto_type , 0 )
105
+
106
+
95
107
@dataclasses .dataclass (frozen = True )
96
108
class FieldMetadata :
97
109
"""Stores internal metadata used for parsing & serialization."""
@@ -114,7 +126,7 @@ def get(field: dataclasses.Field) -> "FieldMetadata":
114
126
def dataclass_field (
115
127
number : int ,
116
128
proto_type : str ,
117
- default : Any ,
129
+ default : Any = None ,
118
130
map_types : Optional [Tuple [str , str ]] = None ,
119
131
** kwargs : dict ,
120
132
) -> dataclasses .Field :
@@ -141,6 +153,10 @@ def enum_field(number: int, default: Union[int, Type[Iterable]] = 0) -> Any:
141
153
return dataclass_field (number , TYPE_ENUM , default = default )
142
154
143
155
156
+ def bool_field (number : int , default : Union [bool , Type [Iterable ]] = 0 ) -> Any :
157
+ return dataclass_field (number , TYPE_BOOL , default = default )
158
+
159
+
144
160
def int32_field (number : int , default : Union [int , Type [Iterable ]] = 0 ) -> Any :
145
161
return dataclass_field (number , TYPE_INT32 , default = default )
146
162
@@ -193,8 +209,8 @@ def string_field(number: int, default: str = "") -> Any:
193
209
return dataclass_field (number , TYPE_STRING , default = default )
194
210
195
211
196
- def message_field (number : int , default : Type [ "Message" ] ) -> Any :
197
- return dataclass_field (number , TYPE_MESSAGE , default = default )
212
+ def message_field (number : int ) -> Any :
213
+ return dataclass_field (number , TYPE_MESSAGE )
198
214
199
215
200
216
def map_field (number : int , key_type : str , value_type : str ) -> Any :
@@ -345,6 +361,29 @@ class Message(ABC):
345
361
to go between Python, binary and JSON protobuf message representations.
346
362
"""
347
363
364
+ def __post_init__ (self ) -> None :
365
+ # Set a default value for each field in the class after `__init__` has
366
+ # already been run.
367
+ for field in dataclasses .fields (self ):
368
+ meta = FieldMetadata .get (field )
369
+
370
+ t = self ._cls_for (field , index = - 1 )
371
+
372
+ value = 0
373
+ if meta .proto_type == TYPE_MAP :
374
+ # Maps cannot be repeated, so we check these first.
375
+ value = {}
376
+ elif hasattr (t , "__args__" ) and len (t .__args__ ) == 1 :
377
+ # Anything else with type args is a list.
378
+ value = []
379
+ elif meta .proto_type == TYPE_MESSAGE :
380
+ # Message means creating an instance of the right type.
381
+ value = t ()
382
+ else :
383
+ value = get_default (meta .proto_type )
384
+
385
+ setattr (self , field .name , value )
386
+
348
387
def __bytes__ (self ) -> bytes :
349
388
"""
350
389
Get the binary encoded Protobuf representation of this instance.
@@ -356,6 +395,7 @@ def __bytes__(self) -> bytes:
356
395
357
396
if isinstance (value , list ):
358
397
if not len (value ):
398
+ # Empty values are not serialized
359
399
continue
360
400
361
401
if meta .proto_type in PACKED_TYPES :
@@ -371,14 +411,16 @@ def __bytes__(self) -> bytes:
371
411
output += _serialize_single (meta .number , meta .proto_type , item )
372
412
elif isinstance (value , dict ):
373
413
if not len (value ):
414
+ # Empty values are not serialized
374
415
continue
375
416
376
417
for k , v in value .items ():
377
418
sk = _serialize_single (1 , meta .map_types [0 ], k )
378
419
sv = _serialize_single (2 , meta .map_types [1 ], v )
379
420
output += _serialize_single (meta .number , meta .proto_type , sk + sv )
380
421
else :
381
- if value == field .default :
422
+ if value == get_default (meta .proto_type ):
423
+ # Default (zero) values are not serialized
382
424
continue
383
425
384
426
output += _serialize_single (meta .number , meta .proto_type , value )
@@ -390,7 +432,7 @@ def _cls_for(self, field: dataclasses.Field, index: int = 0) -> Type:
390
432
module = inspect .getmodule (self )
391
433
type_hints = get_type_hints (self , vars (module ))
392
434
cls = type_hints [field .name ]
393
- if hasattr (cls , "__args__" ):
435
+ if hasattr (cls , "__args__" ) and index >= 0 :
394
436
cls = type_hints [field .name ].__args__ [index ]
395
437
return cls
396
438
@@ -522,7 +564,7 @@ def from_dict(self, value: dict) -> T:
522
564
"""
523
565
for field in dataclasses .fields (self ):
524
566
meta = FieldMetadata .get (field )
525
- if field .name in value :
567
+ if field .name in value and value [ field . name ] is not None :
526
568
if meta .proto_type == "message" :
527
569
v = getattr (self , field .name )
528
570
# print(v, value[field.name])
0 commit comments