Skip to content

Commit ad7162a

Browse files
committed
Support for repeated message fields
1 parent 1a488fa commit ad7162a

File tree

4 files changed

+77
-34
lines changed

4 files changed

+77
-34
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
- [x] Zig-zag signed fields (sint32, sint64)
66
- [x] Don't encode zero values for nested types
77
- [x] Enums
8-
- [ ] Repeated message fields
8+
- [x] Repeated message fields
99
- [ ] Maps
1010
- [ ] Support passthrough of unknown fields
1111
- [ ] Refs to nested types

betterproto/__init__.py

Lines changed: 57 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import json
33
import struct
44
from typing import (
5+
get_type_hints,
56
Union,
67
Generator,
78
Any,
@@ -15,6 +16,8 @@
1516
)
1617
import dataclasses
1718

19+
import inspect
20+
1821
# Proto 3 data types
1922
TYPE_ENUM = "enum"
2023
TYPE_BOOL = "bool"
@@ -283,35 +286,6 @@ def decode_varint(buffer: bytes, pos: int, signed: bool = False) -> Tuple[int, i
283286
raise ValueError("Too many bytes when decoding varint.")
284287

285288

286-
def _postprocess_single(
287-
wire_type: int, meta: FieldMetadata, field: Any, value: Any
288-
) -> Any:
289-
"""Adjusts values after parsing."""
290-
if wire_type == WIRE_VARINT:
291-
if meta.proto_type in ["int32", "int64"]:
292-
bits = int(meta.proto_type[3:])
293-
value = value & ((1 << bits) - 1)
294-
signbit = 1 << (bits - 1)
295-
value = int((value ^ signbit) - signbit)
296-
elif meta.proto_type in ["sint32", "sint64"]:
297-
# Undo zig-zag encoding
298-
value = (value >> 1) ^ (-(value & 1))
299-
elif wire_type in [WIRE_FIXED_32, WIRE_FIXED_64]:
300-
fmt = _pack_fmt(meta.proto_type)
301-
value = struct.unpack(fmt, value)[0]
302-
elif wire_type == WIRE_LEN_DELIM:
303-
if meta.proto_type in ["string"]:
304-
value = value.decode("utf-8")
305-
elif meta.proto_type in ["message"]:
306-
orig = value
307-
value = field.default_factory()
308-
if isinstance(value, Message):
309-
# If it's a message (instead of e.g. list) then keep going!
310-
value.parse(orig)
311-
312-
return value
313-
314-
315289
@dataclasses.dataclass(frozen=True)
316290
class ParsedField:
317291
number: int
@@ -388,6 +362,41 @@ def __bytes__(self) -> bytes:
388362

389363
return output
390364

365+
def _cls_for(self, field: dataclasses.Field) -> Type:
366+
"""Get the message class for a field from the type hints."""
367+
module = inspect.getmodule(self)
368+
type_hints = get_type_hints(self, vars(module))
369+
cls = type_hints[field.name]
370+
if hasattr(cls, "__args__"):
371+
print(type_hints[field.name].__args__[0])
372+
cls = type_hints[field.name].__args__[0]
373+
return cls
374+
375+
def _postprocess_single(
376+
self, wire_type: int, meta: FieldMetadata, field: dataclasses.Field, value: Any
377+
) -> Any:
378+
"""Adjusts values after parsing."""
379+
if wire_type == WIRE_VARINT:
380+
if meta.proto_type in ["int32", "int64"]:
381+
bits = int(meta.proto_type[3:])
382+
value = value & ((1 << bits) - 1)
383+
signbit = 1 << (bits - 1)
384+
value = int((value ^ signbit) - signbit)
385+
elif meta.proto_type in ["sint32", "sint64"]:
386+
# Undo zig-zag encoding
387+
value = (value >> 1) ^ (-(value & 1))
388+
elif wire_type in [WIRE_FIXED_32, WIRE_FIXED_64]:
389+
fmt = _pack_fmt(meta.proto_type)
390+
value = struct.unpack(fmt, value)[0]
391+
elif wire_type == WIRE_LEN_DELIM:
392+
if meta.proto_type in ["string"]:
393+
value = value.decode("utf-8")
394+
elif meta.proto_type in ["message"]:
395+
cls = self._cls_for(field)
396+
value = cls().parse(value)
397+
398+
return value
399+
391400
def parse(self, data: bytes) -> T:
392401
"""
393402
Parse the binary encoded Protobuf into this message instance. This
@@ -416,10 +425,12 @@ def parse(self, data: bytes) -> T:
416425
else:
417426
decoded, pos = decode_varint(parsed.value, pos)
418427
wire_type = WIRE_VARINT
419-
decoded = _postprocess_single(wire_type, meta, field, decoded)
428+
decoded = self._postprocess_single(
429+
wire_type, meta, field, decoded
430+
)
420431
value.append(decoded)
421432
else:
422-
value = _postprocess_single(
433+
value = self._postprocess_single(
423434
parsed.wire_type, meta, field, parsed.value
424435
)
425436

@@ -445,7 +456,13 @@ def to_dict(self) -> dict:
445456
meta = FieldMetadata.get(field)
446457
v = getattr(self, field.name)
447458
if meta.proto_type == "message":
448-
v = v.to_dict()
459+
if isinstance(v, list):
460+
# Convert each item.
461+
v = [i.to_dict() for i in v]
462+
# Filter out empty items which we won't serialize.
463+
v = [i for i in v if i]
464+
else:
465+
v = v.to_dict()
449466
if v:
450467
output[field.name] = v
451468
elif v != field.default:
@@ -461,7 +478,14 @@ def from_dict(self, value: dict) -> T:
461478
meta = FieldMetadata.get(field)
462479
if field.name in value:
463480
if meta.proto_type == "message":
464-
getattr(self, field.name).from_dict(value[field.name])
481+
v = getattr(self, field.name)
482+
print(v, value[field.name])
483+
if isinstance(v, list):
484+
cls = self._cls_for(field)
485+
for i in range(len(value[field.name])):
486+
v.append(cls().from_dict(value[field.name][i]))
487+
else:
488+
v.from_dict(value[field.name])
465489
else:
466490
setattr(self, field.name, value[field.name])
467491
return self
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
{
2+
"greetings": [
3+
{
4+
"greeting": "hello"
5+
},
6+
{
7+
"greeting": "hi"
8+
}
9+
]
10+
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
syntax = "proto3";
2+
3+
message Test {
4+
repeated Sub greetings = 1;
5+
}
6+
7+
message Sub {
8+
string greeting = 1;
9+
}

0 commit comments

Comments
 (0)