Skip to content

Commit c137997

Browse files
committed
feat: add grpcio integration
1 parent 914f37e commit c137997

File tree

7 files changed

+539
-117
lines changed

7 files changed

+539
-117
lines changed

aioinject/ext/grpcio.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
from __future__ import annotations
2+
3+
import contextvars
4+
from typing import TYPE_CHECKING
5+
6+
from grpc import ( # type: ignore[import-untyped]
7+
HandlerCallDetails,
8+
RpcMethodHandler,
9+
ServicerContext,
10+
stream_stream_rpc_method_handler,
11+
stream_unary_rpc_method_handler,
12+
unary_stream_rpc_method_handler,
13+
unary_unary_rpc_method_handler,
14+
)
15+
from grpc.aio import ServerInterceptor # type: ignore[import-untyped]
16+
17+
from aioinject._types import (
18+
P,
19+
T,
20+
)
21+
from aioinject.decorators import base_inject
22+
23+
24+
if TYPE_CHECKING:
25+
from collections.abc import AsyncIterator, Awaitable, Callable
26+
27+
from google.protobuf.message import Message # type: ignore[import-untyped]
28+
29+
from aioinject import Container, Context
30+
31+
__all__ = ["AioInjectInterceptor", "inject"]
32+
33+
34+
def inject(function: Callable[P, T]) -> Callable[P, T]:
35+
return base_inject(
36+
function,
37+
context_parameters=(),
38+
context_getter=lambda args, kwargs: _context_var.get(), # noqa: ARG005
39+
)
40+
41+
42+
_context_var: contextvars.ContextVar[Context] = contextvars.ContextVar(
43+
"aioinject.grpcio.context"
44+
)
45+
46+
47+
class AioInjectInterceptor(ServerInterceptor): # type: ignore[misc]
48+
def __init__(self, container: Container) -> None:
49+
self._container = container
50+
51+
async def intercept_service( # noqa: C901
52+
self,
53+
continuation: Callable[
54+
[HandlerCallDetails],
55+
Awaitable[RpcMethodHandler],
56+
],
57+
handler_call_details: HandlerCallDetails,
58+
) -> RpcMethodHandler:
59+
handler = await continuation(handler_call_details)
60+
deserializer = handler.request_deserializer
61+
serializer = handler.response_serializer
62+
63+
if handler.unary_unary:
64+
65+
async def unary_unary_behavior(
66+
request: Message, context: ServicerContext
67+
) -> object:
68+
async with self._container.context() as di_context:
69+
_context_var.set(di_context)
70+
return await handler.unary_unary(request, context)
71+
72+
return unary_unary_rpc_method_handler(
73+
behavior=unary_unary_behavior,
74+
request_deserializer=deserializer,
75+
response_serializer=serializer,
76+
)
77+
if handler.unary_stream:
78+
79+
async def unary_stream_behavior(
80+
request: Message, context: ServicerContext
81+
) -> AsyncIterator[object]:
82+
async with self._container.context() as di_context:
83+
_context_var.set(di_context)
84+
async for message in handler.unary_stream(
85+
request, context
86+
):
87+
yield message
88+
89+
return unary_stream_rpc_method_handler(
90+
unary_stream_behavior,
91+
request_deserializer=deserializer,
92+
response_serializer=serializer,
93+
)
94+
if handler.stream_unary:
95+
96+
async def stream_unary_behavior(
97+
request: AsyncIterator[Message], context: ServicerContext
98+
) -> object:
99+
_context_var.set(self._container.root)
100+
return await handler.stream_unary(request, context)
101+
102+
return stream_unary_rpc_method_handler(
103+
stream_unary_behavior,
104+
request_deserializer=deserializer,
105+
response_serializer=serializer,
106+
)
107+
108+
if handler.stream_stream:
109+
110+
async def stream_stream_behavior(
111+
request: AsyncIterator[Message], context: ServicerContext
112+
) -> AsyncIterator[object]:
113+
_context_var.set(self._container.root)
114+
async for message in handler.stream_stream(request, context):
115+
yield message
116+
117+
return stream_stream_rpc_method_handler(
118+
stream_stream_behavior,
119+
request_deserializer=deserializer,
120+
response_serializer=serializer,
121+
)
122+
return handler # pragma: no cover

pyproject.toml

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,26 +32,28 @@ benchmark = [
3232
]
3333
dev = [
3434
"mypy>=1.17.1",
35-
"ruff>=0.12.9",
35+
"ruff>=0.12.10",
3636
]
3737
docs = [
3838
"mkdocs>=1.6.1",
39-
"mkdocs-material>=9.6.17",
39+
"mkdocs-material>=9.6.18",
4040
]
4141
test = [
4242
"aiogram>=3.22.0",
4343
"anyio>=4.10.0",
4444
"asgi-lifespan>=2.1.0",
45-
"coverage[toml]>=7.10.4",
45+
"coverage[toml]>=7.10.5",
4646
"django>=5.2.5",
4747
"djangorestframework>=3.16.1",
4848
"fastapi>=0.116.1",
49+
"grpcio>=1.75.0",
50+
"grpcio-tools>=1.75.0",
4951
"litestar>=2.17.0",
5052
"pydantic-settings>=2.10.1",
5153
"pytest>=8.4.1",
5254
"pytest-cov>=6.2.1",
5355
"pytest-django>=4.11.1",
54-
"strawberry-graphql>=0.278.1",
56+
"strawberry-graphql>=0.280.0",
5557
"trio>=0.30.0",
5658
"uvicorn>=0.35.0",
5759
]
@@ -152,6 +154,7 @@ ignore = [
152154
"S101",
153155
"S311", # random
154156
]
157+
"tests/integrations/grpcio/*" = ["ARG002", "N802"]
155158
"tests/*/test_*.py" = ["FBT001"]
156159
"docs/*" = [
157160
"INP001", # Implicit package (no __init__.py)

tests/integrations/conftest.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import pytest
66

77
import aioinject
8+
from aioinject import Context, FromContext
9+
from aioinject.scope import CurrentScope
810

911

1012
@pytest.fixture
@@ -33,6 +35,7 @@ def get_node() -> ScopedNode:
3335
@pytest.fixture
3436
def container(provided_value: int) -> aioinject.Container:
3537
container = aioinject.Container()
38+
container.register(FromContext(Context, scope=CurrentScope()))
3639
container.register(aioinject.Object(provided_value))
3740
container.register(aioinject.Scoped(NumberService))
3841
container.register(aioinject.Scoped(get_node))

tests/integrations/grpcio/__init__.py

Whitespace-only changes.
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
syntax = "proto3";
2+
3+
message Request {
4+
string field = 1;
5+
}
6+
7+
message Response {
8+
string field = 1;
9+
}
10+
11+
service Service {
12+
rpc Unary(Request) returns (Response);
13+
rpc UnaryStream(Request) returns (stream Response);
14+
rpc StreamUnary(stream Request) returns (Response);
15+
rpc StreamStream(stream Request) returns (stream Response);
16+
}
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
import uuid
2+
from collections.abc import AsyncIterator
3+
from pathlib import Path
4+
from typing import Any, cast
5+
6+
import grpc # type: ignore[import-untyped]
7+
import pytest
8+
from _pytest.fixtures import SubRequest
9+
from grpc import ServicerContext
10+
from grpc.aio import Server # type: ignore[import-untyped]
11+
12+
from aioinject import Container, Context, Injected, Scope
13+
from aioinject.ext.grpcio import AioInjectInterceptor, inject
14+
15+
16+
protos: Any
17+
services: Any
18+
protos, services = grpc.protos_and_services(
19+
str(
20+
Path(__file__).parent.joinpath("service.proto").relative_to(Path.cwd())
21+
)
22+
)
23+
24+
25+
@pytest.fixture(scope="session", autouse=True, params=["asyncio"])
26+
def anyio_backend(request: SubRequest) -> str:
27+
return cast("str", request.param)
28+
29+
30+
class Service(services.Service): # type: ignore[misc]
31+
@inject
32+
async def Unary(
33+
self,
34+
request: protos.Request,
35+
context: ServicerContext,
36+
value: Injected[int],
37+
di_context: Injected[Context],
38+
) -> protos.Response:
39+
assert di_context.scope is Scope.request
40+
return protos.Response(field=f"{request.field} {value}")
41+
42+
@inject
43+
async def UnaryStream(
44+
self,
45+
request: protos.Request,
46+
context: ServicerContext,
47+
value: Injected[int],
48+
di_context: Injected[Context],
49+
) -> AsyncIterator[protos.Response]:
50+
assert di_context.scope is Scope.request
51+
52+
for i in range(10):
53+
yield protos.Response(field=f"{request.field} {value} {i}")
54+
55+
@inject
56+
async def StreamUnary(
57+
self,
58+
request: AsyncIterator[protos.Request],
59+
context: ServicerContext,
60+
value: Injected[int],
61+
di_context: Injected[Context],
62+
) -> protos.Response:
63+
assert di_context.scope == Scope.lifetime
64+
65+
values = [message.field async for message in request]
66+
values.append(str(value))
67+
return protos.Response(field=" ".join(values))
68+
69+
@inject
70+
async def StreamStream(
71+
self,
72+
request: AsyncIterator[protos.Request],
73+
context: ServicerContext,
74+
value: Injected[int],
75+
di_context: Injected[Context],
76+
) -> AsyncIterator[protos.Response]:
77+
assert di_context.scope == Scope.lifetime
78+
79+
async for message in request:
80+
yield protos.Response(field=f"{message.field} {value}")
81+
82+
83+
@pytest.fixture
84+
async def grpcio_server(container: Container) -> AsyncIterator[Server]:
85+
server = grpc.aio.server(interceptors=[AioInjectInterceptor(container)])
86+
services.add_ServiceServicer_to_server(Service(), server)
87+
server.add_insecure_port("localhost:50051")
88+
await server.start()
89+
yield server
90+
await server.stop(0)
91+
92+
93+
@pytest.fixture
94+
async def grpcio_client(
95+
grpcio_server: object, # noqa: ARG001
96+
) -> AsyncIterator[services.ServiceStub]:
97+
async with (
98+
grpc.aio.insecure_channel("localhost:50051") as channel,
99+
):
100+
yield services.ServiceStub(channel)
101+
102+
103+
async def test_unary_unary_ok(
104+
grpcio_client: services.ServiceStub, provided_value: int
105+
) -> None:
106+
field = str(uuid.uuid4())
107+
response = await grpcio_client.Unary(protos.Request(field=field))
108+
assert response.field == f"{field} {provided_value}"
109+
110+
111+
async def test_unary_stream_ok(
112+
grpcio_client: services.ServiceStub, provided_value: int
113+
) -> None:
114+
field = str(uuid.uuid4())
115+
messages = [
116+
message
117+
async for message in grpcio_client.UnaryStream(
118+
protos.Request(field=field)
119+
)
120+
]
121+
for number, message in enumerate(messages):
122+
assert message.field == f"{field} {provided_value} {number}"
123+
124+
125+
async def test_stream_unary_ok(
126+
grpcio_client: services.ServiceStub, provided_value: int
127+
) -> None:
128+
fields = [str(uuid.uuid4()) for _ in range(10)]
129+
call = grpcio_client.StreamUnary()
130+
for field in fields:
131+
await call.write(protos.Request(field=field))
132+
await call.done_writing()
133+
134+
response = await call
135+
assert response.field == " ".join([*fields, str(provided_value)])
136+
137+
138+
async def test_stream_stream_ok(
139+
grpcio_client: services.ServiceStub,
140+
provided_value: int,
141+
) -> None:
142+
call = grpcio_client.StreamStream()
143+
for _ in range(10):
144+
field = str(uuid.uuid4())
145+
await call.write(protos.Request(field=field))
146+
response = await call.read()
147+
assert response.field == f"{field} {provided_value}"

0 commit comments

Comments
 (0)