Skip to content
This repository was archived by the owner on Jun 9, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 68 additions & 75 deletions poetry.lock

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ typing-extensions = "^4.7.1"

[tool.poetry.group.dev.dependencies]
pre-commit = "^2.17.0"
grpcio-tools = "^1.69.0"
grpcio-tools = "^1.54.2"
mkdocs-material = {version = "^9.5.49", python = ">=3.10"}
mkdocstrings = {version = "^0.27.0", python = ">=3.10", extras = ["python"]}
poethepoet = ">=0.9.0"
Expand All @@ -31,6 +31,7 @@ ipykernel = "^6.29.5"

[tool.poetry.group.test.dependencies]
pytest = "^6.2.5"
protobuf = "^4"

[tool.poetry.scripts]
protoc-gen-python_betterproto2 = "betterproto2_compiler.plugin:main"
Expand Down
10 changes: 10 additions & 0 deletions src/betterproto2_compiler/known_types/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from collections.abc import Callable

from .any import Any

# For each (package, message name), lists the methods that should be added to the message definition.
# The source code of the method is read from the `known_types` folder. If imports are needed, they can be directly added
# to the template file: they will automatically be removed if not necessary.
KNOWN_METHODS: dict[tuple[str, str], list[Callable]] = {
("google.protobuf", "Any"): [Any.pack, Any.unpack, Any.to_dict],
}
33 changes: 33 additions & 0 deletions src/betterproto2_compiler/known_types/any.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import betterproto2
from betterproto2.lib.std.google.protobuf import Any as VanillaAny


class Any(VanillaAny):
def pack(self, message: betterproto2.Message, message_pool: "betterproto2.MessagePool | None" = None) -> None:
"""
Pack the given message in the `Any` object.

The message type must be registered in the message pool, which is done automatically when the module defining
the message type is imported.
"""
message_pool = message_pool or betterproto2.default_message_pool

self.type_url = message_pool.type_to_url[type(message)]
self.value = bytes(message)

def unpack(self, message_pool: "betterproto2.MessagePool | None" = None) -> betterproto2.Message:
"""
Return the message packed inside the `Any` object.

The target message type must be registered in the message pool, which is done automatically when the module
defining the message type is imported.
"""
message_pool = message_pool or betterproto2.default_message_pool

message_type = message_pool.url_to_type[self.type_url]

return message_type().parse(self.value)

def to_dict(self) -> dict: # pyright: ignore [reportIncompatibleMethodOverride]
# TOOO improve when dict is updated
return {"@type": self.type_url, "value": self.unpack().to_dict()}
27 changes: 18 additions & 9 deletions src/betterproto2_compiler/plugin/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
"""

import builtins
import inspect
import re
from collections.abc import Iterator
from dataclasses import (
Expand All @@ -47,23 +48,18 @@
ServiceDescriptorProto,
)

from betterproto2_compiler.compile.importing import get_type_reference, parse_source_type_name
from betterproto2_compiler.compile.naming import (
pythonize_class_name,
pythonize_enum_member_name,
pythonize_field_name,
pythonize_method_name,
)
from betterproto2_compiler.known_types import KNOWN_METHODS
from betterproto2_compiler.lib.google.protobuf.compiler import CodeGeneratorRequest
from betterproto2_compiler.plugin.typing_compiler import TypingCompiler
from betterproto2_compiler.settings import Settings

from ..compile.importing import get_type_reference, parse_source_type_name
from ..compile.naming import (
pythonize_class_name,
pythonize_enum_member_name,
pythonize_field_name,
pythonize_method_name,
)
from .typing_compiler import TypingCompiler

# Organize proto types into categories
PROTO_FLOAT_TYPES = (
FieldDescriptorProtoType.TYPE_DOUBLE, # 1
Expand Down Expand Up @@ -264,6 +260,19 @@ def has_message_field(self) -> bool:
if isinstance(field.proto_obj, FieldDescriptorProto)
)

@property
def custom_methods(self) -> list[str]:
"""
Return a list of the custom methods.
"""
methods_source: list[str] = []

for method in KNOWN_METHODS.get((self.source_file.package, self.py_name), []):
source = inspect.getsource(method)
methods_source.append(source.strip())

return methods_source


def is_map(proto_field_obj: FieldDescriptorProto, parent_message: DescriptorProto) -> bool:
"""True if proto_field_obj is a map, otherwise False."""
Expand Down
10 changes: 9 additions & 1 deletion src/betterproto2_compiler/templates/template.py.j2
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ class {{ message.py_name }}(betterproto2.Message):
{% if message.comment or message.oneofs %}
"""
{{ message.comment | indent(4) }}

{% if message.oneofs %}

Oneofs:
{% for oneof in message.oneofs %}
- {{ oneof.name }}: {{ oneof.comment | indent(12) }}
Expand All @@ -52,6 +52,7 @@ class {{ message.py_name }}(betterproto2.Message):
{{ field.comment | indent(4) }}
"""
{% endif %}

{% endfor %}

{% if not message.fields %}
Expand All @@ -74,7 +75,14 @@ class {{ message.py_name }}(betterproto2.Message):
@model_validator(mode='after')
def check_oneof(cls, values):
return cls._validate_field_groups(values)

{% endif %}
{% for method_source in message.custom_methods %}
{{ method_source }}
{% endfor %}

betterproto2.default_message_pool.register_message("{{ output_file.package }}", "{{ message.proto_name }}", {{ message.py_name }})


{% endfor %}
{% for _, service in output_file.services|dictsort(by="key") %}
Expand Down
Loading