Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 117 additions & 0 deletions stubs/grpcio/@tests/test_cases/check_client_interceptor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
from __future__ import annotations

from collections.abc import Callable
from typing import AsyncIterable, AsyncIterator, Awaitable, Iterable, Iterator, TypeVar

import grpc
import grpc.aio

RequestT = TypeVar("RequestT")
ResponseT = TypeVar("ResponseT")


class NoopUnaryUnaryInterceptor(grpc.UnaryUnaryClientInterceptor):
def intercept_unary_unary(
self,
continuation: Callable[[grpc.ClientCallDetails, RequestT], grpc._CallFuture[ResponseT]],
client_call_details: grpc.ClientCallDetails,
request: RequestT,
) -> grpc._CallFuture[ResponseT]:
return continuation(client_call_details, request)


class NoopUnaryStreamInterceptor(grpc.UnaryStreamClientInterceptor):
def intercept_unary_stream(
self,
continuation: Callable[[grpc.ClientCallDetails, RequestT], grpc._CallIterator[ResponseT]],
client_call_details: grpc.ClientCallDetails,
request: RequestT,
) -> grpc._CallIterator[ResponseT]:
return continuation(client_call_details, request)


class NoopStreamUnaryInterceptor(grpc.StreamUnaryClientInterceptor):
def intercept_stream_unary(
self,
continuation: Callable[[grpc.ClientCallDetails, Iterator[RequestT]], grpc._CallFuture[ResponseT]],
client_call_details: grpc.ClientCallDetails,
request_iterator: Iterator[RequestT],
) -> grpc._CallFuture[ResponseT]:
return continuation(client_call_details, request_iterator)


class NoopStreamStreamInterceptor(grpc.StreamStreamClientInterceptor):
def intercept_stream_stream(
self,
continuation: Callable[[grpc.ClientCallDetails, Iterator[RequestT]], grpc._CallIterator[ResponseT]],
client_call_details: grpc.ClientCallDetails,
request_iterator: Iterator[RequestT],
) -> grpc._CallIterator[ResponseT]:
return continuation(client_call_details, request_iterator)


channel = grpc.insecure_channel("target")
channel = grpc.intercept_channel(
channel,
NoopUnaryUnaryInterceptor(),
NoopUnaryStreamInterceptor(),
NoopStreamUnaryInterceptor(),
NoopStreamStreamInterceptor(),
)


class NoopAioUnaryUnaryInterceptor(grpc.aio.UnaryUnaryClientInterceptor):
async def intercept_unary_unary(
self,
continuation: Callable[[grpc.aio.ClientCallDetails, RequestT], Awaitable[grpc.aio.UnaryUnaryCall[RequestT, ResponseT]]],
client_call_details: grpc.aio.ClientCallDetails,
request: RequestT,
) -> ResponseT | grpc.aio.UnaryUnaryCall[RequestT, ResponseT]:
return await continuation(client_call_details, request)


class NoopAioUnaryStreamInterceptor(grpc.aio.UnaryStreamClientInterceptor):
async def intercept_unary_stream(
self,
continuation: Callable[[grpc.aio.ClientCallDetails, RequestT], Awaitable[grpc.aio.UnaryStreamCall[RequestT, ResponseT]]],
client_call_details: grpc.aio.ClientCallDetails,
request: RequestT,
) -> AsyncIterator[ResponseT] | grpc.aio.UnaryStreamCall[RequestT, ResponseT]:
return await continuation(client_call_details, request)


class NoopAioStreamUnaryInterceptor(grpc.aio.StreamUnaryClientInterceptor):
async def intercept_stream_unary(
self,
continuation: Callable[
[grpc.aio.ClientCallDetails, AsyncIterable[RequestT] | Iterable[RequestT]],
Awaitable[grpc.aio.StreamUnaryCall[RequestT, ResponseT]],
],
client_call_details: grpc.aio.ClientCallDetails,
request_iterator: AsyncIterable[RequestT] | Iterable[RequestT],
) -> ResponseT | grpc.aio.StreamUnaryCall[RequestT, ResponseT]:
return await continuation(client_call_details, request_iterator)


class NoopAioStreamStreamInterceptor(grpc.aio.StreamStreamClientInterceptor):
async def intercept_stream_stream(
self,
continuation: Callable[
[grpc.aio.ClientCallDetails, AsyncIterable[RequestT] | Iterable[RequestT]],
Awaitable[grpc.aio.StreamStreamCall[RequestT, ResponseT]],
],
client_call_details: grpc.aio.ClientCallDetails,
request_iterator: AsyncIterable[RequestT] | Iterable[RequestT],
) -> AsyncIterator[ResponseT] | grpc.aio.StreamStreamCall[RequestT, ResponseT]:
return await continuation(client_call_details, request_iterator)


grpc.aio.insecure_channel(
"target",
interceptors=[
NoopAioUnaryUnaryInterceptor(),
NoopAioUnaryStreamInterceptor(),
NoopAioStreamUnaryInterceptor(),
NoopAioStreamStreamInterceptor(),
],
)
38 changes: 12 additions & 26 deletions stubs/grpcio/grpc/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,10 @@ def secure_channel(
) -> Channel: ...

_Interceptor: TypeAlias = (
UnaryUnaryClientInterceptor[_TRequest, _TResponse]
| UnaryStreamClientInterceptor[_TRequest, _TResponse]
| StreamUnaryClientInterceptor[_TRequest, _TResponse]
| StreamStreamClientInterceptor[_TRequest, _TResponse]
UnaryUnaryClientInterceptor | UnaryStreamClientInterceptor | StreamUnaryClientInterceptor | StreamStreamClientInterceptor
)

def intercept_channel(channel: Channel, *interceptors: _Interceptor[_TRequest, _TResponse]) -> Channel: ...
def intercept_channel(channel: Channel, *interceptors: _Interceptor) -> Channel: ...

# Create Client Credentials:

Expand Down Expand Up @@ -378,25 +375,11 @@ class ClientCallDetails(abc.ABC):
@type_check_only
class _CallFuture(Call, Future[_TResponse], metaclass=abc.ABCMeta): ...

class UnaryUnaryClientInterceptor(abc.ABC, Generic[_TRequest, _TResponse]):
class UnaryUnaryClientInterceptor(abc.ABC):
# This method (not the class) is generic over _TRequest and _TResponse.
@abc.abstractmethod
def intercept_unary_unary(
self,
# FIXME: decode these cryptic runes to confirm the typing mystery of
# this callable's signature that was left for us by past civilisations:
#
# continuation - A function that proceeds with the invocation by
# executing the next interceptor in chain or invoking the actual RPC
# on the underlying Channel. It is the interceptor's responsibility
# to call it if it decides to move the RPC forward. The interceptor
# can use response_future = continuation(client_call_details,
# request) to continue with the RPC. continuation returns an object
# that is both a Call for the RPC and a Future. In the event of RPC
# completion, the return Call-Future's result value will be the
# response message of the RPC. Should the event terminate with non-OK
# status, the returned Call-Future's exception value will be an
# RpcError.
#
continuation: Callable[[ClientCallDetails, _TRequest], _CallFuture[_TResponse]],
client_call_details: ClientCallDetails,
request: _TRequest,
Expand All @@ -407,7 +390,8 @@ class _CallIterator(Call, Generic[_TResponse], metaclass=abc.ABCMeta):
def __iter__(self) -> Iterator[_TResponse]: ...
def __next__(self) -> _TResponse: ...

class UnaryStreamClientInterceptor(abc.ABC, Generic[_TRequest, _TResponse]):
class UnaryStreamClientInterceptor(abc.ABC):
# This method (not the class) is generic over _TRequest and _TResponse.
@abc.abstractmethod
def intercept_unary_stream(
self,
Expand All @@ -416,20 +400,22 @@ class UnaryStreamClientInterceptor(abc.ABC, Generic[_TRequest, _TResponse]):
request: _TRequest,
) -> _CallIterator[_TResponse]: ...

class StreamUnaryClientInterceptor(abc.ABC, Generic[_TRequest, _TResponse]):
class StreamUnaryClientInterceptor(abc.ABC):
# This method (not the class) is generic over _TRequest and _TResponse.
@abc.abstractmethod
def intercept_stream_unary(
self,
continuation: Callable[[ClientCallDetails, _TRequest], _CallFuture[_TResponse]],
continuation: Callable[[ClientCallDetails, Iterator[_TRequest]], _CallFuture[_TResponse]],
client_call_details: ClientCallDetails,
request_iterator: Iterator[_TRequest],
) -> _CallFuture[_TResponse]: ...

class StreamStreamClientInterceptor(abc.ABC, Generic[_TRequest, _TResponse]):
class StreamStreamClientInterceptor(abc.ABC):
# This method (not the class) is generic over _TRequest and _TResponse.
@abc.abstractmethod
def intercept_stream_stream(
self,
continuation: Callable[[ClientCallDetails, _TRequest], _CallIterator[_TResponse]],
continuation: Callable[[ClientCallDetails, Iterator[_TRequest]], _CallIterator[_TResponse]],
client_call_details: ClientCallDetails,
request_iterator: Iterator[_TRequest],
) -> _CallIterator[_TResponse]: ...
Expand Down
41 changes: 25 additions & 16 deletions stubs/grpcio/grpc/aio/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,6 @@ class AioRpcError(RpcError):

# Create Client:

class ClientInterceptor(metaclass=abc.ABCMeta): ...

def insecure_channel(
target: str,
options: _Options | None = None,
Expand Down Expand Up @@ -288,7 +286,7 @@ class InterceptedUnaryUnaryCall(_InterceptedCall[_TRequest, _TResponse], metacla
def __await__(self) -> Generator[Incomplete, None, _TResponse]: ...
def __init__(
self,
interceptors: Sequence[UnaryUnaryClientInterceptor[_TRequest, _TResponse]],
interceptors: Sequence[UnaryUnaryClientInterceptor],
request: _TRequest,
timeout: float | None,
metadata: Metadata,
Expand All @@ -304,7 +302,7 @@ class InterceptedUnaryUnaryCall(_InterceptedCall[_TRequest, _TResponse], metacla
# pylint: disable=too-many-arguments
async def _invoke(
self,
interceptors: Sequence[UnaryUnaryClientInterceptor[_TRequest, _TResponse]],
interceptors: Sequence[UnaryUnaryClientInterceptor],
method: bytes,
timeout: float | None,
metadata: Metadata | None,
Expand All @@ -316,42 +314,53 @@ class InterceptedUnaryUnaryCall(_InterceptedCall[_TRequest, _TResponse], metacla
) -> UnaryUnaryCall[_TRequest, _TResponse]: ...
def time_remaining(self) -> float | None: ...

class UnaryUnaryClientInterceptor(Generic[_TRequest, _TResponse], metaclass=abc.ABCMeta):
class ClientInterceptor(metaclass=abc.ABCMeta): ...

class UnaryUnaryClientInterceptor(ClientInterceptor, metaclass=abc.ABCMeta):
# This method (not the class) is generic over _TRequest and _TResponse.
@abc.abstractmethod
async def intercept_unary_unary(
self,
# XXX: See equivalent function in grpc types for notes about continuation:
continuation: Callable[[ClientCallDetails, _TRequest], UnaryUnaryCall[_TRequest, _TResponse]],
continuation: Callable[[ClientCallDetails, _TRequest], Awaitable[UnaryUnaryCall[_TRequest, _TResponse]]],
client_call_details: ClientCallDetails,
request: _TRequest,
) -> _TResponse: ...
) -> _TResponse | UnaryUnaryCall[_TRequest, _TResponse]: ...

class UnaryStreamClientInterceptor(Generic[_TRequest, _TResponse], metaclass=abc.ABCMeta):
class UnaryStreamClientInterceptor(ClientInterceptor, metaclass=abc.ABCMeta):
# This method (not the class) is generic over _TRequest and _TResponse.
@abc.abstractmethod
async def intercept_unary_stream(
self,
continuation: Callable[[ClientCallDetails, _TRequest], UnaryStreamCall[_TRequest, _TResponse]],
continuation: Callable[[ClientCallDetails, _TRequest], Awaitable[UnaryStreamCall[_TRequest, _TResponse]]],
client_call_details: ClientCallDetails,
request: _TRequest,
) -> AsyncIterable[_TResponse] | UnaryStreamCall[_TRequest, _TResponse]: ...
) -> AsyncIterator[_TResponse] | UnaryStreamCall[_TRequest, _TResponse]: ...

class StreamUnaryClientInterceptor(Generic[_TRequest, _TResponse], metaclass=abc.ABCMeta):
class StreamUnaryClientInterceptor(ClientInterceptor, metaclass=abc.ABCMeta):
# This method (not the class) is generic over _TRequest and _TResponse.
@abc.abstractmethod
async def intercept_stream_unary(
self,
continuation: Callable[[ClientCallDetails, _TRequest], StreamUnaryCall[_TRequest, _TResponse]],
continuation: Callable[
[ClientCallDetails, AsyncIterable[_TRequest] | Iterable[_TRequest]], Awaitable[StreamUnaryCall[_TRequest, _TResponse]]
],
client_call_details: ClientCallDetails,
request_iterator: AsyncIterable[_TRequest] | Iterable[_TRequest],
) -> AsyncIterable[_TResponse] | UnaryStreamCall[_TRequest, _TResponse]: ...
) -> _TResponse | StreamUnaryCall[_TRequest, _TResponse]: ...

class StreamStreamClientInterceptor(Generic[_TRequest, _TResponse], metaclass=abc.ABCMeta):
class StreamStreamClientInterceptor(ClientInterceptor, metaclass=abc.ABCMeta):
# This method (not the class) is generic over _TRequest and _TResponse.
@abc.abstractmethod
async def intercept_stream_stream(
self,
continuation: Callable[[ClientCallDetails, _TRequest], StreamStreamCall[_TRequest, _TResponse]],
continuation: Callable[
[ClientCallDetails, AsyncIterable[_TRequest] | Iterable[_TRequest]],
Awaitable[StreamStreamCall[_TRequest, _TResponse]],
],
client_call_details: ClientCallDetails,
request_iterator: AsyncIterable[_TRequest] | Iterable[_TRequest],
) -> AsyncIterable[_TResponse] | StreamStreamCall[_TRequest, _TResponse]: ...
) -> AsyncIterator[_TResponse] | StreamStreamCall[_TRequest, _TResponse]: ...

# Server-Side Interceptor:

Expand Down