120
120
121
121
122
122
# Protobuf datetimes start at the Unix Epoch in 1970 in UTC.
123
- DATETIME_ZERO = datetime (1970 , 1 , 1 , tzinfo = timezone .utc )
123
+ def datetime_default_gen ():
124
+ return datetime (1970 , 1 , 1 , tzinfo = timezone .utc )
125
+
126
+
127
+ DATETIME_ZERO = datetime_default_gen ()
124
128
125
129
126
130
class Casing (enum .Enum ):
@@ -428,6 +432,57 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
428
432
T = TypeVar ("T" , bound = "Message" )
429
433
430
434
435
+ class ProtoClassMetadata :
436
+ cls : "Message"
437
+
438
+ def __init__ (self , cls : "Message" ):
439
+ self .cls = cls
440
+ by_field = {}
441
+ by_group = {}
442
+
443
+ for field in dataclasses .fields (cls ):
444
+ meta = FieldMetadata .get (field )
445
+
446
+ if meta .group :
447
+ # This is part of a one-of group.
448
+ by_field [field .name ] = meta .group
449
+
450
+ by_group .setdefault (meta .group , set ()).add (field )
451
+
452
+ self .oneof_group_by_field = by_field
453
+ self .oneof_field_by_group = by_group
454
+
455
+ def __getattr__ (self , item ):
456
+ # Lazy init because forward reference classes may not be available at the beginning.
457
+ if item == 'default_gen' :
458
+ defaults = {}
459
+ for field in dataclasses .fields (self .cls ):
460
+ meta = FieldMetadata .get (field )
461
+ defaults [field .name ] = self .cls ._get_field_default_gen (field , meta )
462
+
463
+ self .default_gen = defaults # __getattr__ won't be called next time
464
+ return defaults
465
+
466
+ if item == 'cls_by_field' :
467
+ field_cls = {}
468
+ for field in dataclasses .fields (self .cls ):
469
+ meta = FieldMetadata .get (field )
470
+ field_cls [field .name ] = self .cls ._type_hint (field .name )
471
+
472
+ self .cls_by_field = field_cls # __getattr__ won't be called next time
473
+ return field_cls
474
+
475
+
476
+ def make_protoclass (cls ):
477
+ setattr (cls , "_betterproto" , ProtoClassMetadata (cls ))
478
+
479
+
480
+ def protoclass (* args , ** kwargs ):
481
+ cls = dataclasses .dataclass (* args , ** kwargs )
482
+ make_protoclass (cls )
483
+ return cls
484
+
485
+
431
486
class Message (ABC ):
432
487
"""
433
488
A protobuf message base class. Generated code will inherit from this and
@@ -445,25 +500,20 @@ def __post_init__(self) -> None:
445
500
446
501
# Set a default value for each field in the class after `__init__` has
447
502
# already been run.
448
- group_map : Dict [str , dict ] = {"fields" : {}, "groups" : {} }
503
+ group_map : Dict [str , dataclasses . Field ] = {}
449
504
for field in dataclasses .fields (self ):
450
505
meta = FieldMetadata .get (field )
451
506
452
507
if meta .group :
453
- # This is part of a one-of group.
454
- group_map ["fields" ][field .name ] = meta .group
455
-
456
- if meta .group not in group_map ["groups" ]:
457
- group_map ["groups" ][meta .group ] = {"current" : None , "fields" : set ()}
458
- group_map ["groups" ][meta .group ]["fields" ].add (field )
508
+ group_map .setdefault (meta .group )
459
509
460
510
if getattr (self , field .name ) != PLACEHOLDER :
461
511
# Skip anything not set to the sentinel value
462
512
all_sentinel = False
463
513
464
514
if meta .group :
465
515
# This was set, so make it the selected value of the one-of.
466
- group_map ["groups" ][ meta .group ][ "current" ] = field
516
+ group_map [meta .group ] = field
467
517
468
518
continue
469
519
@@ -479,16 +529,17 @@ def __setattr__(self, attr: str, value: Any) -> None:
479
529
# Track when a field has been set.
480
530
self .__dict__ ["_serialized_on_wire" ] = True
481
531
482
- if attr in getattr (self , "_group_map" , {}).get ("fields" , {}):
483
- group = self ._group_map ["fields" ][attr ]
484
- for field in self ._group_map ["groups" ][group ]["fields" ]:
485
- if field .name == attr :
486
- self ._group_map ["groups" ][group ]["current" ] = field
487
- else :
488
- super ().__setattr__ (
489
- field .name ,
490
- self ._get_field_default (field , FieldMetadata .get (field )),
491
- )
532
+ if hasattr (self , "_group_map" ): # __post_init__ had already run
533
+ if attr in self ._betterproto .oneof_group_by_field :
534
+ group = self ._betterproto .oneof_group_by_field [attr ]
535
+ for field in self ._betterproto .oneof_field_by_group [group ]:
536
+ if field .name == attr :
537
+ self ._group_map [group ] = field
538
+ else :
539
+ super ().__setattr__ (
540
+ field .name ,
541
+ self ._get_field_default (field , FieldMetadata .get (field )),
542
+ )
492
543
493
544
super ().__setattr__ (attr , value )
494
545
@@ -510,7 +561,7 @@ def __bytes__(self) -> bytes:
510
561
# currently set in a `oneof` group, so it must be serialized even
511
562
# if the value is the default zero value.
512
563
selected_in_group = False
513
- if meta .group and self ._group_map ["groups" ][ meta .group ][ "current" ] == field :
564
+ if meta .group and self ._group_map [meta .group ] == field :
514
565
selected_in_group = True
515
566
516
567
serialize_empty = False
@@ -562,47 +613,49 @@ def __bytes__(self) -> bytes:
562
613
# For compatibility with other libraries
563
614
SerializeToString = __bytes__
564
615
565
- def _type_hint (self , field_name : str ) -> Type :
566
- module = inspect .getmodule (self .__class__ )
567
- type_hints = get_type_hints (self .__class__ , vars (module ))
616
+ @classmethod
617
+ def _type_hint (cls , field_name : str ) -> Type :
618
+ module = inspect .getmodule (cls )
619
+ type_hints = get_type_hints (cls , vars (module ))
568
620
return type_hints [field_name ]
569
621
570
622
def _cls_for (self , field : dataclasses .Field , index : int = 0 ) -> Type :
571
623
"""Get the message class for a field from the type hints."""
572
- cls = self ._type_hint ( field .name )
624
+ cls = self ._betterproto . cls_by_field [ field .name ]
573
625
if hasattr (cls , "__args__" ) and index >= 0 :
574
626
cls = cls .__args__ [index ]
575
627
return cls
576
628
577
629
def _get_field_default (self , field : dataclasses .Field , meta : FieldMetadata ) -> Any :
578
- t = self ._type_hint (field .name )
630
+ return self ._betterproto .default_gen [field .name ]()
631
+
632
+ @classmethod
633
+ def _get_field_default_gen (cls , field : dataclasses .Field , meta : FieldMetadata ) -> Any :
634
+ t = cls ._type_hint (field .name )
579
635
580
- value : Any = 0
581
636
if hasattr (t , "__origin__" ):
582
637
if t .__origin__ in (dict , Dict ):
583
638
# This is some kind of map (dict in Python).
584
- value = {}
639
+ return dict
585
640
elif t .__origin__ in (list , List ):
586
641
# This is some kind of list (repeated) field.
587
- value = []
642
+ return list
588
643
elif t .__origin__ == Union and t .__args__ [1 ] == type (None ):
589
644
# This is an optional (wrapped) field. For setting the default we
590
645
# really don't care what kind of field it is.
591
- value = None
646
+ return type ( None )
592
647
else :
593
- value = t ()
648
+ return t
594
649
elif issubclass (t , Enum ):
595
650
# Enums always default to zero.
596
- value = 0
651
+ return int
597
652
elif t == datetime :
598
653
# Offsets are relative to 1970-01-01T00:00:00Z
599
- value = DATETIME_ZERO
654
+ return datetime_default_gen
600
655
else :
601
656
# This is either a primitive scalar or another message type. Calling
602
657
# it should result in its zero value.
603
- value = t ()
604
-
605
- return value
658
+ return t
606
659
607
660
def _postprocess_single (
608
661
self , wire_type : int , meta : FieldMetadata , field : dataclasses .Field , value : Any
@@ -654,6 +707,7 @@ def _postprocess_single(
654
707
],
655
708
bases = (Message ,),
656
709
)
710
+ make_protoclass (Entry )
657
711
value = Entry ().parse (value )
658
712
659
713
return value
@@ -861,13 +915,13 @@ def serialized_on_wire(message: Message) -> bool:
861
915
862
916
def which_one_of (message : Message , group_name : str ) -> Tuple [str , Any ]:
863
917
"""Return the name and value of a message's one-of field group."""
864
- field = message ._group_map [ "groups" ] .get (group_name , {}). get ( "current" )
918
+ field = message ._group_map .get (group_name )
865
919
if not field :
866
920
return ("" , None )
867
921
return (field .name , getattr (message , field .name ))
868
922
869
923
870
- @dataclasses . dataclass
924
+ @protoclass
871
925
class _Duration (Message ):
872
926
# Signed seconds of the span of time. Must be from -315,576,000,000 to
873
927
# +315,576,000,000 inclusive. Note: these bounds are computed from: 60
@@ -892,7 +946,7 @@ def delta_to_json(delta: timedelta) -> str:
892
946
return "." .join (parts ) + "s"
893
947
894
948
895
- @dataclasses . dataclass
949
+ @protoclass
896
950
class _Timestamp (Message ):
897
951
# Represents seconds of UTC time since Unix epoch 1970-01-01T00:00:00Z. Must
898
952
# be from 0001-01-01T00:00:00Z to 9999-12-31T23:59:59Z inclusive.
@@ -942,47 +996,47 @@ def from_dict(self: T, value: Any) -> T:
942
996
return self
943
997
944
998
945
- @dataclasses . dataclass
999
+ @protoclass
946
1000
class _BoolValue (_WrappedMessage ):
947
1001
value : bool = bool_field (1 )
948
1002
949
1003
950
- @dataclasses . dataclass
1004
+ @protoclass
951
1005
class _Int32Value (_WrappedMessage ):
952
1006
value : int = int32_field (1 )
953
1007
954
1008
955
- @dataclasses . dataclass
1009
+ @protoclass
956
1010
class _UInt32Value (_WrappedMessage ):
957
1011
value : int = uint32_field (1 )
958
1012
959
1013
960
- @dataclasses . dataclass
1014
+ @protoclass
961
1015
class _Int64Value (_WrappedMessage ):
962
1016
value : int = int64_field (1 )
963
1017
964
1018
965
- @dataclasses . dataclass
1019
+ @protoclass
966
1020
class _UInt64Value (_WrappedMessage ):
967
1021
value : int = uint64_field (1 )
968
1022
969
1023
970
- @dataclasses . dataclass
1024
+ @protoclass
971
1025
class _FloatValue (_WrappedMessage ):
972
1026
value : float = float_field (1 )
973
1027
974
1028
975
- @dataclasses . dataclass
1029
+ @protoclass
976
1030
class _DoubleValue (_WrappedMessage ):
977
1031
value : float = double_field (1 )
978
1032
979
1033
980
- @dataclasses . dataclass
1034
+ @protoclass
981
1035
class _StringValue (_WrappedMessage ):
982
1036
value : str = string_field (1 )
983
1037
984
1038
985
- @dataclasses . dataclass
1039
+ @protoclass
986
1040
class _BytesValue (_WrappedMessage ):
987
1041
value : bytes = bytes_field (1 )
988
1042
0 commit comments