|
5 | 5 | import struct
|
6 | 6 | from abc import ABC
|
7 | 7 | from base64 import b64encode, b64decode
|
| 8 | +from datetime import datetime, timedelta, timezone |
8 | 9 | from typing import (
|
9 | 10 | Any,
|
10 | 11 | AsyncGenerator,
|
@@ -295,6 +296,18 @@ def _preprocess_single(proto_type: str, value: Any) -> bytes:
|
295 | 296 | elif proto_type == TYPE_STRING:
|
296 | 297 | return value.encode("utf-8")
|
297 | 298 | elif proto_type == TYPE_MESSAGE:
|
| 299 | + if isinstance(value, datetime): |
| 300 | + # Convert the `datetime` to a timestamp message. |
| 301 | + seconds = int(value.timestamp()) |
| 302 | + nanos = int(value.microsecond * 1e3) |
| 303 | + value = _Timestamp(seconds=seconds, nanos=nanos) |
| 304 | + elif isinstance(value, timedelta): |
| 305 | + # Convert the `timedelta` to a duration message. |
| 306 | + total_ms = value // timedelta(microseconds=1) |
| 307 | + seconds = int(total_ms / 1e6) |
| 308 | + nanos = int((total_ms % 1e6) * 1e3) |
| 309 | + value = _Duration(seconds=seconds, nanos=nanos) |
| 310 | + |
298 | 311 | return bytes(value)
|
299 | 312 |
|
300 | 313 | return value
|
@@ -399,6 +412,7 @@ def __post_init__(self) -> None:
|
399 | 412 | meta = FieldMetadata.get(field)
|
400 | 413 |
|
401 | 414 | if meta.group:
|
| 415 | + # This is part of a one-of group. |
402 | 416 | group_map["fields"][field.name] = meta.group
|
403 | 417 |
|
404 | 418 | if meta.group not in group_map["groups"]:
|
@@ -530,6 +544,9 @@ def _get_field_default(self, field: dataclasses.Field, meta: FieldMetadata) -> A
|
530 | 544 | elif issubclass(t, Enum):
|
531 | 545 | # Enums always default to zero.
|
532 | 546 | value = 0
|
| 547 | + elif t == datetime: |
| 548 | + # Offsets are relative to 1970-01-01T00:00:00Z |
| 549 | + value = datetime(1970, 1, 1, tzinfo=timezone.utc) |
533 | 550 | else:
|
534 | 551 | # This is either a primitive scalar or another message type. Calling
|
535 | 552 | # it should result in its zero value.
|
@@ -558,8 +575,14 @@ def _postprocess_single(
|
558 | 575 | value = value.decode("utf-8")
|
559 | 576 | elif meta.proto_type == TYPE_MESSAGE:
|
560 | 577 | cls = self._cls_for(field)
|
561 |
| - value = cls().parse(value) |
562 |
| - value._serialized_on_wire = True |
| 578 | + |
| 579 | + if cls == datetime: |
| 580 | + value = _Timestamp().parse(value).to_datetime() |
| 581 | + elif cls == timedelta: |
| 582 | + value = _Duration().parse(value).to_timedelta() |
| 583 | + else: |
| 584 | + value = cls().parse(value) |
| 585 | + value._serialized_on_wire = True |
563 | 586 | elif meta.proto_type == TYPE_MAP:
|
564 | 587 | # TODO: This is slow, use a cache to make it faster since each
|
565 | 588 | # key/value pair will recreate the class.
|
@@ -646,7 +669,11 @@ def to_dict(self, casing: Casing = Casing.CAMEL) -> dict:
|
646 | 669 | v = getattr(self, field.name)
|
647 | 670 | cased_name = casing(field.name).rstrip("_")
|
648 | 671 | if meta.proto_type == "message":
|
649 |
| - if isinstance(v, list): |
| 672 | + if isinstance(v, datetime): |
| 673 | + output[cased_name] = _Timestamp.to_json(v) |
| 674 | + elif isinstance(v, timedelta): |
| 675 | + output[cased_name] = _Duration.to_json(v) |
| 676 | + elif isinstance(v, list): |
650 | 677 | # Convert each item.
|
651 | 678 | v = [i.to_dict() for i in v]
|
652 | 679 | output[cased_name] = v
|
@@ -701,6 +728,12 @@ def from_dict(self: T, value: dict) -> T:
|
701 | 728 | cls = self._cls_for(field)
|
702 | 729 | for i in range(len(value[key])):
|
703 | 730 | v.append(cls().from_dict(value[key][i]))
|
| 731 | + elif isinstance(v, datetime): |
| 732 | + v = datetime.fromisoformat(value[key].replace("Z", "+00:00")) |
| 733 | + setattr(self, field.name, v) |
| 734 | + elif isinstance(v, timedelta): |
| 735 | + v = timedelta(seconds=float(value[key][:-1])) |
| 736 | + setattr(self, field.name, v) |
704 | 737 | else:
|
705 | 738 | v.from_dict(value[key])
|
706 | 739 | elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
|
@@ -760,6 +793,65 @@ def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]:
|
760 | 793 | return (field.name, getattr(message, field.name))
|
761 | 794 |
|
762 | 795 |
|
| 796 | +@dataclasses.dataclass |
| 797 | +class _Duration(Message): |
| 798 | + # Signed seconds of the span of time. Must be from -315,576,000,000 to |
| 799 | + # +315,576,000,000 inclusive. Note: these bounds are computed from: 60 |
| 800 | + # sec/min * 60 min/hr * 24 hr/day * 365.25 days/year * 10000 years |
| 801 | + seconds: int = int64_field(1) |
| 802 | + # Signed fractions of a second at nanosecond resolution of the span of time. |
| 803 | + # Durations less than one second are represented with a 0 `seconds` field and |
| 804 | + # a positive or negative `nanos` field. For durations of one second or more, |
| 805 | + # a non-zero value for the `nanos` field must be of the same sign as the |
| 806 | + # `seconds` field. Must be from -999,999,999 to +999,999,999 inclusive. |
| 807 | + nanos: int = int32_field(2) |
| 808 | + |
| 809 | + def to_timedelta(self) -> timedelta: |
| 810 | + return timedelta(seconds=self.seconds, microseconds=self.nanos / 1e3) |
| 811 | + |
| 812 | + @staticmethod |
| 813 | + def to_json(delta: timedelta) -> str: |
| 814 | + parts = str(delta.total_seconds()).split(".") |
| 815 | + if len(parts) > 1: |
| 816 | + while len(parts[1]) not in [3, 6, 9]: |
| 817 | + parts[1] = parts[1] + "0" |
| 818 | + return ".".join(parts) + "s" |
| 819 | + |
| 820 | + |
| 821 | +@dataclasses.dataclass |
| 822 | +class _Timestamp(Message): |
| 823 | + # Represents seconds of UTC time since Unix epoch 1970-01-01T00:00:00Z. Must |
| 824 | + # be from 0001-01-01T00:00:00Z to 9999-12-31T23:59:59Z inclusive. |
| 825 | + seconds: int = int64_field(1) |
| 826 | + # Non-negative fractions of a second at nanosecond resolution. Negative |
| 827 | + # second values with fractions must still have non-negative nanos values that |
| 828 | + # count forward in time. Must be from 0 to 999,999,999 inclusive. |
| 829 | + nanos: int = int32_field(2) |
| 830 | + |
| 831 | + def to_datetime(self) -> datetime: |
| 832 | + ts = self.seconds + (self.nanos / 1e9) |
| 833 | + print('to-datetime', ts, datetime.fromtimestamp(ts, tz=timezone.utc)) |
| 834 | + return datetime.fromtimestamp(ts, tz=timezone.utc) |
| 835 | + |
| 836 | + @staticmethod |
| 837 | + def to_json(dt: datetime) -> str: |
| 838 | + nanos = dt.microsecond * 1e3 |
| 839 | + copy = dt.replace(microsecond=0, tzinfo=None) |
| 840 | + result = copy.isoformat() |
| 841 | + if (nanos % 1e9) == 0: |
| 842 | + # If there are 0 fractional digits, the fractional |
| 843 | + # point '.' should be omitted when serializing. |
| 844 | + return result + 'Z' |
| 845 | + if (nanos % 1e6) == 0: |
| 846 | + # Serialize 3 fractional digits. |
| 847 | + return result + '.%03dZ' % (nanos / 1e6) |
| 848 | + if (nanos % 1e3) == 0: |
| 849 | + # Serialize 6 fractional digits. |
| 850 | + return result + '.%06dZ' % (nanos / 1e3) |
| 851 | + # Serialize 9 fractional digits. |
| 852 | + return result + '.%09dZ' % nanos |
| 853 | + |
| 854 | + |
763 | 855 | class ServiceStub(ABC):
|
764 | 856 | """
|
765 | 857 | Base class for async gRPC service stubs.
|
|
0 commit comments