Skip to content

Commit 2f658df

Browse files
Use betterproto wrapper classes, extract to module for testability
1 parent b813d1c commit 2f658df

File tree

7 files changed

+165
-85
lines changed

7 files changed

+165
-85
lines changed

betterproto/compile/__init__.py

Whitespace-only changes.

betterproto/compile/importing.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
from typing import Dict, Type
2+
3+
import stringcase
4+
5+
from betterproto import safe_snake_case
6+
from betterproto.lib.google import protobuf as google_protobuf
7+
8+
WRAPPER_TYPES: Dict[str, Type] = {
9+
"google.protobuf.DoubleValue": google_protobuf.DoubleValue,
10+
"google.protobuf.FloatValue": google_protobuf.FloatValue,
11+
"google.protobuf.Int32Value": google_protobuf.Int32Value,
12+
"google.protobuf.Int64Value": google_protobuf.Int64Value,
13+
"google.protobuf.UInt32Value": google_protobuf.UInt32Value,
14+
"google.protobuf.UInt64Value": google_protobuf.UInt64Value,
15+
"google.protobuf.BoolValue": google_protobuf.BoolValue,
16+
"google.protobuf.StringValue": google_protobuf.StringValue,
17+
"google.protobuf.BytesValue": google_protobuf.BytesValue,
18+
}
19+
20+
21+
def get_ref_type(
22+
package: str, imports: set, type_name: str, unwrap: bool = True
23+
) -> str:
24+
"""
25+
Return a Python type name for a proto type reference. Adds the import if
26+
necessary. Unwraps well known type if required.
27+
"""
28+
# If the package name is a blank string, then this should still work
29+
# because by convention packages are lowercase and message/enum types are
30+
# pascal-cased. May require refactoring in the future.
31+
type_name = type_name.lstrip(".")
32+
33+
is_wrapper = type_name in WRAPPER_TYPES
34+
35+
if unwrap:
36+
if is_wrapper:
37+
wrapped_type = type(WRAPPER_TYPES[type_name]().value)
38+
return f"Optional[{wrapped_type.__name__}]"
39+
40+
if type_name == "google.protobuf.Duration":
41+
return "timedelta"
42+
43+
if type_name == "google.protobuf.Timestamp":
44+
return "datetime"
45+
46+
if type_name.startswith(package):
47+
parts = type_name.lstrip(package).lstrip(".").split(".")
48+
if len(parts) == 1 or (len(parts) > 1 and parts[0][0] == parts[0][0].upper()):
49+
# This is the current package, which has nested types flattened.
50+
# foo.bar_thing => FooBarThing
51+
cased = [stringcase.pascalcase(part) for part in parts]
52+
type_name = f'"{"".join(cased)}"'
53+
54+
# Use precompiled classes for google.protobuf.* objects
55+
if type_name.startswith("google.protobuf.") and type_name.count(".") == 2:
56+
type_name = type_name.rsplit(".", maxsplit=1)[1]
57+
import_package = "betterproto.lib.google.protobuf"
58+
import_alias = safe_snake_case(import_package)
59+
imports.add(f"import {import_package} as {import_alias}")
60+
return f"{import_alias}.{type_name}"
61+
62+
if "." in type_name:
63+
# This is imported from another package. No need
64+
# to use a forward ref and we need to add the import.
65+
parts = type_name.split(".")
66+
parts[-1] = stringcase.pascalcase(parts[-1])
67+
imports.add(f"from .{'.'.join(parts[:-2])} import {parts[-2]}")
68+
type_name = f"{parts[-2]}.{parts[-1]}"
69+
70+
return type_name

betterproto/plugin.py

Lines changed: 2 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
import stringcase
77
import sys
88
import textwrap
9-
from collections import defaultdict
10-
from typing import Dict, List, Optional, Type
9+
from typing import List
1110
from betterproto.casing import safe_snake_case
11+
from betterproto.compile.importing import get_ref_type
1212
import betterproto
1313

1414
try:
@@ -35,78 +35,6 @@
3535
raise SystemExit(1)
3636

3737

38-
WRAPPER_TYPES: Dict[str, Optional[Type]] = defaultdict(
39-
lambda: None,
40-
{
41-
"google.protobuf.DoubleValue": google_wrappers.DoubleValue,
42-
"google.protobuf.FloatValue": google_wrappers.FloatValue,
43-
"google.protobuf.Int64Value": google_wrappers.Int64Value,
44-
"google.protobuf.UInt64Value": google_wrappers.UInt64Value,
45-
"google.protobuf.Int32Value": google_wrappers.Int32Value,
46-
"google.protobuf.UInt32Value": google_wrappers.UInt32Value,
47-
"google.protobuf.BoolValue": google_wrappers.BoolValue,
48-
"google.protobuf.StringValue": google_wrappers.StringValue,
49-
"google.protobuf.BytesValue": google_wrappers.BytesValue,
50-
},
51-
)
52-
53-
54-
def get_ref_type(
55-
package: str, imports: set, type_name: str, unwrap: bool = True
56-
) -> str:
57-
"""
58-
Return a Python type name for a proto type reference. Adds the import if
59-
necessary. Unwraps well known type if required.
60-
"""
61-
# If the package name is a blank string, then this should still work
62-
# because by convention packages are lowercase and message/enum types are
63-
# pascal-cased. May require refactoring in the future.
64-
type_name = type_name.lstrip(".")
65-
66-
# Check if type is wrapper.
67-
wrapper_class = WRAPPER_TYPES[type_name]
68-
69-
if unwrap:
70-
if wrapper_class:
71-
wrapped_type = type(wrapper_class().value)
72-
return f"Optional[{wrapped_type.__name__}]"
73-
74-
if type_name == "google.protobuf.Duration":
75-
return "timedelta"
76-
77-
if type_name == "google.protobuf.Timestamp":
78-
return "datetime"
79-
elif wrapper_class:
80-
imports.add(f"from {wrapper_class.__module__} import {wrapper_class.__name__}")
81-
return f"{wrapper_class.__name__}"
82-
83-
if type_name.startswith(package):
84-
parts = type_name.lstrip(package).lstrip(".").split(".")
85-
if len(parts) == 1 or (len(parts) > 1 and parts[0][0] == parts[0][0].upper()):
86-
# This is the current package, which has nested types flattened.
87-
# foo.bar_thing => FooBarThing
88-
cased = [stringcase.pascalcase(part) for part in parts]
89-
type_name = f'"{"".join(cased)}"'
90-
91-
# Use precompiled classes for google.protobuf.* objects
92-
if type_name.startswith("google.protobuf.") and type_name.count(".") == 2:
93-
type_name = type_name.rsplit(".", maxsplit=1)[1]
94-
import_package = "betterproto.lib.google.protobuf"
95-
import_alias = safe_snake_case(import_package)
96-
imports.add(f"import {import_package} as {import_alias}")
97-
return f"{import_alias}.{type_name}"
98-
99-
if "." in type_name:
100-
# This is imported from another package. No need
101-
# to use a forward ref and we need to add the import.
102-
parts = type_name.split(".")
103-
parts[-1] = stringcase.pascalcase(parts[-1])
104-
imports.add(f"from .{'.'.join(parts[:-2])} import {parts[-2]}")
105-
type_name = f"{parts[-2]}.{parts[-1]}"
106-
107-
return type_name
108-
109-
11038
def py_type(
11139
package: str,
11240
imports: set,

betterproto/tests/__init__.py

Whitespace-only changes.

betterproto/tests/inputs/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
"import_circular_dependency", # failing because of other bugs now
99
"import_packages_same_name", # 25
1010
"oneof_enum", # 63
11-
"googletypes_service_returns_empty", # 9
1211
"casing_message_field_uppercase", # 11
1312
"namespace_keywords", # 70
1413
"namespace_builtin_types", # 53
@@ -22,4 +21,5 @@
2221
"service",
2322
"import_service_input_message",
2423
"googletypes_service_returns_empty",
24+
"googletypes_service_returns_googletype",
2525
}

betterproto/tests/inputs/googletypes_response/test_googletypes_response.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import Any, Callable, Optional
22

3-
import google.protobuf.wrappers_pb2 as wrappers
3+
import betterproto.lib.google.protobuf as protobuf
44
import pytest
55

66
from betterproto.tests.mocks import MockChannel
@@ -9,15 +9,15 @@
99
)
1010

1111
test_cases = [
12-
(TestStub.get_double, wrappers.DoubleValue, 2.5),
13-
(TestStub.get_float, wrappers.FloatValue, 2.5),
14-
(TestStub.get_int64, wrappers.Int64Value, -64),
15-
(TestStub.get_u_int64, wrappers.UInt64Value, 64),
16-
(TestStub.get_int32, wrappers.Int32Value, -32),
17-
(TestStub.get_u_int32, wrappers.UInt32Value, 32),
18-
(TestStub.get_bool, wrappers.BoolValue, True),
19-
(TestStub.get_string, wrappers.StringValue, "string"),
20-
(TestStub.get_bytes, wrappers.BytesValue, bytes(0xFF)[0:4]),
12+
(TestStub.get_double, protobuf.DoubleValue, 2.5),
13+
(TestStub.get_float, protobuf.FloatValue, 2.5),
14+
(TestStub.get_int64, protobuf.Int64Value, -64),
15+
(TestStub.get_u_int64, protobuf.UInt64Value, 64),
16+
(TestStub.get_int32, protobuf.Int32Value, -32),
17+
(TestStub.get_u_int32, protobuf.UInt32Value, 32),
18+
(TestStub.get_bool, protobuf.BoolValue, True),
19+
(TestStub.get_string, protobuf.StringValue, "string"),
20+
(TestStub.get_bytes, protobuf.BytesValue, bytes(0xFF)[0:4]),
2121
]
2222

2323

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
import pytest
2+
3+
from ..compile.importing import get_ref_type
4+
5+
6+
@pytest.mark.parametrize(
7+
["google_type", "expected_name", "expected_import"],
8+
[
9+
(
10+
".google.protobuf.Empty",
11+
"betterproto_lib_google_protobuf.Empty",
12+
"import betterproto.lib.google.protobuf as betterproto_lib_google_protobuf",
13+
),
14+
(
15+
".google.protobuf.Struct",
16+
"betterproto_lib_google_protobuf.Struct",
17+
"import betterproto.lib.google.protobuf as betterproto_lib_google_protobuf",
18+
),
19+
(
20+
".google.protobuf.ListValue",
21+
"betterproto_lib_google_protobuf.ListValue",
22+
"import betterproto.lib.google.protobuf as betterproto_lib_google_protobuf",
23+
),
24+
(
25+
".google.protobuf.Value",
26+
"betterproto_lib_google_protobuf.Value",
27+
"import betterproto.lib.google.protobuf as betterproto_lib_google_protobuf",
28+
),
29+
],
30+
)
31+
def test_import_google_wellknown_types_non_wrappers(
32+
google_type: str, expected_name: str, expected_import: str
33+
):
34+
imports = set()
35+
name = get_ref_type(package="", imports=imports, type_name=google_type)
36+
37+
assert name == expected_name
38+
assert imports.__contains__(expected_import)
39+
40+
41+
@pytest.mark.parametrize(
42+
["google_type", "expected_name"],
43+
[
44+
(".google.protobuf.DoubleValue", "Optional[float]"),
45+
(".google.protobuf.FloatValue", "Optional[float]"),
46+
(".google.protobuf.Int32Value", "Optional[int]"),
47+
(".google.protobuf.Int64Value", "Optional[int]"),
48+
(".google.protobuf.UInt32Value", "Optional[int]"),
49+
(".google.protobuf.UInt64Value", "Optional[int]"),
50+
(".google.protobuf.BoolValue", "Optional[bool]"),
51+
(".google.protobuf.StringValue", "Optional[str]"),
52+
(".google.protobuf.BytesValue", "Optional[bytes]"),
53+
],
54+
)
55+
def test_importing_google_wrappers_unwraps_them(google_type: str, expected_name: str):
56+
imports = set()
57+
name = get_ref_type(package="", imports=imports, type_name=google_type)
58+
59+
assert name == expected_name
60+
assert imports == set()
61+
62+
63+
@pytest.mark.parametrize(
64+
["google_type", "expected_name"],
65+
[
66+
(".google.protobuf.DoubleValue", "betterproto_lib_google_protobuf.DoubleValue"),
67+
(".google.protobuf.FloatValue", "betterproto_lib_google_protobuf.FloatValue"),
68+
(".google.protobuf.Int32Value", "betterproto_lib_google_protobuf.Int32Value"),
69+
(".google.protobuf.Int64Value", "betterproto_lib_google_protobuf.Int64Value"),
70+
(".google.protobuf.UInt32Value", "betterproto_lib_google_protobuf.UInt32Value"),
71+
(".google.protobuf.UInt64Value", "betterproto_lib_google_protobuf.UInt64Value"),
72+
(".google.protobuf.BoolValue", "betterproto_lib_google_protobuf.BoolValue"),
73+
(".google.protobuf.StringValue", "betterproto_lib_google_protobuf.StringValue"),
74+
(".google.protobuf.BytesValue", "betterproto_lib_google_protobuf.BytesValue"),
75+
],
76+
)
77+
def test_importing_google_wrappers_without_unwrapping(
78+
google_type: str, expected_name: str
79+
):
80+
name = get_ref_type(package="", imports=set(), type_name=google_type, unwrap=False)
81+
82+
assert name == expected_name

0 commit comments

Comments
 (0)