Skip to content

Commit c79535b

Browse files
committed
Support Duration/Timestamp Google well-known types
1 parent 5daf61f commit c79535b

File tree

3 files changed

+135
-3
lines changed

3 files changed

+135
-3
lines changed

betterproto/__init__.py

Lines changed: 95 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import struct
66
from abc import ABC
77
from base64 import b64encode, b64decode
8+
from datetime import datetime, timedelta, timezone
89
from typing import (
910
Any,
1011
AsyncGenerator,
@@ -295,6 +296,18 @@ def _preprocess_single(proto_type: str, value: Any) -> bytes:
295296
elif proto_type == TYPE_STRING:
296297
return value.encode("utf-8")
297298
elif proto_type == TYPE_MESSAGE:
299+
if isinstance(value, datetime):
300+
# Convert the `datetime` to a timestamp message.
301+
seconds = int(value.timestamp())
302+
nanos = int(value.microsecond * 1e3)
303+
value = _Timestamp(seconds=seconds, nanos=nanos)
304+
elif isinstance(value, timedelta):
305+
# Convert the `timedelta` to a duration message.
306+
total_ms = value // timedelta(microseconds=1)
307+
seconds = int(total_ms / 1e6)
308+
nanos = int((total_ms % 1e6) * 1e3)
309+
value = _Duration(seconds=seconds, nanos=nanos)
310+
298311
return bytes(value)
299312

300313
return value
@@ -399,6 +412,7 @@ def __post_init__(self) -> None:
399412
meta = FieldMetadata.get(field)
400413

401414
if meta.group:
415+
# This is part of a one-of group.
402416
group_map["fields"][field.name] = meta.group
403417

404418
if meta.group not in group_map["groups"]:
@@ -530,6 +544,9 @@ def _get_field_default(self, field: dataclasses.Field, meta: FieldMetadata) -> A
530544
elif issubclass(t, Enum):
531545
# Enums always default to zero.
532546
value = 0
547+
elif t == datetime:
548+
# Offsets are relative to 1970-01-01T00:00:00Z
549+
value = datetime(1970, 1, 1, tzinfo=timezone.utc)
533550
else:
534551
# This is either a primitive scalar or another message type. Calling
535552
# it should result in its zero value.
@@ -558,8 +575,14 @@ def _postprocess_single(
558575
value = value.decode("utf-8")
559576
elif meta.proto_type == TYPE_MESSAGE:
560577
cls = self._cls_for(field)
561-
value = cls().parse(value)
562-
value._serialized_on_wire = True
578+
579+
if cls == datetime:
580+
value = _Timestamp().parse(value).to_datetime()
581+
elif cls == timedelta:
582+
value = _Duration().parse(value).to_timedelta()
583+
else:
584+
value = cls().parse(value)
585+
value._serialized_on_wire = True
563586
elif meta.proto_type == TYPE_MAP:
564587
# TODO: This is slow, use a cache to make it faster since each
565588
# key/value pair will recreate the class.
@@ -646,7 +669,11 @@ def to_dict(self, casing: Casing = Casing.CAMEL) -> dict:
646669
v = getattr(self, field.name)
647670
cased_name = casing(field.name).rstrip("_")
648671
if meta.proto_type == "message":
649-
if isinstance(v, list):
672+
if isinstance(v, datetime):
673+
output[cased_name] = _Timestamp.to_json(v)
674+
elif isinstance(v, timedelta):
675+
output[cased_name] = _Duration.to_json(v)
676+
elif isinstance(v, list):
650677
# Convert each item.
651678
v = [i.to_dict() for i in v]
652679
output[cased_name] = v
@@ -701,6 +728,12 @@ def from_dict(self: T, value: dict) -> T:
701728
cls = self._cls_for(field)
702729
for i in range(len(value[key])):
703730
v.append(cls().from_dict(value[key][i]))
731+
elif isinstance(v, datetime):
732+
v = datetime.fromisoformat(value[key].replace("Z", "+00:00"))
733+
setattr(self, field.name, v)
734+
elif isinstance(v, timedelta):
735+
v = timedelta(seconds=float(value[key][:-1]))
736+
setattr(self, field.name, v)
704737
else:
705738
v.from_dict(value[key])
706739
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
@@ -760,6 +793,65 @@ def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]:
760793
return (field.name, getattr(message, field.name))
761794

762795

796+
@dataclasses.dataclass
797+
class _Duration(Message):
798+
# Signed seconds of the span of time. Must be from -315,576,000,000 to
799+
# +315,576,000,000 inclusive. Note: these bounds are computed from: 60
800+
# sec/min * 60 min/hr * 24 hr/day * 365.25 days/year * 10000 years
801+
seconds: int = int64_field(1)
802+
# Signed fractions of a second at nanosecond resolution of the span of time.
803+
# Durations less than one second are represented with a 0 `seconds` field and
804+
# a positive or negative `nanos` field. For durations of one second or more,
805+
# a non-zero value for the `nanos` field must be of the same sign as the
806+
# `seconds` field. Must be from -999,999,999 to +999,999,999 inclusive.
807+
nanos: int = int32_field(2)
808+
809+
def to_timedelta(self) -> timedelta:
810+
return timedelta(seconds=self.seconds, microseconds=self.nanos / 1e3)
811+
812+
@staticmethod
813+
def to_json(delta: timedelta) -> str:
814+
parts = str(delta.total_seconds()).split(".")
815+
if len(parts) > 1:
816+
while len(parts[1]) not in [3, 6, 9]:
817+
parts[1] = parts[1] + "0"
818+
return ".".join(parts) + "s"
819+
820+
821+
@dataclasses.dataclass
822+
class _Timestamp(Message):
823+
# Represents seconds of UTC time since Unix epoch 1970-01-01T00:00:00Z. Must
824+
# be from 0001-01-01T00:00:00Z to 9999-12-31T23:59:59Z inclusive.
825+
seconds: int = int64_field(1)
826+
# Non-negative fractions of a second at nanosecond resolution. Negative
827+
# second values with fractions must still have non-negative nanos values that
828+
# count forward in time. Must be from 0 to 999,999,999 inclusive.
829+
nanos: int = int32_field(2)
830+
831+
def to_datetime(self) -> datetime:
832+
ts = self.seconds + (self.nanos / 1e9)
833+
print('to-datetime', ts, datetime.fromtimestamp(ts, tz=timezone.utc))
834+
return datetime.fromtimestamp(ts, tz=timezone.utc)
835+
836+
@staticmethod
837+
def to_json(dt: datetime) -> str:
838+
nanos = dt.microsecond * 1e3
839+
copy = dt.replace(microsecond=0, tzinfo=None)
840+
result = copy.isoformat()
841+
if (nanos % 1e9) == 0:
842+
# If there are 0 fractional digits, the fractional
843+
# point '.' should be omitted when serializing.
844+
return result + 'Z'
845+
if (nanos % 1e6) == 0:
846+
# Serialize 3 fractional digits.
847+
return result + '.%03dZ' % (nanos / 1e6)
848+
if (nanos % 1e3) == 0:
849+
# Serialize 6 fractional digits.
850+
return result + '.%06dZ' % (nanos / 1e3)
851+
# Serialize 9 fractional digits.
852+
return result + '.%09dZ' % nanos
853+
854+
763855
class ServiceStub(ABC):
764856
"""
765857
Base class for async gRPC service stubs.

betterproto/plugin.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,19 @@
3030
from betterproto.casing import safe_snake_case
3131

3232

33+
WRAPPER_TYPES = {
34+
"google.protobuf.DoubleValue": "float",
35+
"google.protobuf.FloatValue": "float",
36+
"google.protobuf.Int64Value": "int",
37+
"google.protobuf.UInt64Value": "int",
38+
"google.protobuf.Int32Value": "int",
39+
"google.protobuf.UInt32Value": "int",
40+
"google.protobuf.BoolValue": "bool",
41+
"google.protobuf.StringValue": "str",
42+
"google.protobuf.BytesValue": "bytes",
43+
}
44+
45+
3346
def get_ref_type(package: str, imports: set, type_name: str) -> str:
3447
"""
3548
Return a Python type name for a proto type reference. Adds the import if
@@ -39,6 +52,16 @@ def get_ref_type(package: str, imports: set, type_name: str) -> str:
3952
# because by convention packages are lowercase and message/enum types are
4053
# pascal-cased. May require refactoring in the future.
4154
type_name = type_name.lstrip(".")
55+
56+
if type_name in WRAPPER_TYPES:
57+
return f"Optional[{WRAPPER_TYPES[type_name]}]"
58+
59+
if type_name == "google.protobuf.Duration":
60+
return "timedelta"
61+
62+
if type_name == "google.protobuf.Timestamp":
63+
return "datetime"
64+
4265
if type_name.startswith(package):
4366
parts = type_name.lstrip(package).lstrip(".").split(".")
4467
if len(parts) == 1 or (len(parts) > 1 and parts[0][0] == parts[0][0].upper()):
@@ -152,6 +175,9 @@ def generate_code(request, response):
152175
output_map = {}
153176
for proto_file in request.proto_file:
154177
out = proto_file.package
178+
if out == "google.protobuf":
179+
continue
180+
155181
if not out:
156182
out = os.path.splitext(proto_file.name)[0].replace(os.path.sep, ".")
157183

@@ -169,6 +195,7 @@ def generate_code(request, response):
169195
"package": package,
170196
"files": [f.name for f in options["files"]],
171197
"imports": set(),
198+
"datetime_imports": set(),
172199
"typing_imports": set(),
173200
"messages": [],
174201
"enums": [],
@@ -258,6 +285,14 @@ def generate_code(request, response):
258285
if f.HasField("oneof_index"):
259286
one_of = item.oneof_decl[f.oneof_index].name
260287

288+
if "Optional[" in t:
289+
output["typing_imports"].add("Optional")
290+
291+
if "timedelta" in t:
292+
output["datetime_imports"].add("timedelta")
293+
elif "datetime" in t:
294+
output["datetime_imports"].add("datetime")
295+
261296
data["properties"].append(
262297
{
263298
"name": f.name,
@@ -346,6 +381,7 @@ def generate_code(request, response):
346381
output["services"].append(data)
347382

348383
output["imports"] = sorted(output["imports"])
384+
output["datetime_imports"] = sorted(output["datetime_imports"])
349385
output["typing_imports"] = sorted(output["typing_imports"])
350386

351387
# Fill response

betterproto/templates/template.py

Lines changed: 4 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)