Skip to content
This repository was archived by the owner on Jun 9, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions src/betterproto2_compiler/compile/importing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@

from betterproto2.lib.google import protobuf as google_protobuf

from betterproto2_compiler.settings import Settings

from ..casing import safe_snake_case
from .naming import pythonize_class_name

if TYPE_CHECKING:
from ..plugin.models import PluginRequestCompiler
from ..plugin.typing_compiler import TypingCompiler

WRAPPER_TYPES: dict[str, type] = {
".google.protobuf.DoubleValue": google_protobuf.DoubleValue,
Expand Down Expand Up @@ -72,10 +73,9 @@ def get_type_reference(
package: str,
imports: set,
source_type: str,
typing_compiler: TypingCompiler,
request: PluginRequestCompiler,
unwrap: bool = True,
pydantic: bool = False,
settings: Settings,
) -> str:
"""
Return a Python type name for a proto type reference. Adds the import if
Expand All @@ -84,7 +84,7 @@ def get_type_reference(
if unwrap:
if source_type in WRAPPER_TYPES:
wrapped_type = type(WRAPPER_TYPES[source_type]().value)
return typing_compiler.optional(wrapped_type.__name__)
return settings.typing_compiler.optional(wrapped_type.__name__)

if source_type == ".google.protobuf.Duration":
return "datetime.timedelta"
Expand All @@ -101,7 +101,7 @@ def get_type_reference(
compiling_google_protobuf = current_package == ["google", "protobuf"]
importing_google_protobuf = py_package == ["google", "protobuf"]
if importing_google_protobuf and not compiling_google_protobuf:
py_package = ["betterproto2", "lib"] + (["pydantic"] if pydantic else []) + py_package
py_package = ["betterproto2", "lib"] + (["pydantic"] if settings.pydantic_dataclasses else []) + py_package

if py_package[:1] == ["betterproto2"]:
return reference_absolute(imports, py_package, py_type)
Expand Down
19 changes: 7 additions & 12 deletions src/betterproto2_compiler/plugin/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
pythonize_method_name,
)
from betterproto2_compiler.lib.google.protobuf.compiler import CodeGeneratorRequest
from betterproto2_compiler.settings import Settings

from ..compile.importing import get_type_reference, parse_source_type_name
from ..compile.naming import (
Expand All @@ -61,10 +62,7 @@
pythonize_field_name,
pythonize_method_name,
)
from .typing_compiler import (
DirectImportTypingCompiler,
TypingCompiler,
)
from .typing_compiler import TypingCompiler

# Organize proto types into categories
PROTO_FLOAT_TYPES = (
Expand Down Expand Up @@ -195,9 +193,9 @@ class OutputTemplate:
messages: dict[str, "MessageCompiler"] = field(default_factory=dict)
enums: dict[str, "EnumDefinitionCompiler"] = field(default_factory=dict)
services: dict[str, "ServiceCompiler"] = field(default_factory=dict)
pydantic_dataclasses: bool = False
output: bool = True
typing_compiler: TypingCompiler = field(default_factory=DirectImportTypingCompiler)

settings: Settings

@property
def package(self) -> str:
Expand Down Expand Up @@ -395,9 +393,8 @@ def py_type(self) -> str:
package=self.output_file.package,
imports=self.output_file.imports_end,
source_type=self.proto_obj.type_name,
typing_compiler=self.typing_compiler,
request=self.output_file.parent_request,
pydantic=self.output_file.pydantic_dataclasses,
settings=self.output_file.settings,
)
else:
raise NotImplementedError(f"Unknown type {self.proto_obj.type}")
Expand Down Expand Up @@ -582,10 +579,9 @@ def py_input_message_type(self) -> str:
package=self.parent.output_file.package,
imports=self.parent.output_file.imports_end,
source_type=self.proto_obj.input_type,
typing_compiler=self.parent.output_file.typing_compiler,
request=self.parent.output_file.parent_request,
unwrap=False,
pydantic=self.parent.output_file.pydantic_dataclasses,
settings=self.parent.output_file.settings,
)

@property
Expand Down Expand Up @@ -621,10 +617,9 @@ def py_output_message_type(self) -> str:
package=self.parent.output_file.package,
imports=self.parent.output_file.imports_end,
source_type=self.proto_obj.output_type,
typing_compiler=self.parent.output_file.typing_compiler,
request=self.parent.output_file.parent_request,
unwrap=False,
pydantic=self.parent.output_file.pydantic_dataclasses,
settings=self.parent.output_file.settings,
)

@property
Expand Down
55 changes: 31 additions & 24 deletions src/betterproto2_compiler/plugin/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
CodeGeneratorResponseFeature,
CodeGeneratorResponseFile,
)
from betterproto2_compiler.settings import Settings

from .compiler import outputfile_compiler
from .models import (
Expand Down Expand Up @@ -64,11 +65,35 @@ def _traverse(
yield from _traverse([4], proto_file.message_type)


def get_settings(plugin_options: list[str]) -> Settings:
# Gather any typing generation options.
typing_opts = [opt[len("typing.") :] for opt in plugin_options if opt.startswith("typing.")]

if len(typing_opts) > 1:
raise ValueError("Multiple typing options provided")

# Set the compiler type.
typing_opt = typing_opts[0] if typing_opts else "direct"
if typing_opt == "direct":
typing_compiler = DirectImportTypingCompiler()
elif typing_opt == "root":
typing_compiler = TypingImportTypingCompiler()
elif typing_opt == "310":
typing_compiler = NoTyping310TypingCompiler()
else:
raise ValueError("Invalid typing option provided")

return Settings(
typing_compiler=typing_compiler,
pydantic_dataclasses="pydantic_dataclasses" in plugin_options,
)


def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
response = CodeGeneratorResponse()
response = CodeGeneratorResponse(supported_features=CodeGeneratorResponseFeature.FEATURE_PROTO3_OPTIONAL)

plugin_options = request.parameter.split(",") if request.parameter else []
response.supported_features = CodeGeneratorResponseFeature.FEATURE_PROTO3_OPTIONAL
settings = get_settings(plugin_options)

request_data = PluginRequestCompiler(plugin_request_obj=request)
# Gather output packages
Expand All @@ -77,8 +102,7 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
if output_package_name not in request_data.output_packages:
# Create a new output if there is no output for this package
request_data.output_packages[output_package_name] = OutputTemplate(
parent_request=request_data,
package_proto_obj=proto_file,
parent_request=request_data, package_proto_obj=proto_file, settings=settings
)
# Add this input file to the output corresponding to this package
request_data.output_packages[output_package_name].input_files.append(proto_file)
Expand All @@ -88,23 +112,6 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
# skip outputting Google's well-known types
request_data.output_packages[output_package_name].output = False

if "pydantic_dataclasses" in plugin_options:
request_data.output_packages[output_package_name].pydantic_dataclasses = True

# Gather any typing generation options.
typing_opts = [opt[len("typing.") :] for opt in plugin_options if opt.startswith("typing.")]

if len(typing_opts) > 1:
raise ValueError("Multiple typing options provided")
# Set the compiler type.
typing_opt = typing_opts[0] if typing_opts else "direct"
if typing_opt == "direct":
request_data.output_packages[output_package_name].typing_compiler = DirectImportTypingCompiler()
elif typing_opt == "root":
request_data.output_packages[output_package_name].typing_compiler = TypingImportTypingCompiler()
elif typing_opt == "310":
request_data.output_packages[output_package_name].typing_compiler = NoTyping310TypingCompiler()

# Read Messages and Enums
# We need to read Messages before Services in so that we can
# get the references to input/output messages for each service
Expand Down Expand Up @@ -199,7 +206,7 @@ def read_protobuf_type(
message=message_data,
proto_obj=field,
path=path + [2, index],
typing_compiler=output_package.typing_compiler,
typing_compiler=output_package.settings.typing_compiler,
)
)
elif is_oneof(field):
Expand All @@ -209,7 +216,7 @@ def read_protobuf_type(
message=message_data,
proto_obj=field,
path=path + [2, index],
typing_compiler=output_package.typing_compiler,
typing_compiler=output_package.settings.typing_compiler,
)
)
else:
Expand All @@ -219,7 +226,7 @@ def read_protobuf_type(
message=message_data,
proto_obj=field,
path=path + [2, index],
typing_compiler=output_package.typing_compiler,
typing_compiler=output_package.settings.typing_compiler,
)
)

Expand Down
9 changes: 9 additions & 0 deletions src/betterproto2_compiler/settings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from dataclasses import dataclass

from .plugin.typing_compiler import TypingCompiler


@dataclass
class Settings:
pydantic_dataclasses: bool
typing_compiler: TypingCompiler
6 changes: 3 additions & 3 deletions src/betterproto2_compiler/templates/header.py.j2
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,16 @@ import builtins
import datetime
import warnings

{% if output_file.pydantic_dataclasses %}
{% if output_file.settings.pydantic_dataclasses %}
from pydantic.dataclasses import dataclass
from pydantic import model_validator
{%- else -%}
from dataclasses import dataclass
{% endif %}

{% set typing_imports = output_file.typing_compiler.imports() %}
{% set typing_imports = output_file.settings.typing_compiler.imports() %}
{% if typing_imports %}
{% for line in output_file.typing_compiler.import_lines() %}
{% for line in output_file.settings.typing_compiler.import_lines() %}
{{ line }}
{% endfor %}
{% endif %}
Expand Down
24 changes: 12 additions & 12 deletions src/betterproto2_compiler/templates/template.py.j2
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class {{ enum.py_name }}(betterproto2.Enum):

{% endfor %}

{% if output_file.pydantic_dataclasses %}
{% if output_file.settings.pydantic_dataclasses %}
@classmethod
def __get_pydantic_core_schema__(cls, _source_type, _handler):
from pydantic_core import core_schema
Expand All @@ -26,7 +26,7 @@ class {{ enum.py_name }}(betterproto2.Enum):

{% endfor %}
{% for _, message in output_file.messages|dictsort(by="key") %}
{% if output_file.pydantic_dataclasses %}
{% if output_file.settings.pydantic_dataclasses %}
@dataclass(eq=False, repr=False, config={"extra": "forbid"})
{% else %}
@dataclass(eq=False, repr=False)
Expand Down Expand Up @@ -70,7 +70,7 @@ class {{ message.py_name }}(betterproto2.Message):
{% endfor %}
{% endif %}

{% if output_file.pydantic_dataclasses and message.has_oneof_fields %}
{% if output_file.settings.pydantic_dataclasses and message.has_oneof_fields %}
@model_validator(mode='after')
def check_oneof(cls, values):
return cls._validate_field_groups(values)
Expand All @@ -92,20 +92,20 @@ class {{ service.py_name }}Stub(betterproto2.ServiceStub):
{%- if not method.client_streaming -%}
, {{ method.py_input_message_param }}:
{%- if method.is_input_msg_empty -%}
"{{ method.py_input_message_type }} | None" = None
"{{ output_file.settings.typing_compiler.optional(method.py_input_message_type) }}" = None
{%- else -%}
"{{ method.py_input_message_type }}"
{%- endif -%}
{%- else -%}
{# Client streaming: need a request iterator instead #}
, {{ method.py_input_message_param }}_iterator: "{{ output_file.typing_compiler.union(output_file.typing_compiler.async_iterable(method.py_input_message_type), output_file.typing_compiler.iterable(method.py_input_message_type)) }}"
, {{ method.py_input_message_param }}_iterator: "{{ output_file.settings.typing_compiler.union(output_file.settings.typing_compiler.async_iterable(method.py_input_message_type), output_file.settings.typing_compiler.iterable(method.py_input_message_type)) }}"
{%- endif -%}
,
*
, timeout: {{ output_file.typing_compiler.optional("float") }} = None
, deadline: "{{ output_file.typing_compiler.optional("Deadline") }}" = None
, metadata: "{{ output_file.typing_compiler.optional("MetadataLike") }}" = None
) -> "{% if method.server_streaming %}{{ output_file.typing_compiler.async_iterator(method.py_output_message_type ) }}{% else %}{{ method.py_output_message_type }}{% endif %}":
, timeout: {{ output_file.settings.typing_compiler.optional("float") }} = None
, deadline: "{{ output_file.settings.typing_compiler.optional("Deadline") }}" = None
, metadata: "{{ output_file.settings.typing_compiler.optional("MetadataLike") }}" = None
) -> "{% if method.server_streaming %}{{ output_file.settings.typing_compiler.async_iterator(method.py_output_message_type ) }}{% else %}{{ method.py_output_message_type }}{% endif %}":
{% if method.comment %}
"""
{{ method.comment | indent(8) }}
Expand Down Expand Up @@ -194,9 +194,9 @@ class {{ service.py_name }}Base(ServiceBase):
, {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}"
{%- else -%}
{# Client streaming: need a request iterator instead #}
, {{ method.py_input_message_param }}_iterator: {{ output_file.typing_compiler.async_iterator(method.py_input_message_type) }}
, {{ method.py_input_message_param }}_iterator: {{ output_file.settings.typing_compiler.async_iterator(method.py_input_message_type) }}
{%- endif -%}
) -> {% if method.server_streaming %}{{ output_file.typing_compiler.async_iterator(method.py_output_message_type) }}{% else %}"{{ method.py_output_message_type }}"{% endif %}:
) -> {% if method.server_streaming %}{{ output_file.settings.typing_compiler.async_iterator(method.py_output_message_type) }}{% else %}"{{ method.py_output_message_type }}"{% endif %}:
{% if method.comment %}
"""
{{ method.comment | indent(8) }}
Expand Down Expand Up @@ -227,7 +227,7 @@ class {{ service.py_name }}Base(ServiceBase):

{% endfor %}

def __mapping__(self) -> {{ output_file.typing_compiler.dict("str", "grpclib.const.Handler") }}:
def __mapping__(self) -> {{ output_file.settings.typing_compiler.dict("str", "grpclib.const.Handler") }}:
return {
{% for method in service.methods %}
"{{ method.route }}": grpclib.const.Handler(
Expand Down
Loading