From 3e391dccbc463f9495e5c055da44d80f87c30e14 Mon Sep 17 00:00:00 2001 From: Robin McCorkell Date: Thu, 25 Sep 2025 22:03:22 +0100 Subject: [PATCH] grpcio: improve server interceptor typing --- stubs/grpcio/@tests/test_cases/check_aio.py | 4 +-- .../test_cases/check_handler_inheritance.py | 6 ++-- .../test_cases/check_server_interceptor.py | 31 +++++++++++++------ stubs/grpcio/grpc/__init__.pyi | 18 ++++++----- stubs/grpcio/grpc/aio/__init__.pyi | 9 +++--- 5 files changed, 42 insertions(+), 26 deletions(-) diff --git a/stubs/grpcio/@tests/test_cases/check_aio.py b/stubs/grpcio/@tests/test_cases/check_aio.py index 2eef7eec05f4..7338268ed3f8 100644 --- a/stubs/grpcio/@tests/test_cases/check_aio.py +++ b/stubs/grpcio/@tests/test_cases/check_aio.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, cast +from typing import cast from typing_extensions import assert_type import grpc.aio @@ -9,7 +9,7 @@ client_interceptors: list[grpc.aio.ClientInterceptor] = [] grpc.aio.insecure_channel("target", interceptors=client_interceptors) -server_interceptors: list[grpc.aio.ServerInterceptor[Any, Any]] = [] +server_interceptors: list[grpc.aio.ServerInterceptor] = [] grpc.aio.server(interceptors=server_interceptors) diff --git a/stubs/grpcio/@tests/test_cases/check_handler_inheritance.py b/stubs/grpcio/@tests/test_cases/check_handler_inheritance.py index 72cefce0bd41..21826a18b18f 100644 --- a/stubs/grpcio/@tests/test_cases/check_handler_inheritance.py +++ b/stubs/grpcio/@tests/test_cases/check_handler_inheritance.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import cast +from typing import Any, cast from typing_extensions import assert_type import grpc @@ -19,11 +19,11 @@ def unary_unary_call(rq: Request, ctx: grpc.ServicerContext) -> Response: return Response() -class ServiceHandler(grpc.ServiceRpcHandler[Request, Response]): +class ServiceHandler(grpc.ServiceRpcHandler): def service_name(self) -> str: return "hello" - def service(self, handler_call_details: grpc.HandlerCallDetails) -> grpc.RpcMethodHandler[Request, Response] | None: + def service(self, handler_call_details: grpc.HandlerCallDetails) -> grpc.RpcMethodHandler[Any, Any] | None: rpc = grpc.RpcMethodHandler[Request, Response]() rpc.unary_unary = unary_unary_call return rpc diff --git a/stubs/grpcio/@tests/test_cases/check_server_interceptor.py b/stubs/grpcio/@tests/test_cases/check_server_interceptor.py index 9a84f37e41d9..d5ef592a61ad 100644 --- a/stubs/grpcio/@tests/test_cases/check_server_interceptor.py +++ b/stubs/grpcio/@tests/test_cases/check_server_interceptor.py @@ -1,22 +1,35 @@ from __future__ import annotations from collections.abc import Callable +from concurrent.futures.thread import ThreadPoolExecutor +from typing import Awaitable, TypeVar import grpc +import grpc.aio +RequestT = TypeVar("RequestT") +ResponseT = TypeVar("ResponseT") -class Request: - pass +class NoopInterceptor(grpc.ServerInterceptor): + def intercept_service( + self, + continuation: Callable[[grpc.HandlerCallDetails], grpc.RpcMethodHandler[RequestT, ResponseT] | None], + handler_call_details: grpc.HandlerCallDetails, + ) -> grpc.RpcMethodHandler[RequestT, ResponseT] | None: + return continuation(handler_call_details) -class Response: - pass +grpc.server(interceptors=[NoopInterceptor()], thread_pool=ThreadPoolExecutor()) -class NoopInterceptor(grpc.ServerInterceptor[Request, Response]): - def intercept_service( + +class NoopAioInterceptor(grpc.aio.ServerInterceptor): + async def intercept_service( self, - continuation: Callable[[grpc.HandlerCallDetails], grpc.RpcMethodHandler[Request, Response] | None], + continuation: Callable[[grpc.HandlerCallDetails], Awaitable[grpc.RpcMethodHandler[RequestT, ResponseT]]], handler_call_details: grpc.HandlerCallDetails, - ) -> grpc.RpcMethodHandler[Request, Response] | None: - return continuation(handler_call_details) + ) -> grpc.RpcMethodHandler[RequestT, ResponseT]: + return await continuation(handler_call_details) + + +grpc.aio.server(interceptors=[NoopAioInterceptor()]) diff --git a/stubs/grpcio/grpc/__init__.pyi b/stubs/grpcio/grpc/__init__.pyi index c56ab90d0cf6..27f969f0e79e 100644 --- a/stubs/grpcio/grpc/__init__.pyi +++ b/stubs/grpcio/grpc/__init__.pyi @@ -108,8 +108,8 @@ def composite_channel_credentials( def server( thread_pool: futures.ThreadPoolExecutor, - handlers: list[GenericRpcHandler[Any, Any]] | None = None, - interceptors: list[ServerInterceptor[Any, Any]] | None = None, + handlers: list[GenericRpcHandler] | None = None, + interceptors: list[ServerInterceptor] | None = None, options: _Options | None = None, maximum_concurrent_rpcs: int | None = None, compression: Compression | None = None, @@ -173,7 +173,7 @@ def stream_stream_rpc_method_handler( ) -> RpcMethodHandler[_TRequest, _TResponse]: ... def method_handlers_generic_handler( service: str, method_handlers: dict[str, RpcMethodHandler[Any, Any]] -) -> GenericRpcHandler[Any, Any]: ... +) -> GenericRpcHandler: ... # Channel Ready Future: @@ -264,7 +264,7 @@ class Channel(abc.ABC): class Server(abc.ABC): @abc.abstractmethod - def add_generic_rpc_handlers(self, generic_rpc_handlers: Iterable[GenericRpcHandler[Any, Any]]) -> None: ... + def add_generic_rpc_handlers(self, generic_rpc_handlers: Iterable[GenericRpcHandler]) -> None: ... # Returns an integer port on which server will accept RPC requests. @abc.abstractmethod @@ -493,17 +493,19 @@ class HandlerCallDetails(abc.ABC): method: str invocation_metadata: _Metadata -class GenericRpcHandler(abc.ABC, Generic[_TRequest, _TResponse]): +class GenericRpcHandler(abc.ABC): + # The return type depends on the handler call details. @abc.abstractmethod - def service(self, handler_call_details: HandlerCallDetails) -> RpcMethodHandler[_TRequest, _TResponse] | None: ... + def service(self, handler_call_details: HandlerCallDetails) -> RpcMethodHandler[Any, Any] | None: ... -class ServiceRpcHandler(GenericRpcHandler[_TRequest, _TResponse], metaclass=abc.ABCMeta): +class ServiceRpcHandler(GenericRpcHandler, metaclass=abc.ABCMeta): @abc.abstractmethod def service_name(self) -> str: ... # Service-Side Interceptor: -class ServerInterceptor(abc.ABC, Generic[_TRequest, _TResponse]): +class ServerInterceptor(abc.ABC): + # This method (not the class) is generic over _TRequest and _TResponse. @abc.abstractmethod def intercept_service( self, diff --git a/stubs/grpcio/grpc/aio/__init__.pyi b/stubs/grpcio/grpc/aio/__init__.pyi index 092f87080a72..4a31deda9f58 100644 --- a/stubs/grpcio/grpc/aio/__init__.pyi +++ b/stubs/grpcio/grpc/aio/__init__.pyi @@ -65,8 +65,8 @@ def secure_channel( def server( migration_thread_pool: futures.Executor | None = None, - handlers: Sequence[GenericRpcHandler[Any, Any]] | None = None, - interceptors: Sequence[ServerInterceptor[Any, Any]] | None = None, + handlers: Sequence[GenericRpcHandler] | None = None, + interceptors: Sequence[ServerInterceptor] | None = None, options: _Options | None = None, maximum_concurrent_rpcs: int | None = None, compression: Compression | None = None, @@ -125,7 +125,7 @@ class Channel(abc.ABC): class Server(metaclass=abc.ABCMeta): @abc.abstractmethod - def add_generic_rpc_handlers(self, generic_rpc_handlers: Iterable[GenericRpcHandler[Any, Any]]) -> None: ... + def add_generic_rpc_handlers(self, generic_rpc_handlers: Iterable[GenericRpcHandler]) -> None: ... # Returns an integer port on which server will accept RPC requests. @abc.abstractmethod @@ -355,7 +355,8 @@ class StreamStreamClientInterceptor(Generic[_TRequest, _TResponse], metaclass=ab # Server-Side Interceptor: -class ServerInterceptor(Generic[_TRequest, _TResponse], metaclass=abc.ABCMeta): +class ServerInterceptor(metaclass=abc.ABCMeta): + # This method (not the class) is generic over _TRequest and _TResponse. @abc.abstractmethod async def intercept_service( self,