Skip to content

Commit 5414a6d

Browse files
authored
Remove generic type from BaseApiClient (#92)
The `BaseApiClient` class is not generic anymore, and doesn't take a function to create the stub. Instead, subclasses should create their own stub right after calling the parent constructor. This enables subclasses to cast the stub to the generated `XxxAsyncStub` class, which have proper `async` type hints.
2 parents 9864776 + 091256b commit 5414a6d

File tree

3 files changed

+87
-65
lines changed

3 files changed

+87
-65
lines changed

RELEASE_NOTES.md

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,34 @@
1616

1717
- HTTP2 keep-alive is now enabled by default, with an interval of 60 seconds between pings, and a 20 second timeout for responses from the service. These values are configurable and may be updated based on specific requirements.
1818

19+
* The `BaseApiClient` class is not generic anymore, and doesn't take a function to create the stub. Instead, subclasses should create their own stub right after calling the parent constructor. This enables subclasses to cast the stub to the generated `XxxAsyncStub` class, which have proper `async` type hints. To convert you client:
20+
21+
```python
22+
# Old
23+
from my_service_pb2_grpc import MyServiceStub
24+
class MyApiClient(BaseApiClient[MyServiceStub]):
25+
def __init__(self, server_url: str, *, ...) -> None:
26+
super().__init__(server_url, MyServiceStub, ...)
27+
...
28+
29+
# New
30+
from typing import cast
31+
from my_service_pb2_grpc import MyServiceStub, MyServiceAsyncStub
32+
class MyApiClient(BaseApiClient):
33+
def __init__(self, server_url: str, *, ...) -> None:
34+
super().__init__(server_url, connect=connect)
35+
self._stub = cast(MyServiceAsyncStub, MyServiceStub(self.channel))
36+
...
37+
38+
@property
39+
def stub(self) -> MyServiceAsyncStub:
40+
if self._channel is None:
41+
raise ClientNotConnected(server_url=self.server_url, operation="stub")
42+
return self._stub
43+
```
44+
45+
After this, you should be able to remove a lot of `cast`s or `type: ignore` from the code when calling the stub `async` methods.
46+
1947
## New Features
2048

2149
- Added support for HTTP2 keep-alive.

src/frequenz/client/base/client.py

Lines changed: 56 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,15 @@
66
import abc
77
import inspect
88
from collections.abc import Awaitable, Callable
9-
from typing import Any, Generic, Self, TypeVar, overload
9+
from typing import Any, Self, TypeVar, overload
1010

1111
from grpc.aio import AioRpcError, Channel
1212

1313
from .channel import ChannelOptions, parse_grpc_uri
1414
from .exception import ApiClientError, ClientNotConnected
1515

16-
StubT = TypeVar("StubT")
17-
"""The type of the gRPC stub."""
1816

19-
20-
class BaseApiClient(abc.ABC, Generic[StubT]):
17+
class BaseApiClient(abc.ABC):
2118
"""A base class for API clients.
2219
2320
This class provides a common interface for API clients that communicate with a API
@@ -32,12 +29,31 @@ class BaseApiClient(abc.ABC, Generic[StubT]):
3229
a class that helps sending messages from a gRPC stream to
3330
a [Broadcast][frequenz.channels.Broadcast] channel.
3431
32+
Note:
33+
Because grpcio doesn't provide proper type hints, a hack is needed to have
34+
propepr async type hints for the stubs generated by protoc. When using
35+
`mypy-protobuf`, a `XxxAsyncStub` class is generated for each `XxxStub` class
36+
but in the `.pyi` file, so the type can be used to specify type hints, but
37+
**not** in any other context, as the class doesn't really exist for the Python
38+
interpreter. This include generics, and because of this, this class can't be
39+
even parametrized using the async class, so the instantiation of the stub can't
40+
be done in the base class.
41+
42+
Because of this, subclasses need to create the stubs by themselves, using the
43+
real stub class and casting it to the `XxxAsyncStub` class, so `mypy` can use
44+
the async version of the stubs.
45+
46+
It is recommended to define a `stub` property that returns the async stub, so
47+
this hack is completely hidden from clients, even if they need to access the
48+
stub for more advanced uses.
49+
3550
Example:
3651
This example illustrates how to create a simple API client that connects to a
3752
gRPC server and calls a method on a stub.
3853
3954
```python
4055
from collections.abc import AsyncIterable
56+
from typing import cast
4157
from frequenz.client.base.client import BaseApiClient, call_stub_method
4258
from frequenz.client.base.streaming import GrpcStreamBroadcaster
4359
from frequenz.channels import Receiver
@@ -57,25 +73,51 @@ async def example_method(
5773
) -> ExampleResponse:
5874
...
5975
60-
def example_stream(self) -> AsyncIterable[ExampleResponse]:
76+
def example_stream(self, _: ExampleRequest) -> AsyncIterable[ExampleResponse]:
77+
...
78+
79+
class ExampleAsyncStub:
80+
async def example_method(
81+
self,
82+
request: ExampleRequest # pylint: disable=unused-argument
83+
) -> ExampleResponse:
84+
...
85+
86+
def example_stream(self, _: ExampleRequest) -> AsyncIterable[ExampleResponse]:
6187
...
6288
# End of generated classes
6389
6490
class ExampleResponseWrapper:
65-
def __init__(self, response: ExampleResponse):
91+
def __init__(self, response: ExampleResponse) -> None:
6692
self.transformed_value = f"{response.float_value:.2f}"
6793
68-
class MyApiClient(BaseApiClient[ExampleStub]):
69-
def __init__(self, server_url: str, *, connect: bool = True):
70-
super().__init__(
71-
server_url, ExampleStub, connect=connect
94+
# Change defaults as needed
95+
DEFAULT_CHANNEL_OPTIONS = ChannelOptions()
96+
97+
class MyApiClient(BaseApiClient):
98+
def __init__(
99+
self,
100+
server_url: str,
101+
*,
102+
connect: bool = True,
103+
channel_defaults: ChannelOptions = DEFAULT_CHANNEL_OPTIONS,
104+
) -> None:
105+
super().__init__(server_url, connect=connect, channel_defaults=channel_defaults)
106+
self._stub = cast(
107+
ExampleAsyncStub, ExampleStub(self.channel)
72108
)
73109
self._broadcaster = GrpcStreamBroadcaster(
74110
"stream",
75111
lambda: self.stub.example_stream(ExampleRequest()),
76112
ExampleResponseWrapper,
77113
)
78114
115+
@property
116+
def stub(self) -> ExampleAsyncStub:
117+
if self._channel is None:
118+
raise ClientNotConnected(server_url=self.server_url, operation="stub")
119+
return self._stub
120+
79121
async def example_method(
80122
self, int_value: int, str_value: str
81123
) -> ExampleResponseWrapper:
@@ -114,7 +156,6 @@ async def main():
114156
def __init__(
115157
self,
116158
server_url: str,
117-
create_stub: Callable[[Channel], StubT],
118159
*,
119160
connect: bool = True,
120161
channel_defaults: ChannelOptions = ChannelOptions(),
@@ -123,7 +164,6 @@ def __init__(
123164
124165
Args:
125166
server_url: The URL of the server to connect to.
126-
create_stub: A function that creates a stub from a channel.
127167
connect: Whether to connect to the server as soon as a client instance is
128168
created. If `False`, the client will not connect to the server until
129169
[connect()][frequenz.client.base.client.BaseApiClient.connect] is
@@ -132,10 +172,8 @@ def __init__(
132172
the server URL.
133173
"""
134174
self._server_url: str = server_url
135-
self._create_stub: Callable[[Channel], StubT] = create_stub
136175
self._channel_defaults: ChannelOptions = channel_defaults
137176
self._channel: Channel | None = None
138-
self._stub: StubT | None = None
139177
if connect:
140178
self.connect(server_url)
141179

@@ -165,22 +203,6 @@ def channel_defaults(self) -> ChannelOptions:
165203
"""The default options for the gRPC channel."""
166204
return self._channel_defaults
167205

168-
@property
169-
def stub(self) -> StubT:
170-
"""The underlying gRPC stub.
171-
172-
Warning:
173-
This stub is provided as a last resort for advanced users. It is not
174-
recommended to use this property directly unless you know what you are
175-
doing and you don't care about being tied to a specific gRPC library.
176-
177-
Raises:
178-
ClientNotConnected: If the client is not connected to the server.
179-
"""
180-
if self._stub is None:
181-
raise ClientNotConnected(server_url=self.server_url, operation="stub")
182-
return self._stub
183-
184206
@property
185207
def is_connected(self) -> bool:
186208
"""Whether the client is connected to the server."""
@@ -202,7 +224,6 @@ def connect(self, server_url: str | None = None) -> None:
202224
elif self.is_connected:
203225
return
204226
self._channel = parse_grpc_uri(self._server_url, self._channel_defaults)
205-
self._stub = self._create_stub(self._channel)
206227

207228
async def disconnect(self) -> None:
208229
"""Disconnect from the server.
@@ -227,7 +248,6 @@ async def __aexit__(
227248
return None
228249
result = await self._channel.__aexit__(_exc_type, _exc_val, _exc_tb)
229250
self._channel = None
230-
self._stub = None
231251
return result
232252

233253

@@ -240,7 +260,7 @@ async def __aexit__(
240260

241261
@overload
242262
async def call_stub_method(
243-
client: BaseApiClient[StubT],
263+
client: BaseApiClient,
244264
stub_method: Callable[[], Awaitable[StubOutT]],
245265
*,
246266
method_name: str | None = None,
@@ -250,7 +270,7 @@ async def call_stub_method(
250270

251271
@overload
252272
async def call_stub_method(
253-
client: BaseApiClient[StubT],
273+
client: BaseApiClient,
254274
stub_method: Callable[[], Awaitable[StubOutT]],
255275
*,
256276
method_name: str | None = None,
@@ -261,7 +281,7 @@ async def call_stub_method(
261281
# We need the `noqa: DOC503` because `pydoclint` can't figure out that
262282
# `ApiClientError.from_grpc_error()` returns a `GrpcError` instance.
263283
async def call_stub_method( # noqa: DOC503
264-
client: BaseApiClient[StubT],
284+
client: BaseApiClient,
265285
stub_method: Callable[[], Awaitable[StubOutT]],
266286
*,
267287
method_name: str | None = None,

tests/test_client.py

Lines changed: 3 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,15 @@
1212
import pytest_mock
1313

1414
from frequenz.client.base.channel import ChannelOptions, SslOptions
15-
from frequenz.client.base.client import BaseApiClient, StubT, call_stub_method
15+
from frequenz.client.base.client import BaseApiClient, call_stub_method
1616
from frequenz.client.base.exception import ClientNotConnected, UnknownError
1717

1818

1919
def _auto_connect_name(auto_connect: bool) -> str:
2020
return f"{auto_connect=}"
2121

2222

23-
def _assert_is_disconnected(client: BaseApiClient[StubT]) -> None:
23+
def _assert_is_disconnected(client: BaseApiClient) -> None:
2424
"""Assert that the client is disconnected."""
2525
assert not client.is_connected
2626

@@ -30,17 +30,9 @@ def _assert_is_disconnected(client: BaseApiClient[StubT]) -> None:
3030
assert exc.server_url == _DEFAULT_SERVER_URL
3131
assert exc.operation == "channel"
3232

33-
with pytest.raises(ClientNotConnected, match=r"") as exc_info:
34-
_ = client.stub
35-
exc = exc_info.value
36-
assert exc.server_url == _DEFAULT_SERVER_URL
37-
assert exc.operation == "stub"
38-
3933

4034
@dataclass(kw_only=True, frozen=True)
4135
class _ClientMocks:
42-
stub: mock.MagicMock
43-
create_stub: mock.MagicMock
4436
channel: mock.MagicMock
4537
parse_grpc_uri: mock.MagicMock
4638

@@ -54,10 +46,8 @@ def create_client_with_mocks(
5446
auto_connect: bool = True,
5547
server_url: str = _DEFAULT_SERVER_URL,
5648
channel_defaults: ChannelOptions | None = None,
57-
) -> tuple[BaseApiClient[mock.MagicMock], _ClientMocks]:
49+
) -> tuple[BaseApiClient, _ClientMocks]:
5850
"""Create a BaseApiClient instance with mocks."""
59-
mock_stub = mock.MagicMock(name="stub")
60-
mock_create_stub = mock.MagicMock(name="create_stub", return_value=mock_stub)
6151
mock_channel = mock.MagicMock(name="channel", spec=grpc.aio.Channel)
6252
mock_parse_grpc_uri = mocker.patch(
6353
"frequenz.client.base.client.parse_grpc_uri", return_value=mock_channel
@@ -67,13 +57,10 @@ def create_client_with_mocks(
6757
kwargs["channel_defaults"] = channel_defaults
6858
client = BaseApiClient(
6959
server_url=server_url,
70-
create_stub=mock_create_stub,
7160
connect=auto_connect,
7261
**kwargs,
7362
)
7463
return client, _ClientMocks(
75-
stub=mock_stub,
76-
create_stub=mock_create_stub,
7764
channel=mock_channel,
7865
parse_grpc_uri=mock_parse_grpc_uri,
7966
)
@@ -92,13 +79,10 @@ def test_base_api_client_init(
9279
client.server_url, ChannelOptions()
9380
)
9481
assert client.channel is mocks.channel
95-
assert client.stub is mocks.stub
9682
assert client.is_connected
97-
mocks.create_stub.assert_called_once_with(mocks.channel)
9883
else:
9984
_assert_is_disconnected(client)
10085
mocks.parse_grpc_uri.assert_not_called()
101-
mocks.create_stub.assert_not_called()
10286

10387

10488
def test_base_api_client_init_with_channel_defaults(
@@ -110,9 +94,7 @@ def test_base_api_client_init_with_channel_defaults(
11094
assert client.server_url == _DEFAULT_SERVER_URL
11195
mocks.parse_grpc_uri.assert_called_once_with(client.server_url, channel_defaults)
11296
assert client.channel is mocks.channel
113-
assert client.stub is mocks.stub
11497
assert client.is_connected
115-
mocks.create_stub.assert_called_once_with(mocks.channel)
11698

11799

118100
@pytest.mark.parametrize(
@@ -129,12 +111,10 @@ def test_base_api_client_connect(
129111
# We want to check only what happens when we call connect, so we reset the mocks
130112
# that were called during initialization
131113
mocks.parse_grpc_uri.reset_mock()
132-
mocks.create_stub.reset_mock()
133114

134115
client.connect(new_server_url)
135116

136117
assert client.channel is mocks.channel
137-
assert client.stub is mocks.stub
138118
assert client.is_connected
139119

140120
same_url = new_server_url is None or new_server_url == _DEFAULT_SERVER_URL
@@ -148,12 +128,10 @@ def test_base_api_client_connect(
148128
# reconnect
149129
if auto_connect and same_url:
150130
mocks.parse_grpc_uri.assert_not_called()
151-
mocks.create_stub.assert_not_called()
152131
else:
153132
mocks.parse_grpc_uri.assert_called_once_with(
154133
client.server_url, ChannelOptions()
155134
)
156-
mocks.create_stub.assert_called_once_with(mocks.channel)
157135

158136

159137
async def test_base_api_client_disconnect(mocker: pytest_mock.MockFixture) -> None:
@@ -177,23 +155,19 @@ async def test_base_api_client_async_context_manager(
177155
# We want to check only what happens when we enter the context manager, so we reset
178156
# the mocks that were called during initialization
179157
mocks.parse_grpc_uri.reset_mock()
180-
mocks.create_stub.reset_mock()
181158

182159
async with client:
183160
assert client.channel is mocks.channel
184-
assert client.stub is mocks.stub
185161
assert client.is_connected
186162
mocks.channel.__aexit__.assert_not_called()
187163
# If we were previously connected, the client should not reconnect when entering
188164
# the context manager
189165
if auto_connect:
190166
mocks.parse_grpc_uri.assert_not_called()
191-
mocks.create_stub.assert_not_called()
192167
else:
193168
mocks.parse_grpc_uri.assert_called_once_with(
194169
client.server_url, ChannelOptions()
195170
)
196-
mocks.create_stub.assert_called_once_with(mocks.channel)
197171

198172
mocks.channel.__aexit__.assert_called_once_with(None, None, None)
199173
assert client.server_url == _DEFAULT_SERVER_URL

0 commit comments

Comments
 (0)