diff --git a/src/betterproto2_compiler/known_types/any.py b/src/betterproto2_compiler/known_types/any.py index 6952b52e..60ab1ea3 100644 --- a/src/betterproto2_compiler/known_types/any.py +++ b/src/betterproto2_compiler/known_types/any.py @@ -1,3 +1,5 @@ +import typing + import betterproto2 from betterproto2_compiler.lib.google.protobuf import Any as VanillaAny @@ -18,19 +20,37 @@ def pack(self, message: betterproto2.Message, message_pool: "betterproto2.Messag self.type_url = message_pool.type_to_url[type(message)] self.value = bytes(message) - def unpack(self, message_pool: "betterproto2.MessagePool | None" = None) -> betterproto2.Message: + def unpack(self, message_pool: "betterproto2.MessagePool | None" = None) -> betterproto2.Message | None: """ Return the message packed inside the `Any` object. The target message type must be registered in the message pool, which is done automatically when the module defining the message type is imported. """ + if not self.type_url: + return None + message_pool = message_pool or default_message_pool - message_type = message_pool.url_to_type[self.type_url] + try: + message_type = message_pool.url_to_type[self.type_url] + except KeyError: + raise TypeError(f"Can't unpack unregistered type: {self.type_url}") return message_type().parse(self.value) - def to_dict(self) -> dict: # pyright: ignore [reportIncompatibleMethodOverride] - # TOOO improve when dict is updated - return {"@type": self.type_url, "value": self.unpack().to_dict()} + def to_dict(self, **kwargs) -> dict[str, typing.Any]: + # TODO allow passing a message pool to `to_dict` + output: dict[str, typing.Any] = {"@type": self.type_url} + + value = self.unpack() + + if value is None: + return output + + if type(value).to_dict == betterproto2.Message.to_dict: + output.update(value.to_dict(**kwargs)) + else: + output["value"] = value.to_dict(**kwargs) + + return output diff --git a/src/betterproto2_compiler/lib/google/protobuf/__init__.py b/src/betterproto2_compiler/lib/google/protobuf/__init__.py index dbc4cf47..6a9991c6 100644 --- a/src/betterproto2_compiler/lib/google/protobuf/__init__.py +++ b/src/betterproto2_compiler/lib/google/protobuf/__init__.py @@ -724,35 +724,6 @@ class Any(betterproto2.Message): Must be a valid serialized protocol buffer of the above specified type. """ - def pack(self, message: betterproto2.Message, message_pool: "betterproto2.MessagePool | None" = None) -> None: - """ - Pack the given message in the `Any` object. - - The message type must be registered in the message pool, which is done automatically when the module defining - the message type is imported. - """ - message_pool = message_pool or default_message_pool - - self.type_url = message_pool.type_to_url[type(message)] - self.value = bytes(message) - - def unpack(self, message_pool: "betterproto2.MessagePool | None" = None) -> betterproto2.Message: - """ - Return the message packed inside the `Any` object. - - The target message type must be registered in the message pool, which is done automatically when the module - defining the message type is imported. - """ - message_pool = message_pool or default_message_pool - - message_type = message_pool.url_to_type[self.type_url] - - return message_type().parse(self.value) - - def to_dict(self) -> dict: # pyright: ignore [reportIncompatibleMethodOverride] - # TOOO improve when dict is updated - return {"@type": self.type_url, "value": self.unpack().to_dict()} - default_message_pool.register_message("google.protobuf", "Any", Any) diff --git a/src/betterproto2_compiler/templates/header.py.j2 b/src/betterproto2_compiler/templates/header.py.j2 index c73e7868..f7d12d24 100644 --- a/src/betterproto2_compiler/templates/header.py.j2 +++ b/src/betterproto2_compiler/templates/header.py.j2 @@ -22,6 +22,7 @@ import builtins import datetime import warnings from collections.abc import AsyncIterable, AsyncIterator, Iterable +import typing from typing import TYPE_CHECKING {% if output_file.settings.pydantic_dataclasses %}