Skip to content

Commit 811b54c

Browse files
committed
Better JSON 64-bit int handling, add way to determine whether a message was sent on the wire, various fixes
1 parent bbceff9 commit 811b54c

File tree

11 files changed

+132
-55
lines changed

11 files changed

+132
-55
lines changed

betterproto/__init__.py

Lines changed: 60 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,28 @@
1-
from abc import ABC
1+
import dataclasses
2+
import inspect
23
import json
34
import struct
5+
from abc import ABC
46
from typing import (
5-
get_type_hints,
7+
Any,
68
AsyncGenerator,
7-
Union,
9+
Callable,
10+
Dict,
811
Generator,
9-
Any,
10-
SupportsBytes,
12+
Iterable,
1113
List,
14+
Optional,
15+
SupportsBytes,
1216
Tuple,
13-
Callable,
1417
Type,
15-
Iterable,
1618
TypeVar,
17-
Optional,
19+
Union,
20+
get_type_hints,
1821
)
19-
import dataclasses
2022

2123
import grpclib.client
2224
import grpclib.const
2325

24-
import inspect
25-
2626
# Proto 3 data types
2727
TYPE_ENUM = "enum"
2828
TYPE_BOOL = "bool"
@@ -54,6 +54,9 @@
5454
TYPE_SFIXED64,
5555
]
5656

57+
# Fields that are numerical 64-bit types
58+
INT_64_TYPES = [TYPE_INT64, TYPE_UINT64, TYPE_SINT64, TYPE_FIXED64, TYPE_SFIXED64]
59+
5760
# Fields that are efficiently packed when
5861
PACKED_TYPES = [
5962
TYPE_ENUM,
@@ -275,7 +278,9 @@ def _preprocess_single(proto_type: str, value: Any) -> bytes:
275278
return value
276279

277280

278-
def _serialize_single(field_number: int, proto_type: str, value: Any) -> bytes:
281+
def _serialize_single(
282+
field_number: int, proto_type: str, value: Any, *, serialize_empty: bool = False
283+
) -> bytes:
279284
"""Serializes a single field and value."""
280285
value = _preprocess_single(proto_type, value)
281286

@@ -290,7 +295,7 @@ def _serialize_single(field_number: int, proto_type: str, value: Any) -> bytes:
290295
key = encode_varint((field_number << 3) | 1)
291296
output += key + value
292297
elif proto_type in WIRE_LEN_DELIM_TYPES:
293-
if len(value):
298+
if len(value) or serialize_empty:
294299
key = encode_varint((field_number << 3) | 2)
295300
output += key + encode_varint(len(value)) + value
296301
else:
@@ -362,6 +367,11 @@ class Message(ABC):
362367
to go between Python, binary and JSON protobuf message representations.
363368
"""
364369

370+
# True if this message was or should be serialized on the wire. This can
371+
# be used to detect presence (e.g. optional wrapper message) and is used
372+
# internally during parsing/serialization.
373+
serialized_on_wire: bool
374+
365375
def __post_init__(self) -> None:
366376
# Set a default value for each field in the class after `__init__` has
367377
# already been run.
@@ -389,6 +399,15 @@ def __post_init__(self) -> None:
389399

390400
setattr(self, field.name, value)
391401

402+
# Now that all the defaults are set, reset it!
403+
self.__dict__["serialized_on_wire"] = False
404+
405+
def __setattr__(self, attr: str, value: Any) -> None:
406+
if attr != "serialized_on_wire":
407+
# Track when a field has been set.
408+
self.__dict__["serialized_on_wire"] = True
409+
super().__setattr__(attr, value)
410+
392411
def __bytes__(self) -> bytes:
393412
"""
394413
Get the binary encoded Protobuf representation of this instance.
@@ -429,7 +448,12 @@ def __bytes__(self) -> bytes:
429448
# Default (zero) values are not serialized
430449
continue
431450

432-
output += _serialize_single(meta.number, meta.proto_type, value)
451+
serialize_empty = False
452+
if isinstance(value, Message) and value.serialized_on_wire:
453+
serialize_empty = True
454+
output += _serialize_single(
455+
meta.number, meta.proto_type, value, serialize_empty=serialize_empty
456+
)
433457

434458
return output
435459

@@ -462,12 +486,13 @@ def _postprocess_single(
462486
fmt = _pack_fmt(meta.proto_type)
463487
value = struct.unpack(fmt, value)[0]
464488
elif wire_type == WIRE_LEN_DELIM:
465-
if meta.proto_type in [TYPE_STRING]:
489+
if meta.proto_type == TYPE_STRING:
466490
value = value.decode("utf-8")
467-
elif meta.proto_type in [TYPE_MESSAGE]:
491+
elif meta.proto_type == TYPE_MESSAGE:
468492
cls = self._cls_for(field)
469493
value = cls().parse(value)
470-
elif meta.proto_type in [TYPE_MAP]:
494+
value.serialized_on_wire = True
495+
elif meta.proto_type == TYPE_MAP:
471496
# TODO: This is slow, use a cache to make it faster since each
472497
# key/value pair will recreate the class.
473498
assert meta.map_types
@@ -535,8 +560,6 @@ def parse(self: T, data: bytes) -> T:
535560
# TODO: handle unknown fields
536561
pass
537562

538-
from typing import cast
539-
540563
return self
541564

542565
# For compatibility with other libraries.
@@ -549,21 +572,17 @@ def to_dict(self) -> dict:
549572
Returns a dict representation of this message instance which can be
550573
used to serialize to e.g. JSON.
551574
"""
552-
output = {}
575+
output: Dict[str, Any] = {}
553576
for field in dataclasses.fields(self):
554577
meta = FieldMetadata.get(field)
555578
v = getattr(self, field.name)
556579
if meta.proto_type == "message":
557580
if isinstance(v, list):
558581
# Convert each item.
559582
v = [i.to_dict() for i in v]
560-
# Filter out empty items which we won't serialize.
561-
v = [i for i in v if i]
562-
else:
563-
v = v.to_dict()
564-
565-
if v:
566583
output[field.name] = v
584+
elif v.serialized_on_wire:
585+
output[field.name] = v.to_dict()
567586
elif meta.proto_type == "map":
568587
for k in v:
569588
if hasattr(v[k], "to_dict"):
@@ -572,14 +591,21 @@ def to_dict(self) -> dict:
572591
if v:
573592
output[field.name] = v
574593
elif v != get_default(meta.proto_type):
575-
output[field.name] = v
594+
if meta.proto_type in INT_64_TYPES:
595+
if isinstance(v, list):
596+
output[field.name] = [str(n) for n in v]
597+
else:
598+
output[field.name] = str(v)
599+
else:
600+
output[field.name] = v
576601
return output
577602

578603
def from_dict(self: T, value: dict) -> T:
579604
"""
580605
Parse the key/value pairs in `value` into this message instance. This
581606
returns the instance itself and is therefore assignable and chainable.
582607
"""
608+
self.serialized_on_wire = True
583609
for field in dataclasses.fields(self):
584610
meta = FieldMetadata.get(field)
585611
if field.name in value and value[field.name] is not None:
@@ -598,7 +624,13 @@ def from_dict(self: T, value: dict) -> T:
598624
for k in value[field.name]:
599625
v[k] = cls().from_dict(value[field.name][k])
600626
else:
601-
setattr(self, field.name, value[field.name])
627+
v = value[field.name]
628+
if meta.proto_type in INT_64_TYPES:
629+
if isinstance(value[field.name], list):
630+
v = [int(n) for n in value[field.name]]
631+
else:
632+
v = int(value[field.name])
633+
setattr(self, field.name, v)
602634
return self
603635

604636
def to_json(self) -> str:
@@ -613,9 +645,6 @@ def from_json(self: T, value: Union[str, bytes]) -> T:
613645
return self.from_dict(json.loads(value))
614646

615647

616-
ResponseType = TypeVar("ResponseType", bound="Message")
617-
618-
619648
class ServiceStub(ABC):
620649
"""
621650
Base class for async gRPC service stubs.

betterproto/tests/generate.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,21 @@
11
#!/usr/bin/env python
2+
import importlib
3+
import json
24
import os # isort: skip
5+
import subprocess
6+
import sys
7+
from typing import Generator, Tuple
8+
9+
from google.protobuf import symbol_database
10+
from google.protobuf.descriptor_pool import DescriptorPool
11+
from google.protobuf.json_format import MessageToJson, Parse
312

413
# Force pure-python implementation instead of C++, otherwise imports
514
# break things because we can't properly reset the symbol database.
615
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
716

817

9-
import subprocess
10-
import importlib
11-
import sys
12-
from typing import Generator, Tuple
1318

14-
from google.protobuf.json_format import Parse
15-
from google.protobuf import symbol_database
16-
from google.protobuf.descriptor_pool import DescriptorPool
1719

1820
root = os.path.dirname(os.path.realpath(__file__))
1921

@@ -68,5 +70,10 @@ def ensure_ext(filename: str, ext: str) -> str:
6870
print(f"Using {parts[0]}_pb2 to generate {os.path.basename(out)}")
6971

7072
imported = importlib.import_module(f"{parts[0]}_pb2")
71-
serialized = Parse(open(filename).read(), imported.Test()).SerializeToString()
73+
parsed = Parse(open(filename).read(), imported.Test())
74+
serialized = parsed.SerializeToString()
75+
serialized_json = MessageToJson(
76+
parsed, preserving_proto_field_name=True, use_integers_for_enums=True
77+
)
78+
assert json.loads(serialized_json) == json.load(open(filename))
7279
open(out, "wb").write(serialized)

betterproto/tests/nested.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
{
22
"nested": {
33
"count": 150
4-
}
4+
},
5+
"sibling": {}
56
}

betterproto/tests/nested.proto

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@ message Test {
1010

1111
Nested nested = 1;
1212
Sibling sibling = 2;
13+
Sibling sibling2 = 3;
1314
}
1415

1516
message Sibling {
1617
int32 foo = 1;
17-
}
18+
}

betterproto/tests/repeatedpacked.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
{
22
"counts": [1, 2, -1, -2],
3-
"signed": [1, 2, -1, -2],
3+
"signed": ["1", "2", "-1", "-2"],
44
"fixed": [1.0, 2.7, 3.4]
55
}
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
{
22
"signed_32": -150,
3-
"signed_64": -150
3+
"signed_64": "-150"
44
}

betterproto/tests/signed.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
{
22
"signed_32": 150,
3-
"signed_64": 150
3+
"signed_64": "150"
44
}

betterproto/tests/test_features.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import betterproto
2+
from dataclasses import dataclass
3+
4+
5+
def test_has_field():
6+
@dataclass
7+
class Bar(betterproto.Message):
8+
baz: int = betterproto.int32_field(1)
9+
10+
@dataclass
11+
class Foo(betterproto.Message):
12+
bar: Bar = betterproto.message_field(1)
13+
14+
# Unset by default
15+
foo = Foo()
16+
assert foo.bar.serialized_on_wire == False
17+
18+
# Serialized after setting something
19+
foo.bar.baz = 1
20+
assert foo.bar.serialized_on_wire == True
21+
22+
# Still has it after setting the default value
23+
foo.bar.baz = 0
24+
assert foo.bar.serialized_on_wire == True
25+
26+
# Manual override
27+
foo.bar.serialized_on_wire = False
28+
assert foo.bar.serialized_on_wire == False
29+
30+
# Can manually set it but defaults to false
31+
foo.bar = Bar()
32+
assert foo.bar.serialized_on_wire == False

betterproto/tests/test_inputs.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import importlib
2-
import pytest
32
import json
43

5-
from .generate import get_files, get_base
4+
import pytest
5+
6+
from .generate import get_base, get_files
67

78
inputs = get_files(".bin")
89

protoc-gen-betterpy.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,24 @@
11
#!/usr/bin/env python
22

3-
import sys
4-
53
import itertools
64
import json
75
import os.path
86
import re
9-
from typing import Tuple, Any, List
7+
import sys
108
import textwrap
9+
from typing import Any, List, Tuple
1110

11+
from jinja2 import Environment, PackageLoader
12+
13+
from google.protobuf.compiler import plugin_pb2 as plugin
1214
from google.protobuf.descriptor_pb2 import (
1315
DescriptorProto,
1416
EnumDescriptorProto,
15-
FileDescriptorProto,
1617
FieldDescriptorProto,
18+
FileDescriptorProto,
1719
ServiceDescriptorProto,
1820
)
1921

20-
from google.protobuf.compiler import plugin_pb2 as plugin
21-
22-
23-
from jinja2 import Environment, PackageLoader
24-
2522

2623
def snake_case(value: str) -> str:
2724
return (

0 commit comments

Comments
 (0)