Skip to content
This repository was archived by the owner on Jun 9, 2025. It is now read-only.

Commit 17b205d

Browse files
Remove group current and setattr (#21)
* Remove setattribute and group current * Remove useless parameter * Simplify __post_init__
1 parent a8f4887 commit 17b205d

File tree

2 files changed

+39
-121
lines changed

2 files changed

+39
-121
lines changed

src/betterproto/__init__.py

Lines changed: 24 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -574,7 +574,6 @@ def _serialize_single(
574574
proto_type: str,
575575
value: Any,
576576
*,
577-
serialize_empty: bool = False,
578577
wraps: str = "",
579578
) -> bytes:
580579
"""Serializes a single field and value."""
@@ -591,9 +590,8 @@ def _serialize_single(
591590
key = encode_varint((field_number << 3) | 1)
592591
output += key + value
593592
elif proto_type in WIRE_LEN_DELIM_TYPES:
594-
if len(value) or serialize_empty or wraps:
595-
key = encode_varint((field_number << 3) | 2)
596-
output += key + encode_varint(len(value)) + value
593+
key = encode_varint((field_number << 3) | 2)
594+
output += key + encode_varint(len(value)) + value
597595
else:
598596
raise NotImplementedError(proto_type)
599597

@@ -842,26 +840,10 @@ class Message(ABC):
842840
"""
843841

844842
_unknown_fields: bytes
845-
_group_current: Dict[str, str]
846843
_betterproto_meta: ClassVar[ProtoClassMetadata]
847844

848845
def __post_init__(self) -> None:
849-
# # Keep track of whether every field was default
850-
# all_sentinel = True
851-
852-
# Set current field of each group after `__init__` has already been run.
853-
group_current: Dict[str, Optional[str]] = {}
854-
for field_name, meta in self._betterproto.meta_by_field_name.items():
855-
if meta.group:
856-
group_current.setdefault(meta.group)
857-
858-
value = self.__getattribute__(field_name)
859-
if value is not None:
860-
group_current[meta.group] = field_name
861-
862-
# Now that all the defaults are set, reset it!
863-
self.__dict__["_unknown_fields"] = b""
864-
self.__dict__["_group_current"] = group_current
846+
self._unknown_fields = b""
865847

866848
def __eq__(self, other) -> bool:
867849
if type(self) is not type(other):
@@ -900,18 +882,6 @@ def __repr__(self) -> str:
900882
# for field_name in self._betterproto.sorted_field_names:
901883
# yield field_name, self.__getattribute__(field_name), PLACEHOLDER
902884

903-
def __setattr__(self, attr: str, value: Any) -> None:
904-
if hasattr(self, "_group_current"): # __post_init__ had already run
905-
if attr in self._betterproto.oneof_group_by_field:
906-
group = self._betterproto.oneof_group_by_field[attr]
907-
for field in self._betterproto.oneof_field_by_group[group]:
908-
if field.name == attr:
909-
self._group_current[group] = field.name
910-
else:
911-
super().__setattr__(field.name, None)
912-
913-
super().__setattr__(attr, value)
914-
915885
def __bool__(self) -> bool:
916886
"""True if the Message has any fields with non-default values."""
917887
return any(
@@ -978,26 +948,8 @@ def __bytes__(self) -> bytes:
978948
# wrapper types and proto3 field presence/optional fields.
979949
continue
980950

981-
# Being selected in a a group means this field is the one that is
982-
# currently set in a `oneof` group, so it must be serialized even
983-
# if the value is the default zero value.
984-
#
985-
# Note that proto3 field presence/optional fields are put in a
986-
# synthetic single-item oneof by protoc, which helps us ensure we
987-
# send the value even if the value is the default zero value.
988-
selected_in_group = bool(meta.group) or meta.optional
989-
990-
include_default_value_for_oneof = self._include_default_value_for_oneof(
991-
field_name=field_name, meta=meta
992-
)
993-
994-
if value == self._get_field_default(field_name) and not (
995-
selected_in_group or include_default_value_for_oneof
996-
):
997-
# Default (zero) values are not serialized. Two exceptions are
998-
# if this is the selected oneof item or if we know we have to
999-
# serialize an empty message (i.e. zero value was explicitly
1000-
# set by the user).
951+
if value == self._get_field_default(field_name):
952+
# Default (zero) values are not serialized.
1001953
continue
1002954

1003955
if isinstance(value, list):
@@ -1017,7 +969,6 @@ def __bytes__(self) -> bytes:
1017969
meta.proto_type,
1018970
item,
1019971
wraps=meta.wraps or "",
1020-
serialize_empty=True,
1021972
)
1022973
# if it's an empty message it still needs to be represented
1023974
# as an item in the repeated list
@@ -1033,24 +984,11 @@ def __bytes__(self) -> bytes:
1033984
_serialize_single(meta.number, meta.proto_type, sk + sv)
1034985
)
1035986
else:
1036-
# If we have an empty string and we're including the default value for
1037-
# a oneof, make sure we serialize it. This ensures that the byte string
1038-
# output isn't simply an empty string. This also ensures that round trip
1039-
# serialization will keep `which_one_of` calls consistent.
1040-
serialize_empty = False
1041-
if (
1042-
isinstance(value, str)
1043-
and value == ""
1044-
and include_default_value_for_oneof
1045-
):
1046-
serialize_empty = True
1047-
1048987
stream.write(
1049988
_serialize_single(
1050989
meta.number,
1051990
meta.proto_type,
1052991
value,
1053-
serialize_empty=serialize_empty or bool(selected_in_group),
1054992
wraps=meta.wraps or "",
1055993
)
1056994
)
@@ -1109,6 +1047,9 @@ def _get_field_default(self, field_name: str) -> Any:
11091047

11101048
@classmethod
11111049
def _get_field_default_gen(cls, field: dataclasses.Field) -> Any:
1050+
if field.metadata["betterproto"].optional:
1051+
return type(None)
1052+
11121053
t = cls._type_hint(field.name)
11131054

11141055
is_310_union = isinstance(t, _types_UnionType)
@@ -1181,13 +1122,6 @@ def _postprocess_single(
11811122

11821123
return value
11831124

1184-
def _include_default_value_for_oneof(
1185-
self, field_name: str, meta: FieldMetadata
1186-
) -> bool:
1187-
return (
1188-
meta.group is not None and self._group_current.get(meta.group) == field_name
1189-
)
1190-
11911125
def load(
11921126
self: T,
11931127
stream: "SupportsRead[bytes]",
@@ -1357,23 +1291,9 @@ def to_dict(
13571291
cased_name = casing(field_name).rstrip("_") # type: ignore
13581292
if meta.proto_type == TYPE_MESSAGE:
13591293
if isinstance(value, datetime):
1360-
if (
1361-
value != DATETIME_ZERO
1362-
or include_default_values
1363-
or self._include_default_value_for_oneof(
1364-
field_name=field_name, meta=meta
1365-
)
1366-
):
1367-
output[cased_name] = _Timestamp.timestamp_to_json(value)
1294+
output[cased_name] = _Timestamp.timestamp_to_json(value)
13681295
elif isinstance(value, timedelta):
1369-
if (
1370-
value != timedelta(0)
1371-
or include_default_values
1372-
or self._include_default_value_for_oneof(
1373-
field_name=field_name, meta=meta
1374-
)
1375-
):
1376-
output[cased_name] = _Duration.delta_to_json(value)
1296+
output[cased_name] = _Duration.delta_to_json(value)
13771297
elif meta.wraps:
13781298
if value is not None or include_default_values:
13791299
output[cased_name] = value
@@ -1403,13 +1323,7 @@ def to_dict(
14031323

14041324
if value or include_default_values:
14051325
output[cased_name] = output_map
1406-
elif (
1407-
value != self._get_field_default(field_name)
1408-
or include_default_values
1409-
or self._include_default_value_for_oneof(
1410-
field_name=field_name, meta=meta
1411-
)
1412-
):
1326+
elif value != self._get_field_default(field_name) or include_default_values:
14131327
if meta.proto_type in INT_64_TYPES:
14141328
if field_is_repeated:
14151329
output[cased_name] = [str(n) for n in value]
@@ -1680,13 +1594,7 @@ def to_pydict(
16801594

16811595
if value or include_default_values:
16821596
output[cased_name] = value
1683-
elif (
1684-
value != self._get_field_default(field_name)
1685-
or include_default_values
1686-
or self._include_default_value_for_oneof(
1687-
field_name=field_name, meta=meta
1688-
)
1689-
):
1597+
elif value != self._get_field_default(field_name) or include_default_values:
16901598
output[cased_name] = value
16911599
return output
16921600

@@ -1796,10 +1704,18 @@ def which_one_of(message: Message, group_name: str) -> Tuple[str, Optional[Any]]
17961704
Tuple[:class:`str`, Any]
17971705
The field name and the value for that field.
17981706
"""
1799-
field_name = message._group_current.get(group_name)
1800-
if not field_name:
1801-
return "", None
1802-
return field_name, getattr(message, field_name)
1707+
field_name, value = "", None
1708+
for field in message._betterproto.oneof_field_by_group[group_name]:
1709+
v = getattr(message, field.name)
1710+
1711+
if v is not None:
1712+
if field_name:
1713+
raise RuntimeError(
1714+
f"more than one field set in oneof: {field.name} and {field_name}"
1715+
)
1716+
field_name, value = field.name, v
1717+
1718+
return field_name, value
18031719

18041720

18051721
# Circular import workaround: google.protobuf depends on base classes defined above.

tests/test_features.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -95,30 +95,32 @@ class Sub(betterproto.Message):
9595

9696
@dataclass
9797
class Foo(betterproto.Message):
98-
bar: int = betterproto.int32_field(1, group="group1")
99-
baz: str = betterproto.string_field(2, group="group1")
100-
sub: Sub = betterproto.message_field(3, group="group2")
101-
abc: str = betterproto.string_field(4, group="group2")
98+
bar: int = betterproto.int32_field(1, optional=True, group="group1")
99+
baz: str = betterproto.string_field(2, optional=True, group="group1")
100+
sub: Sub = betterproto.message_field(3, optional=True, group="group2")
101+
abc: str = betterproto.string_field(4, optional=True, group="group2")
102102

103103
foo = Foo()
104104

105105
assert betterproto.which_one_of(foo, "group1")[0] == ""
106106

107107
foo.bar = 1
108-
foo.baz = "test"
108+
assert betterproto.which_one_of(foo, "group1")[0] == "bar"
109109

110-
# Other oneof fields should now be unset
111-
assert foo.bar is None
110+
foo.bar = None
111+
foo.baz = "test"
112112
assert betterproto.which_one_of(foo, "group1")[0] == "baz"
113113

114114
foo.sub = Sub(val=1)
115+
assert betterproto.which_one_of(foo, "group2")[0] == "sub"
115116

117+
foo.sub = None
116118
foo.abc = "test"
117-
118-
# Group 1 shouldn't be touched, group 2 should have reset
119-
assert foo.sub is None
120119
assert betterproto.which_one_of(foo, "group2")[0] == "abc"
121120

121+
# Group 1 shouldn't be touched
122+
assert betterproto.which_one_of(foo, "group1")[0] == "baz"
123+
122124
# Zero value should always serialize for one-of
123125
foo = Foo(bar=0)
124126
assert betterproto.which_one_of(foo, "group1")[0] == "bar"
@@ -419,9 +421,9 @@ class Empty(betterproto.Message):
419421

420422
@dataclass
421423
class Foo(betterproto.Message):
422-
bar: int = betterproto.int32_field(1, group="group1")
423-
baz: str = betterproto.string_field(2, group="group1")
424-
qux: Empty = betterproto.message_field(3, group="group1")
424+
bar: int = betterproto.int32_field(1, optional=True, group="group1")
425+
baz: str = betterproto.string_field(2, optional=True, group="group1")
426+
qux: Empty = betterproto.message_field(3, optional=True, group="group1")
425427

426428
def _round_trip_serialization(foo: Foo) -> Foo:
427429
return Foo().parse(bytes(foo))

0 commit comments

Comments
 (0)