Skip to content

Commit 4cf2e9e

Browse files
Make parse classmethod (#96)
* Make parse a class method * Fix tests * Fix typechecking
1 parent af6ac9a commit 4cf2e9e

File tree

8 files changed

+38
-50
lines changed

8 files changed

+38
-50
lines changed

src/betterproto2/__init__.py

Lines changed: 21 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -768,14 +768,9 @@ def SerializeToString(self) -> bytes:
768768
"""
769769
return bytes(self)
770770

771-
def __getstate__(self) -> bytes:
772-
return bytes(self)
773-
774-
def __setstate__(self: T, pickled_bytes: bytes) -> T:
775-
return self.parse(pickled_bytes)
776-
777771
def __reduce__(self) -> tuple[Any, ...]:
778-
return (self.__class__.FromString, (bytes(self),))
772+
# To support pickling
773+
return (self.__class__.parse, (bytes(self),))
779774

780775
@classmethod
781776
def _type_hint(cls, field_name: str) -> type:
@@ -829,12 +824,12 @@ def _postprocess_single(self, wire_type: int, meta: FieldMetadata, field_name: s
829824
else:
830825
msg_cls = self._betterproto.cls_by_field[field_name]
831826

832-
value = msg_cls().parse(value)
827+
value = msg_cls.parse(value)
833828

834829
if meta.unwrap:
835830
value = value.to_wrapped()
836831
elif meta.proto_type == TYPE_MAP:
837-
value = self._betterproto.cls_by_field[field_name]().parse(value)
832+
value = self._betterproto.cls_by_field[field_name].parse(value)
838833

839834
return value
840835

@@ -931,7 +926,8 @@ def load(
931926

932927
return self
933928

934-
def parse(self: T, data: bytes) -> T:
929+
@classmethod
930+
def parse(cls, data: bytes) -> Self:
935931
"""
936932
Parse the binary encoded Protobuf into this message instance. This
937933
returns the instance itself and is therefore assignable and chainable.
@@ -947,7 +943,7 @@ def parse(self: T, data: bytes) -> T:
947943
The initialized message.
948944
"""
949945
with BytesIO(data) as stream:
950-
return self.load(stream)
946+
return cls().load(stream)
951947

952948
# For compatibility with other libraries.
953949
@classmethod
@@ -971,7 +967,7 @@ def FromString(cls: type[T], data: bytes) -> T:
971967
:class:`Message`
972968
The initialized message.
973969
"""
974-
return cls().parse(data)
970+
return cls.parse(data)
975971

976972
def to_dict(
977973
self,
@@ -1210,21 +1206,22 @@ def _validate_field_groups(cls, values):
12101206
Message.__annotations__ = {} # HACK to avoid typing.get_type_hints breaking :)
12111207

12121208

1213-
try:
1214-
import betterproto2_rust_codec # pyright: ignore[reportMissingImports]
1215-
except ModuleNotFoundError:
1216-
pass
1217-
else:
1209+
# The Rust codec is not available for now
1210+
# try:
1211+
# import betterproto2_rust_codec # pyright: ignore[reportMissingImports]
1212+
# except ModuleNotFoundError:
1213+
# pass
1214+
# else:
12181215

1219-
def parse_patched(self, data: bytes) -> Message:
1220-
betterproto2_rust_codec.deserialize(self, data)
1221-
return self
1216+
# def parse_patched(self, data: bytes) -> Message:
1217+
# betterproto2_rust_codec.deserialize(self, data)
1218+
# return self
12221219

1223-
def bytes_patched(self) -> bytes:
1224-
return betterproto2_rust_codec.serialize(self)
1220+
# def bytes_patched(self) -> bytes:
1221+
# return betterproto2_rust_codec.serialize(self)
12251222

1226-
Message.parse = parse_patched
1227-
Message.__bytes__ = bytes_patched
1223+
# Message.parse = parse_patched
1224+
# Message.__bytes__ = bytes_patched
12281225

12291226

12301227
def which_one_of(message: Message, group_name: str) -> tuple[str, Any | None]:

tests/inputs/enum/test_enum.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,17 +75,17 @@ def enum_generator():
7575

7676
def test_enum_mapped_on_parse():
7777
# test default value
78-
b = Test().parse(bytes(Test()))
78+
b = Test.parse(bytes(Test()))
7979
assert b.choice.name == Choice.ZERO.name
8080
assert b.choices == []
8181

8282
# test non default value
83-
a = Test().parse(bytes(Test(choice=Choice.ONE)))
83+
a = Test.parse(bytes(Test(choice=Choice.ONE)))
8484
assert a.choice.name == Choice.ONE.name
8585
assert b.choices == []
8686

8787
# test repeated
88-
c = Test().parse(bytes(Test(choices=[Choice.THREE, Choice.FOUR])))
88+
c = Test.parse(bytes(Test(choices=[Choice.THREE, Choice.FOUR])))
8989
assert c.choices[0].name == Choice.THREE.name
9090
assert c.choices[1].name == Choice.FOUR.name
9191

tests/inputs/google_impl_behavior_equivalence/test_google_impl_behavior_equivalence.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def test_bytes_are_the_same_for_oneof():
4848

4949
assert message_bytes == message_reference_bytes
5050

51-
message2 = Test().parse(message_reference_bytes)
51+
message2 = Test.parse(message_reference_bytes)
5252
message_reference2 = ReferenceTest()
5353
message_reference2.ParseFromString(message_reference_bytes)
5454

@@ -71,7 +71,7 @@ def test_datetime_clamping(dt): # see #407
7171
assert bytes(Spam(dt)) == ReferenceSpam(ts=ts).SerializeToString()
7272
message_bytes = bytes(Spam(dt))
7373

74-
assert Spam().parse(message_bytes).ts.timestamp() == ReferenceSpam.FromString(message_bytes).ts.seconds
74+
assert Spam.parse(message_bytes).ts.timestamp() == ReferenceSpam.FromString(message_bytes).ts.seconds
7575

7676

7777
def test_empty_message_field():
Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,9 @@
1-
from tests.output_betterproto.regression_387 import (
2-
ParentElement,
3-
Test,
4-
)
1+
from tests.output_betterproto.regression_387 import ParentElement, Test
52

63

74
def test_regression_387():
85
el = ParentElement(name="test", elems=[Test(id=0), Test(id=42)])
96
binary = bytes(el)
10-
decoded = ParentElement().parse(binary)
7+
decoded = ParentElement.parse(binary)
118
assert decoded == el
129
assert decoded.elems == [Test(id=0), Test(id=42)]

tests/inputs/regression_414/test_regression_414.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ def test_full_cycle():
88

99
obj = Test(body=body, auth=auth, signatures=sig)
1010

11-
decoded = Test().parse(bytes(obj))
11+
decoded = Test.parse(bytes(obj))
1212
assert decoded == obj
1313
assert decoded.body == body
1414
assert decoded.auth == auth

tests/test_any.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ def test_any() -> None:
88
any = Any()
99
any.pack(person)
1010

11-
new_any = Any().parse(bytes(any))
11+
new_any = Any.parse(bytes(any))
1212

1313
assert new_any.unpack() == person
1414

tests/test_features.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,10 @@ def test_unknown_fields():
4747
serialized_newer = bytes(newer)
4848

4949
# Unknown fields in `Newer` should round trip with `Older`
50-
round_trip = bytes(Older().parse(serialized_newer))
50+
round_trip = bytes(Older.parse(serialized_newer))
5151
assert serialized_newer == round_trip
5252

53-
new_again = Newer().parse(round_trip)
53+
new_again = Newer.parse(round_trip)
5454
assert newer == new_again
5555

5656

@@ -84,7 +84,7 @@ def test_oneof_support():
8484
assert bytes(msg) == b"\x08\x00"
8585

8686
# Round trip should also work
87-
msg = OneofMsg().parse(bytes(msg))
87+
msg = OneofMsg.parse(bytes(msg))
8888
assert betterproto2.which_one_of(msg, "group1")[0] == "x"
8989
assert msg.x == 0
9090
assert betterproto2.which_one_of(msg, "group2")[0] == ""
@@ -147,8 +147,8 @@ def test_optional_flag():
147147
assert bytes(OptionalBoolMsg(field=False)) == b"\n\x00"
148148

149149
# Differentiate between not passed and the zero-value.
150-
assert OptionalBoolMsg().parse(b"").field is None
151-
assert OptionalBoolMsg().parse(b"\n\x00").field is False
150+
assert OptionalBoolMsg.parse(b"").field is None
151+
assert OptionalBoolMsg.parse(b"\n\x00").field is False
152152

153153

154154
def test_optional_datetime_to_dict():
@@ -271,7 +271,7 @@ def test_oneof_default_value_set_causes_writes_wire():
271271
from tests.output_betterproto.features import Empty, MsgC
272272

273273
def _round_trip_serialization(msg: MsgC) -> MsgC:
274-
return MsgC().parse(bytes(msg))
274+
return MsgC.parse(bytes(msg))
275275

276276
int_msg = MsgC(int_field=0)
277277
str_msg = MsgC(string_field="")

tests/test_pickling.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
11
import pickle
2-
from copy import (
3-
copy,
4-
deepcopy,
5-
)
2+
from copy import copy, deepcopy
63

74
import cachelib
85

@@ -49,10 +46,7 @@ def test_pickling_complex_message():
4946

5047

5148
def test_recursive_message_defaults():
52-
from tests.output_betterproto.recursivemessage import (
53-
Intermediate,
54-
Test as RecursiveMessage,
55-
)
49+
from tests.output_betterproto.recursivemessage import Intermediate, Test as RecursiveMessage
5650

5751
msg = RecursiveMessage(name="bob", intermediate=Intermediate(42))
5852
msg = unpickled(msg)

0 commit comments

Comments
 (0)