Skip to content

Commit a5fac1c

Browse files
committed
Support pass-through of unknown fields
1 parent b5c1f1a commit a5fac1c

File tree

3 files changed

+32
-8
lines changed

3 files changed

+32
-8
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ $ pipenv run tests
218218
- [x] Repeated message fields
219219
- [x] Maps
220220
- [x] Maps of message fields
221-
- [ ] Support passthrough of unknown fields
221+
- [x] Support passthrough of unknown fields
222222
- [x] Refs to nested types
223223
- [x] Imports in proto files
224224
- [x] Well-known Google types

betterproto/__init__.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -341,17 +341,19 @@ class ParsedField:
341341
number: int
342342
wire_type: int
343343
value: Any
344+
raw: bytes
344345

345346

346347
def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
347348
i = 0
348349
while i < len(value):
350+
start = i
349351
num_wire, i = decode_varint(value, i)
350352
# print(num_wire, i)
351353
number = num_wire >> 3
352354
wire_type = num_wire & 0x7
353355

354-
decoded: Any
356+
decoded: Any = None
355357
if wire_type == 0:
356358
decoded, i = decode_varint(value, i)
357359
elif wire_type == 1:
@@ -362,12 +364,12 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
362364
i += length
363365
elif wire_type == 5:
364366
decoded, i = value[i : i + 4], i + 4
365-
else:
366-
raise NotImplementedError(f"Wire type {wire_type}")
367367

368368
# print(ParsedField(number=number, wire_type=wire_type, value=decoded))
369369

370-
yield ParsedField(number=number, wire_type=wire_type, value=decoded)
370+
yield ParsedField(
371+
number=number, wire_type=wire_type, value=decoded, raw=value[start:i]
372+
)
371373

372374

373375
# Bound type variable to allow methods to return `self` of subclasses
@@ -415,6 +417,7 @@ def __post_init__(self) -> None:
415417

416418
# Now that all the defaults are set, reset it!
417419
self.__dict__["serialized_on_wire"] = False
420+
self.__dict__["_unknown_fields"] = b""
418421

419422
def __setattr__(self, attr: str, value: Any) -> None:
420423
if attr != "serialized_on_wire":
@@ -469,7 +472,7 @@ def __bytes__(self) -> bytes:
469472
meta.number, meta.proto_type, value, serialize_empty=serialize_empty
470473
)
471474

472-
return output
475+
return output + self._unknown_fields
473476

474477
# For compatibility with other libraries
475478
SerializeToString = __bytes__
@@ -571,8 +574,7 @@ def parse(self: T, data: bytes) -> T:
571574
else:
572575
setattr(self, field.name, value)
573576
else:
574-
# TODO: handle unknown fields
575-
pass
577+
self._unknown_fields += parsed.raw
576578

577579
return self
578580

betterproto/tests/test_features.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,25 @@ class Foo(betterproto.Message):
4848
# Plain-ol'-ints should serialize properly too.
4949
foo.bar = 1
5050
assert foo.to_dict() == {"bar": "ONE"}
51+
52+
53+
def test_unknown_fields():
54+
@dataclass
55+
class Newer(betterproto.Message):
56+
foo: bool = betterproto.bool_field(1)
57+
bar: int = betterproto.int32_field(2)
58+
baz: str = betterproto.string_field(3)
59+
60+
@dataclass
61+
class Older(betterproto.Message):
62+
foo: bool = betterproto.bool_field(1)
63+
64+
newer = Newer(foo=True, bar=1, baz="Hello")
65+
serialized_newer = bytes(newer)
66+
67+
# Unknown fields in `Newer` should round trip with `Older`
68+
round_trip = bytes(Older().parse(serialized_newer))
69+
assert serialized_newer == round_trip
70+
71+
new_again = Newer().parse(round_trip)
72+
assert newer == new_again

0 commit comments

Comments
 (0)