diff --git a/poetry.lock b/poetry.lock index 1fcf9245..73be47d4 100644 --- a/poetry.lock +++ b/poetry.lock @@ -57,6 +57,7 @@ files = [ ] [package.dependencies] +grpclib = {version = "*", optional = true, markers = "extra == \"grpc-async\""} python-dateutil = ">=2.8,<3.0" typing-extensions = ">=4.7.1,<5.0.0" @@ -617,19 +618,19 @@ protobuf = ["protobuf (>=3.20.0)"] [[package]] name = "h2" -version = "4.1.0" -description = "HTTP/2 State-Machine based protocol implementation" +version = "4.2.0" +description = "Pure-Python HTTP/2 protocol implementation" optional = false -python-versions = ">=3.6.1" +python-versions = ">=3.9" groups = ["main"] files = [ - {file = "h2-4.1.0-py3-none-any.whl", hash = "sha256:03a46bcf682256c95b5fd9e9a99c1323584c3eec6440d379b9903d709476bc6d"}, - {file = "h2-4.1.0.tar.gz", hash = "sha256:a83aca08fbe7aacb79fec788c9c0bac936343560ed9ec18b82a13a12c28d2abb"}, + {file = "h2-4.2.0-py3-none-any.whl", hash = "sha256:479a53ad425bb29af087f3458a61d30780bc818e4ebcf01f0b536ba916462ed0"}, + {file = "h2-4.2.0.tar.gz", hash = "sha256:c8a52129695e88b1a0578d8d2cc6842bbd79128ac685463b887ee278126ad01f"}, ] [package.dependencies] -hpack = ">=4.0,<5" -hyperframe = ">=6.0,<7" +hpack = ">=4.1,<5" +hyperframe = ">=6.1,<7" [[package]] name = "hpack" @@ -2040,6 +2041,24 @@ pure-eval = "*" [package.extras] tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"] +[[package]] +name = "strenum" +version = "0.4.15" +description = "An Enum that inherits from str." +optional = false +python-versions = "*" +groups = ["main"] +markers = "python_version == \"3.10\"" +files = [ + {file = "StrEnum-0.4.15-py3-none-any.whl", hash = "sha256:a30cda4af7cc6b5bf52c8055bc4bf4b2b6b14a93b574626da33df53cf7740659"}, + {file = "StrEnum-0.4.15.tar.gz", hash = "sha256:878fb5ab705442070e4dd1929bb5e2249511c0bcf2b0eeacf3bcd80875c82eff"}, +] + +[package.extras] +docs = ["myst-parser[linkify]", "sphinx", "sphinx-rtd-theme"] +release = ["twine"] +test = ["pylint", "pytest", "pytest-black", "pytest-cov", "pytest-pylint"] + [[package]] name = "tomli" version = "2.2.1" @@ -2229,4 +2248,4 @@ files = [ [metadata] lock-version = "2.1" python-versions = "^3.10" -content-hash = "199339c6e33ce9a1cfd1b3864790835240f695651004dcdcec8909d97aa30ca8" +content-hash = "3b4c4984f013afc8bfe33cd40a1d9a61fc70fbebc1228365abcd50a86e48a029" diff --git a/pyproject.toml b/pyproject.toml index c652868a..e6ed00c6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,13 +26,15 @@ packages = [ [tool.poetry.dependencies] python = "^3.10" -betterproto2 = "^0.3.1" +betterproto2 = { version = "^0.3.1", extras = ["grpc-async"] } # betterproto2 = { git="https://github.com/betterproto/python-betterproto2" } # The Ruff version is pinned. To update it, also update it in .pre-commit-config.yaml ruff = "~0.9.3" -grpclib = "^0.4.1" jinja2 = ">=3.0.3" typing-extensions = "^4.7.1" +strenum = [ + {version = "^0.4.15", python = "=3.10"}, +] [tool.poetry.group.dev.dependencies] pre-commit = "^2.17.0" diff --git a/src/betterproto2_compiler/plugin/parser.py b/src/betterproto2_compiler/plugin/parser.py index e114594a..67405006 100644 --- a/src/betterproto2_compiler/plugin/parser.py +++ b/src/betterproto2_compiler/plugin/parser.py @@ -14,7 +14,7 @@ CodeGeneratorResponseFeature, CodeGeneratorResponseFile, ) -from betterproto2_compiler.settings import Settings +from betterproto2_compiler.settings import ClientGeneration, Settings from .compiler import outputfile_compiler from .models import ( @@ -60,8 +60,24 @@ def _traverse( def get_settings(plugin_options: list[str]) -> Settings: + # Synchronous clients are suitable for most users + client_generation = ClientGeneration.SYNC + + for opt in plugin_options: + if opt.startswith("client_generation="): + name = opt.split("=")[1] + + # print(ClientGeneration.__members__, file=sys.stderr) + # print([member.value for member in ClientGeneration]) + + try: + client_generation = ClientGeneration(name) + except ValueError: + raise ValueError(f"Invalid client_generation option: {name}") + return Settings( pydantic_dataclasses="pydantic_dataclasses" in plugin_options, + client_generation=client_generation, ) diff --git a/src/betterproto2_compiler/settings.py b/src/betterproto2_compiler/settings.py index 394f3ced..fc230875 100644 --- a/src/betterproto2_compiler/settings.py +++ b/src/betterproto2_compiler/settings.py @@ -1,6 +1,67 @@ +import sys from dataclasses import dataclass +if sys.version_info >= (3, 11): + from enum import StrEnum +else: + from strenum import StrEnum + + +class ClientGeneration(StrEnum): + NONE = "none" + """Clients are not generated.""" + + SYNC = "sync" + """Only synchronous clients are generated.""" + + ASYNC = "async" + """Only asynchronous clients are generated.""" + + SYNC_ASYNC = "sync_async" + """Both synchronous and asynchronous clients are generated. + + The asynchronous client is generated with the Async suffix.""" + + ASYNC_SYNC = "async_sync" + """Both synchronous and asynchronous clients are generated. + + The synchronous client is generated with the Sync suffix.""" + + SYNC_ASYNC_NO_DEFAULT = "sync_async_no_default" + """Both synchronous and asynchronous clients are generated. + + The synchronous client is generated with the Sync suffix, and the asynchronous client is generated with the Async + suffix.""" + + @property + def is_sync_generated(self) -> bool: + return self in { + ClientGeneration.SYNC, + ClientGeneration.SYNC_ASYNC, + ClientGeneration.ASYNC_SYNC, + ClientGeneration.SYNC_ASYNC_NO_DEFAULT, + } + + @property + def is_async_generated(self) -> bool: + return self in { + ClientGeneration.ASYNC, + ClientGeneration.SYNC_ASYNC, + ClientGeneration.ASYNC_SYNC, + ClientGeneration.SYNC_ASYNC_NO_DEFAULT, + } + + @property + def is_sync_prefixed(self) -> bool: + return self in {ClientGeneration.ASYNC_SYNC, ClientGeneration.SYNC_ASYNC_NO_DEFAULT} + + @property + def is_async_prefixed(self) -> bool: + return self in {ClientGeneration.SYNC_ASYNC, ClientGeneration.SYNC_ASYNC_NO_DEFAULT} + @dataclass class Settings: pydantic_dataclasses: bool + + client_generation: ClientGeneration diff --git a/src/betterproto2_compiler/templates/header.py.j2 b/src/betterproto2_compiler/templates/header.py.j2 index f7d12d24..7d5f7cb0 100644 --- a/src/betterproto2_compiler/templates/header.py.j2 +++ b/src/betterproto2_compiler/templates/header.py.j2 @@ -21,7 +21,7 @@ __all__ = ( import builtins import datetime import warnings -from collections.abc import AsyncIterable, AsyncIterator, Iterable +from collections.abc import AsyncIterable, AsyncIterator, Iterable, Iterator import typing from typing import TYPE_CHECKING @@ -33,10 +33,9 @@ from dataclasses import dataclass {% endif %} import betterproto2 -{% if output_file.services %} from betterproto2.grpc.grpclib_server import ServiceBase +import grpc import grpclib -{% endif %} {# Import the message pool of the generated code. #} {% if output_file.package %} diff --git a/src/betterproto2_compiler/templates/service_stub.py.j2 b/src/betterproto2_compiler/templates/service_stub.py.j2 new file mode 100644 index 00000000..f0958055 --- /dev/null +++ b/src/betterproto2_compiler/templates/service_stub.py.j2 @@ -0,0 +1,32 @@ +class {% block class_name %}{% endblock %}({% block inherit_from %}{% endblock %}): + {% block service_docstring scoped %} + {% if service.comment %} + """ + {{ service.comment | indent(4) }} + """ + {% elif not service.methods %} + pass + {% endif %} + {% endblock %} + + {% block class_content %}{% endblock %} + + {% for method in service.methods %} + {% block method_definition scoped required %}{% endblock %} + {% block method_docstring scoped %} + {% if method.comment %} + """ + {{ method.comment | indent(8) }} + """ + {% endif %} + {% endblock %} + + {% block deprecation_warning scoped %} + {% if method.deprecated %} + warnings.warn("{{ service.py_name }}.{{ method.py_name }} is deprecated", DeprecationWarning) + {% endif %} + {% endblock %} + + {% block method_body scoped required %}{% endblock %} + + {% endfor %} \ No newline at end of file diff --git a/src/betterproto2_compiler/templates/service_stub_async.py.j2 b/src/betterproto2_compiler/templates/service_stub_async.py.j2 new file mode 100644 index 00000000..191d6a94 --- /dev/null +++ b/src/betterproto2_compiler/templates/service_stub_async.py.j2 @@ -0,0 +1,86 @@ +{% extends "service_stub.py.j2" %} + +{# Class definition #} +{% block class_name %}{{ service.py_name }}{% if output_file.settings.client_generation.is_async_prefixed %}Async{% endif %}Stub{% endblock %} +{% block inherit_from %}betterproto2.ServiceStub{% endblock %} + +{# Methods definition #} +{% block method_definition %} + async def {{ method.py_name }}(self + {%- if not method.client_streaming -%} + , message: + {%- if method.is_input_msg_empty -%} + "{{ method.py_input_message_type }} | None" = None + {%- else -%} + "{{ method.py_input_message_type }}" + {%- endif -%} + {%- else -%} + {# Client streaming: need a request iterator instead #} + , messages: "AsyncIterable[{{ method.py_input_message_type }}] | Iterable[{{ method.py_input_message_type }}]" + {%- endif -%} + , + * + , timeout: "float | None" = None + , deadline: "Deadline | None" = None + , metadata: "MetadataLike | None" = None + ) -> "{% if method.server_streaming %}AsyncIterator[{{ method.py_output_message_type }}]{% else %}{{ method.py_output_message_type }}{% endif %}": +{% endblock %} + +{% block method_body %} + {% if method.server_streaming %} + {% if method.client_streaming %} + async for response in self._stream_stream( + "{{ method.route }}", + messages, + {{ method.py_input_message_type }}, + {{ method.py_output_message_type }}, + timeout=timeout, + deadline=deadline, + metadata=metadata, + ): + yield response + {% else %} + {% if method.is_input_msg_empty %} + if message is None: + message = {{ method.py_input_message_type }}() + + {% endif %} + async for response in self._unary_stream( + "{{ method.route }}", + message, + {{ method.py_output_message_type }}, + timeout=timeout, + deadline=deadline, + metadata=metadata, + ): + yield response + + {% endif %} + {% else %} + {% if method.client_streaming %} + return await self._stream_unary( + "{{ method.route }}", + messages, + {{ method.py_input_message_type }}, + {{ method.py_output_message_type }}, + timeout=timeout, + deadline=deadline, + metadata=metadata, + ) + {% else %} + {% if method.is_input_msg_empty %} + if message is None: + message = {{ method.py_input_message_type }}() + + {% endif %} + return await self._unary_unary( + "{{ method.route }}", + message, + {{ method.py_output_message_type }}, + timeout=timeout, + deadline=deadline, + metadata=metadata, + ) + {% endif %} + {% endif %} +{% endblock %} \ No newline at end of file diff --git a/src/betterproto2_compiler/templates/service_stub_sync.py.j2 b/src/betterproto2_compiler/templates/service_stub_sync.py.j2 new file mode 100644 index 00000000..97647025 --- /dev/null +++ b/src/betterproto2_compiler/templates/service_stub_sync.py.j2 @@ -0,0 +1,72 @@ +{% extends "service_stub.py.j2" %} + +{# Class definition #} +{% block class_name %}{{ service.py_name }}{% if output_file.settings.client_generation.is_sync_prefixed %}Sync{% endif %}Stub{% endblock %} + +{% block class_content %} + {# TODO move to parent class #} + def __init__(self, channel: grpc.Channel): + self._channel = channel +{% endblock %} + +{# Methods definition #} +{% block method_definition %} + def {{ method.py_name }}(self + {%- if not method.client_streaming -%} + , message: + {%- if method.is_input_msg_empty -%} + "{{ method.py_input_message_type }} | None" = None + {%- else -%} + "{{ method.py_input_message_type }}" + {%- endif -%} + {%- else -%} + {# Client streaming: need a request iterator instead #} + , messages: "Iterable[{{ method.py_input_message_type }}]" + {%- endif -%} + ) -> "{% if method.server_streaming %}Iterator[{{ method.py_output_message_type }}]{% else %}{{ method.py_output_message_type }}{% endif %}": +{% endblock %} + +{% block method_body %} + {% if method.server_streaming %} + {% if method.client_streaming %} + for response in self._channel.stream_stream( + "{{ method.route }}", + {{ method.py_input_message_type }}.SerializeToString, + {{ method.py_output_message_type }}.FromString, + )(iter(messages)): + yield response + {% else %} + {% if method.is_input_msg_empty %} + if message is None: + message = {{ method.py_input_message_type }}() + + {% endif %} + for response in self._channel.unary_stream( + "{{ method.route }}", + {{ method.py_input_message_type }}.SerializeToString, + {{ method.py_output_message_type }}.FromString, + )(message): + yield response + + {% endif %} + {% else %} + {% if method.client_streaming %} + return self._channel.stream_unary( + "{{ method.route }}", + {{ method.py_input_message_type }}.SerializeToString, + {{ method.py_output_message_type }}.FromString, + )(iter(messages)) + {% else %} + {% if method.is_input_msg_empty %} + if message is None: + message = {{ method.py_input_message_type }}() + + {% endif %} + return self._channel.unary_unary( + "{{ method.route }}", + {{ method.py_input_message_type }}.SerializeToString, + {{ method.py_output_message_type }}.FromString, + )(message) + {% endif %} + {% endif %} +{% endblock %} \ No newline at end of file diff --git a/src/betterproto2_compiler/templates/template.py.j2 b/src/betterproto2_compiler/templates/template.py.j2 index 7ee3718b..1281e679 100644 --- a/src/betterproto2_compiler/templates/template.py.j2 +++ b/src/betterproto2_compiler/templates/template.py.j2 @@ -85,103 +85,17 @@ default_message_pool.register_message("{{ output_file.package }}", "{{ message.p {% endfor %} -{% for _, service in output_file.services|dictsort(by="key") %} -class {{ service.py_name }}Stub(betterproto2.ServiceStub): - {% if service.comment %} - """ - {{ service.comment | indent(4) }} - """ - {% elif not service.methods %} - pass - {% endif %} - - {% for method in service.methods %} - async def {{ method.py_name }}(self - {%- if not method.client_streaming -%} - , message: - {%- if method.is_input_msg_empty -%} - "{{ method.py_input_message_type }} | None" = None - {%- else -%} - "{{ method.py_input_message_type }}" - {%- endif -%} - {%- else -%} - {# Client streaming: need a request iterator instead #} - , messages: "AsyncIterable[{{ method.py_input_message_type }}] | Iterable[{{ method.py_input_message_type }}]" - {%- endif -%} - , - * - , timeout: "float | None" = None - , deadline: "Deadline | None" = None - , metadata: "MetadataLike | None" = None - ) -> "{% if method.server_streaming %}AsyncIterator[{{ method.py_output_message_type }}]{% else %}{{ method.py_output_message_type }}{% endif %}": - {% if method.comment %} - """ - {{ method.comment | indent(8) }} - """ - {% endif %} - {% if method.deprecated %} - warnings.warn("{{ service.py_name }}.{{ method.py_name }} is deprecated", DeprecationWarning) - {% endif %} - - {% if method.server_streaming %} - {% if method.client_streaming %} - async for response in self._stream_stream( - "{{ method.route }}", - messages, - {{ method.py_input_message_type }}, - {{ method.py_output_message_type }}, - timeout=timeout, - deadline=deadline, - metadata=metadata, - ): - yield response - {% else %}{# i.e. not client streaming #} - {% if method.is_input_msg_empty %} - if message is None: - message = {{ method.py_input_message_type }}() - - {% endif %} - async for response in self._unary_stream( - "{{ method.route }}", - message, - {{ method.py_output_message_type }}, - timeout=timeout, - deadline=deadline, - metadata=metadata, - ): - yield response +{% for _, service in output_file.services|dictsort(by="key") %} - {% endif %}{# if client streaming #} - {% else %}{# i.e. not server streaming #} - {% if method.client_streaming %} - return await self._stream_unary( - "{{ method.route }}", - messages, - {{ method.py_input_message_type }}, - {{ method.py_output_message_type }}, - timeout=timeout, - deadline=deadline, - metadata=metadata, - ) - {% else %}{# i.e. not client streaming #} - {% if method.is_input_msg_empty %} - if message is None: - message = {{ method.py_input_message_type }}() +{% if output_file.settings.client_generation.is_sync_generated %} +{% include "service_stub_sync.py.j2" %} +{% endif %} - {% endif %} - return await self._unary_unary( - "{{ method.route }}", - message, - {{ method.py_output_message_type }}, - timeout=timeout, - deadline=deadline, - metadata=metadata, - ) - {% endif %}{# client streaming #} - {% endif %} +{% if output_file.settings.client_generation.is_async_generated %} +{% include "service_stub_async.py.j2" %} +{% endif %} - {% endfor %} {% endfor %} {% for i in output_file.imports_end %} diff --git a/tests/inputs/simple_service/simple_service.proto b/tests/inputs/simple_service/simple_service.proto new file mode 100644 index 00000000..fc8e82b6 --- /dev/null +++ b/tests/inputs/simple_service/simple_service.proto @@ -0,0 +1,21 @@ +syntax = "proto3"; + +package simple_service; + +message Request { + int32 value = 1; +} + +message Response { + string message = 1; +} + +service SimpleService { + rpc GetUnaryUnary (Request) returns (Response); + + rpc GetUnaryStream (Request) returns (stream Response); + + rpc GetStreamUnary (stream Request) returns (Response); + + rpc GetStreamStream (stream Request) returns (stream Response); +} diff --git a/tests/util.py b/tests/util.py index 63d189f6..d4772c7e 100644 --- a/tests/util.py +++ b/tests/util.py @@ -63,6 +63,7 @@ async def protoc( f"--plugin=protoc-gen-custom={plugin_path.as_posix()}", "--experimental_allow_proto3_optional", "--custom_opt=pydantic_dataclasses", + "--custom_opt=client_generation=async_sync", f"--proto_path={path.as_posix()}", f"--custom_out={output_dir.as_posix()}", *[p.as_posix() for p in path.glob("*.proto")], @@ -76,6 +77,10 @@ async def protoc( f"--{python_out_option}={output_dir.as_posix()}", *[p.as_posix() for p in path.glob("*.proto")], ] + + if not reference: + command.insert(3, "--python_betterproto2_opt=client_generation=async_sync") + proc = await asyncio.create_subprocess_exec( *command, stdout=asyncio.subprocess.PIPE,