66import abc
77import inspect
88from collections .abc import Awaitable , Callable
9- from typing import Any , Generic , Self , TypeVar , overload
9+ from typing import Any , Generic , Self , TypeVar , cast , overload
1010
1111from grpc .aio import AioRpcError , Channel
1212
1313from .channel import ChannelOptions , parse_grpc_uri
1414from .exception import ApiClientError , ClientNotConnected
1515
16- StubT = TypeVar ("StubT" )
17- """The type of the gRPC stub."""
1816
19-
20- class BaseApiClient (abc .ABC , Generic [StubT ]):
17+ class BaseApiClient (abc .ABC ):
2118 """A base class for API clients.
2219
2320 This class provides a common interface for API clients that communicate with a API
@@ -32,12 +29,31 @@ class BaseApiClient(abc.ABC, Generic[StubT]):
3229 a class that helps sending messages from a gRPC stream to
3330 a [Broadcast][frequenz.channels.Broadcast] channel.
3431
32+ Note:
33+ Because grpcio doesn't provide proper type hints, a hack is needed to have
34+ propepr async type hints for the stubs generated by protoc. When using
35+ `mypy-protobuf`, a `XxxAsyncStub` class is generated for each `XxxStub` class
36+ but in the `.pyi` file, so the type can be used to specify type hints, but
37+ **not** in any other context, as the class doesn't really exist for the Python
38+ interpreter. This include generics, and because of this, this class can't be
39+ even parametrized using the async class, so the instantiation of the stub can't
40+ be done in the base class.
41+
42+ Because of this, subclasses need to create the stubs by themselves, using the
43+ real stub class and casting it to the `XxxAsyncStub` class, so `mypy` can use
44+ the async version of the stubs.
45+
46+ It is recommended to define a `stub` property that returns the async stub, so
47+ this hack is completely hidden from clients, even if they need to access the
48+ stub for more advanced uses.
49+
3550 Example:
3651 This example illustrates how to create a simple API client that connects to a
3752 gRPC server and calls a method on a stub.
3853
3954 ```python
4055 from collections.abc import AsyncIterable
56+ from typing import cast
4157 from frequenz.client.base.client import BaseApiClient, call_stub_method
4258 from frequenz.client.base.streaming import GrpcStreamBroadcaster
4359 from frequenz.channels import Receiver
@@ -59,23 +75,38 @@ async def example_method(
5975
6076 def example_stream(self) -> AsyncIterable[ExampleResponse]:
6177 ...
78+
79+ class ExampleAsyncStub:
80+ async def example_method(
81+ self,
82+ request: ExampleRequest # pylint: disable=unused-argument
83+ ) -> ExampleResponse:
84+ ...
85+
86+ def example_stream(self) -> AsyncIterable[ExampleResponse]:
87+ ...
6288 # End of generated classes
6389
6490 class ExampleResponseWrapper:
65- def __init__(self, response: ExampleResponse):
91+ def __init__(self, response: ExampleResponse) -> None :
6692 self.transformed_value = f"{response.float_value:.2f}"
6793
68- class MyApiClient(BaseApiClient[ExampleStub]):
69- def __init__(self, server_url: str, *, connect: bool = True):
70- super().__init__(
71- server_url, ExampleStub, connect=connect
94+ class MyApiClient(BaseApiClient):
95+ def __init__(self, server_url: str, *, connect: bool = True) -> None:
96+ super().__init__(server_url, connect=connect)
97+ self._stub = cast(
98+ ExampleAsyncStub, ExampleStub(self.channel)
7299 )
73100 self._broadcaster = GrpcStreamBroadcaster(
74101 "stream",
75102 lambda: self.stub.example_stream(ExampleRequest()),
76103 ExampleResponseWrapper,
77104 )
78105
106+ @property
107+ def stub(self) -> ExampleAsyncStub:
108+ return self._stub
109+
79110 async def example_method(
80111 self, int_value: int, str_value: str
81112 ) -> ExampleResponseWrapper:
@@ -114,7 +145,6 @@ async def main():
114145 def __init__ (
115146 self ,
116147 server_url : str ,
117- create_stub : Callable [[Channel ], StubT ],
118148 * ,
119149 connect : bool = True ,
120150 channel_defaults : ChannelOptions = ChannelOptions (),
@@ -123,7 +153,6 @@ def __init__(
123153
124154 Args:
125155 server_url: The URL of the server to connect to.
126- create_stub: A function that creates a stub from a channel.
127156 connect: Whether to connect to the server as soon as a client instance is
128157 created. If `False`, the client will not connect to the server until
129158 [connect()][frequenz.client.base.client.BaseApiClient.connect] is
@@ -132,10 +161,8 @@ def __init__(
132161 the server URL.
133162 """
134163 self ._server_url : str = server_url
135- self ._create_stub : Callable [[Channel ], StubT ] = create_stub
136164 self ._channel_defaults : ChannelOptions = channel_defaults
137165 self ._channel : Channel | None = None
138- self ._stub : StubT | None = None
139166 if connect :
140167 self .connect (server_url )
141168
@@ -165,22 +192,6 @@ def channel_defaults(self) -> ChannelOptions:
165192 """The default options for the gRPC channel."""
166193 return self ._channel_defaults
167194
168- @property
169- def stub (self ) -> StubT :
170- """The underlying gRPC stub.
171-
172- Warning:
173- This stub is provided as a last resort for advanced users. It is not
174- recommended to use this property directly unless you know what you are
175- doing and you don't care about being tied to a specific gRPC library.
176-
177- Raises:
178- ClientNotConnected: If the client is not connected to the server.
179- """
180- if self ._stub is None :
181- raise ClientNotConnected (server_url = self .server_url , operation = "stub" )
182- return self ._stub
183-
184195 @property
185196 def is_connected (self ) -> bool :
186197 """Whether the client is connected to the server."""
@@ -202,7 +213,6 @@ def connect(self, server_url: str | None = None) -> None:
202213 elif self .is_connected :
203214 return
204215 self ._channel = parse_grpc_uri (self ._server_url , self ._channel_defaults )
205- self ._stub = self ._create_stub (self ._channel )
206216
207217 async def disconnect (self ) -> None :
208218 """Disconnect from the server.
@@ -240,7 +250,7 @@ async def __aexit__(
240250
241251@overload
242252async def call_stub_method (
243- client : BaseApiClient [ StubT ] ,
253+ client : BaseApiClient ,
244254 stub_method : Callable [[], Awaitable [StubOutT ]],
245255 * ,
246256 method_name : str | None = None ,
@@ -250,7 +260,7 @@ async def call_stub_method(
250260
251261@overload
252262async def call_stub_method (
253- client : BaseApiClient [ StubT ] ,
263+ client : BaseApiClient ,
254264 stub_method : Callable [[], Awaitable [StubOutT ]],
255265 * ,
256266 method_name : str | None = None ,
@@ -261,7 +271,7 @@ async def call_stub_method(
261271# We need the `noqa: DOC503` because `pydoclint` can't figure out that
262272# `ApiClientError.from_grpc_error()` returns a `GrpcError` instance.
263273async def call_stub_method ( # noqa: DOC503
264- client : BaseApiClient [ StubT ] ,
274+ client : BaseApiClient ,
265275 stub_method : Callable [[], Awaitable [StubOutT ]],
266276 * ,
267277 method_name : str | None = None ,
0 commit comments