92
92
WIRE_LEN_DELIM_TYPES = [TYPE_STRING , TYPE_BYTES , TYPE_MESSAGE , TYPE_MAP ]
93
93
94
94
95
+ def get_default (proto_type : int ) -> Any :
96
+ """Get the default (zero value) for a given type."""
97
+ return {
98
+ TYPE_BOOL : False ,
99
+ TYPE_FLOAT : 0.0 ,
100
+ TYPE_DOUBLE : 0.0 ,
101
+ TYPE_STRING : "" ,
102
+ TYPE_BYTES : b"" ,
103
+ TYPE_MAP : {},
104
+ }.get (proto_type , 0 )
105
+
106
+
95
107
@dataclasses .dataclass (frozen = True )
96
108
class FieldMetadata :
97
109
"""Stores internal metadata used for parsing & serialization."""
@@ -114,7 +126,7 @@ def get(field: dataclasses.Field) -> "FieldMetadata":
114
126
def dataclass_field (
115
127
number : int ,
116
128
proto_type : str ,
117
- default : Any ,
129
+ default : Any = None ,
118
130
map_types : Optional [Tuple [str , str ]] = None ,
119
131
** kwargs : dict ,
120
132
) -> dataclasses .Field :
@@ -141,6 +153,10 @@ def enum_field(number: int, default: Union[int, Type[Iterable]] = 0) -> Any:
141
153
return dataclass_field (number , TYPE_ENUM , default = default )
142
154
143
155
156
+ def bool_field (number : int , default : Union [bool , Type [Iterable ]] = 0 ) -> Any :
157
+ return dataclass_field (number , TYPE_BOOL , default = default )
158
+
159
+
144
160
def int32_field (number : int , default : Union [int , Type [Iterable ]] = 0 ) -> Any :
145
161
return dataclass_field (number , TYPE_INT32 , default = default )
146
162
@@ -193,8 +209,12 @@ def string_field(number: int, default: str = "") -> Any:
193
209
return dataclass_field (number , TYPE_STRING , default = default )
194
210
195
211
196
- def message_field (number : int , default : Type ["Message" ]) -> Any :
197
- return dataclass_field (number , TYPE_MESSAGE , default = default )
212
+ def bytes_field (number : int , default : bytes = b"" ) -> Any :
213
+ return dataclass_field (number , TYPE_BYTES , default = default )
214
+
215
+
216
+ def message_field (number : int ) -> Any :
217
+ return dataclass_field (number , TYPE_MESSAGE )
198
218
199
219
200
220
def map_field (number : int , key_type : str , value_type : str ) -> Any :
@@ -345,6 +365,29 @@ class Message(ABC):
345
365
to go between Python, binary and JSON protobuf message representations.
346
366
"""
347
367
368
+ def __post_init__ (self ) -> None :
369
+ # Set a default value for each field in the class after `__init__` has
370
+ # already been run.
371
+ for field in dataclasses .fields (self ):
372
+ meta = FieldMetadata .get (field )
373
+
374
+ t = self ._cls_for (field , index = - 1 )
375
+
376
+ value = 0
377
+ if meta .proto_type == TYPE_MAP :
378
+ # Maps cannot be repeated, so we check these first.
379
+ value = {}
380
+ elif hasattr (t , "__args__" ) and len (t .__args__ ) == 1 :
381
+ # Anything else with type args is a list.
382
+ value = []
383
+ elif meta .proto_type == TYPE_MESSAGE :
384
+ # Message means creating an instance of the right type.
385
+ value = t ()
386
+ else :
387
+ value = get_default (meta .proto_type )
388
+
389
+ setattr (self , field .name , value )
390
+
348
391
def __bytes__ (self ) -> bytes :
349
392
"""
350
393
Get the binary encoded Protobuf representation of this instance.
@@ -356,6 +399,7 @@ def __bytes__(self) -> bytes:
356
399
357
400
if isinstance (value , list ):
358
401
if not len (value ):
402
+ # Empty values are not serialized
359
403
continue
360
404
361
405
if meta .proto_type in PACKED_TYPES :
@@ -371,14 +415,16 @@ def __bytes__(self) -> bytes:
371
415
output += _serialize_single (meta .number , meta .proto_type , item )
372
416
elif isinstance (value , dict ):
373
417
if not len (value ):
418
+ # Empty values are not serialized
374
419
continue
375
420
376
421
for k , v in value .items ():
377
422
sk = _serialize_single (1 , meta .map_types [0 ], k )
378
423
sv = _serialize_single (2 , meta .map_types [1 ], v )
379
424
output += _serialize_single (meta .number , meta .proto_type , sk + sv )
380
425
else :
381
- if value == field .default :
426
+ if value == get_default (meta .proto_type ):
427
+ # Default (zero) values are not serialized
382
428
continue
383
429
384
430
output += _serialize_single (meta .number , meta .proto_type , value )
@@ -390,7 +436,7 @@ def _cls_for(self, field: dataclasses.Field, index: int = 0) -> Type:
390
436
module = inspect .getmodule (self )
391
437
type_hints = get_type_hints (self , vars (module ))
392
438
cls = type_hints [field .name ]
393
- if hasattr (cls , "__args__" ):
439
+ if hasattr (cls , "__args__" ) and index >= 0 :
394
440
cls = type_hints [field .name ].__args__ [index ]
395
441
return cls
396
442
@@ -522,7 +568,7 @@ def from_dict(self, value: dict) -> T:
522
568
"""
523
569
for field in dataclasses .fields (self ):
524
570
meta = FieldMetadata .get (field )
525
- if field .name in value :
571
+ if field .name in value and value [ field . name ] is not None :
526
572
if meta .proto_type == "message" :
527
573
v = getattr (self , field .name )
528
574
# print(v, value[field.name])
0 commit comments