Skip to content
This repository was archived by the owner on Jun 9, 2025. It is now read-only.

Commit 7e76561

Browse files
Improved Any support (#33)
* Generate additional methods for the Any type * Fix protobuf version * Fix type checking
1 parent b4e460f commit 7e76561

File tree

6 files changed

+140
-86
lines changed

6 files changed

+140
-86
lines changed

poetry.lock

Lines changed: 68 additions & 75 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ typing-extensions = "^4.7.1"
2222

2323
[tool.poetry.group.dev.dependencies]
2424
pre-commit = "^2.17.0"
25-
grpcio-tools = "^1.69.0"
25+
grpcio-tools = "^1.54.2"
2626
mkdocs-material = {version = "^9.5.49", python = ">=3.10"}
2727
mkdocstrings = {version = "^0.27.0", python = ">=3.10", extras = ["python"]}
2828
poethepoet = ">=0.9.0"
@@ -31,6 +31,7 @@ ipykernel = "^6.29.5"
3131

3232
[tool.poetry.group.test.dependencies]
3333
pytest = "^6.2.5"
34+
protobuf = "^4"
3435

3536
[tool.poetry.scripts]
3637
protoc-gen-python_betterproto2 = "betterproto2_compiler.plugin:main"
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from collections.abc import Callable
2+
3+
from .any import Any
4+
5+
# For each (package, message name), lists the methods that should be added to the message definition.
6+
# The source code of the method is read from the `known_types` folder. If imports are needed, they can be directly added
7+
# to the template file: they will automatically be removed if not necessary.
8+
KNOWN_METHODS: dict[tuple[str, str], list[Callable]] = {
9+
("google.protobuf", "Any"): [Any.pack, Any.unpack, Any.to_dict],
10+
}
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import betterproto2
2+
from betterproto2.lib.std.google.protobuf import Any as VanillaAny
3+
4+
5+
class Any(VanillaAny):
6+
def pack(self, message: betterproto2.Message, message_pool: "betterproto2.MessagePool | None" = None) -> None:
7+
"""
8+
Pack the given message in the `Any` object.
9+
10+
The message type must be registered in the message pool, which is done automatically when the module defining
11+
the message type is imported.
12+
"""
13+
message_pool = message_pool or betterproto2.default_message_pool
14+
15+
self.type_url = message_pool.type_to_url[type(message)]
16+
self.value = bytes(message)
17+
18+
def unpack(self, message_pool: "betterproto2.MessagePool | None" = None) -> betterproto2.Message:
19+
"""
20+
Return the message packed inside the `Any` object.
21+
22+
The target message type must be registered in the message pool, which is done automatically when the module
23+
defining the message type is imported.
24+
"""
25+
message_pool = message_pool or betterproto2.default_message_pool
26+
27+
message_type = message_pool.url_to_type[self.type_url]
28+
29+
return message_type().parse(self.value)
30+
31+
def to_dict(self) -> dict: # pyright: ignore [reportIncompatibleMethodOverride]
32+
# TOOO improve when dict is updated
33+
return {"@type": self.type_url, "value": self.unpack().to_dict()}

src/betterproto2_compiler/plugin/models.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
"""
2626

2727
import builtins
28+
import inspect
2829
import re
2930
from collections.abc import Iterator
3031
from dataclasses import (
@@ -47,23 +48,18 @@
4748
ServiceDescriptorProto,
4849
)
4950

51+
from betterproto2_compiler.compile.importing import get_type_reference, parse_source_type_name
5052
from betterproto2_compiler.compile.naming import (
5153
pythonize_class_name,
54+
pythonize_enum_member_name,
5255
pythonize_field_name,
5356
pythonize_method_name,
5457
)
58+
from betterproto2_compiler.known_types import KNOWN_METHODS
5559
from betterproto2_compiler.lib.google.protobuf.compiler import CodeGeneratorRequest
60+
from betterproto2_compiler.plugin.typing_compiler import TypingCompiler
5661
from betterproto2_compiler.settings import Settings
5762

58-
from ..compile.importing import get_type_reference, parse_source_type_name
59-
from ..compile.naming import (
60-
pythonize_class_name,
61-
pythonize_enum_member_name,
62-
pythonize_field_name,
63-
pythonize_method_name,
64-
)
65-
from .typing_compiler import TypingCompiler
66-
6763
# Organize proto types into categories
6864
PROTO_FLOAT_TYPES = (
6965
FieldDescriptorProtoType.TYPE_DOUBLE, # 1
@@ -264,6 +260,19 @@ def has_message_field(self) -> bool:
264260
if isinstance(field.proto_obj, FieldDescriptorProto)
265261
)
266262

263+
@property
264+
def custom_methods(self) -> list[str]:
265+
"""
266+
Return a list of the custom methods.
267+
"""
268+
methods_source: list[str] = []
269+
270+
for method in KNOWN_METHODS.get((self.source_file.package, self.py_name), []):
271+
source = inspect.getsource(method)
272+
methods_source.append(source.strip())
273+
274+
return methods_source
275+
267276

268277
def is_map(proto_field_obj: FieldDescriptorProto, parent_message: DescriptorProto) -> bool:
269278
"""True if proto_field_obj is a map, otherwise False."""

src/betterproto2_compiler/templates/template.py.j2

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ class {{ message.py_name }}(betterproto2.Message):
3535
{% if message.comment or message.oneofs %}
3636
"""
3737
{{ message.comment | indent(4) }}
38-
3938
{% if message.oneofs %}
39+
4040
Oneofs:
4141
{% for oneof in message.oneofs %}
4242
- {{ oneof.name }}: {{ oneof.comment | indent(12) }}
@@ -52,6 +52,7 @@ class {{ message.py_name }}(betterproto2.Message):
5252
{{ field.comment | indent(4) }}
5353
"""
5454
{% endif %}
55+
5556
{% endfor %}
5657

5758
{% if not message.fields %}
@@ -74,7 +75,14 @@ class {{ message.py_name }}(betterproto2.Message):
7475
@model_validator(mode='after')
7576
def check_oneof(cls, values):
7677
return cls._validate_field_groups(values)
78+
7779
{% endif %}
80+
{% for method_source in message.custom_methods %}
81+
{{ method_source }}
82+
{% endfor %}
83+
84+
betterproto2.default_message_pool.register_message("{{ output_file.package }}", "{{ message.proto_name }}", {{ message.py_name }})
85+
7886

7987
{% endfor %}
8088
{% for _, service in output_file.services|dictsort(by="key") %}

0 commit comments

Comments
 (0)