Skip to content

Commit 035793a

Browse files
committed
Support wrapper types
1 parent c79535b commit 035793a

File tree

7 files changed

+233
-32
lines changed

7 files changed

+233
-32
lines changed

README.md

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,53 @@ Again this is a little different than the official Google code generator:
238238
["foo", "foo's value"]
239239
```
240240

241+
### Well-Known Google Types
242+
243+
Google provides several well-known message types like a timestamp, duration, and several wrappers used to provide optional zero value support. Each of these has a special JSON representation and is handled a little differently from normal messages. The Python mapping for these is as follows:
244+
245+
| Google Message | Python Type | Default |
246+
| --------------------------- | ---------------------------------------- | ---------------------- |
247+
| `google.protobuf.Duration`  | [`datetime.timedelta`][td]               | `0`                    |
248+
| `google.protobuf.Timestamp` | Timezone-aware [`datetime.datetime`][dt] | `1970-01-01T00:00:00Z` |
249+
| `google.protobuf.*Value` | `Optional[...]` | `None` |
250+
251+
[td]: https://docs.python.org/3/library/datetime.html#timedelta-objects
252+
[dt]: https://docs.python.org/3/library/datetime.html#datetime.datetime
253+
254+
For the wrapper types, the Python type corresponds to the wrapped type, e.g. `google.protobuf.BoolValue` becomes `Optional[bool]` while `google.protobuf.Int32Value` becomes `Optional[int]`. All of the optional values default to `None`, so don't forget to check for that possible state. Given:
255+
256+
```protobuf
257+
syntax = "proto3";
258+
259+
import "google/protobuf/duration.proto";
260+
import "google/protobuf/timestamp.proto";
261+
import "google/protobuf/wrappers.proto";
262+
263+
message Test {
264+
google.protobuf.BoolValue maybe = 1;
265+
google.protobuf.Timestamp ts = 2;
266+
google.protobuf.Duration duration = 3;
267+
}
268+
```
269+
270+
You can do stuff like:
271+
272+
```py
273+
>>> t = Test().from_dict({"maybe": True, "ts": "2019-01-01T12:00:00Z", "duration": "1.200s"})
274+
>>> t
275+
Test(maybe=True, ts=datetime.datetime(2019, 1, 1, 12, 0, tzinfo=datetime.timezone.utc), duration=datetime.timedelta(seconds=1, microseconds=200000))
276+
277+
>>> t.ts - t.duration
278+
datetime.datetime(2019, 1, 1, 11, 59, 58, 800000, tzinfo=datetime.timezone.utc)
279+
280+
>>> t.ts.isoformat()
281+
'2019-01-01T12:00:00+00:00'
282+
283+
>>> t.maybe = None
284+
>>> t.to_dict()
285+
{'ts': '2019-01-01T12:00:00Z', 'duration': '1.200s'}
286+
```
287+
241288
## Development
242289

243290
First, make sure you have Python 3.7+ and `pipenv` installed, along with the official [Protobuf Compiler](https://github.com/protocolbuffers/protobuf/releases) for your platform. Then:
@@ -295,7 +342,7 @@ $ pipenv run tests
295342
- [x] Bytes as base64
296343
- [ ] Any support
297344
- [x] Enum strings
298-
- [ ] Well known types support (timestamp, duration, wrappers)
345+
- [x] Well known types support (timestamp, duration, wrappers)
299346
- [x] Support different casing (orig vs. camel vs. others?)
300347
- [ ] Async service stubs
301348
- [x] Unary-unary

betterproto/__init__.py

Lines changed: 141 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,10 @@
105105
WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE, TYPE_MAP]
106106

107107

108+
# Protobuf datetimes start at the Unix Epoch in 1970 in UTC.
109+
DATETIME_ZERO = datetime(1970, 1, 1, tzinfo=timezone.utc)
110+
111+
108112
class Casing(enum.Enum):
109113
"""Casing constants for serialization."""
110114

@@ -128,9 +132,11 @@ class FieldMetadata:
128132
# Protobuf type name
129133
proto_type: str
130134
# Map information if the proto_type is a map
131-
map_types: Optional[Tuple[str, str]]
135+
map_types: Optional[Tuple[str, str]] = None
132136
# Groups several "one-of" fields together
133-
group: Optional[str]
137+
group: Optional[str] = None
138+
# Describes the wrapped type (e.g. when using google.protobuf.BoolValue)
139+
wraps: Optional[str] = None
134140

135141
@staticmethod
136142
def get(field: dataclasses.Field) -> "FieldMetadata":
@@ -144,11 +150,14 @@ def dataclass_field(
144150
*,
145151
map_types: Optional[Tuple[str, str]] = None,
146152
group: Optional[str] = None,
153+
wraps: Optional[str] = None,
147154
) -> dataclasses.Field:
148155
"""Creates a dataclass field with attached protobuf metadata."""
149156
return dataclasses.field(
150157
default=PLACEHOLDER,
151-
metadata={"betterproto": FieldMetadata(number, proto_type, map_types, group)},
158+
metadata={
159+
"betterproto": FieldMetadata(number, proto_type, map_types, group, wraps)
160+
},
152161
)
153162

154163

@@ -221,8 +230,10 @@ def bytes_field(number: int, group: Optional[str] = None) -> Any:
221230
return dataclass_field(number, TYPE_BYTES, group=group)
222231

223232

224-
def message_field(number: int, group: Optional[str] = None) -> Any:
225-
return dataclass_field(number, TYPE_MESSAGE, group=group)
233+
def message_field(
234+
number: int, group: Optional[str] = None, wraps: Optional[str] = None
235+
) -> Any:
236+
return dataclass_field(number, TYPE_MESSAGE, group=group, wraps=wraps)
226237

227238

228239
def map_field(
@@ -273,7 +284,7 @@ def encode_varint(value: int) -> bytes:
273284
return bytes(b + [bits])
274285

275286

276-
def _preprocess_single(proto_type: str, value: Any) -> bytes:
287+
def _preprocess_single(proto_type: str, wraps: str, value: Any) -> bytes:
277288
"""Adjusts values before serialization."""
278289
if proto_type in [
279290
TYPE_ENUM,
@@ -307,17 +318,26 @@ def _preprocess_single(proto_type: str, value: Any) -> bytes:
307318
seconds = int(total_ms / 1e6)
308319
nanos = int((total_ms % 1e6) * 1e3)
309320
value = _Duration(seconds=seconds, nanos=nanos)
321+
elif wraps:
322+
if value is None:
323+
return b""
324+
value = _get_wrapper(wraps)(value=value)
310325

311326
return bytes(value)
312327

313328
return value
314329

315330

316331
def _serialize_single(
317-
field_number: int, proto_type: str, value: Any, *, serialize_empty: bool = False
332+
field_number: int,
333+
proto_type: str,
334+
value: Any,
335+
*,
336+
serialize_empty: bool = False,
337+
wraps: str = "",
318338
) -> bytes:
319339
"""Serializes a single field and value."""
320-
value = _preprocess_single(proto_type, value)
340+
value = _preprocess_single(proto_type, wraps, value)
321341

322342
output = b""
323343
if proto_type in WIRE_VARINT_TYPES:
@@ -330,7 +350,7 @@ def _serialize_single(
330350
key = encode_varint((field_number << 3) | 1)
331351
output += key + value
332352
elif proto_type in WIRE_LEN_DELIM_TYPES:
333-
if len(value) or serialize_empty:
353+
if len(value) or serialize_empty or wraps:
334354
key = encode_varint((field_number << 3) | 2)
335355
output += key + encode_varint(len(value)) + value
336356
else:
@@ -370,7 +390,6 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
370390
while i < len(value):
371391
start = i
372392
num_wire, i = decode_varint(value, i)
373-
# print(num_wire, i)
374393
number = num_wire >> 3
375394
wire_type = num_wire & 0x7
376395

@@ -386,8 +405,6 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
386405
elif wire_type == 5:
387406
decoded, i = value[i : i + 4], i + 4
388407

389-
# print(ParsedField(number=number, wire_type=wire_type, value=decoded))
390-
391408
yield ParsedField(
392409
number=number, wire_type=wire_type, value=decoded, raw=value[start:i]
393410
)
@@ -462,6 +479,11 @@ def __bytes__(self) -> bytes:
462479
meta = FieldMetadata.get(field)
463480
value = getattr(self, field.name)
464481

482+
if value is None:
483+
# Optional items should be skipped. This is used for the Google
484+
# wrapper types.
485+
continue
486+
465487
# Being selected in a group means this field is the one that is
466488
# currently set in a `oneof` group, so it must be serialized even
467489
# if the value is the default zero value.
@@ -491,11 +513,13 @@ def __bytes__(self) -> bytes:
491513
# treat it like a field of raw bytes.
492514
buf = b""
493515
for item in value:
494-
buf += _preprocess_single(meta.proto_type, item)
516+
buf += _preprocess_single(meta.proto_type, "", item)
495517
output += _serialize_single(meta.number, TYPE_BYTES, buf)
496518
else:
497519
for item in value:
498-
output += _serialize_single(meta.number, meta.proto_type, item)
520+
output += _serialize_single(
521+
meta.number, meta.proto_type, item, wraps=meta.wraps
522+
)
499523
elif isinstance(value, dict):
500524
for k, v in value.items():
501525
assert meta.map_types
@@ -504,7 +528,11 @@ def __bytes__(self) -> bytes:
504528
output += _serialize_single(meta.number, meta.proto_type, sk + sv)
505529
else:
506530
output += _serialize_single(
507-
meta.number, meta.proto_type, value, serialize_empty=serialize_empty
531+
meta.number,
532+
meta.proto_type,
533+
value,
534+
serialize_empty=serialize_empty,
535+
wraps=meta.wraps,
508536
)
509537

510538
return output + self._unknown_fields
@@ -546,7 +574,7 @@ def _get_field_default(self, field: dataclasses.Field, meta: FieldMetadata) -> A
546574
value = 0
547575
elif t == datetime:
548576
# Offsets are relative to 1970-01-01T00:00:00Z
549-
value = datetime(1970, 1, 1, tzinfo=timezone.utc)
577+
value = DATETIME_ZERO
550578
else:
551579
# This is either a primitive scalar or another message type. Calling
552580
# it should result in its zero value.
@@ -580,6 +608,10 @@ def _postprocess_single(
580608
value = _Timestamp().parse(value).to_datetime()
581609
elif cls == timedelta:
582610
value = _Duration().parse(value).to_timedelta()
611+
elif meta.wraps:
612+
# This is a Google wrapper value message around a single
613+
# scalar type.
614+
value = _get_wrapper(meta.wraps)().parse(value).value
583615
else:
584616
value = cls().parse(value)
585617
value._serialized_on_wire = True
@@ -670,9 +702,14 @@ def to_dict(self, casing: Casing = Casing.CAMEL) -> dict:
670702
cased_name = casing(field.name).rstrip("_")
671703
if meta.proto_type == "message":
672704
if isinstance(v, datetime):
673-
output[cased_name] = _Timestamp.to_json(v)
705+
if v != DATETIME_ZERO:
706+
output[cased_name] = _Timestamp.to_json(v)
674707
elif isinstance(v, timedelta):
675-
output[cased_name] = _Duration.to_json(v)
708+
if v != timedelta(0):
709+
output[cased_name] = _Duration.to_json(v)
710+
elif meta.wraps:
711+
if v is not None:
712+
output[cased_name] = v
676713
elif isinstance(v, list):
677714
# Convert each item.
678715
v = [i.to_dict() for i in v]
@@ -723,17 +760,20 @@ def from_dict(self: T, value: dict) -> T:
723760
if value[key] is not None:
724761
if meta.proto_type == "message":
725762
v = getattr(self, field.name)
726-
# print(v, value[key])
727763
if isinstance(v, list):
728764
cls = self._cls_for(field)
729765
for i in range(len(value[key])):
730766
v.append(cls().from_dict(value[key][i]))
731767
elif isinstance(v, datetime):
732-
v = datetime.fromisoformat(value[key].replace("Z", "+00:00"))
768+
v = datetime.fromisoformat(
769+
value[key].replace("Z", "+00:00")
770+
)
733771
setattr(self, field.name, v)
734772
elif isinstance(v, timedelta):
735773
v = timedelta(seconds=float(value[key][:-1]))
736774
setattr(self, field.name, v)
775+
elif meta.wraps:
776+
setattr(self, field.name, value[key])
737777
else:
738778
v.from_dict(value[key])
739779
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
@@ -830,7 +870,6 @@ class _Timestamp(Message):
830870

831871
def to_datetime(self) -> datetime:
832872
ts = self.seconds + (self.nanos / 1e9)
833-
print('to-datetime', ts, datetime.fromtimestamp(ts, tz=timezone.utc))
834873
return datetime.fromtimestamp(ts, tz=timezone.utc)
835874

836875
@staticmethod
@@ -839,17 +878,90 @@ def to_json(dt: datetime) -> str:
839878
copy = dt.replace(microsecond=0, tzinfo=None)
840879
result = copy.isoformat()
841880
if (nanos % 1e9) == 0:
842-
# If there are 0 fractional digits, the fractional
843-
# point '.' should be omitted when serializing.
844-
return result + 'Z'
881+
# If there are 0 fractional digits, the fractional
882+
# point '.' should be omitted when serializing.
883+
return result + "Z"
845884
if (nanos % 1e6) == 0:
846-
# Serialize 3 fractional digits.
847-
return result + '.%03dZ' % (nanos / 1e6)
885+
# Serialize 3 fractional digits.
886+
return result + ".%03dZ" % (nanos / 1e6)
848887
if (nanos % 1e3) == 0:
849-
# Serialize 6 fractional digits.
850-
return result + '.%06dZ' % (nanos / 1e3)
888+
# Serialize 6 fractional digits.
889+
return result + ".%06dZ" % (nanos / 1e3)
851890
# Serialize 9 fractional digits.
852-
return result + '.%09dZ' % nanos
891+
return result + ".%09dZ" % nanos
892+
893+
894+
class _WrappedMessage(Message):
895+
"""
896+
Google protobuf wrapper types base class. JSON representation is just the
897+
value itself.
898+
"""
899+
def to_dict(self) -> Any:
900+
return self.value
901+
902+
def from_dict(self, value: Any) -> None:
903+
if value is not None:
904+
self.value = value
905+
906+
907+
@dataclasses.dataclass
908+
class _BoolValue(_WrappedMessage):
909+
value: bool = bool_field(1)
910+
911+
912+
@dataclasses.dataclass
913+
class _Int32Value(_WrappedMessage):
914+
value: int = int32_field(1)
915+
916+
917+
@dataclasses.dataclass
918+
class _UInt32Value(_WrappedMessage):
919+
value: int = uint32_field(1)
920+
921+
922+
@dataclasses.dataclass
923+
class _Int64Value(_WrappedMessage):
924+
value: int = int64_field(1)
925+
926+
927+
@dataclasses.dataclass
928+
class _UInt64Value(_WrappedMessage):
929+
value: int = uint64_field(1)
930+
931+
932+
@dataclasses.dataclass
933+
class _FloatValue(_WrappedMessage):
934+
value: float = float_field(1)
935+
936+
937+
@dataclasses.dataclass
938+
class _DoubleValue(_WrappedMessage):
939+
value: float = double_field(1)
940+
941+
942+
@dataclasses.dataclass
943+
class _StringValue(_WrappedMessage):
944+
value: str = string_field(1)
945+
946+
947+
@dataclasses.dataclass
948+
class _BytesValue(_WrappedMessage):
949+
value: bytes = bytes_field(1)
950+
951+
952+
def _get_wrapper(proto_type: str) -> _WrappedMessage:
953+
"""Get the wrapper message class for a wrapped type."""
954+
return {
955+
TYPE_BOOL: _BoolValue,
956+
TYPE_INT32: _Int32Value,
957+
TYPE_UINT32: _UInt32Value,
958+
TYPE_INT64: _Int64Value,
959+
TYPE_UINT64: _UInt64Value,
960+
TYPE_FLOAT: _FloatValue,
961+
TYPE_DOUBLE: _DoubleValue,
962+
TYPE_STRING: _StringValue,
963+
TYPE_BYTES: _BytesValue,
964+
}[proto_type]
853965

854966

855967
class ServiceStub(ABC):

0 commit comments

Comments
 (0)