Skip to content

Commit dcb7102

Browse files
committed
Implement imports, simplified default value handling
1 parent 55be5ee commit dcb7102

File tree

9 files changed

+235
-126
lines changed

9 files changed

+235
-126
lines changed

Pipfile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,5 +18,6 @@ jinja2 = "*"
1818
python_version = "3.7"
1919

2020
[scripts]
21+
plugin = "protoc --plugin=protoc-gen-custom=protoc-gen-betterpy.py --custom_out=."
2122
generate = "python betterproto/tests/generate.py"
2223
test = "pytest ./betterproto/tests"

README.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ This project aims to provide an improved experience when using Protobuf / gRPC i
1313

1414
This project is heavily inspired by, and borrows functionality from:
1515

16+
- https://github.com/protocolbuffers/protobuf/tree/master/python
1617
- https://github.com/eigenein/protobuf/
1718
- https://github.com/vmagamedov/grpclib
1819

@@ -27,8 +28,8 @@ This project is heavily inspired by, and borrows functionality from:
2728
- [x] Maps
2829
- [x] Maps of message fields
2930
- [ ] Support passthrough of unknown fields
30-
- [ ] Refs to nested types
31-
- [ ] Imports in proto files
31+
- [x] Refs to nested types
32+
- [x] Imports in proto files
3233
- [ ] Well-known Google types
3334
- [ ] JSON that isn't completely naive.
3435
- [ ] Async service stubs

betterproto/__init__.py

Lines changed: 48 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,18 @@
9292
WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE, TYPE_MAP]
9393

9494

95+
def get_default(proto_type: int) -> Any:
96+
"""Get the default (zero value) for a given type."""
97+
return {
98+
TYPE_BOOL: False,
99+
TYPE_FLOAT: 0.0,
100+
TYPE_DOUBLE: 0.0,
101+
TYPE_STRING: "",
102+
TYPE_BYTES: b"",
103+
TYPE_MAP: {},
104+
}.get(proto_type, 0)
105+
106+
95107
@dataclasses.dataclass(frozen=True)
96108
class FieldMetadata:
97109
"""Stores internal metadata used for parsing & serialization."""
@@ -114,7 +126,7 @@ def get(field: dataclasses.Field) -> "FieldMetadata":
114126
def dataclass_field(
115127
number: int,
116128
proto_type: str,
117-
default: Any,
129+
default: Any = None,
118130
map_types: Optional[Tuple[str, str]] = None,
119131
**kwargs: dict,
120132
) -> dataclasses.Field:
@@ -141,6 +153,10 @@ def enum_field(number: int, default: Union[int, Type[Iterable]] = 0) -> Any:
141153
return dataclass_field(number, TYPE_ENUM, default=default)
142154

143155

156+
def bool_field(number: int, default: Union[bool, Type[Iterable]] = 0) -> Any:
157+
return dataclass_field(number, TYPE_BOOL, default=default)
158+
159+
144160
def int32_field(number: int, default: Union[int, Type[Iterable]] = 0) -> Any:
145161
return dataclass_field(number, TYPE_INT32, default=default)
146162

@@ -193,8 +209,8 @@ def string_field(number: int, default: str = "") -> Any:
193209
return dataclass_field(number, TYPE_STRING, default=default)
194210

195211

196-
def message_field(number: int, default: Type["Message"]) -> Any:
197-
return dataclass_field(number, TYPE_MESSAGE, default=default)
212+
def message_field(number: int) -> Any:
213+
return dataclass_field(number, TYPE_MESSAGE)
198214

199215

200216
def map_field(number: int, key_type: str, value_type: str) -> Any:
@@ -345,6 +361,29 @@ class Message(ABC):
345361
to go between Python, binary and JSON protobuf message representations.
346362
"""
347363

364+
def __post_init__(self) -> None:
365+
# Set a default value for each field in the class after `__init__` has
366+
# already been run.
367+
for field in dataclasses.fields(self):
368+
meta = FieldMetadata.get(field)
369+
370+
t = self._cls_for(field, index=-1)
371+
372+
value = 0
373+
if meta.proto_type == TYPE_MAP:
374+
# Maps cannot be repeated, so we check these first.
375+
value = {}
376+
elif hasattr(t, "__args__") and len(t.__args__) == 1:
377+
# Anything else with type args is a list.
378+
value = []
379+
elif meta.proto_type == TYPE_MESSAGE:
380+
# Message means creating an instance of the right type.
381+
value = t()
382+
else:
383+
value = get_default(meta.proto_type)
384+
385+
setattr(self, field.name, value)
386+
348387
def __bytes__(self) -> bytes:
349388
"""
350389
Get the binary encoded Protobuf representation of this instance.
@@ -356,6 +395,7 @@ def __bytes__(self) -> bytes:
356395

357396
if isinstance(value, list):
358397
if not len(value):
398+
# Empty values are not serialized
359399
continue
360400

361401
if meta.proto_type in PACKED_TYPES:
@@ -371,14 +411,16 @@ def __bytes__(self) -> bytes:
371411
output += _serialize_single(meta.number, meta.proto_type, item)
372412
elif isinstance(value, dict):
373413
if not len(value):
414+
# Empty values are not serialized
374415
continue
375416

376417
for k, v in value.items():
377418
sk = _serialize_single(1, meta.map_types[0], k)
378419
sv = _serialize_single(2, meta.map_types[1], v)
379420
output += _serialize_single(meta.number, meta.proto_type, sk + sv)
380421
else:
381-
if value == field.default:
422+
if value == get_default(meta.proto_type):
423+
# Default (zero) values are not serialized
382424
continue
383425

384426
output += _serialize_single(meta.number, meta.proto_type, value)
@@ -390,7 +432,7 @@ def _cls_for(self, field: dataclasses.Field, index: int = 0) -> Type:
390432
module = inspect.getmodule(self)
391433
type_hints = get_type_hints(self, vars(module))
392434
cls = type_hints[field.name]
393-
if hasattr(cls, "__args__"):
435+
if hasattr(cls, "__args__") and index >= 0:
394436
cls = type_hints[field.name].__args__[index]
395437
return cls
396438

@@ -522,7 +564,7 @@ def from_dict(self, value: dict) -> T:
522564
"""
523565
for field in dataclasses.fields(self):
524566
meta = FieldMetadata.get(field)
525-
if field.name in value:
567+
if field.name in value and value[field.name] is not None:
526568
if meta.proto_type == "message":
527569
v = getattr(self, field.name)
528570
# print(v, value[field.name])

betterproto/templates/main.py

Lines changed: 9 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

betterproto/tests/ref.json

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
{
2+
"greeting": {
3+
"greeting": "hello"
4+
}
5+
}

betterproto/tests/ref.proto

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
syntax = "proto3";
2+
3+
package ref;
4+
5+
import "repeatedmessage.proto";
6+
7+
message Test {
8+
repeatedmessage.Sub greeting = 1;
9+
}

betterproto/tests/repeatedmessage.proto

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
syntax = "proto3";
22

3+
package repeatedmessage;
4+
35
message Test {
46
repeated Sub greetings = 1;
57
}

betterproto/tests/test_inputs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,15 @@
22
import pytest
33
import json
44

5-
from generate import get_files, get_base
5+
from .generate import get_files, get_base
66

77
inputs = get_files(".bin")
88

99

1010
@pytest.mark.parametrize("filename", inputs)
1111
def test_sample(filename: str) -> None:
1212
module = get_base(filename).split("-")[0]
13-
imported = importlib.import_module(module)
13+
imported = importlib.import_module(f"betterproto.tests.{module}")
1414
data_binary = open(filename, "rb").read()
1515
data_dict = json.loads(open(filename.replace(".bin", ".json")).read())
1616
t1 = imported.Test().parse(data_binary)

0 commit comments

Comments
 (0)