Skip to content

Commit cf64f31

Browse files
committed
Remove generic type from BaseApiClient
The `BaseApiClient` class is not generic anymore, and doesn't take a function to create the stub. Instead, subclasses should create their own stub right after calling the parent constructor. This enables subclasses to cast the stub to the generated `XxxAsyncStub` class, which have proper `async` type hints. Signed-off-by: Leandro Lucarella <[email protected]>
1 parent 92028a8 commit cf64f31

File tree

3 files changed

+78
-65
lines changed

3 files changed

+78
-65
lines changed

RELEASE_NOTES.md

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,34 @@
1414

1515
* The `ExponentialBackoff` and `LinearBackoff` classes now require keyword arguments for their constructor. This change was made to make the classes easier to use and to avoid confusion with the order of the arguments.
1616

17+
* The `BaseApiClient` class is not generic anymore, and doesn't take a function to create the stub. Instead, subclasses should create their own stub right after calling the parent constructor. This enables subclasses to cast the stub to the generated `XxxAsyncStub` class, which have proper `async` type hints. To convert you client:
18+
19+
```python
20+
# Old
21+
from my_service_pb2_grpc import MyServiceStub
22+
class MyApiClient(BaseApiClient[MyServiceStub]):
23+
def __init__(self, server_url: str, *, ...) -> None:
24+
super().__init__(server_url, MyServiceStub, ...)
25+
...
26+
27+
# New
28+
from typing import cast
29+
from my_service_pb2_grpc import MyServiceStub, MyServiceAsyncStub
30+
class MyApiClient(BaseApiClient):
31+
def __init__(self, server_url: str, *, ...) -> None:
32+
super().__init__(server_url, connect=connect)
33+
self._stub = cast(MyServiceAsyncStub, MyServiceStub(self.channel))
34+
...
35+
36+
@property
37+
def stub(self) -> MyServiceAsyncStub:
38+
if self._channel is None:
39+
raise ClientNotConnected(server_url=self.server_url, operation="stub")
40+
return self._stub
41+
```
42+
43+
After this, you should be able to remove a lot of `cast`s or `type: ignore` from the code when calling the stub `async` methods.
44+
1745
## New Features
1846

1947
<!-- Here goes the main new features and examples or instructions on how to use them -->

src/frequenz/client/base/client.py

Lines changed: 47 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,15 @@
66
import abc
77
import inspect
88
from collections.abc import Awaitable, Callable
9-
from typing import Any, Generic, Self, TypeVar, overload
9+
from typing import Any, Self, TypeVar, overload
1010

1111
from grpc.aio import AioRpcError, Channel
1212

1313
from .channel import ChannelOptions, parse_grpc_uri
1414
from .exception import ApiClientError, ClientNotConnected
1515

16-
StubT = TypeVar("StubT")
17-
"""The type of the gRPC stub."""
1816

19-
20-
class BaseApiClient(abc.ABC, Generic[StubT]):
17+
class BaseApiClient(abc.ABC):
2118
"""A base class for API clients.
2219
2320
This class provides a common interface for API clients that communicate with a API
@@ -32,12 +29,31 @@ class BaseApiClient(abc.ABC, Generic[StubT]):
3229
a class that helps sending messages from a gRPC stream to
3330
a [Broadcast][frequenz.channels.Broadcast] channel.
3431
32+
Note:
33+
Because grpcio doesn't provide proper type hints, a hack is needed to have
34+
propepr async type hints for the stubs generated by protoc. When using
35+
`mypy-protobuf`, a `XxxAsyncStub` class is generated for each `XxxStub` class
36+
but in the `.pyi` file, so the type can be used to specify type hints, but
37+
**not** in any other context, as the class doesn't really exist for the Python
38+
interpreter. This include generics, and because of this, this class can't be
39+
even parametrized using the async class, so the instantiation of the stub can't
40+
be done in the base class.
41+
42+
Because of this, subclasses need to create the stubs by themselves, using the
43+
real stub class and casting it to the `XxxAsyncStub` class, so `mypy` can use
44+
the async version of the stubs.
45+
46+
It is recommended to define a `stub` property that returns the async stub, so
47+
this hack is completely hidden from clients, even if they need to access the
48+
stub for more advanced uses.
49+
3550
Example:
3651
This example illustrates how to create a simple API client that connects to a
3752
gRPC server and calls a method on a stub.
3853
3954
```python
4055
from collections.abc import AsyncIterable
56+
from typing import cast
4157
from frequenz.client.base.client import BaseApiClient, call_stub_method
4258
from frequenz.client.base.streaming import GrpcStreamBroadcaster
4359
from frequenz.channels import Receiver
@@ -57,25 +73,42 @@ async def example_method(
5773
) -> ExampleResponse:
5874
...
5975
60-
def example_stream(self) -> AsyncIterable[ExampleResponse]:
76+
def example_stream(self, _: ExampleRequest) -> AsyncIterable[ExampleResponse]:
77+
...
78+
79+
class ExampleAsyncStub:
80+
async def example_method(
81+
self,
82+
request: ExampleRequest # pylint: disable=unused-argument
83+
) -> ExampleResponse:
84+
...
85+
86+
def example_stream(self, _: ExampleRequest) -> AsyncIterable[ExampleResponse]:
6187
...
6288
# End of generated classes
6389
6490
class ExampleResponseWrapper:
65-
def __init__(self, response: ExampleResponse):
91+
def __init__(self, response: ExampleResponse) -> None:
6692
self.transformed_value = f"{response.float_value:.2f}"
6793
68-
class MyApiClient(BaseApiClient[ExampleStub]):
69-
def __init__(self, server_url: str, *, connect: bool = True):
70-
super().__init__(
71-
server_url, ExampleStub, connect=connect
94+
class MyApiClient(BaseApiClient):
95+
def __init__(self, server_url: str, *, connect: bool = True) -> None:
96+
super().__init__(server_url, connect=connect)
97+
self._stub = cast(
98+
ExampleAsyncStub, ExampleStub(self.channel)
7299
)
73100
self._broadcaster = GrpcStreamBroadcaster(
74101
"stream",
75102
lambda: self.stub.example_stream(ExampleRequest()),
76103
ExampleResponseWrapper,
77104
)
78105
106+
@property
107+
def stub(self) -> ExampleAsyncStub:
108+
if self._channel is None:
109+
raise ClientNotConnected(server_url=self.server_url, operation="stub")
110+
return self._stub
111+
79112
async def example_method(
80113
self, int_value: int, str_value: str
81114
) -> ExampleResponseWrapper:
@@ -114,7 +147,6 @@ async def main():
114147
def __init__(
115148
self,
116149
server_url: str,
117-
create_stub: Callable[[Channel], StubT],
118150
*,
119151
connect: bool = True,
120152
channel_defaults: ChannelOptions = ChannelOptions(),
@@ -123,7 +155,6 @@ def __init__(
123155
124156
Args:
125157
server_url: The URL of the server to connect to.
126-
create_stub: A function that creates a stub from a channel.
127158
connect: Whether to connect to the server as soon as a client instance is
128159
created. If `False`, the client will not connect to the server until
129160
[connect()][frequenz.client.base.client.BaseApiClient.connect] is
@@ -132,10 +163,8 @@ def __init__(
132163
the server URL.
133164
"""
134165
self._server_url: str = server_url
135-
self._create_stub: Callable[[Channel], StubT] = create_stub
136166
self._channel_defaults: ChannelOptions = channel_defaults
137167
self._channel: Channel | None = None
138-
self._stub: StubT | None = None
139168
if connect:
140169
self.connect(server_url)
141170

@@ -165,22 +194,6 @@ def channel_defaults(self) -> ChannelOptions:
165194
"""The default options for the gRPC channel."""
166195
return self._channel_defaults
167196

168-
@property
169-
def stub(self) -> StubT:
170-
"""The underlying gRPC stub.
171-
172-
Warning:
173-
This stub is provided as a last resort for advanced users. It is not
174-
recommended to use this property directly unless you know what you are
175-
doing and you don't care about being tied to a specific gRPC library.
176-
177-
Raises:
178-
ClientNotConnected: If the client is not connected to the server.
179-
"""
180-
if self._stub is None:
181-
raise ClientNotConnected(server_url=self.server_url, operation="stub")
182-
return self._stub
183-
184197
@property
185198
def is_connected(self) -> bool:
186199
"""Whether the client is connected to the server."""
@@ -202,7 +215,6 @@ def connect(self, server_url: str | None = None) -> None:
202215
elif self.is_connected:
203216
return
204217
self._channel = parse_grpc_uri(self._server_url, self._channel_defaults)
205-
self._stub = self._create_stub(self._channel)
206218

207219
async def disconnect(self) -> None:
208220
"""Disconnect from the server.
@@ -227,7 +239,6 @@ async def __aexit__(
227239
return None
228240
result = await self._channel.__aexit__(_exc_type, _exc_val, _exc_tb)
229241
self._channel = None
230-
self._stub = None
231242
return result
232243

233244

@@ -240,7 +251,7 @@ async def __aexit__(
240251

241252
@overload
242253
async def call_stub_method(
243-
client: BaseApiClient[StubT],
254+
client: BaseApiClient,
244255
stub_method: Callable[[], Awaitable[StubOutT]],
245256
*,
246257
method_name: str | None = None,
@@ -250,7 +261,7 @@ async def call_stub_method(
250261

251262
@overload
252263
async def call_stub_method(
253-
client: BaseApiClient[StubT],
264+
client: BaseApiClient,
254265
stub_method: Callable[[], Awaitable[StubOutT]],
255266
*,
256267
method_name: str | None = None,
@@ -261,7 +272,7 @@ async def call_stub_method(
261272
# We need the `noqa: DOC503` because `pydoclint` can't figure out that
262273
# `ApiClientError.from_grpc_error()` returns a `GrpcError` instance.
263274
async def call_stub_method( # noqa: DOC503
264-
client: BaseApiClient[StubT],
275+
client: BaseApiClient,
265276
stub_method: Callable[[], Awaitable[StubOutT]],
266277
*,
267278
method_name: str | None = None,

tests/test_client.py

Lines changed: 3 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,15 @@
1212
import pytest_mock
1313

1414
from frequenz.client.base.channel import ChannelOptions, SslOptions
15-
from frequenz.client.base.client import BaseApiClient, StubT, call_stub_method
15+
from frequenz.client.base.client import BaseApiClient, call_stub_method
1616
from frequenz.client.base.exception import ClientNotConnected, UnknownError
1717

1818

1919
def _auto_connect_name(auto_connect: bool) -> str:
2020
return f"{auto_connect=}"
2121

2222

23-
def _assert_is_disconnected(client: BaseApiClient[StubT]) -> None:
23+
def _assert_is_disconnected(client: BaseApiClient) -> None:
2424
"""Assert that the client is disconnected."""
2525
assert not client.is_connected
2626

@@ -30,17 +30,9 @@ def _assert_is_disconnected(client: BaseApiClient[StubT]) -> None:
3030
assert exc.server_url == _DEFAULT_SERVER_URL
3131
assert exc.operation == "channel"
3232

33-
with pytest.raises(ClientNotConnected, match=r"") as exc_info:
34-
_ = client.stub
35-
exc = exc_info.value
36-
assert exc.server_url == _DEFAULT_SERVER_URL
37-
assert exc.operation == "stub"
38-
3933

4034
@dataclass(kw_only=True, frozen=True)
4135
class _ClientMocks:
42-
stub: mock.MagicMock
43-
create_stub: mock.MagicMock
4436
channel: mock.MagicMock
4537
parse_grpc_uri: mock.MagicMock
4638

@@ -54,10 +46,8 @@ def create_client_with_mocks(
5446
auto_connect: bool = True,
5547
server_url: str = _DEFAULT_SERVER_URL,
5648
channel_defaults: ChannelOptions | None = None,
57-
) -> tuple[BaseApiClient[mock.MagicMock], _ClientMocks]:
49+
) -> tuple[BaseApiClient, _ClientMocks]:
5850
"""Create a BaseApiClient instance with mocks."""
59-
mock_stub = mock.MagicMock(name="stub")
60-
mock_create_stub = mock.MagicMock(name="create_stub", return_value=mock_stub)
6151
mock_channel = mock.MagicMock(name="channel", spec=grpc.aio.Channel)
6252
mock_parse_grpc_uri = mocker.patch(
6353
"frequenz.client.base.client.parse_grpc_uri", return_value=mock_channel
@@ -67,13 +57,10 @@ def create_client_with_mocks(
6757
kwargs["channel_defaults"] = channel_defaults
6858
client = BaseApiClient(
6959
server_url=server_url,
70-
create_stub=mock_create_stub,
7160
connect=auto_connect,
7261
**kwargs,
7362
)
7463
return client, _ClientMocks(
75-
stub=mock_stub,
76-
create_stub=mock_create_stub,
7764
channel=mock_channel,
7865
parse_grpc_uri=mock_parse_grpc_uri,
7966
)
@@ -92,13 +79,10 @@ def test_base_api_client_init(
9279
client.server_url, ChannelOptions()
9380
)
9481
assert client.channel is mocks.channel
95-
assert client.stub is mocks.stub
9682
assert client.is_connected
97-
mocks.create_stub.assert_called_once_with(mocks.channel)
9883
else:
9984
_assert_is_disconnected(client)
10085
mocks.parse_grpc_uri.assert_not_called()
101-
mocks.create_stub.assert_not_called()
10286

10387

10488
def test_base_api_client_init_with_channel_defaults(
@@ -110,9 +94,7 @@ def test_base_api_client_init_with_channel_defaults(
11094
assert client.server_url == _DEFAULT_SERVER_URL
11195
mocks.parse_grpc_uri.assert_called_once_with(client.server_url, channel_defaults)
11296
assert client.channel is mocks.channel
113-
assert client.stub is mocks.stub
11497
assert client.is_connected
115-
mocks.create_stub.assert_called_once_with(mocks.channel)
11698

11799

118100
@pytest.mark.parametrize(
@@ -129,12 +111,10 @@ def test_base_api_client_connect(
129111
# We want to check only what happens when we call connect, so we reset the mocks
130112
# that were called during initialization
131113
mocks.parse_grpc_uri.reset_mock()
132-
mocks.create_stub.reset_mock()
133114

134115
client.connect(new_server_url)
135116

136117
assert client.channel is mocks.channel
137-
assert client.stub is mocks.stub
138118
assert client.is_connected
139119

140120
same_url = new_server_url is None or new_server_url == _DEFAULT_SERVER_URL
@@ -148,12 +128,10 @@ def test_base_api_client_connect(
148128
# reconnect
149129
if auto_connect and same_url:
150130
mocks.parse_grpc_uri.assert_not_called()
151-
mocks.create_stub.assert_not_called()
152131
else:
153132
mocks.parse_grpc_uri.assert_called_once_with(
154133
client.server_url, ChannelOptions()
155134
)
156-
mocks.create_stub.assert_called_once_with(mocks.channel)
157135

158136

159137
async def test_base_api_client_disconnect(mocker: pytest_mock.MockFixture) -> None:
@@ -177,23 +155,19 @@ async def test_base_api_client_async_context_manager(
177155
# We want to check only what happens when we enter the context manager, so we reset
178156
# the mocks that were called during initialization
179157
mocks.parse_grpc_uri.reset_mock()
180-
mocks.create_stub.reset_mock()
181158

182159
async with client:
183160
assert client.channel is mocks.channel
184-
assert client.stub is mocks.stub
185161
assert client.is_connected
186162
mocks.channel.__aexit__.assert_not_called()
187163
# If we were previously connected, the client should not reconnect when entering
188164
# the context manager
189165
if auto_connect:
190166
mocks.parse_grpc_uri.assert_not_called()
191-
mocks.create_stub.assert_not_called()
192167
else:
193168
mocks.parse_grpc_uri.assert_called_once_with(
194169
client.server_url, ChannelOptions()
195170
)
196-
mocks.create_stub.assert_called_once_with(mocks.channel)
197171

198172
mocks.channel.__aexit__.assert_called_once_with(None, None, None)
199173
assert client.server_url == _DEFAULT_SERVER_URL

0 commit comments

Comments
 (0)