Skip to content
This repository was archived by the owner on Jun 9, 2025. It is now read-only.

Commit 5cc48fb

Browse files
Create and use Settings class (#27)
1 parent 19228ff commit 5cc48fb

File tree

6 files changed

+67
-56
lines changed

6 files changed

+67
-56
lines changed

src/betterproto2_compiler/compile/importing.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,13 @@
77

88
from betterproto2.lib.google import protobuf as google_protobuf
99

10+
from betterproto2_compiler.settings import Settings
11+
1012
from ..casing import safe_snake_case
1113
from .naming import pythonize_class_name
1214

1315
if TYPE_CHECKING:
1416
from ..plugin.models import PluginRequestCompiler
15-
from ..plugin.typing_compiler import TypingCompiler
1617

1718
WRAPPER_TYPES: dict[str, type] = {
1819
".google.protobuf.DoubleValue": google_protobuf.DoubleValue,
@@ -72,10 +73,9 @@ def get_type_reference(
7273
package: str,
7374
imports: set,
7475
source_type: str,
75-
typing_compiler: TypingCompiler,
7676
request: PluginRequestCompiler,
7777
unwrap: bool = True,
78-
pydantic: bool = False,
78+
settings: Settings,
7979
) -> str:
8080
"""
8181
Return a Python type name for a proto type reference. Adds the import if
@@ -84,7 +84,7 @@ def get_type_reference(
8484
if unwrap:
8585
if source_type in WRAPPER_TYPES:
8686
wrapped_type = type(WRAPPER_TYPES[source_type]().value)
87-
return typing_compiler.optional(wrapped_type.__name__)
87+
return settings.typing_compiler.optional(wrapped_type.__name__)
8888

8989
if source_type == ".google.protobuf.Duration":
9090
return "datetime.timedelta"
@@ -101,7 +101,7 @@ def get_type_reference(
101101
compiling_google_protobuf = current_package == ["google", "protobuf"]
102102
importing_google_protobuf = py_package == ["google", "protobuf"]
103103
if importing_google_protobuf and not compiling_google_protobuf:
104-
py_package = ["betterproto2", "lib"] + (["pydantic"] if pydantic else []) + py_package
104+
py_package = ["betterproto2", "lib"] + (["pydantic"] if settings.pydantic_dataclasses else []) + py_package
105105

106106
if py_package[:1] == ["betterproto2"]:
107107
return reference_absolute(imports, py_package, py_type)

src/betterproto2_compiler/plugin/models.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
pythonize_method_name,
5454
)
5555
from betterproto2_compiler.lib.google.protobuf.compiler import CodeGeneratorRequest
56+
from betterproto2_compiler.settings import Settings
5657

5758
from ..compile.importing import get_type_reference, parse_source_type_name
5859
from ..compile.naming import (
@@ -61,10 +62,7 @@
6162
pythonize_field_name,
6263
pythonize_method_name,
6364
)
64-
from .typing_compiler import (
65-
DirectImportTypingCompiler,
66-
TypingCompiler,
67-
)
65+
from .typing_compiler import TypingCompiler
6866

6967
# Organize proto types into categories
7068
PROTO_FLOAT_TYPES = (
@@ -195,9 +193,9 @@ class OutputTemplate:
195193
messages: dict[str, "MessageCompiler"] = field(default_factory=dict)
196194
enums: dict[str, "EnumDefinitionCompiler"] = field(default_factory=dict)
197195
services: dict[str, "ServiceCompiler"] = field(default_factory=dict)
198-
pydantic_dataclasses: bool = False
199196
output: bool = True
200-
typing_compiler: TypingCompiler = field(default_factory=DirectImportTypingCompiler)
197+
198+
settings: Settings
201199

202200
@property
203201
def package(self) -> str:
@@ -395,9 +393,8 @@ def py_type(self) -> str:
395393
package=self.output_file.package,
396394
imports=self.output_file.imports_end,
397395
source_type=self.proto_obj.type_name,
398-
typing_compiler=self.typing_compiler,
399396
request=self.output_file.parent_request,
400-
pydantic=self.output_file.pydantic_dataclasses,
397+
settings=self.output_file.settings,
401398
)
402399
else:
403400
raise NotImplementedError(f"Unknown type {self.proto_obj.type}")
@@ -582,10 +579,9 @@ def py_input_message_type(self) -> str:
582579
package=self.parent.output_file.package,
583580
imports=self.parent.output_file.imports_end,
584581
source_type=self.proto_obj.input_type,
585-
typing_compiler=self.parent.output_file.typing_compiler,
586582
request=self.parent.output_file.parent_request,
587583
unwrap=False,
588-
pydantic=self.parent.output_file.pydantic_dataclasses,
584+
settings=self.parent.output_file.settings,
589585
)
590586

591587
@property
@@ -621,10 +617,9 @@ def py_output_message_type(self) -> str:
621617
package=self.parent.output_file.package,
622618
imports=self.parent.output_file.imports_end,
623619
source_type=self.proto_obj.output_type,
624-
typing_compiler=self.parent.output_file.typing_compiler,
625620
request=self.parent.output_file.parent_request,
626621
unwrap=False,
627-
pydantic=self.parent.output_file.pydantic_dataclasses,
622+
settings=self.parent.output_file.settings,
628623
)
629624

630625
@property

src/betterproto2_compiler/plugin/parser.py

Lines changed: 31 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
CodeGeneratorResponseFeature,
1616
CodeGeneratorResponseFile,
1717
)
18+
from betterproto2_compiler.settings import Settings
1819

1920
from .compiler import outputfile_compiler
2021
from .models import (
@@ -64,11 +65,35 @@ def _traverse(
6465
yield from _traverse([4], proto_file.message_type)
6566

6667

68+
def get_settings(plugin_options: list[str]) -> Settings:
69+
# Gather any typing generation options.
70+
typing_opts = [opt[len("typing.") :] for opt in plugin_options if opt.startswith("typing.")]
71+
72+
if len(typing_opts) > 1:
73+
raise ValueError("Multiple typing options provided")
74+
75+
# Set the compiler type.
76+
typing_opt = typing_opts[0] if typing_opts else "direct"
77+
if typing_opt == "direct":
78+
typing_compiler = DirectImportTypingCompiler()
79+
elif typing_opt == "root":
80+
typing_compiler = TypingImportTypingCompiler()
81+
elif typing_opt == "310":
82+
typing_compiler = NoTyping310TypingCompiler()
83+
else:
84+
raise ValueError("Invalid typing option provided")
85+
86+
return Settings(
87+
typing_compiler=typing_compiler,
88+
pydantic_dataclasses="pydantic_dataclasses" in plugin_options,
89+
)
90+
91+
6792
def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
68-
response = CodeGeneratorResponse()
93+
response = CodeGeneratorResponse(supported_features=CodeGeneratorResponseFeature.FEATURE_PROTO3_OPTIONAL)
6994

7095
plugin_options = request.parameter.split(",") if request.parameter else []
71-
response.supported_features = CodeGeneratorResponseFeature.FEATURE_PROTO3_OPTIONAL
96+
settings = get_settings(plugin_options)
7297

7398
request_data = PluginRequestCompiler(plugin_request_obj=request)
7499
# Gather output packages
@@ -77,8 +102,7 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
77102
if output_package_name not in request_data.output_packages:
78103
# Create a new output if there is no output for this package
79104
request_data.output_packages[output_package_name] = OutputTemplate(
80-
parent_request=request_data,
81-
package_proto_obj=proto_file,
105+
parent_request=request_data, package_proto_obj=proto_file, settings=settings
82106
)
83107
# Add this input file to the output corresponding to this package
84108
request_data.output_packages[output_package_name].input_files.append(proto_file)
@@ -88,23 +112,6 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
88112
# skip outputting Google's well-known types
89113
request_data.output_packages[output_package_name].output = False
90114

91-
if "pydantic_dataclasses" in plugin_options:
92-
request_data.output_packages[output_package_name].pydantic_dataclasses = True
93-
94-
# Gather any typing generation options.
95-
typing_opts = [opt[len("typing.") :] for opt in plugin_options if opt.startswith("typing.")]
96-
97-
if len(typing_opts) > 1:
98-
raise ValueError("Multiple typing options provided")
99-
# Set the compiler type.
100-
typing_opt = typing_opts[0] if typing_opts else "direct"
101-
if typing_opt == "direct":
102-
request_data.output_packages[output_package_name].typing_compiler = DirectImportTypingCompiler()
103-
elif typing_opt == "root":
104-
request_data.output_packages[output_package_name].typing_compiler = TypingImportTypingCompiler()
105-
elif typing_opt == "310":
106-
request_data.output_packages[output_package_name].typing_compiler = NoTyping310TypingCompiler()
107-
108115
# Read Messages and Enums
109116
# We need to read Messages before Services in so that we can
110117
# get the references to input/output messages for each service
@@ -199,7 +206,7 @@ def read_protobuf_type(
199206
message=message_data,
200207
proto_obj=field,
201208
path=path + [2, index],
202-
typing_compiler=output_package.typing_compiler,
209+
typing_compiler=output_package.settings.typing_compiler,
203210
)
204211
)
205212
elif is_oneof(field):
@@ -209,7 +216,7 @@ def read_protobuf_type(
209216
message=message_data,
210217
proto_obj=field,
211218
path=path + [2, index],
212-
typing_compiler=output_package.typing_compiler,
219+
typing_compiler=output_package.settings.typing_compiler,
213220
)
214221
)
215222
else:
@@ -219,7 +226,7 @@ def read_protobuf_type(
219226
message=message_data,
220227
proto_obj=field,
221228
path=path + [2, index],
222-
typing_compiler=output_package.typing_compiler,
229+
typing_compiler=output_package.settings.typing_compiler,
223230
)
224231
)
225232

src/betterproto2_compiler/settings.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from dataclasses import dataclass
2+
3+
from .plugin.typing_compiler import TypingCompiler
4+
5+
6+
@dataclass
7+
class Settings:
8+
pydantic_dataclasses: bool
9+
typing_compiler: TypingCompiler

src/betterproto2_compiler/templates/header.py.j2

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,16 @@ import builtins
2222
import datetime
2323
import warnings
2424

25-
{% if output_file.pydantic_dataclasses %}
25+
{% if output_file.settings.pydantic_dataclasses %}
2626
from pydantic.dataclasses import dataclass
2727
from pydantic import model_validator
2828
{%- else -%}
2929
from dataclasses import dataclass
3030
{% endif %}
3131

32-
{% set typing_imports = output_file.typing_compiler.imports() %}
32+
{% set typing_imports = output_file.settings.typing_compiler.imports() %}
3333
{% if typing_imports %}
34-
{% for line in output_file.typing_compiler.import_lines() %}
34+
{% for line in output_file.settings.typing_compiler.import_lines() %}
3535
{{ line }}
3636
{% endfor %}
3737
{% endif %}

src/betterproto2_compiler/templates/template.py.j2

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class {{ enum.py_name }}(betterproto2.Enum):
1616

1717
{% endfor %}
1818

19-
{% if output_file.pydantic_dataclasses %}
19+
{% if output_file.settings.pydantic_dataclasses %}
2020
@classmethod
2121
def __get_pydantic_core_schema__(cls, _source_type, _handler):
2222
from pydantic_core import core_schema
@@ -26,7 +26,7 @@ class {{ enum.py_name }}(betterproto2.Enum):
2626

2727
{% endfor %}
2828
{% for _, message in output_file.messages|dictsort(by="key") %}
29-
{% if output_file.pydantic_dataclasses %}
29+
{% if output_file.settings.pydantic_dataclasses %}
3030
@dataclass(eq=False, repr=False, config={"extra": "forbid"})
3131
{% else %}
3232
@dataclass(eq=False, repr=False)
@@ -70,7 +70,7 @@ class {{ message.py_name }}(betterproto2.Message):
7070
{% endfor %}
7171
{% endif %}
7272

73-
{% if output_file.pydantic_dataclasses and message.has_oneof_fields %}
73+
{% if output_file.settings.pydantic_dataclasses and message.has_oneof_fields %}
7474
@model_validator(mode='after')
7575
def check_oneof(cls, values):
7676
return cls._validate_field_groups(values)
@@ -92,20 +92,20 @@ class {{ service.py_name }}Stub(betterproto2.ServiceStub):
9292
{%- if not method.client_streaming -%}
9393
, {{ method.py_input_message_param }}:
9494
{%- if method.is_input_msg_empty -%}
95-
"{{ method.py_input_message_type }} | None" = None
95+
"{{ output_file.settings.typing_compiler.optional(method.py_input_message_type) }}" = None
9696
{%- else -%}
9797
"{{ method.py_input_message_type }}"
9898
{%- endif -%}
9999
{%- else -%}
100100
{# Client streaming: need a request iterator instead #}
101-
, {{ method.py_input_message_param }}_iterator: "{{ output_file.typing_compiler.union(output_file.typing_compiler.async_iterable(method.py_input_message_type), output_file.typing_compiler.iterable(method.py_input_message_type)) }}"
101+
, {{ method.py_input_message_param }}_iterator: "{{ output_file.settings.typing_compiler.union(output_file.settings.typing_compiler.async_iterable(method.py_input_message_type), output_file.settings.typing_compiler.iterable(method.py_input_message_type)) }}"
102102
{%- endif -%}
103103
,
104104
*
105-
, timeout: {{ output_file.typing_compiler.optional("float") }} = None
106-
, deadline: "{{ output_file.typing_compiler.optional("Deadline") }}" = None
107-
, metadata: "{{ output_file.typing_compiler.optional("MetadataLike") }}" = None
108-
) -> "{% if method.server_streaming %}{{ output_file.typing_compiler.async_iterator(method.py_output_message_type ) }}{% else %}{{ method.py_output_message_type }}{% endif %}":
105+
, timeout: {{ output_file.settings.typing_compiler.optional("float") }} = None
106+
, deadline: "{{ output_file.settings.typing_compiler.optional("Deadline") }}" = None
107+
, metadata: "{{ output_file.settings.typing_compiler.optional("MetadataLike") }}" = None
108+
) -> "{% if method.server_streaming %}{{ output_file.settings.typing_compiler.async_iterator(method.py_output_message_type ) }}{% else %}{{ method.py_output_message_type }}{% endif %}":
109109
{% if method.comment %}
110110
"""
111111
{{ method.comment | indent(8) }}
@@ -194,9 +194,9 @@ class {{ service.py_name }}Base(ServiceBase):
194194
, {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}"
195195
{%- else -%}
196196
{# Client streaming: need a request iterator instead #}
197-
, {{ method.py_input_message_param }}_iterator: {{ output_file.typing_compiler.async_iterator(method.py_input_message_type) }}
197+
, {{ method.py_input_message_param }}_iterator: {{ output_file.settings.typing_compiler.async_iterator(method.py_input_message_type) }}
198198
{%- endif -%}
199-
) -> {% if method.server_streaming %}{{ output_file.typing_compiler.async_iterator(method.py_output_message_type) }}{% else %}"{{ method.py_output_message_type }}"{% endif %}:
199+
) -> {% if method.server_streaming %}{{ output_file.settings.typing_compiler.async_iterator(method.py_output_message_type) }}{% else %}"{{ method.py_output_message_type }}"{% endif %}:
200200
{% if method.comment %}
201201
"""
202202
{{ method.comment | indent(8) }}
@@ -227,7 +227,7 @@ class {{ service.py_name }}Base(ServiceBase):
227227

228228
{% endfor %}
229229

230-
def __mapping__(self) -> {{ output_file.typing_compiler.dict("str", "grpclib.const.Handler") }}:
230+
def __mapping__(self) -> {{ output_file.settings.typing_compiler.dict("str", "grpclib.const.Handler") }}:
231231
return {
232232
{% for method in service.methods %}
233233
"{{ method.route }}": grpclib.const.Handler(

0 commit comments

Comments
 (0)