Skip to content
This repository was archived by the owner on Jun 9, 2025. It is now read-only.

Commit 6214c86

Browse files
committed
Support more well-known types
1 parent d067d0a commit 6214c86

File tree

3 files changed

+278
-23
lines changed

3 files changed

+278
-23
lines changed

src/betterproto2_compiler/compile/importing.py

Lines changed: 4 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
TYPE_CHECKING,
66
)
77

8-
from betterproto2_compiler.lib.google import protobuf as google_protobuf
8+
from betterproto2_compiler.known_types import WRAPPED_TYPES
99
from betterproto2_compiler.settings import Settings
1010

1111
from ..casing import safe_snake_case
@@ -14,18 +14,6 @@
1414
if TYPE_CHECKING:
1515
from ..plugin.models import PluginRequestCompiler
1616

17-
WRAPPER_TYPES: dict[str, type] = {
18-
".google.protobuf.DoubleValue": google_protobuf.DoubleValue,
19-
".google.protobuf.FloatValue": google_protobuf.FloatValue,
20-
".google.protobuf.Int32Value": google_protobuf.Int32Value,
21-
".google.protobuf.Int64Value": google_protobuf.Int64Value,
22-
".google.protobuf.UInt32Value": google_protobuf.UInt32Value,
23-
".google.protobuf.UInt64Value": google_protobuf.UInt64Value,
24-
".google.protobuf.BoolValue": google_protobuf.BoolValue,
25-
".google.protobuf.StringValue": google_protobuf.StringValue,
26-
".google.protobuf.BytesValue": google_protobuf.BytesValue,
27-
}
28-
2917

3018
def parse_source_type_name(field_type_name: str, request: PluginRequestCompiler) -> tuple[str, str]:
3119
"""
@@ -80,15 +68,11 @@ def get_type_reference(
8068
Return a Python type name for a proto type reference. Adds the import if
8169
necessary. Unwraps well known type if required.
8270
"""
83-
if unwrap: # TODO don't hardcode
84-
if source_type == ".google.protobuf.Duration":
85-
return "datetime.timedelta"
86-
87-
elif source_type == ".google.protobuf.Timestamp":
88-
return "datetime.datetime"
89-
9071
source_package, source_type = parse_source_type_name(source_type, request)
9172

73+
if unwrap and (source_package, source_type) in WRAPPED_TYPES:
74+
return WRAPPED_TYPES[(source_package, source_type)]
75+
9276
current_package: list[str] = package.split(".") if package else []
9377
py_package: list[str] = source_package.split(".") if source_package else []
9478
py_type: str = pythonize_class_name(source_type)

src/betterproto2_compiler/known_types/__init__.py

Lines changed: 65 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,17 @@
22

33
from .any import Any
44
from .duration import Duration
5-
from .google_values import BoolValue, StringValue
5+
from .google_values import (
6+
BoolValue,
7+
BytesValue,
8+
DoubleValue,
9+
FloatValue,
10+
Int32Value,
11+
Int64Value,
12+
StringValue,
13+
UInt32Value,
14+
UInt64Value,
15+
)
616
from .timestamp import Timestamp
717

818
# For each (package, message name), lists the methods that should be added to the message definition.
@@ -28,8 +38,60 @@
2838
Duration.from_wrapped,
2939
Duration.to_wrapped,
3040
],
31-
("google.protobuf", "BoolValue"): [BoolValue.from_wrapped, BoolValue.to_wrapped],
32-
("google.protobuf", "StringValue"): [StringValue.from_wrapped, StringValue.to_wrapped],
41+
("google.protobuf", "BoolValue"): [
42+
BoolValue.from_dict,
43+
BoolValue.to_dict,
44+
BoolValue.from_wrapped,
45+
BoolValue.to_wrapped,
46+
],
47+
("google.protobuf", "Int32Value"): [
48+
Int32Value.from_dict,
49+
Int32Value.to_dict,
50+
Int32Value.from_wrapped,
51+
Int32Value.to_wrapped,
52+
],
53+
("google.protobuf", "Int64Value"): [
54+
Int64Value.from_dict,
55+
Int64Value.to_dict,
56+
Int64Value.from_wrapped,
57+
Int64Value.to_wrapped,
58+
],
59+
("google.protobuf", "UInt32Value"): [
60+
UInt32Value.from_dict,
61+
UInt32Value.to_dict,
62+
UInt32Value.from_wrapped,
63+
UInt32Value.to_wrapped,
64+
],
65+
("google.protobuf", "UInt64Value"): [
66+
UInt64Value.from_dict,
67+
UInt64Value.to_dict,
68+
UInt64Value.from_wrapped,
69+
UInt64Value.to_wrapped,
70+
],
71+
("google.protobuf", "FloatValue"): [
72+
FloatValue.from_dict,
73+
FloatValue.to_dict,
74+
FloatValue.from_wrapped,
75+
FloatValue.to_wrapped,
76+
],
77+
("google.protobuf", "DoubleValue"): [
78+
DoubleValue.from_dict,
79+
DoubleValue.to_dict,
80+
DoubleValue.from_wrapped,
81+
DoubleValue.to_wrapped,
82+
],
83+
("google.protobuf", "StringValue"): [
84+
StringValue.from_dict,
85+
StringValue.to_dict,
86+
StringValue.from_wrapped,
87+
StringValue.to_wrapped,
88+
],
89+
("google.protobuf", "BytesValue"): [
90+
BytesValue.from_dict,
91+
BytesValue.to_dict,
92+
BytesValue.from_wrapped,
93+
BytesValue.to_wrapped,
94+
],
3395
}
3496

3597
# A wrapped type is the type of a message that is automatically replaced by a known Python type.

src/betterproto2_compiler/known_types/google_values.py

Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,17 @@
1+
import typing
2+
3+
import betterproto2
4+
15
from betterproto2_compiler.lib.google.protobuf import (
26
BoolValue as VanillaBoolValue,
7+
BytesValue as VanillaBytesValue,
8+
DoubleValue as VanillaDoubleValue,
9+
FloatValue as VanillaFloatValue,
10+
Int32Value as VanillaInt32Value,
11+
Int64Value as VanillaInt64Value,
312
StringValue as VanillaStringValue,
13+
UInt32Value as VanillaUInt32Value,
14+
UInt64Value as VanillaUInt64Value,
415
)
516

617

@@ -12,6 +23,165 @@ def from_wrapped(wrapped: bool) -> "BoolValue":
1223
def to_wrapped(self) -> bool:
1324
return self.value
1425

26+
@classmethod
27+
def from_dict(cls, value):
28+
if isinstance(value, bool):
29+
return BoolValue(value=value)
30+
return super().from_dict(value)
31+
32+
def to_dict(
33+
self,
34+
*,
35+
output_format: betterproto2.OutputFormat = betterproto2.OutputFormat.PROTO_JSON,
36+
casing: betterproto2.Casing = betterproto2.Casing.CAMEL,
37+
include_default_values: bool = False,
38+
) -> dict[str, typing.Any] | typing.Any:
39+
return self.value
40+
41+
42+
class Int32Value(VanillaInt32Value):
43+
@staticmethod
44+
def from_wrapped(wrapped: int) -> "Int32Value":
45+
return Int32Value(value=wrapped)
46+
47+
def to_wrapped(self) -> int:
48+
return self.value
49+
50+
@classmethod
51+
def from_dict(cls, value):
52+
if isinstance(value, int):
53+
return Int32Value(value=value)
54+
return super().from_dict(value)
55+
56+
def to_dict(
57+
self,
58+
*,
59+
output_format: betterproto2.OutputFormat = betterproto2.OutputFormat.PROTO_JSON,
60+
casing: betterproto2.Casing = betterproto2.Casing.CAMEL,
61+
include_default_values: bool = False,
62+
) -> dict[str, typing.Any] | typing.Any:
63+
return self.value
64+
65+
66+
class Int64Value(VanillaInt64Value):
67+
@staticmethod
68+
def from_wrapped(wrapped: int) -> "Int64Value":
69+
return Int64Value(value=wrapped)
70+
71+
def to_wrapped(self) -> int:
72+
return self.value
73+
74+
@classmethod
75+
def from_dict(cls, value):
76+
if isinstance(value, int):
77+
return Int64Value(value=value)
78+
return super().from_dict(value)
79+
80+
def to_dict(
81+
self,
82+
*,
83+
output_format: betterproto2.OutputFormat = betterproto2.OutputFormat.PROTO_JSON,
84+
casing: betterproto2.Casing = betterproto2.Casing.CAMEL,
85+
include_default_values: bool = False,
86+
) -> dict[str, typing.Any] | typing.Any:
87+
return self.value
88+
89+
90+
class UInt32Value(VanillaUInt32Value):
91+
@staticmethod
92+
def from_wrapped(wrapped: int) -> "UInt32Value":
93+
return UInt32Value(value=wrapped)
94+
95+
def to_wrapped(self) -> int:
96+
return self.value
97+
98+
@classmethod
99+
def from_dict(cls, value):
100+
if isinstance(value, int):
101+
return UInt32Value(value=value)
102+
return super().from_dict(value)
103+
104+
def to_dict(
105+
self,
106+
*,
107+
output_format: betterproto2.OutputFormat = betterproto2.OutputFormat.PROTO_JSON,
108+
casing: betterproto2.Casing = betterproto2.Casing.CAMEL,
109+
include_default_values: bool = False,
110+
) -> dict[str, typing.Any] | typing.Any:
111+
return self.value
112+
113+
114+
class UInt64Value(VanillaUInt64Value):
115+
@staticmethod
116+
def from_wrapped(wrapped: int) -> "UInt64Value":
117+
return UInt64Value(value=wrapped)
118+
119+
def to_wrapped(self) -> int:
120+
return self.value
121+
122+
@classmethod
123+
def from_dict(cls, value):
124+
if isinstance(value, int):
125+
return UInt64Value(value=value)
126+
return super().from_dict(value)
127+
128+
def to_dict(
129+
self,
130+
*,
131+
output_format: betterproto2.OutputFormat = betterproto2.OutputFormat.PROTO_JSON,
132+
casing: betterproto2.Casing = betterproto2.Casing.CAMEL,
133+
include_default_values: bool = False,
134+
) -> dict[str, typing.Any] | typing.Any:
135+
return self.value
136+
137+
138+
class FloatValue(VanillaFloatValue):
139+
@staticmethod
140+
def from_wrapped(wrapped: float) -> "FloatValue":
141+
return FloatValue(value=wrapped)
142+
143+
def to_wrapped(self) -> float:
144+
return self.value
145+
146+
@classmethod
147+
def from_dict(cls, value):
148+
if isinstance(value, float):
149+
return FloatValue(value=value)
150+
return super().from_dict(value)
151+
152+
def to_dict(
153+
self,
154+
*,
155+
output_format: betterproto2.OutputFormat = betterproto2.OutputFormat.PROTO_JSON,
156+
casing: betterproto2.Casing = betterproto2.Casing.CAMEL,
157+
include_default_values: bool = False,
158+
) -> dict[str, typing.Any] | typing.Any:
159+
return self.value
160+
161+
162+
class DoubleValue(VanillaDoubleValue):
163+
@staticmethod
164+
def from_wrapped(wrapped: float) -> "DoubleValue":
165+
return DoubleValue(value=wrapped)
166+
167+
def to_wrapped(self) -> float:
168+
return self.value
169+
170+
@classmethod
171+
def from_dict(cls, value):
172+
if isinstance(value, float):
173+
return DoubleValue(value=value)
174+
return super().from_dict(value)
175+
176+
def to_dict(
177+
self,
178+
*,
179+
output_format: betterproto2.OutputFormat = betterproto2.OutputFormat.PROTO_JSON,
180+
casing: betterproto2.Casing = betterproto2.Casing.CAMEL,
181+
include_default_values: bool = False,
182+
) -> dict[str, typing.Any] | typing.Any:
183+
return self.value
184+
15185

16186
class StringValue(VanillaStringValue):
17187
@staticmethod
@@ -20,3 +190,42 @@ def from_wrapped(wrapped: str) -> "StringValue":
20190

21191
def to_wrapped(self) -> str:
22192
return self.value
193+
194+
@classmethod
195+
def from_dict(cls, value):
196+
if isinstance(value, str):
197+
return StringValue(value=value)
198+
return super().from_dict(value)
199+
200+
def to_dict(
201+
self,
202+
*,
203+
output_format: betterproto2.OutputFormat = betterproto2.OutputFormat.PROTO_JSON,
204+
casing: betterproto2.Casing = betterproto2.Casing.CAMEL,
205+
include_default_values: bool = False,
206+
) -> dict[str, typing.Any] | typing.Any:
207+
return self.value
208+
209+
210+
class BytesValue(VanillaBytesValue):
211+
@staticmethod
212+
def from_wrapped(wrapped: bytes) -> "BytesValue":
213+
return BytesValue(value=wrapped)
214+
215+
def to_wrapped(self) -> bytes:
216+
return self.value
217+
218+
@classmethod
219+
def from_dict(cls, value):
220+
if isinstance(value, bytes):
221+
return BytesValue(value=value)
222+
return super().from_dict(value)
223+
224+
def to_dict(
225+
self,
226+
*,
227+
output_format: betterproto2.OutputFormat = betterproto2.OutputFormat.PROTO_JSON,
228+
casing: betterproto2.Casing = betterproto2.Casing.CAMEL,
229+
include_default_values: bool = False,
230+
) -> dict[str, typing.Any] | typing.Any:
231+
return self.value

0 commit comments

Comments
 (0)