diff --git a/src/betterproto2_compiler/compile/importing.py b/src/betterproto2_compiler/compile/importing.py index 62d2b9c8..9f6b954d 100644 --- a/src/betterproto2_compiler/compile/importing.py +++ b/src/betterproto2_compiler/compile/importing.py @@ -83,7 +83,7 @@ def get_type_reference( if unwrap: if source_type in WRAPPER_TYPES: wrapped_type = type(WRAPPER_TYPES[source_type]().value) - return settings.typing_compiler.optional(wrapped_type.__name__) + return f"{wrapped_type.__name__} | None" if source_type == ".google.protobuf.Duration": return "datetime.timedelta" diff --git a/src/betterproto2_compiler/plugin/models.py b/src/betterproto2_compiler/plugin/models.py index 931e2226..4226bb46 100644 --- a/src/betterproto2_compiler/plugin/models.py +++ b/src/betterproto2_compiler/plugin/models.py @@ -57,7 +57,6 @@ ServiceDescriptorProto, ) from betterproto2_compiler.lib.google.protobuf.compiler import CodeGeneratorRequest -from betterproto2_compiler.plugin.typing_compiler import TypingCompiler from betterproto2_compiler.settings import Settings # Organize proto types into categories @@ -298,7 +297,6 @@ def is_oneof(proto_field_obj: FieldDescriptorProto) -> bool: @dataclass(kw_only=True) class FieldCompiler(ProtoContentBase): - typing_compiler: TypingCompiler builtins_types: set[str] = field(default_factory=set) message: MessageCompiler @@ -413,9 +411,9 @@ def annotation(self) -> str: if self.use_builtins: py_type = f"builtins.{py_type}" if self.repeated: - return self.typing_compiler.list(py_type) + return f"list[{py_type}]" if self.optional: - return self.typing_compiler.optional(py_type) + return f"{py_type} | None" return py_type @@ -449,14 +447,12 @@ def ready(self) -> None: self.py_k_type = FieldCompiler( source_file=self.source_file, proto_obj=nested.field[0], # key - typing_compiler=self.typing_compiler, path=[], message=self.message, ).py_type self.py_v_type = FieldCompiler( source_file=self.source_file, proto_obj=nested.field[1], # value - typing_compiler=self.typing_compiler, path=[], message=self.message, ).py_type @@ -482,7 +478,7 @@ def get_field_string(self) -> str: @property def annotation(self) -> str: - return self.typing_compiler.dict(self.py_k_type, self.py_v_type) + return f"dict[{self.py_k_type}, {self.py_v_type}]" @property def repeated(self) -> bool: diff --git a/src/betterproto2_compiler/plugin/parser.py b/src/betterproto2_compiler/plugin/parser.py index 29dda0af..e114594a 100644 --- a/src/betterproto2_compiler/plugin/parser.py +++ b/src/betterproto2_compiler/plugin/parser.py @@ -31,11 +31,6 @@ is_map, is_oneof, ) -from .typing_compiler import ( - DirectImportTypingCompiler, - NoTyping310TypingCompiler, - TypingImportTypingCompiler, -) def traverse( @@ -65,25 +60,7 @@ def _traverse( def get_settings(plugin_options: list[str]) -> Settings: - # Gather any typing generation options. - typing_opts = [opt[len("typing.") :] for opt in plugin_options if opt.startswith("typing.")] - - if len(typing_opts) > 1: - raise ValueError("Multiple typing options provided") - - # Set the compiler type. - typing_opt = typing_opts[0] if typing_opts else "direct" - if typing_opt == "direct": - typing_compiler = DirectImportTypingCompiler() - elif typing_opt == "root": - typing_compiler = TypingImportTypingCompiler() - elif typing_opt == "310": - typing_compiler = NoTyping310TypingCompiler() - else: - raise ValueError("Invalid typing option provided") - return Settings( - typing_compiler=typing_compiler, pydantic_dataclasses="pydantic_dataclasses" in plugin_options, ) @@ -203,7 +180,6 @@ def read_protobuf_type( message=message_data, proto_obj=field, path=path + [2, index], - typing_compiler=output_package.settings.typing_compiler, ) ) elif is_oneof(field): @@ -213,7 +189,6 @@ def read_protobuf_type( message=message_data, proto_obj=field, path=path + [2, index], - typing_compiler=output_package.settings.typing_compiler, ) ) else: @@ -223,7 +198,6 @@ def read_protobuf_type( message=message_data, proto_obj=field, path=path + [2, index], - typing_compiler=output_package.settings.typing_compiler, ) ) diff --git a/src/betterproto2_compiler/plugin/typing_compiler.py b/src/betterproto2_compiler/plugin/typing_compiler.py deleted file mode 100644 index fd1e120c..00000000 --- a/src/betterproto2_compiler/plugin/typing_compiler.py +++ /dev/null @@ -1,163 +0,0 @@ -import abc -import builtins -from collections import defaultdict -from collections.abc import Iterator -from dataclasses import ( - dataclass, - field, -) - - -class TypingCompiler(metaclass=abc.ABCMeta): - @abc.abstractmethod - def optional(self, type_: str) -> str: - raise NotImplementedError - - @abc.abstractmethod - def list(self, type_: str) -> str: - raise NotImplementedError - - @abc.abstractmethod - def dict(self, key: str, value: str) -> str: - raise NotImplementedError - - @abc.abstractmethod - def union(self, *types: str) -> str: - raise NotImplementedError - - @abc.abstractmethod - def iterable(self, type_: str) -> str: - raise NotImplementedError - - @abc.abstractmethod - def async_iterable(self, type_: str) -> str: - raise NotImplementedError - - @abc.abstractmethod - def async_iterator(self, type_: str) -> str: - raise NotImplementedError - - @abc.abstractmethod - def imports(self) -> builtins.dict[str, set[str] | None]: - """ - Returns either the direct import as a key with none as value, or a set of - values to import from the key. - """ - raise NotImplementedError - - def import_lines(self) -> Iterator: - imports = self.imports() - for key, value in imports.items(): - if value is None: - yield f"import {key}" - else: - yield f"from {key} import (" - for v in sorted(value): - yield f" {v}," - yield ")" - - -@dataclass -class DirectImportTypingCompiler(TypingCompiler): - _imports: dict[str, set[str]] = field(default_factory=lambda: defaultdict(set)) - - def optional(self, type_: str) -> str: - self._imports["typing"].add("Optional") - return f"Optional[{type_}]" - - def list(self, type_: str) -> str: - self._imports["typing"].add("List") - return f"List[{type_}]" - - def dict(self, key: str, value: str) -> str: - self._imports["typing"].add("Dict") - return f"Dict[{key}, {value}]" - - def union(self, *types: str) -> str: - self._imports["typing"].add("Union") - return f"Union[{', '.join(types)}]" - - def iterable(self, type_: str) -> str: - self._imports["typing"].add("Iterable") - return f"Iterable[{type_}]" - - def async_iterable(self, type_: str) -> str: - self._imports["typing"].add("AsyncIterable") - return f"AsyncIterable[{type_}]" - - def async_iterator(self, type_: str) -> str: - self._imports["typing"].add("AsyncIterator") - return f"AsyncIterator[{type_}]" - - def imports(self) -> builtins.dict[str, set[str] | None]: - return {k: v if v else None for k, v in self._imports.items()} - - -@dataclass -class TypingImportTypingCompiler(TypingCompiler): - _imported: bool = False - - def optional(self, type_: str) -> str: - self._imported = True - return f"typing.Optional[{type_}]" - - def list(self, type_: str) -> str: - self._imported = True - return f"typing.List[{type_}]" - - def dict(self, key: str, value: str) -> str: - self._imported = True - return f"typing.Dict[{key}, {value}]" - - def union(self, *types: str) -> str: - self._imported = True - return f"typing.Union[{', '.join(types)}]" - - def iterable(self, type_: str) -> str: - self._imported = True - return f"typing.Iterable[{type_}]" - - def async_iterable(self, type_: str) -> str: - self._imported = True - return f"typing.AsyncIterable[{type_}]" - - def async_iterator(self, type_: str) -> str: - self._imported = True - return f"typing.AsyncIterator[{type_}]" - - def imports(self) -> builtins.dict[str, set[str] | None]: - if self._imported: - return {"typing": None} - return {} - - -@dataclass -class NoTyping310TypingCompiler(TypingCompiler): - _imports: dict[str, set[str]] = field(default_factory=lambda: defaultdict(set)) - - def optional(self, type_: str) -> str: - return f"{type_} | None" - - def list(self, type_: str) -> str: - return f"list[{type_}]" - - def dict(self, key: str, value: str) -> str: - return f"dict[{key}, {value}]" - - def union(self, *types: str) -> str: - return f"{' | '.join(types)}" - - def iterable(self, type_: str) -> str: - self._imports["collections.abc"].add("Iterable") - return f"Iterable[{type_}]" - - def async_iterable(self, type_: str) -> str: - self._imports["collections.abc"].add("AsyncIterable") - return f"AsyncIterable[{type_}]" - - def async_iterator(self, type_: str) -> str: - self._imports["collections.abc"].add("AsyncIterator") - return f"AsyncIterator[{type_}]" - - def imports(self) -> builtins.dict[str, set[str] | None]: - return {k: v if v else None for k, v in self._imports.items()} diff --git a/src/betterproto2_compiler/settings.py b/src/betterproto2_compiler/settings.py index 910f7dd8..394f3ced 100644 --- a/src/betterproto2_compiler/settings.py +++ b/src/betterproto2_compiler/settings.py @@ -1,9 +1,6 @@ from dataclasses import dataclass -from .plugin.typing_compiler import TypingCompiler - @dataclass class Settings: pydantic_dataclasses: bool - typing_compiler: TypingCompiler diff --git a/src/betterproto2_compiler/templates/header.py.j2 b/src/betterproto2_compiler/templates/header.py.j2 index c7658707..c73e7868 100644 --- a/src/betterproto2_compiler/templates/header.py.j2 +++ b/src/betterproto2_compiler/templates/header.py.j2 @@ -21,6 +21,8 @@ __all__ = ( import builtins import datetime import warnings +from collections.abc import AsyncIterable, AsyncIterator, Iterable +from typing import TYPE_CHECKING {% if output_file.settings.pydantic_dataclasses %} from pydantic.dataclasses import dataclass @@ -29,21 +31,12 @@ from pydantic import model_validator from dataclasses import dataclass {% endif %} -{% set typing_imports = output_file.settings.typing_compiler.imports() %} -{% if typing_imports %} -{% for line in output_file.settings.typing_compiler.import_lines() %} -{{ line }} -{% endfor %} -{% endif %} - import betterproto2 {% if output_file.services %} from betterproto2.grpc.grpclib_server import ServiceBase import grpclib {% endif %} -from typing import TYPE_CHECKING - {# Import the message pool of the generated code. #} {% if output_file.package %} from {{ "." * output_file.package.count(".") }}..message_pool import default_message_pool diff --git a/src/betterproto2_compiler/templates/template.py.j2 b/src/betterproto2_compiler/templates/template.py.j2 index 4a99a4dd..ed2304a6 100644 --- a/src/betterproto2_compiler/templates/template.py.j2 +++ b/src/betterproto2_compiler/templates/template.py.j2 @@ -100,20 +100,20 @@ class {{ service.py_name }}Stub(betterproto2.ServiceStub): {%- if not method.client_streaming -%} , message: {%- if method.is_input_msg_empty -%} - "{{ output_file.settings.typing_compiler.optional(method.py_input_message_type) }}" = None + "{{ method.py_input_message_type }} | None" = None {%- else -%} "{{ method.py_input_message_type }}" {%- endif -%} {%- else -%} {# Client streaming: need a request iterator instead #} - , messages: "{{ output_file.settings.typing_compiler.union(output_file.settings.typing_compiler.async_iterable(method.py_input_message_type), output_file.settings.typing_compiler.iterable(method.py_input_message_type)) }}" + , messages: "AsyncIterable[{{ method.py_input_message_type }}] | Iterable[{{ method.py_input_message_type }}]" {%- endif -%} , * - , timeout: {{ output_file.settings.typing_compiler.optional("float") }} = None - , deadline: "{{ output_file.settings.typing_compiler.optional("Deadline") }}" = None - , metadata: "{{ output_file.settings.typing_compiler.optional("MetadataLike") }}" = None - ) -> "{% if method.server_streaming %}{{ output_file.settings.typing_compiler.async_iterator(method.py_output_message_type ) }}{% else %}{{ method.py_output_message_type }}{% endif %}": + , timeout: "float | None" = None + , deadline: "Deadline | None" = None + , metadata: "MetadataLike | None" = None + ) -> "{% if method.server_streaming %}AsyncIterator[{{ method.py_output_message_type }}]{% else %}{{ method.py_output_message_type }}{% endif %}": {% if method.comment %} """ {{ method.comment | indent(8) }} @@ -202,9 +202,9 @@ class {{ service.py_name }}Base(ServiceBase): , message: "{{ method.py_input_message_type }}" {%- else -%} {# Client streaming: need a request iterator instead #} - , messages: {{ output_file.settings.typing_compiler.async_iterator(method.py_input_message_type) }} + , messages: "AsyncIterator[{{ method.py_input_message_type }}]" {%- endif -%} - ) -> {% if method.server_streaming %}{{ output_file.settings.typing_compiler.async_iterator(method.py_output_message_type) }}{% else %}"{{ method.py_output_message_type }}"{% endif %}: + ) -> {% if method.server_streaming %}"AsyncIterator[{{ method.py_output_message_type }}]"{% else %}"{{ method.py_output_message_type }}"{% endif %}: {% if method.comment %} """ {{ method.comment | indent(8) }} @@ -235,7 +235,7 @@ class {{ service.py_name }}Base(ServiceBase): {% endfor %} - def __mapping__(self) -> {{ output_file.settings.typing_compiler.dict("str", "grpclib.const.Handler") }}: + def __mapping__(self) -> "dict[str, grpclib.const.Handler": return { {% for method in service.methods %} "{{ method.route }}": grpclib.const.Handler( diff --git a/tests/test_typing_compiler.py b/tests/test_typing_compiler.py deleted file mode 100644 index 470b6e54..00000000 --- a/tests/test_typing_compiler.py +++ /dev/null @@ -1,70 +0,0 @@ -from betterproto2_compiler.plugin.typing_compiler import ( - DirectImportTypingCompiler, - NoTyping310TypingCompiler, - TypingImportTypingCompiler, -) - - -def test_direct_import_typing_compiler(): - compiler = DirectImportTypingCompiler() - assert compiler.imports() == {} - assert compiler.optional("str") == "Optional[str]" - assert compiler.imports() == {"typing": {"Optional"}} - assert compiler.list("str") == "List[str]" - assert compiler.imports() == {"typing": {"Optional", "List"}} - assert compiler.dict("str", "int") == "Dict[str, int]" - assert compiler.imports() == {"typing": {"Optional", "List", "Dict"}} - assert compiler.union("str", "int") == "Union[str, int]" - assert compiler.imports() == {"typing": {"Optional", "List", "Dict", "Union"}} - assert compiler.iterable("str") == "Iterable[str]" - assert compiler.imports() == {"typing": {"Optional", "List", "Dict", "Union", "Iterable"}} - assert compiler.async_iterable("str") == "AsyncIterable[str]" - assert compiler.imports() == {"typing": {"Optional", "List", "Dict", "Union", "Iterable", "AsyncIterable"}} - assert compiler.async_iterator("str") == "AsyncIterator[str]" - assert compiler.imports() == { - "typing": { - "Optional", - "List", - "Dict", - "Union", - "Iterable", - "AsyncIterable", - "AsyncIterator", - }, - } - - -def test_typing_import_typing_compiler(): - compiler = TypingImportTypingCompiler() - assert compiler.imports() == {} - assert compiler.optional("str") == "typing.Optional[str]" - assert compiler.imports() == {"typing": None} - assert compiler.list("str") == "typing.List[str]" - assert compiler.imports() == {"typing": None} - assert compiler.dict("str", "int") == "typing.Dict[str, int]" - assert compiler.imports() == {"typing": None} - assert compiler.union("str", "int") == "typing.Union[str, int]" - assert compiler.imports() == {"typing": None} - assert compiler.iterable("str") == "typing.Iterable[str]" - assert compiler.imports() == {"typing": None} - assert compiler.async_iterable("str") == "typing.AsyncIterable[str]" - assert compiler.imports() == {"typing": None} - assert compiler.async_iterator("str") == "typing.AsyncIterator[str]" - assert compiler.imports() == {"typing": None} - - -def test_no_typing_311_typing_compiler(): - compiler = NoTyping310TypingCompiler() - assert compiler.imports() == {} - assert compiler.optional("str") == "str | None" - assert compiler.imports() == {} - assert compiler.list("str") == "list[str]" - assert compiler.imports() == {} - assert compiler.dict("str", "int") == "dict[str, int]" - assert compiler.imports() == {} - assert compiler.union("str", "int") == "str | int" - assert compiler.imports() == {} - assert compiler.iterable("str") == "Iterable[str]" - assert compiler.async_iterable("str") == "AsyncIterable[str]" - assert compiler.async_iterator("str") == "AsyncIterator[str]" - assert compiler.imports() == {"collections.abc": {"Iterable", "AsyncIterable", "AsyncIterator"}}