26
26
from .casing import camel_case , safe_snake_case , snake_case
27
27
from .grpc .grpclib_client import ServiceStub
28
28
29
- if not ( sys .version_info . major == 3 and sys . version_info . minor >= 7 ):
29
+ if sys .version_info [: 2 ] < ( 3 , 7 ):
30
30
# Apply backport of datetime.fromisoformat from 3.7
31
31
from backports .datetime_fromisoformat import MonkeyPatch
32
32
110
110
111
111
112
112
# Protobuf datetimes start at the Unix Epoch in 1970 in UTC.
113
- def datetime_default_gen ():
113
+ def datetime_default_gen () -> datetime :
114
114
return datetime (1970 , 1 , 1 , tzinfo = timezone .utc )
115
115
116
116
@@ -256,8 +256,7 @@ class Enum(enum.IntEnum):
256
256
257
257
@classmethod
258
258
def from_string (cls , name : str ) -> "Enum" :
259
- """
260
- Return the value which corresponds to the string name.
259
+ """Return the value which corresponds to the string name.
261
260
262
261
Parameters
263
262
-----------
@@ -316,11 +315,7 @@ def _preprocess_single(proto_type: str, wraps: str, value: Any) -> bytes:
316
315
return encode_varint (value )
317
316
elif proto_type in [TYPE_SINT32 , TYPE_SINT64 ]:
318
317
# Handle zig-zag encoding.
319
- if value >= 0 :
320
- value = value << 1
321
- else :
322
- value = (value << 1 ) ^ (~ 0 )
323
- return encode_varint (value )
318
+ return encode_varint (value << 1 if value >= 0 else (value << 1 ) ^ (~ 0 ))
324
319
elif proto_type in FIXED_TYPES :
325
320
return struct .pack (_pack_fmt (proto_type ), value )
326
321
elif proto_type == TYPE_STRING :
@@ -413,15 +408,15 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
413
408
wire_type = num_wire & 0x7
414
409
415
410
decoded : Any = None
416
- if wire_type == 0 :
411
+ if wire_type == WIRE_VARINT :
417
412
decoded , i = decode_varint (value , i )
418
- elif wire_type == 1 :
413
+ elif wire_type == WIRE_FIXED_64 :
419
414
decoded , i = value [i : i + 8 ], i + 8
420
- elif wire_type == 2 :
415
+ elif wire_type == WIRE_LEN_DELIM :
421
416
length , i = decode_varint (value , i )
422
417
decoded = value [i : i + length ]
423
418
i += length
424
- elif wire_type == 5 :
419
+ elif wire_type == WIRE_FIXED_32 :
425
420
decoded , i = value [i : i + 4 ], i + 4
426
421
427
422
yield ParsedField (
@@ -430,12 +425,6 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
430
425
431
426
432
427
class ProtoClassMetadata :
433
- oneof_group_by_field : Dict [str , str ]
434
- oneof_field_by_group : Dict [str , Set [dataclasses .Field ]]
435
- default_gen : Dict [str , Callable ]
436
- cls_by_field : Dict [str , Type ]
437
- field_name_by_number : Dict [int , str ]
438
- meta_by_field_name : Dict [str , FieldMetadata ]
439
428
__slots__ = (
440
429
"oneof_group_by_field" ,
441
430
"oneof_field_by_group" ,
@@ -446,6 +435,14 @@ class ProtoClassMetadata:
446
435
"sorted_field_names" ,
447
436
)
448
437
438
+ oneof_group_by_field : Dict [str , str ]
439
+ oneof_field_by_group : Dict [str , Set [dataclasses .Field ]]
440
+ field_name_by_number : Dict [int , str ]
441
+ meta_by_field_name : Dict [str , FieldMetadata ]
442
+ sorted_field_names : Tuple [str , ...]
443
+ default_gen : Dict [str , Callable [[], Any ]]
444
+ cls_by_field : Dict [str , Type ]
445
+
449
446
def __init__ (self , cls : Type ["Message" ]):
450
447
by_field = {}
451
448
by_group : Dict [str , Set ] = {}
@@ -470,23 +467,21 @@ def __init__(self, cls: Type["Message"]):
470
467
self .field_name_by_number = by_field_number
471
468
self .meta_by_field_name = by_field_name
472
469
self .sorted_field_names = tuple (
473
- by_field_number [number ] for number in sorted (by_field_number . keys () )
470
+ by_field_number [number ] for number in sorted (by_field_number )
474
471
)
475
-
476
472
self .default_gen = self ._get_default_gen (cls , fields )
477
473
self .cls_by_field = self ._get_cls_by_field (cls , fields )
478
474
479
475
@staticmethod
480
- def _get_default_gen (cls , fields ):
481
- default_gen = {}
482
-
483
- for field in fields :
484
- default_gen [field .name ] = cls ._get_field_default_gen (field )
485
-
486
- return default_gen
476
+ def _get_default_gen (
477
+ cls : Type ["Message" ], fields : List [dataclasses .Field ]
478
+ ) -> Dict [str , Callable [[], Any ]]:
479
+ return {field .name : cls ._get_field_default_gen (field ) for field in fields }
487
480
488
481
@staticmethod
489
- def _get_cls_by_field (cls , fields ):
482
+ def _get_cls_by_field (
483
+ cls : Type ["Message" ], fields : List [dataclasses .Field ]
484
+ ) -> Dict [str , Type ]:
490
485
field_cls = {}
491
486
492
487
for field in fields :
@@ -503,7 +498,7 @@ def _get_cls_by_field(cls, fields):
503
498
],
504
499
bases = (Message ,),
505
500
)
506
- field_cls [field .name + " .value" ] = vt
501
+ field_cls [f" { field .name } .value" ] = vt
507
502
else :
508
503
field_cls [field .name ] = cls ._cls_for (field )
509
504
@@ -612,7 +607,7 @@ def __setattr__(self, attr: str, value: Any) -> None:
612
607
super ().__setattr__ (attr , value )
613
608
614
609
@property
615
- def _betterproto (self ):
610
+ def _betterproto (self ) -> ProtoClassMetadata :
616
611
"""
617
612
Lazy initialize metadata for each protobuf class.
618
613
It may be initialized multiple times in a multi-threaded environment,
@@ -726,9 +721,8 @@ def _type_hint(cls, field_name: str) -> Type:
726
721
727
722
@classmethod
728
723
def _type_hints (cls ) -> Dict [str , Type ]:
729
- module = inspect .getmodule (cls )
730
- type_hints = get_type_hints (cls , vars (module ))
731
- return type_hints
724
+ module = sys .modules [cls .__module__ ]
725
+ return get_type_hints (cls , vars (module ))
732
726
733
727
@classmethod
734
728
def _cls_for (cls , field : dataclasses .Field , index : int = 0 ) -> Type :
@@ -739,7 +733,7 @@ def _cls_for(cls, field: dataclasses.Field, index: int = 0) -> Type:
739
733
field_cls = field_cls .__args__ [index ]
740
734
return field_cls
741
735
742
- def _get_field_default (self , field_name ) :
736
+ def _get_field_default (self , field_name : str ) -> Any :
743
737
return self ._betterproto .default_gen [field_name ]()
744
738
745
739
@classmethod
@@ -762,7 +756,7 @@ def _get_field_default_gen(cls, field: dataclasses.Field) -> Any:
762
756
elif issubclass (t , Enum ):
763
757
# Enums always default to zero.
764
758
return int
765
- elif t == datetime :
759
+ elif t is datetime :
766
760
# Offsets are relative to 1970-01-01T00:00:00Z
767
761
return datetime_default_gen
768
762
else :
@@ -966,7 +960,7 @@ def to_dict(
966
960
)
967
961
):
968
962
output [cased_name ] = value .to_dict (casing , include_default_values )
969
- elif meta .proto_type == "map" :
963
+ elif meta .proto_type == TYPE_MAP :
970
964
for k in value :
971
965
if hasattr (value [k ], "to_dict" ):
972
966
value [k ] = value [k ].to_dict (casing , include_default_values )
@@ -1032,12 +1026,12 @@ def from_dict(self: T, value: Dict[str, Any]) -> T:
1032
1026
continue
1033
1027
1034
1028
if value [key ] is not None :
1035
- if meta .proto_type == "message" :
1029
+ if meta .proto_type == TYPE_MESSAGE :
1036
1030
v = getattr (self , field_name )
1037
1031
if isinstance (v , list ):
1038
1032
cls = self ._betterproto .cls_by_field [field_name ]
1039
- for i in range ( len ( value [key ])) :
1040
- v .append (cls ().from_dict (value [ key ][ i ] ))
1033
+ for item in value [key ]:
1034
+ v .append (cls ().from_dict (item ))
1041
1035
elif isinstance (v , datetime ):
1042
1036
v = datetime .fromisoformat (value [key ].replace ("Z" , "+00:00" ))
1043
1037
setattr (self , field_name , v )
@@ -1052,7 +1046,7 @@ def from_dict(self: T, value: Dict[str, Any]) -> T:
1052
1046
v .from_dict (value [key ])
1053
1047
elif meta .map_types and meta .map_types [1 ] == TYPE_MESSAGE :
1054
1048
v = getattr (self , field_name )
1055
- cls = self ._betterproto .cls_by_field [field_name + " .value" ]
1049
+ cls = self ._betterproto .cls_by_field [f" { field_name } .value" ]
1056
1050
for k in value [key ]:
1057
1051
v [k ] = cls ().from_dict (value [key ][k ])
1058
1052
else :
@@ -1134,7 +1128,7 @@ def serialized_on_wire(message: Message) -> bool:
1134
1128
return message ._serialized_on_wire
1135
1129
1136
1130
1137
- def which_one_of (message : Message , group_name : str ) -> Tuple [str , Any ]:
1131
+ def which_one_of (message : Message , group_name : str ) -> Tuple [str , Optional [ Any ] ]:
1138
1132
"""
1139
1133
Return the name and value of a message's one-of field group.
1140
1134
@@ -1145,21 +1139,21 @@ def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]:
1145
1139
"""
1146
1140
field_name = message ._group_current .get (group_name )
1147
1141
if not field_name :
1148
- return ( "" , None )
1149
- return ( field_name , getattr (message , field_name ) )
1142
+ return "" , None
1143
+ return field_name , getattr (message , field_name )
1150
1144
1151
1145
1152
1146
# Circular import workaround: google.protobuf depends on base classes defined above.
1153
1147
from .lib .google .protobuf import ( # noqa
1154
- Duration ,
1155
- Timestamp ,
1156
1148
BoolValue ,
1157
1149
BytesValue ,
1158
1150
DoubleValue ,
1151
+ Duration ,
1159
1152
FloatValue ,
1160
1153
Int32Value ,
1161
1154
Int64Value ,
1162
1155
StringValue ,
1156
+ Timestamp ,
1163
1157
UInt32Value ,
1164
1158
UInt64Value ,
1165
1159
)
@@ -1174,8 +1168,8 @@ def delta_to_json(delta: timedelta) -> str:
1174
1168
parts = str (delta .total_seconds ()).split ("." )
1175
1169
if len (parts ) > 1 :
1176
1170
while len (parts [1 ]) not in [3 , 6 , 9 ]:
1177
- parts [1 ] = parts [1 ] + " 0"
1178
- return "." .join (parts ) + " s"
1171
+ parts [1 ] = f" { parts [1 ]} 0"
1172
+ return f" { '.' .join (parts )} s"
1179
1173
1180
1174
1181
1175
class _Timestamp (Timestamp ):
@@ -1191,15 +1185,15 @@ def timestamp_to_json(dt: datetime) -> str:
1191
1185
if (nanos % 1e9 ) == 0 :
1192
1186
# If there are 0 fractional digits, the fractional
1193
1187
# point '.' should be omitted when serializing.
1194
- return result + " Z"
1188
+ return f" { result } Z"
1195
1189
if (nanos % 1e6 ) == 0 :
1196
1190
# Serialize 3 fractional digits.
1197
- return result + ".%03dZ" % (nanos / 1e6 )
1191
+ return f" { result } . { int (nanos // 1e6 ) :03d } Z"
1198
1192
if (nanos % 1e3 ) == 0 :
1199
1193
# Serialize 6 fractional digits.
1200
- return result + ".%06dZ" % (nanos / 1e3 )
1194
+ return f" { result } . { int (nanos // 1e3 ) :06d } Z"
1201
1195
# Serialize 9 fractional digits.
1202
- return result + ".%09dZ" % nanos
1196
+ return f" { result } . { nanos :09d } "
1203
1197
1204
1198
1205
1199
class _WrappedMessage (Message ):
0 commit comments