105
105
WIRE_LEN_DELIM_TYPES = [TYPE_STRING , TYPE_BYTES , TYPE_MESSAGE , TYPE_MAP ]
106
106
107
107
108
+ # Protobuf datetimes start at the Unix Epoch in 1970 in UTC.
109
+ DATETIME_ZERO = datetime (1970 , 1 , 1 , tzinfo = timezone .utc )
110
+
111
+
108
112
class Casing (enum .Enum ):
109
113
"""Casing constants for serialization."""
110
114
@@ -128,9 +132,11 @@ class FieldMetadata:
128
132
# Protobuf type name
129
133
proto_type : str
130
134
# Map information if the proto_type is a map
131
- map_types : Optional [Tuple [str , str ]]
135
+ map_types : Optional [Tuple [str , str ]] = None
132
136
# Groups several "one-of" fields together
133
- group : Optional [str ]
137
+ group : Optional [str ] = None
138
+ # Describes the wrapped type (e.g. when using google.protobuf.BoolValue)
139
+ wraps : Optional [str ] = None
134
140
135
141
@staticmethod
136
142
def get (field : dataclasses .Field ) -> "FieldMetadata" :
@@ -144,11 +150,14 @@ def dataclass_field(
144
150
* ,
145
151
map_types : Optional [Tuple [str , str ]] = None ,
146
152
group : Optional [str ] = None ,
153
+ wraps : Optional [str ] = None ,
147
154
) -> dataclasses .Field :
148
155
"""Creates a dataclass field with attached protobuf metadata."""
149
156
return dataclasses .field (
150
157
default = PLACEHOLDER ,
151
- metadata = {"betterproto" : FieldMetadata (number , proto_type , map_types , group )},
158
+ metadata = {
159
+ "betterproto" : FieldMetadata (number , proto_type , map_types , group , wraps )
160
+ },
152
161
)
153
162
154
163
@@ -221,8 +230,10 @@ def bytes_field(number: int, group: Optional[str] = None) -> Any:
221
230
return dataclass_field (number , TYPE_BYTES , group = group )
222
231
223
232
224
- def message_field (number : int , group : Optional [str ] = None ) -> Any :
225
- return dataclass_field (number , TYPE_MESSAGE , group = group )
233
+ def message_field (
234
+ number : int , group : Optional [str ] = None , wraps : Optional [str ] = None
235
+ ) -> Any :
236
+ return dataclass_field (number , TYPE_MESSAGE , group = group , wraps = wraps )
226
237
227
238
228
239
def map_field (
@@ -273,7 +284,7 @@ def encode_varint(value: int) -> bytes:
273
284
return bytes (b + [bits ])
274
285
275
286
276
- def _preprocess_single (proto_type : str , value : Any ) -> bytes :
287
+ def _preprocess_single (proto_type : str , wraps : str , value : Any ) -> bytes :
277
288
"""Adjusts values before serialization."""
278
289
if proto_type in [
279
290
TYPE_ENUM ,
@@ -307,17 +318,26 @@ def _preprocess_single(proto_type: str, value: Any) -> bytes:
307
318
seconds = int (total_ms / 1e6 )
308
319
nanos = int ((total_ms % 1e6 ) * 1e3 )
309
320
value = _Duration (seconds = seconds , nanos = nanos )
321
+ elif wraps :
322
+ if value is None :
323
+ return b""
324
+ value = _get_wrapper (wraps )(value = value )
310
325
311
326
return bytes (value )
312
327
313
328
return value
314
329
315
330
316
331
def _serialize_single (
317
- field_number : int , proto_type : str , value : Any , * , serialize_empty : bool = False
332
+ field_number : int ,
333
+ proto_type : str ,
334
+ value : Any ,
335
+ * ,
336
+ serialize_empty : bool = False ,
337
+ wraps : str = "" ,
318
338
) -> bytes :
319
339
"""Serializes a single field and value."""
320
- value = _preprocess_single (proto_type , value )
340
+ value = _preprocess_single (proto_type , wraps , value )
321
341
322
342
output = b""
323
343
if proto_type in WIRE_VARINT_TYPES :
@@ -330,7 +350,7 @@ def _serialize_single(
330
350
key = encode_varint ((field_number << 3 ) | 1 )
331
351
output += key + value
332
352
elif proto_type in WIRE_LEN_DELIM_TYPES :
333
- if len (value ) or serialize_empty :
353
+ if len (value ) or serialize_empty or wraps :
334
354
key = encode_varint ((field_number << 3 ) | 2 )
335
355
output += key + encode_varint (len (value )) + value
336
356
else :
@@ -370,7 +390,6 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
370
390
while i < len (value ):
371
391
start = i
372
392
num_wire , i = decode_varint (value , i )
373
- # print(num_wire, i)
374
393
number = num_wire >> 3
375
394
wire_type = num_wire & 0x7
376
395
@@ -386,8 +405,6 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
386
405
elif wire_type == 5 :
387
406
decoded , i = value [i : i + 4 ], i + 4
388
407
389
- # print(ParsedField(number=number, wire_type=wire_type, value=decoded))
390
-
391
408
yield ParsedField (
392
409
number = number , wire_type = wire_type , value = decoded , raw = value [start :i ]
393
410
)
@@ -462,6 +479,11 @@ def __bytes__(self) -> bytes:
462
479
meta = FieldMetadata .get (field )
463
480
value = getattr (self , field .name )
464
481
482
+ if value is None :
483
+ # Optional items should be skipped. This is used for the Google
484
+ # wrapper types.
485
+ continue
486
+
465
487
# Being selected in a a group means this field is the one that is
466
488
# currently set in a `oneof` group, so it must be serialized even
467
489
# if the value is the default zero value.
@@ -491,11 +513,13 @@ def __bytes__(self) -> bytes:
491
513
# treat it like a field of raw bytes.
492
514
buf = b""
493
515
for item in value :
494
- buf += _preprocess_single (meta .proto_type , item )
516
+ buf += _preprocess_single (meta .proto_type , "" , item )
495
517
output += _serialize_single (meta .number , TYPE_BYTES , buf )
496
518
else :
497
519
for item in value :
498
- output += _serialize_single (meta .number , meta .proto_type , item )
520
+ output += _serialize_single (
521
+ meta .number , meta .proto_type , item , wraps = meta .wraps
522
+ )
499
523
elif isinstance (value , dict ):
500
524
for k , v in value .items ():
501
525
assert meta .map_types
@@ -504,7 +528,11 @@ def __bytes__(self) -> bytes:
504
528
output += _serialize_single (meta .number , meta .proto_type , sk + sv )
505
529
else :
506
530
output += _serialize_single (
507
- meta .number , meta .proto_type , value , serialize_empty = serialize_empty
531
+ meta .number ,
532
+ meta .proto_type ,
533
+ value ,
534
+ serialize_empty = serialize_empty ,
535
+ wraps = meta .wraps ,
508
536
)
509
537
510
538
return output + self ._unknown_fields
@@ -546,7 +574,7 @@ def _get_field_default(self, field: dataclasses.Field, meta: FieldMetadata) -> A
546
574
value = 0
547
575
elif t == datetime :
548
576
# Offsets are relative to 1970-01-01T00:00:00Z
549
- value = datetime ( 1970 , 1 , 1 , tzinfo = timezone . utc )
577
+ value = DATETIME_ZERO
550
578
else :
551
579
# This is either a primitive scalar or another message type. Calling
552
580
# it should result in its zero value.
@@ -580,6 +608,10 @@ def _postprocess_single(
580
608
value = _Timestamp ().parse (value ).to_datetime ()
581
609
elif cls == timedelta :
582
610
value = _Duration ().parse (value ).to_timedelta ()
611
+ elif meta .wraps :
612
+ # This is a Google wrapper value message around a single
613
+ # scalar type.
614
+ value = _get_wrapper (meta .wraps )().parse (value ).value
583
615
else :
584
616
value = cls ().parse (value )
585
617
value ._serialized_on_wire = True
@@ -670,9 +702,14 @@ def to_dict(self, casing: Casing = Casing.CAMEL) -> dict:
670
702
cased_name = casing (field .name ).rstrip ("_" )
671
703
if meta .proto_type == "message" :
672
704
if isinstance (v , datetime ):
673
- output [cased_name ] = _Timestamp .to_json (v )
705
+ if v != DATETIME_ZERO :
706
+ output [cased_name ] = _Timestamp .to_json (v )
674
707
elif isinstance (v , timedelta ):
675
- output [cased_name ] = _Duration .to_json (v )
708
+ if v != timedelta (0 ):
709
+ output [cased_name ] = _Duration .to_json (v )
710
+ elif meta .wraps :
711
+ if v is not None :
712
+ output [cased_name ] = v
676
713
elif isinstance (v , list ):
677
714
# Convert each item.
678
715
v = [i .to_dict () for i in v ]
@@ -723,17 +760,20 @@ def from_dict(self: T, value: dict) -> T:
723
760
if value [key ] is not None :
724
761
if meta .proto_type == "message" :
725
762
v = getattr (self , field .name )
726
- # print(v, value[key])
727
763
if isinstance (v , list ):
728
764
cls = self ._cls_for (field )
729
765
for i in range (len (value [key ])):
730
766
v .append (cls ().from_dict (value [key ][i ]))
731
767
elif isinstance (v , datetime ):
732
- v = datetime .fromisoformat (value [key ].replace ("Z" , "+00:00" ))
768
+ v = datetime .fromisoformat (
769
+ value [key ].replace ("Z" , "+00:00" )
770
+ )
733
771
setattr (self , field .name , v )
734
772
elif isinstance (v , timedelta ):
735
773
v = timedelta (seconds = float (value [key ][:- 1 ]))
736
774
setattr (self , field .name , v )
775
+ elif meta .wraps :
776
+ setattr (self , field .name , value [key ])
737
777
else :
738
778
v .from_dict (value [key ])
739
779
elif meta .map_types and meta .map_types [1 ] == TYPE_MESSAGE :
@@ -830,7 +870,6 @@ class _Timestamp(Message):
830
870
831
871
def to_datetime (self ) -> datetime :
832
872
ts = self .seconds + (self .nanos / 1e9 )
833
- print ('to-datetime' , ts , datetime .fromtimestamp (ts , tz = timezone .utc ))
834
873
return datetime .fromtimestamp (ts , tz = timezone .utc )
835
874
836
875
@staticmethod
@@ -839,17 +878,90 @@ def to_json(dt: datetime) -> str:
839
878
copy = dt .replace (microsecond = 0 , tzinfo = None )
840
879
result = copy .isoformat ()
841
880
if (nanos % 1e9 ) == 0 :
842
- # If there are 0 fractional digits, the fractional
843
- # point '.' should be omitted when serializing.
844
- return result + 'Z'
881
+ # If there are 0 fractional digits, the fractional
882
+ # point '.' should be omitted when serializing.
883
+ return result + "Z"
845
884
if (nanos % 1e6 ) == 0 :
846
- # Serialize 3 fractional digits.
847
- return result + ' .%03dZ' % (nanos / 1e6 )
885
+ # Serialize 3 fractional digits.
886
+ return result + " .%03dZ" % (nanos / 1e6 )
848
887
if (nanos % 1e3 ) == 0 :
849
- # Serialize 6 fractional digits.
850
- return result + ' .%06dZ' % (nanos / 1e3 )
888
+ # Serialize 6 fractional digits.
889
+ return result + " .%06dZ" % (nanos / 1e3 )
851
890
# Serialize 9 fractional digits.
852
- return result + '.%09dZ' % nanos
891
+ return result + ".%09dZ" % nanos
892
+
893
+
894
+ class _WrappedMessage (Message ):
895
+ """
896
+ Google protobuf wrapper types base class. JSON representation is just the
897
+ value itself.
898
+ """
899
+ def to_dict (self ) -> Any :
900
+ return self .value
901
+
902
+ def from_dict (self , value : Any ) -> None :
903
+ if value is not None :
904
+ self .value = value
905
+
906
+
907
+ @dataclasses .dataclass
908
+ class _BoolValue (_WrappedMessage ):
909
+ value : bool = bool_field (1 )
910
+
911
+
912
+ @dataclasses .dataclass
913
+ class _Int32Value (_WrappedMessage ):
914
+ value : int = int32_field (1 )
915
+
916
+
917
+ @dataclasses .dataclass
918
+ class _UInt32Value (_WrappedMessage ):
919
+ value : int = uint32_field (1 )
920
+
921
+
922
+ @dataclasses .dataclass
923
+ class _Int64Value (_WrappedMessage ):
924
+ value : int = int64_field (1 )
925
+
926
+
927
+ @dataclasses .dataclass
928
+ class _UInt64Value (_WrappedMessage ):
929
+ value : int = uint64_field (1 )
930
+
931
+
932
+ @dataclasses .dataclass
933
+ class _FloatValue (_WrappedMessage ):
934
+ value : float = float_field (1 )
935
+
936
+
937
+ @dataclasses .dataclass
938
+ class _DoubleValue (_WrappedMessage ):
939
+ value : float = double_field (1 )
940
+
941
+
942
+ @dataclasses .dataclass
943
+ class _StringValue (_WrappedMessage ):
944
+ value : str = string_field (1 )
945
+
946
+
947
+ @dataclasses .dataclass
948
+ class _BytesValue (_WrappedMessage ):
949
+ value : bytes = bytes_field (1 )
950
+
951
+
952
+ def _get_wrapper (proto_type : str ) -> _WrappedMessage :
953
+ """Get the wrapper message class for a wrapped type."""
954
+ return {
955
+ TYPE_BOOL : _BoolValue ,
956
+ TYPE_INT32 : _Int32Value ,
957
+ TYPE_UINT32 : _UInt32Value ,
958
+ TYPE_INT64 : _Int64Value ,
959
+ TYPE_UINT64 : _UInt64Value ,
960
+ TYPE_FLOAT : _FloatValue ,
961
+ TYPE_DOUBLE : _DoubleValue ,
962
+ TYPE_STRING : _StringValue ,
963
+ TYPE_BYTES : _BytesValue ,
964
+ }[proto_type ]
853
965
854
966
855
967
class ServiceStub (ABC ):
0 commit comments