Skip to content

Add grpcio support. #675

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,12 @@ protoc \
/usr/local/include/google/protobuf/*.proto
```

### Using grpcio library instead of grpclib

In order to use the `grpcio` library instead of `grpclib`, you can use the `--custom_opt=grpcio`
option when running the `protoc` command.
This will generate stubs compatible with the `grpcio` library.

### TODO

- [x] Fixed length fields
Expand Down
172 changes: 94 additions & 78 deletions poetry.lock

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ dynamic = ["dependencies"]
# The Ruff version is pinned. To update it, also update it in .pre-commit-config.yaml
ruff = { version = "~0.9.1", optional = true }
grpclib = "^0.4.1"
grpcio = { version = ">=1.73.0", optional = true }
jinja2 = { version = ">=3.0.3", optional = true }
python-dateutil = "^2.8"
typing-extensions = "^4.7.1"
Expand All @@ -45,13 +46,15 @@ pydantic = ">=2.0,<3"
protobuf = "^5"
cachelib = "^0.13.0"
tomlkit = ">=0.7.0"
grpcio-testing = "^1.54.2"

[project.scripts]
protoc-gen-python_betterproto = "betterproto.plugin:main"

[project.optional-dependencies]
compiler = ["ruff", "jinja2"]
rust-codec = ["betterproto-rust-codec"]
grpcio = ["grpcio"]

[tool.ruff]
extend-exclude = ["tests/output_*"]
Expand Down
120 changes: 120 additions & 0 deletions src/betterproto/grpc/grpcio_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
from abc import ABC
from typing import (
TYPE_CHECKING,
AsyncIterable,
AsyncIterator,
Iterable,
Mapping,
Optional,
Union,
)

import grpc

if TYPE_CHECKING:
from .._types import (
T,
IProtoMessage,
)

Value = Union[str, bytes]
MetadataLike = Union[Mapping[str, Value], Iterable[tuple[str, Value]]]
MessageSource = Union[Iterable["IProtoMessage"], AsyncIterable["IProtoMessage"]]


class ServiceStub(ABC):

def __init__(
self,
channel: grpc.aio.Channel,
*,
timeout: Optional[float] = None,
metadata: Optional[MetadataLike] = None,
) -> None:
self.channel = channel
self.timeout = timeout
self.metadata = metadata

def _resolve_request_kwargs(
self,
timeout: Optional[float],
metadata: Optional[MetadataLike],
):
return {
"timeout": self.timeout if timeout is None else timeout,
"metadata": self.metadata if metadata is None else metadata,
}

async def _unary_unary(
self,
stub_method: grpc.aio.UnaryUnaryMultiCallable,
request: "IProtoMessage",
*,
timeout: Optional[float] = None,
metadata: Optional[MetadataLike] = None,
) -> "T":
return await stub_method(
request,
**self._resolve_request_kwargs(timeout, metadata),
)

async def _unary_stream(
self,
stub_method: grpc.aio.UnaryStreamMultiCallable,
request: "IProtoMessage",
*,
timeout: Optional[float] = None,
metadata: Optional[MetadataLike] = None,
) -> AsyncIterator["T"]:
call = stub_method(
request,
**self._resolve_request_kwargs(timeout, metadata),
)
async for response in call:
yield response

async def _stream_unary(
self,
stub_method: grpc.aio.StreamUnaryMultiCallable,
request_iterator: MessageSource,
*,
timeout: Optional[float] = None,
metadata: Optional[MetadataLike] = None,
) -> "T":
call = stub_method(
self._wrap_message_iterator(request_iterator),
**self._resolve_request_kwargs(timeout, metadata),
)
return await call

async def _stream_stream(
self,
stub_method: grpc.aio.StreamStreamMultiCallable,
request_iterator: MessageSource,
*,
timeout: Optional[float] = None,
metadata: Optional[MetadataLike] = None,
) -> AsyncIterator["T"]:
call = stub_method(
self._wrap_message_iterator(request_iterator),
**self._resolve_request_kwargs(timeout, metadata),
)
async for response in call:
yield response

@staticmethod
def _wrap_message_iterator(
messages: MessageSource,
) -> AsyncIterator["IProtoMessage"]:
if hasattr(messages, '__aiter__'):
async def async_wrapper():
async for message in messages:
yield message

return async_wrapper()
else:
async def sync_wrapper():
for message in messages:
yield message

return sync_wrapper()
30 changes: 30 additions & 0 deletions src/betterproto/grpc/grpcio_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Dict


if TYPE_CHECKING:
import grpc


class ServiceBase(ABC):

@property
@abstractmethod
def __rpc_methods__(self) -> Dict[str, "grpc.RpcMethodHandler"]: ...

@property
@abstractmethod
def __proto_path__(self) -> str: ...


def register_servicers(server: "grpc.aio.Server", *servicers: ServiceBase):
from grpc import method_handlers_generic_handler

server.add_generic_rpc_handlers(
tuple(
method_handlers_generic_handler(
servicer.__proto_path__, servicer.__rpc_methods__
)
for servicer in servicers
)
)
17 changes: 10 additions & 7 deletions src/betterproto/plugin/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ class OutputTemplate:
imports_type_checking_only: Set[str] = field(default_factory=set)
pydantic_dataclasses: bool = False
output: bool = True
use_grpcio: bool = False
typing_compiler: TypingCompiler = field(default_factory=DirectImportTypingCompiler)

@property
Expand Down Expand Up @@ -697,18 +698,20 @@ class ServiceMethodCompiler(ProtoContentBase):
proto_obj: MethodDescriptorProto
path: List[int] = PLACEHOLDER
comment_indent: int = 8
use_grpcio: bool = False

def __post_init__(self) -> None:
# Add method to service
self.parent.methods.append(self)

self.output_file.imports_type_checking_only.add("import grpclib.server")
self.output_file.imports_type_checking_only.add(
"from betterproto.grpc.grpclib_client import MetadataLike"
)
self.output_file.imports_type_checking_only.add(
"from grpclib.metadata import Deadline"
)
if self.use_grpcio:
imports = ["import grpc.aio", "from betterproto.grpc.grpcio_client import MetadataLike"]
else:
imports = ["import grpclib.server", "from betterproto.grpc.grpclib_client import MetadataLike",
"from grpclib.metadata import Deadline"]

for import_line in imports:
self.output_file.imports_type_checking_only.add(import_line)

super().__post_init__() # check for unset fields

Expand Down
10 changes: 7 additions & 3 deletions src/betterproto/plugin/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,11 @@
from .typing_compiler import (
DirectImportTypingCompiler,
NoTyping310TypingCompiler,
TypingCompiler,
TypingImportTypingCompiler,
)

USE_GRPCIO_FLAG = "USE_GRPCIO"


def traverse(
proto_file: FileDescriptorProto,
Expand Down Expand Up @@ -80,6 +81,7 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
response.supported_features = CodeGeneratorResponseFeature.FEATURE_PROTO3_OPTIONAL

request_data = PluginRequestCompiler(plugin_request_obj=request)
use_grpcio = USE_GRPCIO_FLAG in plugin_options
# Gather output packages
for proto_file in request.proto_file:
output_package_name = proto_file.package
Expand All @@ -90,7 +92,7 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
)
# Add this input file to the output corresponding to this package
request_data.output_packages[output_package_name].input_files.append(proto_file)

request_data.output_packages[output_package_name].use_grpcio = use_grpcio
if (
proto_file.package == "google.protobuf"
and "INCLUDE_GOOGLE" not in plugin_options
Expand Down Expand Up @@ -143,7 +145,7 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
for output_package_name, output_package in request_data.output_packages.items():
for proto_input_file in output_package.input_files:
for index, service in enumerate(proto_input_file.service):
read_protobuf_service(proto_input_file, service, index, output_package)
read_protobuf_service(proto_input_file, service, index, output_package, use_grpcio)

# Generate output files
output_paths: Set[pathlib.Path] = set()
Expand Down Expand Up @@ -253,6 +255,7 @@ def read_protobuf_service(
service: ServiceDescriptorProto,
index: int,
output_package: OutputTemplate,
use_grpcio: bool = False,
) -> None:
service_data = ServiceCompiler(
source_file=source_file,
Expand All @@ -266,4 +269,5 @@ def read_protobuf_service(
parent=service_data,
proto_obj=method,
path=[6, index, 2, j],
use_grpcio=use_grpcio,
)
10 changes: 9 additions & 1 deletion src/betterproto/templates/header.py.j2
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,16 @@ from pydantic import {% for i in output_file.pydantic_imports|sort %}{{ i }}{% i

{% endif %}


{% if output_file.use_grpcio %}
import grpc
from betterproto.grpc.grpcio_client import ServiceStub
from betterproto.grpc.grpcio_server import ServiceBase
{% endif %}

import betterproto
{% if output_file.services %}
{% if not output_file.use_grpcio %}
from betterproto.grpc.grpclib_client import ServiceStub
from betterproto.grpc.grpclib_server import ServiceBase
import grpclib
{% endif %}
Expand Down
Loading