Skip to content

Commit 41a96f6

Browse files
Merge pull request #3 from danielgtaylor/imports
Implement imports, simplified default value handling
2 parents 55be5ee + 130acff commit 41a96f6

File tree

9 files changed

+258
-126
lines changed

9 files changed

+258
-126
lines changed

Pipfile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,5 +18,6 @@ jinja2 = "*"
1818
python_version = "3.7"
1919

2020
[scripts]
21+
plugin = "protoc --plugin=protoc-gen-custom=protoc-gen-betterpy.py --custom_out=output"
2122
generate = "python betterproto/tests/generate.py"
2223
test = "pytest ./betterproto/tests"

README.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ This project aims to provide an improved experience when using Protobuf / gRPC i
1313

1414
This project is heavily inspired by, and borrows functionality from:
1515

16+
- https://github.com/protocolbuffers/protobuf/tree/master/python
1617
- https://github.com/eigenein/protobuf/
1718
- https://github.com/vmagamedov/grpclib
1819

@@ -27,8 +28,8 @@ This project is heavily inspired by, and borrows functionality from:
2728
- [x] Maps
2829
- [x] Maps of message fields
2930
- [ ] Support passthrough of unknown fields
30-
- [ ] Refs to nested types
31-
- [ ] Imports in proto files
31+
- [x] Refs to nested types
32+
- [x] Imports in proto files
3233
- [ ] Well-known Google types
3334
- [ ] JSON that isn't completely naive.
3435
- [ ] Async service stubs

betterproto/__init__.py

Lines changed: 52 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,18 @@
9292
WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE, TYPE_MAP]
9393

9494

95+
def get_default(proto_type: int) -> Any:
96+
"""Get the default (zero value) for a given type."""
97+
return {
98+
TYPE_BOOL: False,
99+
TYPE_FLOAT: 0.0,
100+
TYPE_DOUBLE: 0.0,
101+
TYPE_STRING: "",
102+
TYPE_BYTES: b"",
103+
TYPE_MAP: {},
104+
}.get(proto_type, 0)
105+
106+
95107
@dataclasses.dataclass(frozen=True)
96108
class FieldMetadata:
97109
"""Stores internal metadata used for parsing & serialization."""
@@ -114,7 +126,7 @@ def get(field: dataclasses.Field) -> "FieldMetadata":
114126
def dataclass_field(
115127
number: int,
116128
proto_type: str,
117-
default: Any,
129+
default: Any = None,
118130
map_types: Optional[Tuple[str, str]] = None,
119131
**kwargs: dict,
120132
) -> dataclasses.Field:
@@ -141,6 +153,10 @@ def enum_field(number: int, default: Union[int, Type[Iterable]] = 0) -> Any:
141153
return dataclass_field(number, TYPE_ENUM, default=default)
142154

143155

156+
def bool_field(number: int, default: Union[bool, Type[Iterable]] = 0) -> Any:
157+
return dataclass_field(number, TYPE_BOOL, default=default)
158+
159+
144160
def int32_field(number: int, default: Union[int, Type[Iterable]] = 0) -> Any:
145161
return dataclass_field(number, TYPE_INT32, default=default)
146162

@@ -193,8 +209,12 @@ def string_field(number: int, default: str = "") -> Any:
193209
return dataclass_field(number, TYPE_STRING, default=default)
194210

195211

196-
def message_field(number: int, default: Type["Message"]) -> Any:
197-
return dataclass_field(number, TYPE_MESSAGE, default=default)
212+
def bytes_field(number: int, default: bytes = b"") -> Any:
213+
return dataclass_field(number, TYPE_BYTES, default=default)
214+
215+
216+
def message_field(number: int) -> Any:
217+
return dataclass_field(number, TYPE_MESSAGE)
198218

199219

200220
def map_field(number: int, key_type: str, value_type: str) -> Any:
@@ -345,6 +365,29 @@ class Message(ABC):
345365
to go between Python, binary and JSON protobuf message representations.
346366
"""
347367

368+
def __post_init__(self) -> None:
369+
# Set a default value for each field in the class after `__init__` has
370+
# already been run.
371+
for field in dataclasses.fields(self):
372+
meta = FieldMetadata.get(field)
373+
374+
t = self._cls_for(field, index=-1)
375+
376+
value = 0
377+
if meta.proto_type == TYPE_MAP:
378+
# Maps cannot be repeated, so we check these first.
379+
value = {}
380+
elif hasattr(t, "__args__") and len(t.__args__) == 1:
381+
# Anything else with type args is a list.
382+
value = []
383+
elif meta.proto_type == TYPE_MESSAGE:
384+
# Message means creating an instance of the right type.
385+
value = t()
386+
else:
387+
value = get_default(meta.proto_type)
388+
389+
setattr(self, field.name, value)
390+
348391
def __bytes__(self) -> bytes:
349392
"""
350393
Get the binary encoded Protobuf representation of this instance.
@@ -356,6 +399,7 @@ def __bytes__(self) -> bytes:
356399

357400
if isinstance(value, list):
358401
if not len(value):
402+
# Empty values are not serialized
359403
continue
360404

361405
if meta.proto_type in PACKED_TYPES:
@@ -371,14 +415,16 @@ def __bytes__(self) -> bytes:
371415
output += _serialize_single(meta.number, meta.proto_type, item)
372416
elif isinstance(value, dict):
373417
if not len(value):
418+
# Empty values are not serialized
374419
continue
375420

376421
for k, v in value.items():
377422
sk = _serialize_single(1, meta.map_types[0], k)
378423
sv = _serialize_single(2, meta.map_types[1], v)
379424
output += _serialize_single(meta.number, meta.proto_type, sk + sv)
380425
else:
381-
if value == field.default:
426+
if value == get_default(meta.proto_type):
427+
# Default (zero) values are not serialized
382428
continue
383429

384430
output += _serialize_single(meta.number, meta.proto_type, value)
@@ -390,7 +436,7 @@ def _cls_for(self, field: dataclasses.Field, index: int = 0) -> Type:
390436
module = inspect.getmodule(self)
391437
type_hints = get_type_hints(self, vars(module))
392438
cls = type_hints[field.name]
393-
if hasattr(cls, "__args__"):
439+
if hasattr(cls, "__args__") and index >= 0:
394440
cls = type_hints[field.name].__args__[index]
395441
return cls
396442

@@ -522,7 +568,7 @@ def from_dict(self, value: dict) -> T:
522568
"""
523569
for field in dataclasses.fields(self):
524570
meta = FieldMetadata.get(field)
525-
if field.name in value:
571+
if field.name in value and value[field.name] is not None:
526572
if meta.proto_type == "message":
527573
v = getattr(self, field.name)
528574
# print(v, value[field.name])

betterproto/templates/main.py

Lines changed: 9 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

betterproto/tests/ref.json

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
{
2+
"greeting": {
3+
"greeting": "hello"
4+
}
5+
}

betterproto/tests/ref.proto

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
syntax = "proto3";
2+
3+
package ref;
4+
5+
import "repeatedmessage.proto";
6+
7+
message Test {
8+
repeatedmessage.Sub greeting = 1;
9+
}

betterproto/tests/repeatedmessage.proto

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
syntax = "proto3";
22

3+
package repeatedmessage;
4+
35
message Test {
46
repeated Sub greetings = 1;
57
}

betterproto/tests/test_inputs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,15 @@
22
import pytest
33
import json
44

5-
from generate import get_files, get_base
5+
from .generate import get_files, get_base
66

77
inputs = get_files(".bin")
88

99

1010
@pytest.mark.parametrize("filename", inputs)
1111
def test_sample(filename: str) -> None:
1212
module = get_base(filename).split("-")[0]
13-
imported = importlib.import_module(module)
13+
imported = importlib.import_module(f"betterproto.tests.{module}")
1414
data_binary = open(filename, "rb").read()
1515
data_dict = json.loads(open(filename.replace(".bin", ".json")).read())
1616
t1 = imported.Test().parse(data_binary)

0 commit comments

Comments
 (0)