From d3878c4fbe8ed0dfd2a316db2718bd09bfa57543 Mon Sep 17 00:00:00 2001 From: Adrien Vannson Date: Fri, 10 Jan 2025 14:50:25 +0100 Subject: [PATCH] Create and use Settings class --- .../compile/importing.py | 10 ++-- src/betterproto2_compiler/plugin/models.py | 19 +++---- src/betterproto2_compiler/plugin/parser.py | 55 +++++++++++-------- src/betterproto2_compiler/settings.py | 9 +++ .../templates/header.py.j2 | 6 +- .../templates/template.py.j2 | 24 ++++---- 6 files changed, 67 insertions(+), 56 deletions(-) create mode 100644 src/betterproto2_compiler/settings.py diff --git a/src/betterproto2_compiler/compile/importing.py b/src/betterproto2_compiler/compile/importing.py index 266193e8..ee11699c 100644 --- a/src/betterproto2_compiler/compile/importing.py +++ b/src/betterproto2_compiler/compile/importing.py @@ -7,12 +7,13 @@ from betterproto2.lib.google import protobuf as google_protobuf +from betterproto2_compiler.settings import Settings + from ..casing import safe_snake_case from .naming import pythonize_class_name if TYPE_CHECKING: from ..plugin.models import PluginRequestCompiler - from ..plugin.typing_compiler import TypingCompiler WRAPPER_TYPES: dict[str, type] = { ".google.protobuf.DoubleValue": google_protobuf.DoubleValue, @@ -72,10 +73,9 @@ def get_type_reference( package: str, imports: set, source_type: str, - typing_compiler: TypingCompiler, request: PluginRequestCompiler, unwrap: bool = True, - pydantic: bool = False, + settings: Settings, ) -> str: """ Return a Python type name for a proto type reference. Adds the import if @@ -84,7 +84,7 @@ def get_type_reference( if unwrap: if source_type in WRAPPER_TYPES: wrapped_type = type(WRAPPER_TYPES[source_type]().value) - return typing_compiler.optional(wrapped_type.__name__) + return settings.typing_compiler.optional(wrapped_type.__name__) if source_type == ".google.protobuf.Duration": return "datetime.timedelta" @@ -101,7 +101,7 @@ def get_type_reference( compiling_google_protobuf = current_package == ["google", "protobuf"] importing_google_protobuf = py_package == ["google", "protobuf"] if importing_google_protobuf and not compiling_google_protobuf: - py_package = ["betterproto2", "lib"] + (["pydantic"] if pydantic else []) + py_package + py_package = ["betterproto2", "lib"] + (["pydantic"] if settings.pydantic_dataclasses else []) + py_package if py_package[:1] == ["betterproto2"]: return reference_absolute(imports, py_package, py_type) diff --git a/src/betterproto2_compiler/plugin/models.py b/src/betterproto2_compiler/plugin/models.py index 1e433b30..67cffc93 100644 --- a/src/betterproto2_compiler/plugin/models.py +++ b/src/betterproto2_compiler/plugin/models.py @@ -53,6 +53,7 @@ pythonize_method_name, ) from betterproto2_compiler.lib.google.protobuf.compiler import CodeGeneratorRequest +from betterproto2_compiler.settings import Settings from ..compile.importing import get_type_reference, parse_source_type_name from ..compile.naming import ( @@ -61,10 +62,7 @@ pythonize_field_name, pythonize_method_name, ) -from .typing_compiler import ( - DirectImportTypingCompiler, - TypingCompiler, -) +from .typing_compiler import TypingCompiler # Organize proto types into categories PROTO_FLOAT_TYPES = ( @@ -195,9 +193,9 @@ class OutputTemplate: messages: dict[str, "MessageCompiler"] = field(default_factory=dict) enums: dict[str, "EnumDefinitionCompiler"] = field(default_factory=dict) services: dict[str, "ServiceCompiler"] = field(default_factory=dict) - pydantic_dataclasses: bool = False output: bool = True - typing_compiler: TypingCompiler = field(default_factory=DirectImportTypingCompiler) + + settings: Settings @property def package(self) -> str: @@ -395,9 +393,8 @@ def py_type(self) -> str: package=self.output_file.package, imports=self.output_file.imports_end, source_type=self.proto_obj.type_name, - typing_compiler=self.typing_compiler, request=self.output_file.parent_request, - pydantic=self.output_file.pydantic_dataclasses, + settings=self.output_file.settings, ) else: raise NotImplementedError(f"Unknown type {self.proto_obj.type}") @@ -582,10 +579,9 @@ def py_input_message_type(self) -> str: package=self.parent.output_file.package, imports=self.parent.output_file.imports_end, source_type=self.proto_obj.input_type, - typing_compiler=self.parent.output_file.typing_compiler, request=self.parent.output_file.parent_request, unwrap=False, - pydantic=self.parent.output_file.pydantic_dataclasses, + settings=self.parent.output_file.settings, ) @property @@ -621,10 +617,9 @@ def py_output_message_type(self) -> str: package=self.parent.output_file.package, imports=self.parent.output_file.imports_end, source_type=self.proto_obj.output_type, - typing_compiler=self.parent.output_file.typing_compiler, request=self.parent.output_file.parent_request, unwrap=False, - pydantic=self.parent.output_file.pydantic_dataclasses, + settings=self.parent.output_file.settings, ) @property diff --git a/src/betterproto2_compiler/plugin/parser.py b/src/betterproto2_compiler/plugin/parser.py index 157d13b9..fabca1bd 100644 --- a/src/betterproto2_compiler/plugin/parser.py +++ b/src/betterproto2_compiler/plugin/parser.py @@ -15,6 +15,7 @@ CodeGeneratorResponseFeature, CodeGeneratorResponseFile, ) +from betterproto2_compiler.settings import Settings from .compiler import outputfile_compiler from .models import ( @@ -64,11 +65,35 @@ def _traverse( yield from _traverse([4], proto_file.message_type) +def get_settings(plugin_options: list[str]) -> Settings: + # Gather any typing generation options. + typing_opts = [opt[len("typing.") :] for opt in plugin_options if opt.startswith("typing.")] + + if len(typing_opts) > 1: + raise ValueError("Multiple typing options provided") + + # Set the compiler type. + typing_opt = typing_opts[0] if typing_opts else "direct" + if typing_opt == "direct": + typing_compiler = DirectImportTypingCompiler() + elif typing_opt == "root": + typing_compiler = TypingImportTypingCompiler() + elif typing_opt == "310": + typing_compiler = NoTyping310TypingCompiler() + else: + raise ValueError("Invalid typing option provided") + + return Settings( + typing_compiler=typing_compiler, + pydantic_dataclasses="pydantic_dataclasses" in plugin_options, + ) + + def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse: - response = CodeGeneratorResponse() + response = CodeGeneratorResponse(supported_features=CodeGeneratorResponseFeature.FEATURE_PROTO3_OPTIONAL) plugin_options = request.parameter.split(",") if request.parameter else [] - response.supported_features = CodeGeneratorResponseFeature.FEATURE_PROTO3_OPTIONAL + settings = get_settings(plugin_options) request_data = PluginRequestCompiler(plugin_request_obj=request) # Gather output packages @@ -77,8 +102,7 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse: if output_package_name not in request_data.output_packages: # Create a new output if there is no output for this package request_data.output_packages[output_package_name] = OutputTemplate( - parent_request=request_data, - package_proto_obj=proto_file, + parent_request=request_data, package_proto_obj=proto_file, settings=settings ) # Add this input file to the output corresponding to this package request_data.output_packages[output_package_name].input_files.append(proto_file) @@ -88,23 +112,6 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse: # skip outputting Google's well-known types request_data.output_packages[output_package_name].output = False - if "pydantic_dataclasses" in plugin_options: - request_data.output_packages[output_package_name].pydantic_dataclasses = True - - # Gather any typing generation options. - typing_opts = [opt[len("typing.") :] for opt in plugin_options if opt.startswith("typing.")] - - if len(typing_opts) > 1: - raise ValueError("Multiple typing options provided") - # Set the compiler type. - typing_opt = typing_opts[0] if typing_opts else "direct" - if typing_opt == "direct": - request_data.output_packages[output_package_name].typing_compiler = DirectImportTypingCompiler() - elif typing_opt == "root": - request_data.output_packages[output_package_name].typing_compiler = TypingImportTypingCompiler() - elif typing_opt == "310": - request_data.output_packages[output_package_name].typing_compiler = NoTyping310TypingCompiler() - # Read Messages and Enums # We need to read Messages before Services in so that we can # get the references to input/output messages for each service @@ -199,7 +206,7 @@ def read_protobuf_type( message=message_data, proto_obj=field, path=path + [2, index], - typing_compiler=output_package.typing_compiler, + typing_compiler=output_package.settings.typing_compiler, ) ) elif is_oneof(field): @@ -209,7 +216,7 @@ def read_protobuf_type( message=message_data, proto_obj=field, path=path + [2, index], - typing_compiler=output_package.typing_compiler, + typing_compiler=output_package.settings.typing_compiler, ) ) else: @@ -219,7 +226,7 @@ def read_protobuf_type( message=message_data, proto_obj=field, path=path + [2, index], - typing_compiler=output_package.typing_compiler, + typing_compiler=output_package.settings.typing_compiler, ) ) diff --git a/src/betterproto2_compiler/settings.py b/src/betterproto2_compiler/settings.py new file mode 100644 index 00000000..910f7dd8 --- /dev/null +++ b/src/betterproto2_compiler/settings.py @@ -0,0 +1,9 @@ +from dataclasses import dataclass + +from .plugin.typing_compiler import TypingCompiler + + +@dataclass +class Settings: + pydantic_dataclasses: bool + typing_compiler: TypingCompiler diff --git a/src/betterproto2_compiler/templates/header.py.j2 b/src/betterproto2_compiler/templates/header.py.j2 index 4765f73d..b1ffdd83 100644 --- a/src/betterproto2_compiler/templates/header.py.j2 +++ b/src/betterproto2_compiler/templates/header.py.j2 @@ -22,16 +22,16 @@ import builtins import datetime import warnings -{% if output_file.pydantic_dataclasses %} +{% if output_file.settings.pydantic_dataclasses %} from pydantic.dataclasses import dataclass from pydantic import model_validator {%- else -%} from dataclasses import dataclass {% endif %} -{% set typing_imports = output_file.typing_compiler.imports() %} +{% set typing_imports = output_file.settings.typing_compiler.imports() %} {% if typing_imports %} -{% for line in output_file.typing_compiler.import_lines() %} +{% for line in output_file.settings.typing_compiler.import_lines() %} {{ line }} {% endfor %} {% endif %} diff --git a/src/betterproto2_compiler/templates/template.py.j2 b/src/betterproto2_compiler/templates/template.py.j2 index 6a52bd93..3ff8b32d 100644 --- a/src/betterproto2_compiler/templates/template.py.j2 +++ b/src/betterproto2_compiler/templates/template.py.j2 @@ -16,7 +16,7 @@ class {{ enum.py_name }}(betterproto2.Enum): {% endfor %} - {% if output_file.pydantic_dataclasses %} + {% if output_file.settings.pydantic_dataclasses %} @classmethod def __get_pydantic_core_schema__(cls, _source_type, _handler): from pydantic_core import core_schema @@ -26,7 +26,7 @@ class {{ enum.py_name }}(betterproto2.Enum): {% endfor %} {% for _, message in output_file.messages|dictsort(by="key") %} -{% if output_file.pydantic_dataclasses %} +{% if output_file.settings.pydantic_dataclasses %} @dataclass(eq=False, repr=False, config={"extra": "forbid"}) {% else %} @dataclass(eq=False, repr=False) @@ -70,7 +70,7 @@ class {{ message.py_name }}(betterproto2.Message): {% endfor %} {% endif %} - {% if output_file.pydantic_dataclasses and message.has_oneof_fields %} + {% if output_file.settings.pydantic_dataclasses and message.has_oneof_fields %} @model_validator(mode='after') def check_oneof(cls, values): return cls._validate_field_groups(values) @@ -92,20 +92,20 @@ class {{ service.py_name }}Stub(betterproto2.ServiceStub): {%- if not method.client_streaming -%} , {{ method.py_input_message_param }}: {%- if method.is_input_msg_empty -%} - "{{ method.py_input_message_type }} | None" = None + "{{ output_file.settings.typing_compiler.optional(method.py_input_message_type) }}" = None {%- else -%} "{{ method.py_input_message_type }}" {%- endif -%} {%- else -%} {# Client streaming: need a request iterator instead #} - , {{ method.py_input_message_param }}_iterator: "{{ output_file.typing_compiler.union(output_file.typing_compiler.async_iterable(method.py_input_message_type), output_file.typing_compiler.iterable(method.py_input_message_type)) }}" + , {{ method.py_input_message_param }}_iterator: "{{ output_file.settings.typing_compiler.union(output_file.settings.typing_compiler.async_iterable(method.py_input_message_type), output_file.settings.typing_compiler.iterable(method.py_input_message_type)) }}" {%- endif -%} , * - , timeout: {{ output_file.typing_compiler.optional("float") }} = None - , deadline: "{{ output_file.typing_compiler.optional("Deadline") }}" = None - , metadata: "{{ output_file.typing_compiler.optional("MetadataLike") }}" = None - ) -> "{% if method.server_streaming %}{{ output_file.typing_compiler.async_iterator(method.py_output_message_type ) }}{% else %}{{ method.py_output_message_type }}{% endif %}": + , timeout: {{ output_file.settings.typing_compiler.optional("float") }} = None + , deadline: "{{ output_file.settings.typing_compiler.optional("Deadline") }}" = None + , metadata: "{{ output_file.settings.typing_compiler.optional("MetadataLike") }}" = None + ) -> "{% if method.server_streaming %}{{ output_file.settings.typing_compiler.async_iterator(method.py_output_message_type ) }}{% else %}{{ method.py_output_message_type }}{% endif %}": {% if method.comment %} """ {{ method.comment | indent(8) }} @@ -194,9 +194,9 @@ class {{ service.py_name }}Base(ServiceBase): , {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}" {%- else -%} {# Client streaming: need a request iterator instead #} - , {{ method.py_input_message_param }}_iterator: {{ output_file.typing_compiler.async_iterator(method.py_input_message_type) }} + , {{ method.py_input_message_param }}_iterator: {{ output_file.settings.typing_compiler.async_iterator(method.py_input_message_type) }} {%- endif -%} - ) -> {% if method.server_streaming %}{{ output_file.typing_compiler.async_iterator(method.py_output_message_type) }}{% else %}"{{ method.py_output_message_type }}"{% endif %}: + ) -> {% if method.server_streaming %}{{ output_file.settings.typing_compiler.async_iterator(method.py_output_message_type) }}{% else %}"{{ method.py_output_message_type }}"{% endif %}: {% if method.comment %} """ {{ method.comment | indent(8) }} @@ -227,7 +227,7 @@ class {{ service.py_name }}Base(ServiceBase): {% endfor %} - def __mapping__(self) -> {{ output_file.typing_compiler.dict("str", "grpclib.const.Handler") }}: + def __mapping__(self) -> {{ output_file.settings.typing_compiler.dict("str", "grpclib.const.Handler") }}: return { {% for method in service.methods %} "{{ method.route }}": grpclib.const.Handler(