diff --git a/stubs/grpcio/@tests/test_cases/check_client_interceptor.py b/stubs/grpcio/@tests/test_cases/check_client_interceptor.py new file mode 100644 index 000000000000..2d4427fcf01e --- /dev/null +++ b/stubs/grpcio/@tests/test_cases/check_client_interceptor.py @@ -0,0 +1,117 @@ +from __future__ import annotations + +from collections.abc import Callable +from typing import AsyncIterable, AsyncIterator, Awaitable, Iterable, Iterator, TypeVar + +import grpc +import grpc.aio + +RequestT = TypeVar("RequestT") +ResponseT = TypeVar("ResponseT") + + +class NoopUnaryUnaryInterceptor(grpc.UnaryUnaryClientInterceptor): + def intercept_unary_unary( + self, + continuation: Callable[[grpc.ClientCallDetails, RequestT], grpc._CallFuture[ResponseT]], + client_call_details: grpc.ClientCallDetails, + request: RequestT, + ) -> grpc._CallFuture[ResponseT]: + return continuation(client_call_details, request) + + +class NoopUnaryStreamInterceptor(grpc.UnaryStreamClientInterceptor): + def intercept_unary_stream( + self, + continuation: Callable[[grpc.ClientCallDetails, RequestT], grpc._CallIterator[ResponseT]], + client_call_details: grpc.ClientCallDetails, + request: RequestT, + ) -> grpc._CallIterator[ResponseT]: + return continuation(client_call_details, request) + + +class NoopStreamUnaryInterceptor(grpc.StreamUnaryClientInterceptor): + def intercept_stream_unary( + self, + continuation: Callable[[grpc.ClientCallDetails, Iterator[RequestT]], grpc._CallFuture[ResponseT]], + client_call_details: grpc.ClientCallDetails, + request_iterator: Iterator[RequestT], + ) -> grpc._CallFuture[ResponseT]: + return continuation(client_call_details, request_iterator) + + +class NoopStreamStreamInterceptor(grpc.StreamStreamClientInterceptor): + def intercept_stream_stream( + self, + continuation: Callable[[grpc.ClientCallDetails, Iterator[RequestT]], grpc._CallIterator[ResponseT]], + client_call_details: grpc.ClientCallDetails, + request_iterator: Iterator[RequestT], + ) -> grpc._CallIterator[ResponseT]: + return continuation(client_call_details, request_iterator) + + +channel = grpc.insecure_channel("target") +channel = grpc.intercept_channel( + channel, + NoopUnaryUnaryInterceptor(), + NoopUnaryStreamInterceptor(), + NoopStreamUnaryInterceptor(), + NoopStreamStreamInterceptor(), +) + + +class NoopAioUnaryUnaryInterceptor(grpc.aio.UnaryUnaryClientInterceptor): + async def intercept_unary_unary( + self, + continuation: Callable[[grpc.aio.ClientCallDetails, RequestT], Awaitable[grpc.aio.UnaryUnaryCall[RequestT, ResponseT]]], + client_call_details: grpc.aio.ClientCallDetails, + request: RequestT, + ) -> ResponseT | grpc.aio.UnaryUnaryCall[RequestT, ResponseT]: + return await continuation(client_call_details, request) + + +class NoopAioUnaryStreamInterceptor(grpc.aio.UnaryStreamClientInterceptor): + async def intercept_unary_stream( + self, + continuation: Callable[[grpc.aio.ClientCallDetails, RequestT], Awaitable[grpc.aio.UnaryStreamCall[RequestT, ResponseT]]], + client_call_details: grpc.aio.ClientCallDetails, + request: RequestT, + ) -> AsyncIterator[ResponseT] | grpc.aio.UnaryStreamCall[RequestT, ResponseT]: + return await continuation(client_call_details, request) + + +class NoopAioStreamUnaryInterceptor(grpc.aio.StreamUnaryClientInterceptor): + async def intercept_stream_unary( + self, + continuation: Callable[ + [grpc.aio.ClientCallDetails, AsyncIterable[RequestT] | Iterable[RequestT]], + Awaitable[grpc.aio.StreamUnaryCall[RequestT, ResponseT]], + ], + client_call_details: grpc.aio.ClientCallDetails, + request_iterator: AsyncIterable[RequestT] | Iterable[RequestT], + ) -> ResponseT | grpc.aio.StreamUnaryCall[RequestT, ResponseT]: + return await continuation(client_call_details, request_iterator) + + +class NoopAioStreamStreamInterceptor(grpc.aio.StreamStreamClientInterceptor): + async def intercept_stream_stream( + self, + continuation: Callable[ + [grpc.aio.ClientCallDetails, AsyncIterable[RequestT] | Iterable[RequestT]], + Awaitable[grpc.aio.StreamStreamCall[RequestT, ResponseT]], + ], + client_call_details: grpc.aio.ClientCallDetails, + request_iterator: AsyncIterable[RequestT] | Iterable[RequestT], + ) -> AsyncIterator[ResponseT] | grpc.aio.StreamStreamCall[RequestT, ResponseT]: + return await continuation(client_call_details, request_iterator) + + +grpc.aio.insecure_channel( + "target", + interceptors=[ + NoopAioUnaryUnaryInterceptor(), + NoopAioUnaryStreamInterceptor(), + NoopAioStreamUnaryInterceptor(), + NoopAioStreamStreamInterceptor(), + ], +) diff --git a/stubs/grpcio/grpc/__init__.pyi b/stubs/grpcio/grpc/__init__.pyi index 27f969f0e79e..f9c8c2031ced 100644 --- a/stubs/grpcio/grpc/__init__.pyi +++ b/stubs/grpcio/grpc/__init__.pyi @@ -76,13 +76,10 @@ def secure_channel( ) -> Channel: ... _Interceptor: TypeAlias = ( - UnaryUnaryClientInterceptor[_TRequest, _TResponse] - | UnaryStreamClientInterceptor[_TRequest, _TResponse] - | StreamUnaryClientInterceptor[_TRequest, _TResponse] - | StreamStreamClientInterceptor[_TRequest, _TResponse] + UnaryUnaryClientInterceptor | UnaryStreamClientInterceptor | StreamUnaryClientInterceptor | StreamStreamClientInterceptor ) -def intercept_channel(channel: Channel, *interceptors: _Interceptor[_TRequest, _TResponse]) -> Channel: ... +def intercept_channel(channel: Channel, *interceptors: _Interceptor) -> Channel: ... # Create Client Credentials: @@ -378,25 +375,11 @@ class ClientCallDetails(abc.ABC): @type_check_only class _CallFuture(Call, Future[_TResponse], metaclass=abc.ABCMeta): ... -class UnaryUnaryClientInterceptor(abc.ABC, Generic[_TRequest, _TResponse]): +class UnaryUnaryClientInterceptor(abc.ABC): + # This method (not the class) is generic over _TRequest and _TResponse. @abc.abstractmethod def intercept_unary_unary( self, - # FIXME: decode these cryptic runes to confirm the typing mystery of - # this callable's signature that was left for us by past civilisations: - # - # continuation - A function that proceeds with the invocation by - # executing the next interceptor in chain or invoking the actual RPC - # on the underlying Channel. It is the interceptor's responsibility - # to call it if it decides to move the RPC forward. The interceptor - # can use response_future = continuation(client_call_details, - # request) to continue with the RPC. continuation returns an object - # that is both a Call for the RPC and a Future. In the event of RPC - # completion, the return Call-Future's result value will be the - # response message of the RPC. Should the event terminate with non-OK - # status, the returned Call-Future's exception value will be an - # RpcError. - # continuation: Callable[[ClientCallDetails, _TRequest], _CallFuture[_TResponse]], client_call_details: ClientCallDetails, request: _TRequest, @@ -407,7 +390,8 @@ class _CallIterator(Call, Generic[_TResponse], metaclass=abc.ABCMeta): def __iter__(self) -> Iterator[_TResponse]: ... def __next__(self) -> _TResponse: ... -class UnaryStreamClientInterceptor(abc.ABC, Generic[_TRequest, _TResponse]): +class UnaryStreamClientInterceptor(abc.ABC): + # This method (not the class) is generic over _TRequest and _TResponse. @abc.abstractmethod def intercept_unary_stream( self, @@ -416,20 +400,22 @@ class UnaryStreamClientInterceptor(abc.ABC, Generic[_TRequest, _TResponse]): request: _TRequest, ) -> _CallIterator[_TResponse]: ... -class StreamUnaryClientInterceptor(abc.ABC, Generic[_TRequest, _TResponse]): +class StreamUnaryClientInterceptor(abc.ABC): + # This method (not the class) is generic over _TRequest and _TResponse. @abc.abstractmethod def intercept_stream_unary( self, - continuation: Callable[[ClientCallDetails, _TRequest], _CallFuture[_TResponse]], + continuation: Callable[[ClientCallDetails, Iterator[_TRequest]], _CallFuture[_TResponse]], client_call_details: ClientCallDetails, request_iterator: Iterator[_TRequest], ) -> _CallFuture[_TResponse]: ... -class StreamStreamClientInterceptor(abc.ABC, Generic[_TRequest, _TResponse]): +class StreamStreamClientInterceptor(abc.ABC): + # This method (not the class) is generic over _TRequest and _TResponse. @abc.abstractmethod def intercept_stream_stream( self, - continuation: Callable[[ClientCallDetails, _TRequest], _CallIterator[_TResponse]], + continuation: Callable[[ClientCallDetails, Iterator[_TRequest]], _CallIterator[_TResponse]], client_call_details: ClientCallDetails, request_iterator: Iterator[_TRequest], ) -> _CallIterator[_TResponse]: ... diff --git a/stubs/grpcio/grpc/aio/__init__.pyi b/stubs/grpcio/grpc/aio/__init__.pyi index 4a31deda9f58..e6400b1fdbc2 100644 --- a/stubs/grpcio/grpc/aio/__init__.pyi +++ b/stubs/grpcio/grpc/aio/__init__.pyi @@ -45,8 +45,6 @@ class AioRpcError(RpcError): # Create Client: -class ClientInterceptor(metaclass=abc.ABCMeta): ... - def insecure_channel( target: str, options: _Options | None = None, @@ -288,7 +286,7 @@ class InterceptedUnaryUnaryCall(_InterceptedCall[_TRequest, _TResponse], metacla def __await__(self) -> Generator[Incomplete, None, _TResponse]: ... def __init__( self, - interceptors: Sequence[UnaryUnaryClientInterceptor[_TRequest, _TResponse]], + interceptors: Sequence[UnaryUnaryClientInterceptor], request: _TRequest, timeout: float | None, metadata: Metadata, @@ -304,7 +302,7 @@ class InterceptedUnaryUnaryCall(_InterceptedCall[_TRequest, _TResponse], metacla # pylint: disable=too-many-arguments async def _invoke( self, - interceptors: Sequence[UnaryUnaryClientInterceptor[_TRequest, _TResponse]], + interceptors: Sequence[UnaryUnaryClientInterceptor], method: bytes, timeout: float | None, metadata: Metadata | None, @@ -316,42 +314,53 @@ class InterceptedUnaryUnaryCall(_InterceptedCall[_TRequest, _TResponse], metacla ) -> UnaryUnaryCall[_TRequest, _TResponse]: ... def time_remaining(self) -> float | None: ... -class UnaryUnaryClientInterceptor(Generic[_TRequest, _TResponse], metaclass=abc.ABCMeta): +class ClientInterceptor(metaclass=abc.ABCMeta): ... + +class UnaryUnaryClientInterceptor(ClientInterceptor, metaclass=abc.ABCMeta): + # This method (not the class) is generic over _TRequest and _TResponse. @abc.abstractmethod async def intercept_unary_unary( self, # XXX: See equivalent function in grpc types for notes about continuation: - continuation: Callable[[ClientCallDetails, _TRequest], UnaryUnaryCall[_TRequest, _TResponse]], + continuation: Callable[[ClientCallDetails, _TRequest], Awaitable[UnaryUnaryCall[_TRequest, _TResponse]]], client_call_details: ClientCallDetails, request: _TRequest, - ) -> _TResponse: ... + ) -> _TResponse | UnaryUnaryCall[_TRequest, _TResponse]: ... -class UnaryStreamClientInterceptor(Generic[_TRequest, _TResponse], metaclass=abc.ABCMeta): +class UnaryStreamClientInterceptor(ClientInterceptor, metaclass=abc.ABCMeta): + # This method (not the class) is generic over _TRequest and _TResponse. @abc.abstractmethod async def intercept_unary_stream( self, - continuation: Callable[[ClientCallDetails, _TRequest], UnaryStreamCall[_TRequest, _TResponse]], + continuation: Callable[[ClientCallDetails, _TRequest], Awaitable[UnaryStreamCall[_TRequest, _TResponse]]], client_call_details: ClientCallDetails, request: _TRequest, - ) -> AsyncIterable[_TResponse] | UnaryStreamCall[_TRequest, _TResponse]: ... + ) -> AsyncIterator[_TResponse] | UnaryStreamCall[_TRequest, _TResponse]: ... -class StreamUnaryClientInterceptor(Generic[_TRequest, _TResponse], metaclass=abc.ABCMeta): +class StreamUnaryClientInterceptor(ClientInterceptor, metaclass=abc.ABCMeta): + # This method (not the class) is generic over _TRequest and _TResponse. @abc.abstractmethod async def intercept_stream_unary( self, - continuation: Callable[[ClientCallDetails, _TRequest], StreamUnaryCall[_TRequest, _TResponse]], + continuation: Callable[ + [ClientCallDetails, AsyncIterable[_TRequest] | Iterable[_TRequest]], Awaitable[StreamUnaryCall[_TRequest, _TResponse]] + ], client_call_details: ClientCallDetails, request_iterator: AsyncIterable[_TRequest] | Iterable[_TRequest], - ) -> AsyncIterable[_TResponse] | UnaryStreamCall[_TRequest, _TResponse]: ... + ) -> _TResponse | StreamUnaryCall[_TRequest, _TResponse]: ... -class StreamStreamClientInterceptor(Generic[_TRequest, _TResponse], metaclass=abc.ABCMeta): +class StreamStreamClientInterceptor(ClientInterceptor, metaclass=abc.ABCMeta): + # This method (not the class) is generic over _TRequest and _TResponse. @abc.abstractmethod async def intercept_stream_stream( self, - continuation: Callable[[ClientCallDetails, _TRequest], StreamStreamCall[_TRequest, _TResponse]], + continuation: Callable[ + [ClientCallDetails, AsyncIterable[_TRequest] | Iterable[_TRequest]], + Awaitable[StreamStreamCall[_TRequest, _TResponse]], + ], client_call_details: ClientCallDetails, request_iterator: AsyncIterable[_TRequest] | Iterable[_TRequest], - ) -> AsyncIterable[_TResponse] | StreamStreamCall[_TRequest, _TResponse]: ... + ) -> AsyncIterator[_TResponse] | StreamStreamCall[_TRequest, _TResponse]: ... # Server-Side Interceptor: