Skip to content

Commit b5c1f1a

Browse files
committed
Support JSON base64 bytes and enums as strings
1 parent 7fe64ad commit b5c1f1a

File tree

8 files changed

+81
-13
lines changed

8 files changed

+81
-13
lines changed

README.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -169,10 +169,10 @@ Sometimes it is useful to be able to determine whether a message has been sent o
169169
Use `Message().serialized_on_wire` to determine if it was sent. This is a little bit different from the official Google generated Python code:
170170

171171
```py
172-
# Old way
172+
# Old way (official Google Protobuf package)
173173
>>> mymessage.HasField('myfield')
174174

175-
# New way
175+
# New way (this project)
176176
>>> mymessage.myfield.serialized_on_wire
177177
```
178178

@@ -226,8 +226,9 @@ $ pipenv run tests
226226
- [x] 64-bit ints as strings
227227
- [x] Maps
228228
- [x] Lists
229-
- [ ] Bytes as base64
229+
- [x] Bytes as base64
230230
- [ ] Any support
231+
- [x] Enum strings
231232
- [ ] Well known types support (timestamp, duration, wrappers)
232233
- [ ] Async service stubs
233234
- [x] Unary-unary

betterproto/__init__.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import dataclasses
2+
import enum
23
import inspect
34
import json
45
import struct
56
from abc import ABC
7+
from base64 import b64encode, b64decode
68
from typing import (
79
Any,
810
AsyncGenerator,
@@ -222,6 +224,18 @@ def map_field(number: int, key_type: str, value_type: str) -> Any:
222224
return dataclass_field(number, TYPE_MAP, map_types=(key_type, value_type))
223225

224226

227+
class Enum(int, enum.Enum):
228+
"""Protocol buffers enumeration base class. Acts like `enum.IntEnum`."""
229+
230+
@classmethod
231+
def from_string(cls, name: str) -> int:
232+
"""Return the value which corresponds to the string name."""
233+
try:
234+
return cls.__members__[name]
235+
except KeyError as e:
236+
raise ValueError(f"Unknown value {name} for enum {cls.__name__}") from e
237+
238+
225239
def _pack_fmt(proto_type: str) -> str:
226240
"""Returns a little-endian format string for reading/writing binary."""
227241
return {
@@ -596,6 +610,17 @@ def to_dict(self) -> dict:
596610
output[field.name] = [str(n) for n in v]
597611
else:
598612
output[field.name] = str(v)
613+
elif meta.proto_type == TYPE_BYTES:
614+
if isinstance(v, list):
615+
output[field.name] = [b64encode(b).decode("utf8") for b in v]
616+
else:
617+
output[field.name] = b64encode(v).decode("utf8")
618+
elif meta.proto_type == TYPE_ENUM:
619+
enum_values = list(self._cls_for(field))
620+
if isinstance(v, list):
621+
output[field.name] = [enum_values[e].name for e in v]
622+
else:
623+
output[field.name] = enum_values[v].name
599624
else:
600625
output[field.name] = v
601626
return output
@@ -630,7 +655,20 @@ def from_dict(self: T, value: dict) -> T:
630655
v = [int(n) for n in value[field.name]]
631656
else:
632657
v = int(value[field.name])
633-
setattr(self, field.name, v)
658+
elif meta.proto_type == TYPE_BYTES:
659+
if isinstance(value[field.name], list):
660+
v = [b64decode(n) for n in value[field.name]]
661+
else:
662+
v = b64decode(value[field.name])
663+
elif meta.proto_type == TYPE_ENUM:
664+
enum_cls = self._cls_for(field)
665+
if isinstance(v, list):
666+
v = [enum_cls.from_string(e) for e in v]
667+
elif isinstance(v, str):
668+
v = enum_cls.from_string(v)
669+
670+
if v is not None:
671+
setattr(self, field.name, v)
634672
return self
635673

636674
def to_json(self, indent: Union[None, int, str] = None) -> str:

betterproto/templates/template.py

Lines changed: 1 addition & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

betterproto/tests/bytes.json

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
{
2+
"data": "SGVsbG8sIFdvcmxkIQ=="
3+
}

betterproto/tests/bytes.proto

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
syntax = "proto3";
2+
3+
message Test {
4+
bytes data = 1;
5+
}

betterproto/tests/enums.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
{
2-
"greeting": 1
2+
"greeting": "HEY"
33
}

betterproto/tests/generate.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,15 @@ def ensure_ext(filename: str, ext: str) -> str:
6969
print(f"Using {parts[0]}_pb2 to generate {os.path.basename(out)}")
7070

7171
imported = importlib.import_module(f"{parts[0]}_pb2")
72-
parsed = Parse(open(filename).read(), imported.Test())
72+
input_json = open(filename).read()
73+
parsed = Parse(input_json, imported.Test())
7374
serialized = parsed.SerializeToString()
74-
serialized_json = MessageToJson(
75-
parsed, preserving_proto_field_name=True, use_integers_for_enums=True
76-
)
77-
assert json.loads(serialized_json) == json.load(open(filename))
75+
serialized_json = MessageToJson(parsed, preserving_proto_field_name=True)
76+
77+
s_loaded = json.loads(serialized_json)
78+
in_loaded = json.loads(input_json)
79+
80+
if s_loaded != in_loaded:
81+
raise AssertionError("Expected JSON to be equal:", s_loaded, in_loaded)
82+
7883
open(out, "wb").write(serialized)

betterproto/tests/test_features.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,21 @@ class Foo(betterproto.Message):
3030
# Can manually set it but defaults to false
3131
foo.bar = Bar()
3232
assert foo.bar.serialized_on_wire == False
33+
34+
35+
def test_enum_as_int_json():
36+
class TestEnum(betterproto.Enum):
37+
ZERO = 0
38+
ONE = 1
39+
40+
@dataclass
41+
class Foo(betterproto.Message):
42+
bar: TestEnum = betterproto.enum_field(1)
43+
44+
# JSON strings are supported, but ints should still be supported too.
45+
foo = Foo().from_dict({"bar": 1})
46+
assert foo.bar == TestEnum.ONE
47+
48+
# Plain-ol'-ints should serialize properly too.
49+
foo.bar = 1
50+
assert foo.to_dict() == {"bar": "ONE"}

0 commit comments

Comments
 (0)