Skip to content

Commit 3d001a2

Browse files
committed
Store the class metadata of fields in the class, to improve preformance
Cached data include, - lookup table between groups and fields of "oneof" fields - default value creator of each field - type hint of each field
1 parent de61dda commit 3d001a2

File tree

3 files changed

+117
-64
lines changed

3 files changed

+117
-64
lines changed

betterproto/__init__.py

Lines changed: 101 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,11 @@
120120

121121

122122
# Protobuf datetimes start at the Unix Epoch in 1970 in UTC.
123-
DATETIME_ZERO = datetime(1970, 1, 1, tzinfo=timezone.utc)
123+
def datetime_default_gen():
124+
return datetime(1970, 1, 1, tzinfo=timezone.utc)
125+
126+
127+
DATETIME_ZERO = datetime_default_gen()
124128

125129

126130
class Casing(enum.Enum):
@@ -428,6 +432,57 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
428432
T = TypeVar("T", bound="Message")
429433

430434

435+
class ProtoClassMetadata:
436+
cls: "Message"
437+
438+
def __init__(self, cls: "Message"):
439+
self.cls = cls
440+
by_field = {}
441+
by_group = {}
442+
443+
for field in dataclasses.fields(cls):
444+
meta = FieldMetadata.get(field)
445+
446+
if meta.group:
447+
# This is part of a one-of group.
448+
by_field[field.name] = meta.group
449+
450+
by_group.setdefault(meta.group, set()).add(field)
451+
452+
self.oneof_group_by_field = by_field
453+
self.oneof_field_by_group = by_group
454+
455+
def __getattr__(self, item):
456+
# Lazy init because forward reference classes may not be available at the beginning.
457+
if item == 'default_gen':
458+
defaults = {}
459+
for field in dataclasses.fields(self.cls):
460+
meta = FieldMetadata.get(field)
461+
defaults[field.name] = self.cls._get_field_default_gen(field, meta)
462+
463+
self.default_gen = defaults # __getattr__ won't be called next time
464+
return defaults
465+
466+
if item == 'cls_by_field':
467+
field_cls = {}
468+
for field in dataclasses.fields(self.cls):
469+
meta = FieldMetadata.get(field)
470+
field_cls[field.name] = self.cls._type_hint(field.name)
471+
472+
self.cls_by_field = field_cls # __getattr__ won't be called next time
473+
return field_cls
474+
475+
476+
def make_protoclass(cls):
477+
setattr(cls, "_betterproto", ProtoClassMetadata(cls))
478+
479+
480+
def protoclass(*args, **kwargs):
481+
cls = dataclasses.dataclass(*args, **kwargs)
482+
make_protoclass(cls)
483+
return cls
484+
485+
431486
class Message(ABC):
432487
"""
433488
A protobuf message base class. Generated code will inherit from this and
@@ -445,25 +500,20 @@ def __post_init__(self) -> None:
445500

446501
# Set a default value for each field in the class after `__init__` has
447502
# already been run.
448-
group_map: Dict[str, dict] = {"fields": {}, "groups": {}}
503+
group_map: Dict[str, dataclasses.Field] = {}
449504
for field in dataclasses.fields(self):
450505
meta = FieldMetadata.get(field)
451506

452507
if meta.group:
453-
# This is part of a one-of group.
454-
group_map["fields"][field.name] = meta.group
455-
456-
if meta.group not in group_map["groups"]:
457-
group_map["groups"][meta.group] = {"current": None, "fields": set()}
458-
group_map["groups"][meta.group]["fields"].add(field)
508+
group_map.setdefault(meta.group)
459509

460510
if getattr(self, field.name) != PLACEHOLDER:
461511
# Skip anything not set to the sentinel value
462512
all_sentinel = False
463513

464514
if meta.group:
465515
# This was set, so make it the selected value of the one-of.
466-
group_map["groups"][meta.group]["current"] = field
516+
group_map[meta.group] = field
467517

468518
continue
469519

@@ -479,16 +529,17 @@ def __setattr__(self, attr: str, value: Any) -> None:
479529
# Track when a field has been set.
480530
self.__dict__["_serialized_on_wire"] = True
481531

482-
if attr in getattr(self, "_group_map", {}).get("fields", {}):
483-
group = self._group_map["fields"][attr]
484-
for field in self._group_map["groups"][group]["fields"]:
485-
if field.name == attr:
486-
self._group_map["groups"][group]["current"] = field
487-
else:
488-
super().__setattr__(
489-
field.name,
490-
self._get_field_default(field, FieldMetadata.get(field)),
491-
)
532+
if hasattr(self, "_group_map"): # __post_init__ had already run
533+
if attr in self._betterproto.oneof_group_by_field:
534+
group = self._betterproto.oneof_group_by_field[attr]
535+
for field in self._betterproto.oneof_field_by_group[group]:
536+
if field.name == attr:
537+
self._group_map[group] = field
538+
else:
539+
super().__setattr__(
540+
field.name,
541+
self._get_field_default(field, FieldMetadata.get(field)),
542+
)
492543

493544
super().__setattr__(attr, value)
494545

@@ -510,7 +561,7 @@ def __bytes__(self) -> bytes:
510561
# currently set in a `oneof` group, so it must be serialized even
511562
# if the value is the default zero value.
512563
selected_in_group = False
513-
if meta.group and self._group_map["groups"][meta.group]["current"] == field:
564+
if meta.group and self._group_map[meta.group] == field:
514565
selected_in_group = True
515566

516567
serialize_empty = False
@@ -562,47 +613,49 @@ def __bytes__(self) -> bytes:
562613
# For compatibility with other libraries
563614
SerializeToString = __bytes__
564615

565-
def _type_hint(self, field_name: str) -> Type:
566-
module = inspect.getmodule(self.__class__)
567-
type_hints = get_type_hints(self.__class__, vars(module))
616+
@classmethod
617+
def _type_hint(cls, field_name: str) -> Type:
618+
module = inspect.getmodule(cls)
619+
type_hints = get_type_hints(cls, vars(module))
568620
return type_hints[field_name]
569621

570622
def _cls_for(self, field: dataclasses.Field, index: int = 0) -> Type:
571623
"""Get the message class for a field from the type hints."""
572-
cls = self._type_hint(field.name)
624+
cls = self._betterproto.cls_by_field[field.name]
573625
if hasattr(cls, "__args__") and index >= 0:
574626
cls = cls.__args__[index]
575627
return cls
576628

577629
def _get_field_default(self, field: dataclasses.Field, meta: FieldMetadata) -> Any:
578-
t = self._type_hint(field.name)
630+
return self._betterproto.default_gen[field.name]()
631+
632+
@classmethod
633+
def _get_field_default_gen(cls, field: dataclasses.Field, meta: FieldMetadata) -> Any:
634+
t = cls._type_hint(field.name)
579635

580-
value: Any = 0
581636
if hasattr(t, "__origin__"):
582637
if t.__origin__ in (dict, Dict):
583638
# This is some kind of map (dict in Python).
584-
value = {}
639+
return dict
585640
elif t.__origin__ in (list, List):
586641
# This is some kind of list (repeated) field.
587-
value = []
642+
return list
588643
elif t.__origin__ == Union and t.__args__[1] == type(None):
589644
# This is an optional (wrapped) field. For setting the default we
590645
# really don't care what kind of field it is.
591-
value = None
646+
return type(None)
592647
else:
593-
value = t()
648+
return t
594649
elif issubclass(t, Enum):
595650
# Enums always default to zero.
596-
value = 0
651+
return int
597652
elif t == datetime:
598653
# Offsets are relative to 1970-01-01T00:00:00Z
599-
value = DATETIME_ZERO
654+
return datetime_default_gen
600655
else:
601656
# This is either a primitive scalar or another message type. Calling
602657
# it should result in its zero value.
603-
value = t()
604-
605-
return value
658+
return t
606659

607660
def _postprocess_single(
608661
self, wire_type: int, meta: FieldMetadata, field: dataclasses.Field, value: Any
@@ -654,6 +707,7 @@ def _postprocess_single(
654707
],
655708
bases=(Message,),
656709
)
710+
make_protoclass(Entry)
657711
value = Entry().parse(value)
658712

659713
return value
@@ -861,13 +915,13 @@ def serialized_on_wire(message: Message) -> bool:
861915

862916
def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]:
863917
"""Return the name and value of a message's one-of field group."""
864-
field = message._group_map["groups"].get(group_name, {}).get("current")
918+
field = message._group_map.get(group_name)
865919
if not field:
866920
return ("", None)
867921
return (field.name, getattr(message, field.name))
868922

869923

870-
@dataclasses.dataclass
924+
@protoclass
871925
class _Duration(Message):
872926
# Signed seconds of the span of time. Must be from -315,576,000,000 to
873927
# +315,576,000,000 inclusive. Note: these bounds are computed from: 60
@@ -892,7 +946,7 @@ def delta_to_json(delta: timedelta) -> str:
892946
return ".".join(parts) + "s"
893947

894948

895-
@dataclasses.dataclass
949+
@protoclass
896950
class _Timestamp(Message):
897951
# Represents seconds of UTC time since Unix epoch 1970-01-01T00:00:00Z. Must
898952
# be from 0001-01-01T00:00:00Z to 9999-12-31T23:59:59Z inclusive.
@@ -942,47 +996,47 @@ def from_dict(self: T, value: Any) -> T:
942996
return self
943997

944998

945-
@dataclasses.dataclass
999+
@protoclass
9461000
class _BoolValue(_WrappedMessage):
9471001
value: bool = bool_field(1)
9481002

9491003

950-
@dataclasses.dataclass
1004+
@protoclass
9511005
class _Int32Value(_WrappedMessage):
9521006
value: int = int32_field(1)
9531007

9541008

955-
@dataclasses.dataclass
1009+
@protoclass
9561010
class _UInt32Value(_WrappedMessage):
9571011
value: int = uint32_field(1)
9581012

9591013

960-
@dataclasses.dataclass
1014+
@protoclass
9611015
class _Int64Value(_WrappedMessage):
9621016
value: int = int64_field(1)
9631017

9641018

965-
@dataclasses.dataclass
1019+
@protoclass
9661020
class _UInt64Value(_WrappedMessage):
9671021
value: int = uint64_field(1)
9681022

9691023

970-
@dataclasses.dataclass
1024+
@protoclass
9711025
class _FloatValue(_WrappedMessage):
9721026
value: float = float_field(1)
9731027

9741028

975-
@dataclasses.dataclass
1029+
@protoclass
9761030
class _DoubleValue(_WrappedMessage):
9771031
value: float = double_field(1)
9781032

9791033

980-
@dataclasses.dataclass
1034+
@protoclass
9811035
class _StringValue(_WrappedMessage):
9821036
value: str = string_field(1)
9831037

9841038

985-
@dataclasses.dataclass
1039+
@protoclass
9861040
class _BytesValue(_WrappedMessage):
9871041
value: bytes = bytes_field(1)
9881042

betterproto/templates/template.py

Lines changed: 1 addition & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)