diff --git a/src/betterproto2_compiler/compile/importing.py b/src/betterproto2_compiler/compile/importing.py index 9f6b954d..9e4901db 100644 --- a/src/betterproto2_compiler/compile/importing.py +++ b/src/betterproto2_compiler/compile/importing.py @@ -5,7 +5,7 @@ TYPE_CHECKING, ) -from betterproto2_compiler.lib.google import protobuf as google_protobuf +from betterproto2_compiler.known_types import WRAPPED_TYPES from betterproto2_compiler.settings import Settings from ..casing import safe_snake_case @@ -14,18 +14,6 @@ if TYPE_CHECKING: from ..plugin.models import PluginRequestCompiler -WRAPPER_TYPES: dict[str, type] = { - ".google.protobuf.DoubleValue": google_protobuf.DoubleValue, - ".google.protobuf.FloatValue": google_protobuf.FloatValue, - ".google.protobuf.Int32Value": google_protobuf.Int32Value, - ".google.protobuf.Int64Value": google_protobuf.Int64Value, - ".google.protobuf.UInt32Value": google_protobuf.UInt32Value, - ".google.protobuf.UInt64Value": google_protobuf.UInt64Value, - ".google.protobuf.BoolValue": google_protobuf.BoolValue, - ".google.protobuf.StringValue": google_protobuf.StringValue, - ".google.protobuf.BytesValue": google_protobuf.BytesValue, -} - def parse_source_type_name(field_type_name: str, request: PluginRequestCompiler) -> tuple[str, str]: """ @@ -73,26 +61,18 @@ def get_type_reference( imports: set, source_type: str, request: PluginRequestCompiler, - unwrap: bool = True, + wrap: bool = True, settings: Settings, ) -> str: """ Return a Python type name for a proto type reference. Adds the import if necessary. Unwraps well known type if required. """ - if unwrap: - if source_type in WRAPPER_TYPES: - wrapped_type = type(WRAPPER_TYPES[source_type]().value) - return f"{wrapped_type.__name__} | None" - - if source_type == ".google.protobuf.Duration": - return "datetime.timedelta" - - elif source_type == ".google.protobuf.Timestamp": - return "datetime.datetime" - source_package, source_type = parse_source_type_name(source_type, request) + if wrap and (source_package, source_type) in WRAPPED_TYPES: + return WRAPPED_TYPES[(source_package, source_type)] + current_package: list[str] = package.split(".") if package else [] py_package: list[str] = source_package.split(".") if source_package else [] py_type: str = pythonize_class_name(source_type) diff --git a/src/betterproto2_compiler/known_types/__init__.py b/src/betterproto2_compiler/known_types/__init__.py index 4a56d5f5..00a8c3f4 100644 --- a/src/betterproto2_compiler/known_types/__init__.py +++ b/src/betterproto2_compiler/known_types/__init__.py @@ -2,6 +2,17 @@ from .any import Any from .duration import Duration +from .google_values import ( + BoolValue, + BytesValue, + DoubleValue, + FloatValue, + Int32Value, + Int64Value, + StringValue, + UInt32Value, + UInt64Value, +) from .timestamp import Timestamp # For each (package, message name), lists the methods that should be added to the message definition. @@ -9,6 +20,84 @@ # to the template file: they will automatically be removed if not necessary. KNOWN_METHODS: dict[tuple[str, str], list[Callable]] = { ("google.protobuf", "Any"): [Any.pack, Any.unpack, Any.to_dict], - ("google.protobuf", "Timestamp"): [Timestamp.from_datetime, Timestamp.to_datetime, Timestamp.timestamp_to_json], - ("google.protobuf", "Duration"): [Duration.from_timedelta, Duration.to_timedelta, Duration.delta_to_json], + ("google.protobuf", "Timestamp"): [ + Timestamp.from_datetime, + Timestamp.to_datetime, + Timestamp.timestamp_to_json, + Timestamp.from_dict, + Timestamp.to_dict, + Timestamp.from_wrapped, + Timestamp.to_wrapped, + ], + ("google.protobuf", "Duration"): [ + Duration.from_timedelta, + Duration.to_timedelta, + Duration.delta_to_json, + Duration.from_dict, + Duration.to_dict, + Duration.from_wrapped, + Duration.to_wrapped, + ], + ("google.protobuf", "BoolValue"): [ + BoolValue.from_dict, + BoolValue.to_dict, + BoolValue.from_wrapped, + BoolValue.to_wrapped, + ], + ("google.protobuf", "Int32Value"): [ + Int32Value.from_dict, + Int32Value.to_dict, + Int32Value.from_wrapped, + Int32Value.to_wrapped, + ], + ("google.protobuf", "Int64Value"): [ + Int64Value.from_dict, + Int64Value.to_dict, + Int64Value.from_wrapped, + Int64Value.to_wrapped, + ], + ("google.protobuf", "UInt32Value"): [ + UInt32Value.from_dict, + UInt32Value.to_dict, + UInt32Value.from_wrapped, + UInt32Value.to_wrapped, + ], + ("google.protobuf", "UInt64Value"): [ + UInt64Value.from_dict, + UInt64Value.to_dict, + UInt64Value.from_wrapped, + UInt64Value.to_wrapped, + ], + ("google.protobuf", "FloatValue"): [ + FloatValue.from_dict, + FloatValue.to_dict, + FloatValue.from_wrapped, + FloatValue.to_wrapped, + ], + ("google.protobuf", "DoubleValue"): [ + DoubleValue.from_dict, + DoubleValue.to_dict, + DoubleValue.from_wrapped, + DoubleValue.to_wrapped, + ], + ("google.protobuf", "StringValue"): [ + StringValue.from_dict, + StringValue.to_dict, + StringValue.from_wrapped, + StringValue.to_wrapped, + ], + ("google.protobuf", "BytesValue"): [ + BytesValue.from_dict, + BytesValue.to_dict, + BytesValue.from_wrapped, + BytesValue.to_wrapped, + ], +} + +# A wrapped type is the type of a message that is automatically replaced by a known Python type. +WRAPPED_TYPES: dict[tuple[str, str], str] = { + ("google.protobuf", "BoolValue"): "bool", + ("google.protobuf", "StringValue"): "str", + ("google.protobuf", "Timestamp"): "datetime.datetime", + ("google.protobuf", "Duration"): "datetime.timedelta", } diff --git a/src/betterproto2_compiler/known_types/duration.py b/src/betterproto2_compiler/known_types/duration.py index 14e7d81f..bd1705dc 100644 --- a/src/betterproto2_compiler/known_types/duration.py +++ b/src/betterproto2_compiler/known_types/duration.py @@ -1,4 +1,8 @@ import datetime +import re +import typing + +import betterproto2 from betterproto2_compiler.lib.google.protobuf import Duration as VanillaDuration @@ -23,3 +27,44 @@ def delta_to_json(delta: datetime.timedelta) -> str: while len(parts[1]) not in (3, 6, 9): parts[1] = f"{parts[1]}0" return f"{'.'.join(parts)}s" + + # TODO typing + @classmethod + def from_dict(cls, value): + if isinstance(value, str): + if not re.match(r"^\d+(\.\d+)?s$", value): + raise ValueError(f"Invalid duration string: {value}") + + seconds = float(value[:-1]) + return Duration(seconds=int(seconds), nanos=int((seconds - int(seconds)) * 1e9)) + + return super().from_dict(value) + + # TODO typing + def to_dict( + self, + *, + output_format: betterproto2.OutputFormat = betterproto2.OutputFormat.PROTO_JSON, + casing: betterproto2.Casing = betterproto2.Casing.CAMEL, + include_default_values: bool = False, + ) -> dict[str, typing.Any] | typing.Any: + # If the output format is PYTHON, we should have kept the wrapped type without building the real class + assert output_format == betterproto2.OutputFormat.PROTO_JSON + + assert 0 <= self.nanos < 1e9 + + if self.nanos == 0: + return f"{self.seconds}s" + + nanos = f"{self.nanos:09d}".rstrip("0") + if len(nanos) < 3: + nanos += "0" * (3 - len(nanos)) + + return f"{self.seconds}.{nanos}s" + + @staticmethod + def from_wrapped(wrapped: datetime.timedelta) -> "Duration": + return Duration.from_timedelta(wrapped) + + def to_wrapped(self) -> datetime.timedelta: + return self.to_timedelta() diff --git a/src/betterproto2_compiler/known_types/google_values.py b/src/betterproto2_compiler/known_types/google_values.py new file mode 100644 index 00000000..3b1d4300 --- /dev/null +++ b/src/betterproto2_compiler/known_types/google_values.py @@ -0,0 +1,231 @@ +import typing + +import betterproto2 + +from betterproto2_compiler.lib.google.protobuf import ( + BoolValue as VanillaBoolValue, + BytesValue as VanillaBytesValue, + DoubleValue as VanillaDoubleValue, + FloatValue as VanillaFloatValue, + Int32Value as VanillaInt32Value, + Int64Value as VanillaInt64Value, + StringValue as VanillaStringValue, + UInt32Value as VanillaUInt32Value, + UInt64Value as VanillaUInt64Value, +) + + +class BoolValue(VanillaBoolValue): + @staticmethod + def from_wrapped(wrapped: bool) -> "BoolValue": + return BoolValue(value=wrapped) + + def to_wrapped(self) -> bool: + return self.value + + @classmethod + def from_dict(cls, value): + if isinstance(value, bool): + return BoolValue(value=value) + return super().from_dict(value) + + def to_dict( + self, + *, + output_format: betterproto2.OutputFormat = betterproto2.OutputFormat.PROTO_JSON, + casing: betterproto2.Casing = betterproto2.Casing.CAMEL, + include_default_values: bool = False, + ) -> dict[str, typing.Any] | typing.Any: + return self.value + + +class Int32Value(VanillaInt32Value): + @staticmethod + def from_wrapped(wrapped: int) -> "Int32Value": + return Int32Value(value=wrapped) + + def to_wrapped(self) -> int: + return self.value + + @classmethod + def from_dict(cls, value): + if isinstance(value, int): + return Int32Value(value=value) + return super().from_dict(value) + + def to_dict( + self, + *, + output_format: betterproto2.OutputFormat = betterproto2.OutputFormat.PROTO_JSON, + casing: betterproto2.Casing = betterproto2.Casing.CAMEL, + include_default_values: bool = False, + ) -> dict[str, typing.Any] | typing.Any: + return self.value + + +class Int64Value(VanillaInt64Value): + @staticmethod + def from_wrapped(wrapped: int) -> "Int64Value": + return Int64Value(value=wrapped) + + def to_wrapped(self) -> int: + return self.value + + @classmethod + def from_dict(cls, value): + if isinstance(value, int): + return Int64Value(value=value) + return super().from_dict(value) + + def to_dict( + self, + *, + output_format: betterproto2.OutputFormat = betterproto2.OutputFormat.PROTO_JSON, + casing: betterproto2.Casing = betterproto2.Casing.CAMEL, + include_default_values: bool = False, + ) -> dict[str, typing.Any] | typing.Any: + return self.value + + +class UInt32Value(VanillaUInt32Value): + @staticmethod + def from_wrapped(wrapped: int) -> "UInt32Value": + return UInt32Value(value=wrapped) + + def to_wrapped(self) -> int: + return self.value + + @classmethod + def from_dict(cls, value): + if isinstance(value, int): + return UInt32Value(value=value) + return super().from_dict(value) + + def to_dict( + self, + *, + output_format: betterproto2.OutputFormat = betterproto2.OutputFormat.PROTO_JSON, + casing: betterproto2.Casing = betterproto2.Casing.CAMEL, + include_default_values: bool = False, + ) -> dict[str, typing.Any] | typing.Any: + return self.value + + +class UInt64Value(VanillaUInt64Value): + @staticmethod + def from_wrapped(wrapped: int) -> "UInt64Value": + return UInt64Value(value=wrapped) + + def to_wrapped(self) -> int: + return self.value + + @classmethod + def from_dict(cls, value): + if isinstance(value, int): + return UInt64Value(value=value) + return super().from_dict(value) + + def to_dict( + self, + *, + output_format: betterproto2.OutputFormat = betterproto2.OutputFormat.PROTO_JSON, + casing: betterproto2.Casing = betterproto2.Casing.CAMEL, + include_default_values: bool = False, + ) -> dict[str, typing.Any] | typing.Any: + return self.value + + +class FloatValue(VanillaFloatValue): + @staticmethod + def from_wrapped(wrapped: float) -> "FloatValue": + return FloatValue(value=wrapped) + + def to_wrapped(self) -> float: + return self.value + + @classmethod + def from_dict(cls, value): + if isinstance(value, float): + return FloatValue(value=value) + return super().from_dict(value) + + def to_dict( + self, + *, + output_format: betterproto2.OutputFormat = betterproto2.OutputFormat.PROTO_JSON, + casing: betterproto2.Casing = betterproto2.Casing.CAMEL, + include_default_values: bool = False, + ) -> dict[str, typing.Any] | typing.Any: + return self.value + + +class DoubleValue(VanillaDoubleValue): + @staticmethod + def from_wrapped(wrapped: float) -> "DoubleValue": + return DoubleValue(value=wrapped) + + def to_wrapped(self) -> float: + return self.value + + @classmethod + def from_dict(cls, value): + if isinstance(value, float): + return DoubleValue(value=value) + return super().from_dict(value) + + def to_dict( + self, + *, + output_format: betterproto2.OutputFormat = betterproto2.OutputFormat.PROTO_JSON, + casing: betterproto2.Casing = betterproto2.Casing.CAMEL, + include_default_values: bool = False, + ) -> dict[str, typing.Any] | typing.Any: + return self.value + + +class StringValue(VanillaStringValue): + @staticmethod + def from_wrapped(wrapped: str) -> "StringValue": + return StringValue(value=wrapped) + + def to_wrapped(self) -> str: + return self.value + + @classmethod + def from_dict(cls, value): + if isinstance(value, str): + return StringValue(value=value) + return super().from_dict(value) + + def to_dict( + self, + *, + output_format: betterproto2.OutputFormat = betterproto2.OutputFormat.PROTO_JSON, + casing: betterproto2.Casing = betterproto2.Casing.CAMEL, + include_default_values: bool = False, + ) -> dict[str, typing.Any] | typing.Any: + return self.value + + +class BytesValue(VanillaBytesValue): + @staticmethod + def from_wrapped(wrapped: bytes) -> "BytesValue": + return BytesValue(value=wrapped) + + def to_wrapped(self) -> bytes: + return self.value + + @classmethod + def from_dict(cls, value): + if isinstance(value, bytes): + return BytesValue(value=value) + return super().from_dict(value) + + def to_dict( + self, + *, + output_format: betterproto2.OutputFormat = betterproto2.OutputFormat.PROTO_JSON, + casing: betterproto2.Casing = betterproto2.Casing.CAMEL, + include_default_values: bool = False, + ) -> dict[str, typing.Any] | typing.Any: + return self.value diff --git a/src/betterproto2_compiler/known_types/timestamp.py b/src/betterproto2_compiler/known_types/timestamp.py index e1b8f04f..d1e168be 100644 --- a/src/betterproto2_compiler/known_types/timestamp.py +++ b/src/betterproto2_compiler/known_types/timestamp.py @@ -1,4 +1,8 @@ import datetime +import typing + +import betterproto2 +import dateutil.parser from betterproto2_compiler.lib.google.protobuf import Timestamp as VanillaTimestamp @@ -6,6 +10,11 @@ class Timestamp(VanillaTimestamp): @classmethod def from_datetime(cls, dt: datetime.datetime) -> "Timestamp": + if not dt.tzinfo: + raise ValueError("datetime must be timezone aware") + + dt = dt.astimezone(datetime.timezone.utc) + # manual epoch offset calulation to avoid rounding errors, # to support negative timestamps (before 1970) and skirt # around datetime bugs (apparently 0 isn't a year in [0, 9999]??) @@ -43,3 +52,33 @@ def timestamp_to_json(dt: datetime.datetime) -> str: return f"{result}.{int(nanos // 1e3):06d}Z" # Serialize 9 fractional digits. return f"{result}.{nanos:09d}" + + # TODO typing + @classmethod + def from_dict(cls, value): + if isinstance(value, str): + dt = dateutil.parser.isoparse(value) + dt = dt.astimezone(datetime.timezone.utc) + return Timestamp.from_datetime(dt) + + return super().from_dict(value) + + # TODO typing + def to_dict( + self, + *, + output_format: betterproto2.OutputFormat = betterproto2.OutputFormat.PROTO_JSON, + casing: betterproto2.Casing = betterproto2.Casing.CAMEL, + include_default_values: bool = False, + ) -> dict[str, typing.Any] | typing.Any: + # If the output format is PYTHON, we should have kept the wraped type without building the real class + assert output_format == betterproto2.OutputFormat.PROTO_JSON + + return Timestamp.timestamp_to_json(self.to_datetime()) + + @staticmethod + def from_wrapped(wrapped: datetime.datetime) -> "Timestamp": + return Timestamp.from_datetime(wrapped) + + def to_wrapped(self) -> datetime.datetime: + return self.to_datetime() diff --git a/src/betterproto2_compiler/plugin/models.py b/src/betterproto2_compiler/plugin/models.py index af409dd6..7855b3b2 100644 --- a/src/betterproto2_compiler/plugin/models.py +++ b/src/betterproto2_compiler/plugin/models.py @@ -26,14 +26,12 @@ import builtins import inspect -import re from collections.abc import Iterator from dataclasses import ( dataclass, field, ) -import betterproto2 from betterproto2 import unwrap from betterproto2_compiler.compile.importing import get_type_reference, parse_source_type_name @@ -43,7 +41,7 @@ pythonize_field_name, pythonize_method_name, ) -from betterproto2_compiler.known_types import KNOWN_METHODS +from betterproto2_compiler.known_types import KNOWN_METHODS, WRAPPED_TYPES from betterproto2_compiler.lib.google.protobuf import ( DescriptorProto, EnumDescriptorProto, @@ -318,8 +316,24 @@ def get_field_string(self) -> str: @property def betterproto_field_args(self) -> list[str]: args = [] - if self.field_wraps: - args.append(f"wraps={self.field_wraps}") + + if self.field_type == FieldDescriptorProtoType.TYPE_MESSAGE: + type_package, type_name = parse_source_type_name(self.proto_obj.type_name, self.output_file.parent_request) + + if (type_package, type_name) in WRAPPED_TYPES: + unwrap_type = get_type_reference( + package=self.output_file.package, + imports=self.output_file.imports_end, + source_type=self.proto_obj.type_name, + request=self.output_file.parent_request, + settings=self.output_file.settings, + wrap=False, + ) + + # Without the lambda function, the type is evaluated right away, which fails since the corresponding + # import is placed at the end of the file to avoid circular imports. + args.append(f"unwrap=lambda: {unwrap_type}") + if self.optional: args.append("optional=True") elif self.repeated: @@ -338,16 +352,6 @@ def use_builtins(self) -> bool: self.py_type == self.py_name and self.py_name in dir(builtins) ) - @property - def field_wraps(self) -> str | None: - """Returns betterproto wrapped field type or None.""" - match_wrapper = re.match(r"\.google\.protobuf\.(.+)Value$", self.proto_obj.type_name) - if match_wrapper: - wrapped_type = "TYPE_" + match_wrapper.group(1).upper() - if hasattr(betterproto2, wrapped_type): - return f"betterproto2.{wrapped_type}" - return None - @property def repeated(self) -> bool: return self.proto_obj.label == FieldDescriptorProtoLabel.LABEL_REPEATED @@ -405,6 +409,7 @@ def py_type(self) -> str: @property def annotation(self) -> str: py_type = self.py_type + if self.use_builtins: py_type = f"builtins.{py_type}" if self.repeated: @@ -581,7 +586,7 @@ def py_input_message_type(self) -> str: imports=self.parent.output_file.imports_end, source_type=self.proto_obj.input_type, request=self.parent.output_file.parent_request, - unwrap=False, + wrap=False, settings=self.parent.output_file.settings, ) @@ -608,7 +613,7 @@ def py_output_message_type(self) -> str: imports=self.parent.output_file.imports_end, source_type=self.proto_obj.output_type, request=self.parent.output_file.parent_request, - unwrap=False, + wrap=False, settings=self.parent.output_file.settings, ) diff --git a/src/betterproto2_compiler/templates/header.py.j2 b/src/betterproto2_compiler/templates/header.py.j2 index 7d5f7cb0..41c0882e 100644 --- a/src/betterproto2_compiler/templates/header.py.j2 +++ b/src/betterproto2_compiler/templates/header.py.j2 @@ -18,8 +18,10 @@ __all__ = ( {%- endfor -%} ) +import re import builtins import datetime +import dateutil.parser import warnings from collections.abc import AsyncIterable, AsyncIterator, Iterable, Iterator import typing