Skip to content

Commit 8f7af27

Browse files
authored
QOL fixes (#141)
- Add missing type annotations - Various style improvements - Use constants more consistently - enforce black on benchmark code
1 parent bf9412e commit 8f7af27

File tree

16 files changed

+176
-219
lines changed

16 files changed

+176
-219
lines changed

.github/workflows/code-quality.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ jobs:
1717
- name: Run Black
1818
uses: lgeiger/black-action@master
1919
with:
20-
args: --check src/ tests/
20+
args: --check src/ tests/ benchmarks/
2121

2222
- name: Install rST dependcies
2323
run: python -m pip install doc8

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,5 @@ output
1515
.DS_Store
1616
.tox
1717
.venv
18-
.asv
18+
.asv
19+
venv

benchmarks/benchmarks.py

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,48 +8,44 @@ class TestMessage(betterproto.Message):
88
bar: str = betterproto.string_field(1)
99
baz: float = betterproto.float_field(2)
1010

11+
1112
class BenchMessage:
12-
"""Test creation and usage a proto message.
13-
"""
13+
"""Test creation and usage a proto message."""
1414

1515
def setup(self):
1616
self.cls = TestMessage
1717
self.instance = TestMessage()
1818
self.instance_filled = TestMessage(0, "test", 0.0)
1919

2020
def time_overhead(self):
21-
"""Overhead in class definition.
22-
"""
21+
"""Overhead in class definition."""
22+
2323
@dataclass
2424
class Message(betterproto.Message):
2525
foo: int = betterproto.uint32_field(0)
2626
bar: str = betterproto.string_field(1)
2727
baz: float = betterproto.float_field(2)
2828

2929
def time_instantiation(self):
30-
"""Time instantiation
31-
"""
30+
"""Time instantiation"""
3231
self.cls()
3332

3433
def time_attribute_access(self):
35-
"""Time to access an attribute
36-
"""
34+
"""Time to access an attribute"""
3735
self.instance.foo
3836
self.instance.bar
3937
self.instance.baz
40-
38+
4139
def time_init_with_values(self):
42-
"""Time to set an attribute
43-
"""
40+
"""Time to set an attribute"""
4441
self.cls(0, "test", 0.0)
4542

4643
def time_attribute_setting(self):
47-
"""Time to set attributes
48-
"""
44+
"""Time to set attributes"""
4945
self.instance.foo = 0
5046
self.instance.bar = "test"
5147
self.instance.baz = 0.0
52-
48+
5349
def time_serialize(self):
5450
"""Time serializing a message to wire."""
5551
bytes(self.instance_filled)
@@ -58,6 +54,6 @@ def time_serialize(self):
5854
class MemSuite:
5955
def setup(self):
6056
self.cls = TestMessage
61-
57+
6258
def mem_instance(self):
6359
return self.cls()

src/betterproto/__init__.py

Lines changed: 46 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from .casing import camel_case, safe_snake_case, snake_case
2727
from .grpc.grpclib_client import ServiceStub
2828

29-
if not (sys.version_info.major == 3 and sys.version_info.minor >= 7):
29+
if sys.version_info[:2] < (3, 7):
3030
# Apply backport of datetime.fromisoformat from 3.7
3131
from backports.datetime_fromisoformat import MonkeyPatch
3232

@@ -110,7 +110,7 @@
110110

111111

112112
# Protobuf datetimes start at the Unix Epoch in 1970 in UTC.
113-
def datetime_default_gen():
113+
def datetime_default_gen() -> datetime:
114114
return datetime(1970, 1, 1, tzinfo=timezone.utc)
115115

116116

@@ -256,8 +256,7 @@ class Enum(enum.IntEnum):
256256

257257
@classmethod
258258
def from_string(cls, name: str) -> "Enum":
259-
"""
260-
Return the value which corresponds to the string name.
259+
"""Return the value which corresponds to the string name.
261260
262261
Parameters
263262
-----------
@@ -316,11 +315,7 @@ def _preprocess_single(proto_type: str, wraps: str, value: Any) -> bytes:
316315
return encode_varint(value)
317316
elif proto_type in [TYPE_SINT32, TYPE_SINT64]:
318317
# Handle zig-zag encoding.
319-
if value >= 0:
320-
value = value << 1
321-
else:
322-
value = (value << 1) ^ (~0)
323-
return encode_varint(value)
318+
return encode_varint(value << 1 if value >= 0 else (value << 1) ^ (~0))
324319
elif proto_type in FIXED_TYPES:
325320
return struct.pack(_pack_fmt(proto_type), value)
326321
elif proto_type == TYPE_STRING:
@@ -413,15 +408,15 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
413408
wire_type = num_wire & 0x7
414409

415410
decoded: Any = None
416-
if wire_type == 0:
411+
if wire_type == WIRE_VARINT:
417412
decoded, i = decode_varint(value, i)
418-
elif wire_type == 1:
413+
elif wire_type == WIRE_FIXED_64:
419414
decoded, i = value[i : i + 8], i + 8
420-
elif wire_type == 2:
415+
elif wire_type == WIRE_LEN_DELIM:
421416
length, i = decode_varint(value, i)
422417
decoded = value[i : i + length]
423418
i += length
424-
elif wire_type == 5:
419+
elif wire_type == WIRE_FIXED_32:
425420
decoded, i = value[i : i + 4], i + 4
426421

427422
yield ParsedField(
@@ -430,12 +425,6 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
430425

431426

432427
class ProtoClassMetadata:
433-
oneof_group_by_field: Dict[str, str]
434-
oneof_field_by_group: Dict[str, Set[dataclasses.Field]]
435-
default_gen: Dict[str, Callable]
436-
cls_by_field: Dict[str, Type]
437-
field_name_by_number: Dict[int, str]
438-
meta_by_field_name: Dict[str, FieldMetadata]
439428
__slots__ = (
440429
"oneof_group_by_field",
441430
"oneof_field_by_group",
@@ -446,6 +435,14 @@ class ProtoClassMetadata:
446435
"sorted_field_names",
447436
)
448437

438+
oneof_group_by_field: Dict[str, str]
439+
oneof_field_by_group: Dict[str, Set[dataclasses.Field]]
440+
field_name_by_number: Dict[int, str]
441+
meta_by_field_name: Dict[str, FieldMetadata]
442+
sorted_field_names: Tuple[str, ...]
443+
default_gen: Dict[str, Callable[[], Any]]
444+
cls_by_field: Dict[str, Type]
445+
449446
def __init__(self, cls: Type["Message"]):
450447
by_field = {}
451448
by_group: Dict[str, Set] = {}
@@ -470,23 +467,21 @@ def __init__(self, cls: Type["Message"]):
470467
self.field_name_by_number = by_field_number
471468
self.meta_by_field_name = by_field_name
472469
self.sorted_field_names = tuple(
473-
by_field_number[number] for number in sorted(by_field_number.keys())
470+
by_field_number[number] for number in sorted(by_field_number)
474471
)
475-
476472
self.default_gen = self._get_default_gen(cls, fields)
477473
self.cls_by_field = self._get_cls_by_field(cls, fields)
478474

479475
@staticmethod
480-
def _get_default_gen(cls, fields):
481-
default_gen = {}
482-
483-
for field in fields:
484-
default_gen[field.name] = cls._get_field_default_gen(field)
485-
486-
return default_gen
476+
def _get_default_gen(
477+
cls: Type["Message"], fields: List[dataclasses.Field]
478+
) -> Dict[str, Callable[[], Any]]:
479+
return {field.name: cls._get_field_default_gen(field) for field in fields}
487480

488481
@staticmethod
489-
def _get_cls_by_field(cls, fields):
482+
def _get_cls_by_field(
483+
cls: Type["Message"], fields: List[dataclasses.Field]
484+
) -> Dict[str, Type]:
490485
field_cls = {}
491486

492487
for field in fields:
@@ -503,7 +498,7 @@ def _get_cls_by_field(cls, fields):
503498
],
504499
bases=(Message,),
505500
)
506-
field_cls[field.name + ".value"] = vt
501+
field_cls[f"{field.name}.value"] = vt
507502
else:
508503
field_cls[field.name] = cls._cls_for(field)
509504

@@ -612,7 +607,7 @@ def __setattr__(self, attr: str, value: Any) -> None:
612607
super().__setattr__(attr, value)
613608

614609
@property
615-
def _betterproto(self):
610+
def _betterproto(self) -> ProtoClassMetadata:
616611
"""
617612
Lazy initialize metadata for each protobuf class.
618613
It may be initialized multiple times in a multi-threaded environment,
@@ -726,9 +721,8 @@ def _type_hint(cls, field_name: str) -> Type:
726721

727722
@classmethod
728723
def _type_hints(cls) -> Dict[str, Type]:
729-
module = inspect.getmodule(cls)
730-
type_hints = get_type_hints(cls, vars(module))
731-
return type_hints
724+
module = sys.modules[cls.__module__]
725+
return get_type_hints(cls, vars(module))
732726

733727
@classmethod
734728
def _cls_for(cls, field: dataclasses.Field, index: int = 0) -> Type:
@@ -739,7 +733,7 @@ def _cls_for(cls, field: dataclasses.Field, index: int = 0) -> Type:
739733
field_cls = field_cls.__args__[index]
740734
return field_cls
741735

742-
def _get_field_default(self, field_name):
736+
def _get_field_default(self, field_name: str) -> Any:
743737
return self._betterproto.default_gen[field_name]()
744738

745739
@classmethod
@@ -762,7 +756,7 @@ def _get_field_default_gen(cls, field: dataclasses.Field) -> Any:
762756
elif issubclass(t, Enum):
763757
# Enums always default to zero.
764758
return int
765-
elif t == datetime:
759+
elif t is datetime:
766760
# Offsets are relative to 1970-01-01T00:00:00Z
767761
return datetime_default_gen
768762
else:
@@ -966,7 +960,7 @@ def to_dict(
966960
)
967961
):
968962
output[cased_name] = value.to_dict(casing, include_default_values)
969-
elif meta.proto_type == "map":
963+
elif meta.proto_type == TYPE_MAP:
970964
for k in value:
971965
if hasattr(value[k], "to_dict"):
972966
value[k] = value[k].to_dict(casing, include_default_values)
@@ -1032,12 +1026,12 @@ def from_dict(self: T, value: Dict[str, Any]) -> T:
10321026
continue
10331027

10341028
if value[key] is not None:
1035-
if meta.proto_type == "message":
1029+
if meta.proto_type == TYPE_MESSAGE:
10361030
v = getattr(self, field_name)
10371031
if isinstance(v, list):
10381032
cls = self._betterproto.cls_by_field[field_name]
1039-
for i in range(len(value[key])):
1040-
v.append(cls().from_dict(value[key][i]))
1033+
for item in value[key]:
1034+
v.append(cls().from_dict(item))
10411035
elif isinstance(v, datetime):
10421036
v = datetime.fromisoformat(value[key].replace("Z", "+00:00"))
10431037
setattr(self, field_name, v)
@@ -1052,7 +1046,7 @@ def from_dict(self: T, value: Dict[str, Any]) -> T:
10521046
v.from_dict(value[key])
10531047
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
10541048
v = getattr(self, field_name)
1055-
cls = self._betterproto.cls_by_field[field_name + ".value"]
1049+
cls = self._betterproto.cls_by_field[f"{field_name}.value"]
10561050
for k in value[key]:
10571051
v[k] = cls().from_dict(value[key][k])
10581052
else:
@@ -1134,7 +1128,7 @@ def serialized_on_wire(message: Message) -> bool:
11341128
return message._serialized_on_wire
11351129

11361130

1137-
def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]:
1131+
def which_one_of(message: Message, group_name: str) -> Tuple[str, Optional[Any]]:
11381132
"""
11391133
Return the name and value of a message's one-of field group.
11401134
@@ -1145,21 +1139,21 @@ def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]:
11451139
"""
11461140
field_name = message._group_current.get(group_name)
11471141
if not field_name:
1148-
return ("", None)
1149-
return (field_name, getattr(message, field_name))
1142+
return "", None
1143+
return field_name, getattr(message, field_name)
11501144

11511145

11521146
# Circular import workaround: google.protobuf depends on base classes defined above.
11531147
from .lib.google.protobuf import ( # noqa
1154-
Duration,
1155-
Timestamp,
11561148
BoolValue,
11571149
BytesValue,
11581150
DoubleValue,
1151+
Duration,
11591152
FloatValue,
11601153
Int32Value,
11611154
Int64Value,
11621155
StringValue,
1156+
Timestamp,
11631157
UInt32Value,
11641158
UInt64Value,
11651159
)
@@ -1174,8 +1168,8 @@ def delta_to_json(delta: timedelta) -> str:
11741168
parts = str(delta.total_seconds()).split(".")
11751169
if len(parts) > 1:
11761170
while len(parts[1]) not in [3, 6, 9]:
1177-
parts[1] = parts[1] + "0"
1178-
return ".".join(parts) + "s"
1171+
parts[1] = f"{parts[1]}0"
1172+
return f"{'.'.join(parts)}s"
11791173

11801174

11811175
class _Timestamp(Timestamp):
@@ -1191,15 +1185,15 @@ def timestamp_to_json(dt: datetime) -> str:
11911185
if (nanos % 1e9) == 0:
11921186
# If there are 0 fractional digits, the fractional
11931187
# point '.' should be omitted when serializing.
1194-
return result + "Z"
1188+
return f"{result}Z"
11951189
if (nanos % 1e6) == 0:
11961190
# Serialize 3 fractional digits.
1197-
return result + ".%03dZ" % (nanos / 1e6)
1191+
return f"{result}.{int(nanos // 1e6) :03d}Z"
11981192
if (nanos % 1e3) == 0:
11991193
# Serialize 6 fractional digits.
1200-
return result + ".%06dZ" % (nanos / 1e3)
1194+
return f"{result}.{int(nanos // 1e3) :06d}Z"
12011195
# Serialize 9 fractional digits.
1202-
return result + ".%09dZ" % nanos
1196+
return f"{result}.{nanos:09d}"
12031197

12041198

12051199
class _WrappedMessage(Message):

src/betterproto/_types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from typing import TYPE_CHECKING, TypeVar
22

33
if TYPE_CHECKING:
4-
from . import Message
54
from grpclib._typing import IProtoMessage
5+
from . import Message
66

77
# Bound type variable to allow methods to return `self` of subclasses
88
T = TypeVar("T", bound="Message")

0 commit comments

Comments
 (0)