2
2
import json
3
3
import struct
4
4
from typing import (
5
+ get_type_hints ,
5
6
Union ,
6
7
Generator ,
7
8
Any ,
15
16
)
16
17
import dataclasses
17
18
19
+ import inspect
20
+
18
21
# Proto 3 data types
19
22
TYPE_ENUM = "enum"
20
23
TYPE_BOOL = "bool"
@@ -283,35 +286,6 @@ def decode_varint(buffer: bytes, pos: int, signed: bool = False) -> Tuple[int, i
283
286
raise ValueError ("Too many bytes when decoding varint." )
284
287
285
288
286
- def _postprocess_single (
287
- wire_type : int , meta : FieldMetadata , field : Any , value : Any
288
- ) -> Any :
289
- """Adjusts values after parsing."""
290
- if wire_type == WIRE_VARINT :
291
- if meta .proto_type in ["int32" , "int64" ]:
292
- bits = int (meta .proto_type [3 :])
293
- value = value & ((1 << bits ) - 1 )
294
- signbit = 1 << (bits - 1 )
295
- value = int ((value ^ signbit ) - signbit )
296
- elif meta .proto_type in ["sint32" , "sint64" ]:
297
- # Undo zig-zag encoding
298
- value = (value >> 1 ) ^ (- (value & 1 ))
299
- elif wire_type in [WIRE_FIXED_32 , WIRE_FIXED_64 ]:
300
- fmt = _pack_fmt (meta .proto_type )
301
- value = struct .unpack (fmt , value )[0 ]
302
- elif wire_type == WIRE_LEN_DELIM :
303
- if meta .proto_type in ["string" ]:
304
- value = value .decode ("utf-8" )
305
- elif meta .proto_type in ["message" ]:
306
- orig = value
307
- value = field .default_factory ()
308
- if isinstance (value , Message ):
309
- # If it's a message (instead of e.g. list) then keep going!
310
- value .parse (orig )
311
-
312
- return value
313
-
314
-
315
289
@dataclasses .dataclass (frozen = True )
316
290
class ParsedField :
317
291
number : int
@@ -388,6 +362,41 @@ def __bytes__(self) -> bytes:
388
362
389
363
return output
390
364
365
+ def _cls_for (self , field : dataclasses .Field ) -> Type :
366
+ """Get the message class for a field from the type hints."""
367
+ module = inspect .getmodule (self )
368
+ type_hints = get_type_hints (self , vars (module ))
369
+ cls = type_hints [field .name ]
370
+ if hasattr (cls , "__args__" ):
371
+ print (type_hints [field .name ].__args__ [0 ])
372
+ cls = type_hints [field .name ].__args__ [0 ]
373
+ return cls
374
+
375
+ def _postprocess_single (
376
+ self , wire_type : int , meta : FieldMetadata , field : dataclasses .Field , value : Any
377
+ ) -> Any :
378
+ """Adjusts values after parsing."""
379
+ if wire_type == WIRE_VARINT :
380
+ if meta .proto_type in ["int32" , "int64" ]:
381
+ bits = int (meta .proto_type [3 :])
382
+ value = value & ((1 << bits ) - 1 )
383
+ signbit = 1 << (bits - 1 )
384
+ value = int ((value ^ signbit ) - signbit )
385
+ elif meta .proto_type in ["sint32" , "sint64" ]:
386
+ # Undo zig-zag encoding
387
+ value = (value >> 1 ) ^ (- (value & 1 ))
388
+ elif wire_type in [WIRE_FIXED_32 , WIRE_FIXED_64 ]:
389
+ fmt = _pack_fmt (meta .proto_type )
390
+ value = struct .unpack (fmt , value )[0 ]
391
+ elif wire_type == WIRE_LEN_DELIM :
392
+ if meta .proto_type in ["string" ]:
393
+ value = value .decode ("utf-8" )
394
+ elif meta .proto_type in ["message" ]:
395
+ cls = self ._cls_for (field )
396
+ value = cls ().parse (value )
397
+
398
+ return value
399
+
391
400
def parse (self , data : bytes ) -> T :
392
401
"""
393
402
Parse the binary encoded Protobuf into this message instance. This
@@ -416,10 +425,12 @@ def parse(self, data: bytes) -> T:
416
425
else :
417
426
decoded , pos = decode_varint (parsed .value , pos )
418
427
wire_type = WIRE_VARINT
419
- decoded = _postprocess_single (wire_type , meta , field , decoded )
428
+ decoded = self ._postprocess_single (
429
+ wire_type , meta , field , decoded
430
+ )
420
431
value .append (decoded )
421
432
else :
422
- value = _postprocess_single (
433
+ value = self . _postprocess_single (
423
434
parsed .wire_type , meta , field , parsed .value
424
435
)
425
436
@@ -445,7 +456,13 @@ def to_dict(self) -> dict:
445
456
meta = FieldMetadata .get (field )
446
457
v = getattr (self , field .name )
447
458
if meta .proto_type == "message" :
448
- v = v .to_dict ()
459
+ if isinstance (v , list ):
460
+ # Convert each item.
461
+ v = [i .to_dict () for i in v ]
462
+ # Filter out empty items which we won't serialize.
463
+ v = [i for i in v if i ]
464
+ else :
465
+ v = v .to_dict ()
449
466
if v :
450
467
output [field .name ] = v
451
468
elif v != field .default :
@@ -461,7 +478,14 @@ def from_dict(self, value: dict) -> T:
461
478
meta = FieldMetadata .get (field )
462
479
if field .name in value :
463
480
if meta .proto_type == "message" :
464
- getattr (self , field .name ).from_dict (value [field .name ])
481
+ v = getattr (self , field .name )
482
+ print (v , value [field .name ])
483
+ if isinstance (v , list ):
484
+ cls = self ._cls_for (field )
485
+ for i in range (len (value [field .name ])):
486
+ v .append (cls ().from_dict (value [field .name ][i ]))
487
+ else :
488
+ v .from_dict (value [field .name ])
465
489
else :
466
490
setattr (self , field .name , value [field .name ])
467
491
return self
0 commit comments