Skip to content

Commit 9a45ea9

Browse files
authored
Merge pull request #78 from boukeversteegh/pr/google
Basic general support for Google Protobuf
2 parents eec24e4 + f7769a1 commit 9a45ea9

File tree

21 files changed

+1601
-161
lines changed

21 files changed

+1601
-161
lines changed

README.md

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,7 @@ Google provides several well-known message types like a timestamp, duration, and
256256
| `google.protobuf.duration` | [`datetime.timedelta`][td] | `0` |
257257
| `google.protobuf.timestamp` | Timezone-aware [`datetime.datetime`][dt] | `1970-01-01T00:00:00Z` |
258258
| `google.protobuf.*Value` | `Optional[...]` | `None` |
259+
| `google.protobuf.*` | `betterproto.lib.google.protobuf.*` | `None` |
259260

260261
[td]: https://docs.python.org/3/library/datetime.html#timedelta-objects
261262
[dt]: https://docs.python.org/3/library/datetime.html#datetime.datetime
@@ -354,6 +355,25 @@ $ pipenv run generate
354355
$ pipenv run test
355356
```
356357

358+
### (Re)compiling Google Well-known Types
359+
360+
Betterproto includes compiled versions for Google's well-known types at [betterproto/lib/google](betterproto/lib/google).
361+
Be sure to regenerate these files when modifying the plugin output format, and validate by running the tests.
362+
363+
Normally, the plugin does not compile any references to `google.protobuf`, since they are pre-compiled. To force compilation of `google.protobuf`, use the option `--custom_opt=INCLUDE_GOOGLE`.
364+
365+
Assuming your `google.protobuf` source files (included with all releases of `protoc`) are located in `/usr/local/include`, you can regenerate them as follows:
366+
367+
```sh
368+
protoc \
369+
--plugin=protoc-gen-custom=betterproto/plugin.py \
370+
--custom_opt=INCLUDE_GOOGLE \
371+
--custom_out=betterproto/lib \
372+
-I /usr/local/include/ \
373+
/usr/local/include/google/protobuf/*.proto
374+
```
375+
376+
357377
### TODO
358378

359379
- [x] Fixed length fields

betterproto/__init__.py

Lines changed: 26 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -941,19 +941,23 @@ def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]:
941941
return (field_name, getattr(message, field_name))
942942

943943

944-
@dataclasses.dataclass
945-
class _Duration(Message):
946-
# Signed seconds of the span of time. Must be from -315,576,000,000 to
947-
# +315,576,000,000 inclusive. Note: these bounds are computed from: 60
948-
# sec/min * 60 min/hr * 24 hr/day * 365.25 days/year * 10000 years
949-
seconds: int = int64_field(1)
950-
# Signed fractions of a second at nanosecond resolution of the span of time.
951-
# Durations less than one second are represented with a 0 `seconds` field and
952-
# a positive or negative `nanos` field. For durations of one second or more,
953-
# a non-zero value for the `nanos` field must be of the same sign as the
954-
# `seconds` field. Must be from -999,999,999 to +999,999,999 inclusive.
955-
nanos: int = int32_field(2)
944+
# Circular import workaround: google.protobuf depends on base classes defined above.
945+
from .lib.google.protobuf import (
946+
Duration,
947+
Timestamp,
948+
BoolValue,
949+
BytesValue,
950+
DoubleValue,
951+
FloatValue,
952+
Int32Value,
953+
Int64Value,
954+
StringValue,
955+
UInt32Value,
956+
UInt64Value,
957+
)
958+
956959

960+
class _Duration(Duration):
957961
def to_timedelta(self) -> timedelta:
958962
return timedelta(seconds=self.seconds, microseconds=self.nanos / 1e3)
959963

@@ -966,16 +970,7 @@ def delta_to_json(delta: timedelta) -> str:
966970
return ".".join(parts) + "s"
967971

968972

969-
@dataclasses.dataclass
970-
class _Timestamp(Message):
971-
# Represents seconds of UTC time since Unix epoch 1970-01-01T00:00:00Z. Must
972-
# be from 0001-01-01T00:00:00Z to 9999-12-31T23:59:59Z inclusive.
973-
seconds: int = int64_field(1)
974-
# Non-negative fractions of a second at nanosecond resolution. Negative
975-
# second values with fractions must still have non-negative nanos values that
976-
# count forward in time. Must be from 0 to 999,999,999 inclusive.
977-
nanos: int = int32_field(2)
978-
973+
class _Timestamp(Timestamp):
979974
def to_datetime(self) -> datetime:
980975
ts = self.seconds + (self.nanos / 1e9)
981976
return datetime.fromtimestamp(ts, tz=timezone.utc)
@@ -1016,63 +1011,18 @@ def from_dict(self: T, value: Any) -> T:
10161011
return self
10171012

10181013

1019-
@dataclasses.dataclass
1020-
class _BoolValue(_WrappedMessage):
1021-
value: bool = bool_field(1)
1022-
1023-
1024-
@dataclasses.dataclass
1025-
class _Int32Value(_WrappedMessage):
1026-
value: int = int32_field(1)
1027-
1028-
1029-
@dataclasses.dataclass
1030-
class _UInt32Value(_WrappedMessage):
1031-
value: int = uint32_field(1)
1032-
1033-
1034-
@dataclasses.dataclass
1035-
class _Int64Value(_WrappedMessage):
1036-
value: int = int64_field(1)
1037-
1038-
1039-
@dataclasses.dataclass
1040-
class _UInt64Value(_WrappedMessage):
1041-
value: int = uint64_field(1)
1042-
1043-
1044-
@dataclasses.dataclass
1045-
class _FloatValue(_WrappedMessage):
1046-
value: float = float_field(1)
1047-
1048-
1049-
@dataclasses.dataclass
1050-
class _DoubleValue(_WrappedMessage):
1051-
value: float = double_field(1)
1052-
1053-
1054-
@dataclasses.dataclass
1055-
class _StringValue(_WrappedMessage):
1056-
value: str = string_field(1)
1057-
1058-
1059-
@dataclasses.dataclass
1060-
class _BytesValue(_WrappedMessage):
1061-
value: bytes = bytes_field(1)
1062-
1063-
10641014
def _get_wrapper(proto_type: str) -> Type:
10651015
"""Get the wrapper message class for a wrapped type."""
10661016
return {
1067-
TYPE_BOOL: _BoolValue,
1068-
TYPE_INT32: _Int32Value,
1069-
TYPE_UINT32: _UInt32Value,
1070-
TYPE_INT64: _Int64Value,
1071-
TYPE_UINT64: _UInt64Value,
1072-
TYPE_FLOAT: _FloatValue,
1073-
TYPE_DOUBLE: _DoubleValue,
1074-
TYPE_STRING: _StringValue,
1075-
TYPE_BYTES: _BytesValue,
1017+
TYPE_BOOL: BoolValue,
1018+
TYPE_INT32: Int32Value,
1019+
TYPE_UINT32: UInt32Value,
1020+
TYPE_INT64: Int64Value,
1021+
TYPE_UINT64: UInt64Value,
1022+
TYPE_FLOAT: FloatValue,
1023+
TYPE_DOUBLE: DoubleValue,
1024+
TYPE_STRING: StringValue,
1025+
TYPE_BYTES: BytesValue,
10761026
}[proto_type]
10771027

10781028

betterproto/compile/__init__.py

Whitespace-only changes.

betterproto/compile/importing.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
from typing import Dict, Type
2+
3+
import stringcase
4+
5+
from betterproto import safe_snake_case
6+
from betterproto.lib.google import protobuf as google_protobuf
7+
8+
WRAPPER_TYPES: Dict[str, Type] = {
9+
"google.protobuf.DoubleValue": google_protobuf.DoubleValue,
10+
"google.protobuf.FloatValue": google_protobuf.FloatValue,
11+
"google.protobuf.Int32Value": google_protobuf.Int32Value,
12+
"google.protobuf.Int64Value": google_protobuf.Int64Value,
13+
"google.protobuf.UInt32Value": google_protobuf.UInt32Value,
14+
"google.protobuf.UInt64Value": google_protobuf.UInt64Value,
15+
"google.protobuf.BoolValue": google_protobuf.BoolValue,
16+
"google.protobuf.StringValue": google_protobuf.StringValue,
17+
"google.protobuf.BytesValue": google_protobuf.BytesValue,
18+
}
19+
20+
21+
def get_ref_type(
22+
package: str, imports: set, type_name: str, unwrap: bool = True
23+
) -> str:
24+
"""
25+
Return a Python type name for a proto type reference. Adds the import if
26+
necessary. Unwraps well known type if required.
27+
"""
28+
# If the package name is a blank string, then this should still work
29+
# because by convention packages are lowercase and message/enum types are
30+
# pascal-cased. May require refactoring in the future.
31+
type_name = type_name.lstrip(".")
32+
33+
is_wrapper = type_name in WRAPPER_TYPES
34+
35+
if unwrap:
36+
if is_wrapper:
37+
wrapped_type = type(WRAPPER_TYPES[type_name]().value)
38+
return f"Optional[{wrapped_type.__name__}]"
39+
40+
if type_name == "google.protobuf.Duration":
41+
return "timedelta"
42+
43+
if type_name == "google.protobuf.Timestamp":
44+
return "datetime"
45+
46+
if type_name.startswith(package):
47+
parts = type_name.lstrip(package).lstrip(".").split(".")
48+
if len(parts) == 1 or (len(parts) > 1 and parts[0][0] == parts[0][0].upper()):
49+
# This is the current package, which has nested types flattened.
50+
# foo.bar_thing => FooBarThing
51+
cased = [stringcase.pascalcase(part) for part in parts]
52+
type_name = f'"{"".join(cased)}"'
53+
54+
# Use precompiled classes for google.protobuf.* objects
55+
if type_name.startswith("google.protobuf.") and type_name.count(".") == 2:
56+
type_name = type_name.rsplit(".", maxsplit=1)[1]
57+
import_package = "betterproto.lib.google.protobuf"
58+
import_alias = safe_snake_case(import_package)
59+
imports.add(f"import {import_package} as {import_alias}")
60+
return f"{import_alias}.{type_name}"
61+
62+
if "." in type_name:
63+
# This is imported from another package. No need
64+
# to use a forward ref and we need to add the import.
65+
parts = type_name.split(".")
66+
parts[-1] = stringcase.pascalcase(parts[-1])
67+
imports.add(f"from .{'.'.join(parts[:-2])} import {parts[-2]}")
68+
type_name = f"{parts[-2]}.{parts[-1]}"
69+
70+
return type_name

betterproto/lib/__init__.py

Whitespace-only changes.

betterproto/lib/google/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)