1
- from abc import ABC
1
+ import dataclasses
2
+ import inspect
2
3
import json
3
4
import struct
5
+ from abc import ABC
4
6
from typing import (
5
- get_type_hints ,
7
+ Any ,
6
8
AsyncGenerator ,
7
- Union ,
9
+ Callable ,
10
+ Dict ,
8
11
Generator ,
9
- Any ,
10
- SupportsBytes ,
12
+ Iterable ,
11
13
List ,
14
+ Optional ,
15
+ SupportsBytes ,
12
16
Tuple ,
13
- Callable ,
14
17
Type ,
15
- Iterable ,
16
18
TypeVar ,
17
- Optional ,
19
+ Union ,
20
+ get_type_hints ,
18
21
)
19
- import dataclasses
20
22
21
23
import grpclib .client
22
24
import grpclib .const
23
25
24
- import inspect
25
-
26
26
# Proto 3 data types
27
27
TYPE_ENUM = "enum"
28
28
TYPE_BOOL = "bool"
54
54
TYPE_SFIXED64 ,
55
55
]
56
56
57
+ # Fields that are numerical 64-bit types
58
+ INT_64_TYPES = [TYPE_INT64 , TYPE_UINT64 , TYPE_SINT64 , TYPE_FIXED64 , TYPE_SFIXED64 ]
59
+
57
60
# Fields that are efficiently packed when
58
61
PACKED_TYPES = [
59
62
TYPE_ENUM ,
@@ -275,7 +278,9 @@ def _preprocess_single(proto_type: str, value: Any) -> bytes:
275
278
return value
276
279
277
280
278
- def _serialize_single (field_number : int , proto_type : str , value : Any ) -> bytes :
281
+ def _serialize_single (
282
+ field_number : int , proto_type : str , value : Any , * , serialize_empty : bool = False
283
+ ) -> bytes :
279
284
"""Serializes a single field and value."""
280
285
value = _preprocess_single (proto_type , value )
281
286
@@ -290,7 +295,7 @@ def _serialize_single(field_number: int, proto_type: str, value: Any) -> bytes:
290
295
key = encode_varint ((field_number << 3 ) | 1 )
291
296
output += key + value
292
297
elif proto_type in WIRE_LEN_DELIM_TYPES :
293
- if len (value ):
298
+ if len (value ) or serialize_empty :
294
299
key = encode_varint ((field_number << 3 ) | 2 )
295
300
output += key + encode_varint (len (value )) + value
296
301
else :
@@ -362,6 +367,11 @@ class Message(ABC):
362
367
to go between Python, binary and JSON protobuf message representations.
363
368
"""
364
369
370
+ # True if this message was or should be serialized on the wire. This can
371
+ # be used to detect presence (e.g. optional wrapper message) and is used
372
+ # internally during parsing/serialization.
373
+ serialized_on_wire : bool
374
+
365
375
def __post_init__ (self ) -> None :
366
376
# Set a default value for each field in the class after `__init__` has
367
377
# already been run.
@@ -389,6 +399,15 @@ def __post_init__(self) -> None:
389
399
390
400
setattr (self , field .name , value )
391
401
402
+ # Now that all the defaults are set, reset it!
403
+ self .__dict__ ["serialized_on_wire" ] = False
404
+
405
+ def __setattr__ (self , attr : str , value : Any ) -> None :
406
+ if attr != "serialized_on_wire" :
407
+ # Track when a field has been set.
408
+ self .__dict__ ["serialized_on_wire" ] = True
409
+ super ().__setattr__ (attr , value )
410
+
392
411
def __bytes__ (self ) -> bytes :
393
412
"""
394
413
Get the binary encoded Protobuf representation of this instance.
@@ -429,7 +448,12 @@ def __bytes__(self) -> bytes:
429
448
# Default (zero) values are not serialized
430
449
continue
431
450
432
- output += _serialize_single (meta .number , meta .proto_type , value )
451
+ serialize_empty = False
452
+ if isinstance (value , Message ) and value .serialized_on_wire :
453
+ serialize_empty = True
454
+ output += _serialize_single (
455
+ meta .number , meta .proto_type , value , serialize_empty = serialize_empty
456
+ )
433
457
434
458
return output
435
459
@@ -462,12 +486,13 @@ def _postprocess_single(
462
486
fmt = _pack_fmt (meta .proto_type )
463
487
value = struct .unpack (fmt , value )[0 ]
464
488
elif wire_type == WIRE_LEN_DELIM :
465
- if meta .proto_type in [ TYPE_STRING ] :
489
+ if meta .proto_type == TYPE_STRING :
466
490
value = value .decode ("utf-8" )
467
- elif meta .proto_type in [ TYPE_MESSAGE ] :
491
+ elif meta .proto_type == TYPE_MESSAGE :
468
492
cls = self ._cls_for (field )
469
493
value = cls ().parse (value )
470
- elif meta .proto_type in [TYPE_MAP ]:
494
+ value .serialized_on_wire = True
495
+ elif meta .proto_type == TYPE_MAP :
471
496
# TODO: This is slow, use a cache to make it faster since each
472
497
# key/value pair will recreate the class.
473
498
assert meta .map_types
@@ -535,8 +560,6 @@ def parse(self: T, data: bytes) -> T:
535
560
# TODO: handle unknown fields
536
561
pass
537
562
538
- from typing import cast
539
-
540
563
return self
541
564
542
565
# For compatibility with other libraries.
@@ -549,21 +572,17 @@ def to_dict(self) -> dict:
549
572
Returns a dict representation of this message instance which can be
550
573
used to serialize to e.g. JSON.
551
574
"""
552
- output = {}
575
+ output : Dict [ str , Any ] = {}
553
576
for field in dataclasses .fields (self ):
554
577
meta = FieldMetadata .get (field )
555
578
v = getattr (self , field .name )
556
579
if meta .proto_type == "message" :
557
580
if isinstance (v , list ):
558
581
# Convert each item.
559
582
v = [i .to_dict () for i in v ]
560
- # Filter out empty items which we won't serialize.
561
- v = [i for i in v if i ]
562
- else :
563
- v = v .to_dict ()
564
-
565
- if v :
566
583
output [field .name ] = v
584
+ elif v .serialized_on_wire :
585
+ output [field .name ] = v .to_dict ()
567
586
elif meta .proto_type == "map" :
568
587
for k in v :
569
588
if hasattr (v [k ], "to_dict" ):
@@ -572,14 +591,21 @@ def to_dict(self) -> dict:
572
591
if v :
573
592
output [field .name ] = v
574
593
elif v != get_default (meta .proto_type ):
575
- output [field .name ] = v
594
+ if meta .proto_type in INT_64_TYPES :
595
+ if isinstance (v , list ):
596
+ output [field .name ] = [str (n ) for n in v ]
597
+ else :
598
+ output [field .name ] = str (v )
599
+ else :
600
+ output [field .name ] = v
576
601
return output
577
602
578
603
def from_dict (self : T , value : dict ) -> T :
579
604
"""
580
605
Parse the key/value pairs in `value` into this message instance. This
581
606
returns the instance itself and is therefore assignable and chainable.
582
607
"""
608
+ self .serialized_on_wire = True
583
609
for field in dataclasses .fields (self ):
584
610
meta = FieldMetadata .get (field )
585
611
if field .name in value and value [field .name ] is not None :
@@ -598,7 +624,13 @@ def from_dict(self: T, value: dict) -> T:
598
624
for k in value [field .name ]:
599
625
v [k ] = cls ().from_dict (value [field .name ][k ])
600
626
else :
601
- setattr (self , field .name , value [field .name ])
627
+ v = value [field .name ]
628
+ if meta .proto_type in INT_64_TYPES :
629
+ if isinstance (value [field .name ], list ):
630
+ v = [int (n ) for n in value [field .name ]]
631
+ else :
632
+ v = int (value [field .name ])
633
+ setattr (self , field .name , v )
602
634
return self
603
635
604
636
def to_json (self ) -> str :
@@ -613,9 +645,6 @@ def from_json(self: T, value: Union[str, bytes]) -> T:
613
645
return self .from_dict (json .loads (value ))
614
646
615
647
616
- ResponseType = TypeVar ("ResponseType" , bound = "Message" )
617
-
618
-
619
648
class ServiceStub (ABC ):
620
649
"""
621
650
Base class for async gRPC service stubs.
0 commit comments