From 9744ff1404dbf51f61ab9cb37c9899ae19af8578 Mon Sep 17 00:00:00 2001 From: Leandro Lucarella Date: Mon, 18 Nov 2024 14:14:20 +0100 Subject: [PATCH] Revert "Remove generic type from `BaseApiClient`" This change was intended to make building a client simpler, but it ended up being more complicated than expected, using the async stub is more complicated than expected, as it lives only in the `.pyi` file and can't be used in any other context than type hints. For example the new approach didn't worked well with delaying the connection of the client. To handle that correctly, more work is needed by subclasses. This commit reverts back to making the `BaseApiClient` class generic and it instantiates the stub internally as before. To get proper async type hints, users now only need to write the `stub` property themselves, and use the appropriate async stub type hint there. This reverts commit 035a7949a9bea7237911f64450a4200ba5d38eb7. Signed-off-by: Leandro Lucarella --- RELEASE_NOTES.md | 69 +++++++++++++++++++++++------ src/frequenz/client/base/client.py | 71 ++++++++++++++++-------------- tests/test_client.py | 32 ++++++++++++-- 3 files changed, 122 insertions(+), 50 deletions(-) diff --git a/RELEASE_NOTES.md b/RELEASE_NOTES.md index 4bd60ba..c3340ba 100644 --- a/RELEASE_NOTES.md +++ b/RELEASE_NOTES.md @@ -1,17 +1,60 @@ # Frequenz Client Base Library Release Notes -## Summary - - - ## Upgrading - - -## New Features - - - -## Bug Fixes - - +The `BaseApiClient` class is generic again. There was too many issues with the new approach, so it was rolled back. + +- If you are upgrading from v0.7.x, you should be able to roll back your changes with the upgrade and just keep the new `stub` property. + + ```python + # Old + from __future__ import annotations + import my_service_pb2_grpc + class MyApiClient(BaseApiClient): + def __init__(self, server_url: str, *, ...) -> None: + super().__init__(server_url, ...) + stub = my_service_pb2_grpc.MyServiceStub(self.channel) + self._stub: my_service_pb2_grpc.MyServiceAsyncStub = stub # type: ignore + ... + + @property + def stub(self) -> my_service_pb2_grpc.MyServiceAsyncStub: + if self.channel is None: + raise ClientNotConnected(server_url=self.server_url, operation="stub") + return self._stub + + # New + from __future__ import annotations + import my_service_pb2_grpc + from my_service_pb2_grpc import MyServiceStub + class MyApiClient(BaseApiClient[MyServiceStub]): + def __init__(self, server_url: str, *, ...) -> None: + super().__init__(server_url, MyServiceStub, ...) + ... + + @property + def stub(self) -> my_service_pb2_grpc.MyServiceAsyncStub: + """The gRPC stub for the API.""" + if self.channel is None or self._stub is None: + raise ClientNotConnected(server_url=self.server_url, operation="stub") + # This type: ignore is needed because we need to cast the sync stub to + # the async stub, but we can't use cast because the async stub doesn't + # actually exists to the eyes of the interpreter, it only exists for the + # type-checker, so it can only be used for type hints. + return self._stub # type: ignore + ``` + +- If you are upgrading from v0.6.x, you should only need to add the `stub` property to your client class and then use that property instead of `_stub` in your code. + + ```python + @property + def stub(self) -> my_service_pb2_grpc.MyServiceAsyncStub: + """The gRPC stub for the API.""" + if self.channel is None or self._stub is None: + raise ClientNotConnected(server_url=self.server_url, operation="stub") + # This type: ignore is needed because we need to cast the sync stub to + # the async stub, but we can't use cast because the async stub doesn't + # actually exists to the eyes of the interpreter, it only exists for the + # type-checker, so it can only be used for type hints. + return self._stub # type: ignore + ``` diff --git a/src/frequenz/client/base/client.py b/src/frequenz/client/base/client.py index 423974a..d2c2908 100644 --- a/src/frequenz/client/base/client.py +++ b/src/frequenz/client/base/client.py @@ -6,21 +6,34 @@ import abc import inspect from collections.abc import Awaitable, Callable -from typing import Any, Self, TypeVar, overload +from typing import Any, Generic, Self, TypeVar, overload from grpc.aio import AioRpcError, Channel from .channel import ChannelOptions, parse_grpc_uri from .exception import ApiClientError, ClientNotConnected +StubT = TypeVar("StubT") +"""The type of the gRPC stub.""" -class BaseApiClient(abc.ABC): + +class BaseApiClient(abc.ABC, Generic[StubT]): """A base class for API clients. This class provides a common interface for API clients that communicate with a API server. It is designed to be subclassed by specific API clients that provide a more specific interface. + Note: + It is recommended to add a `stub` property to the subclass that returns the gRPC + stub to use but using the *async stub* type instead of the *sync stub* type. + This is because the gRPC library provides async stubs that have proper async + type hints, but they only live in `.pyi` files, so they can be used in a very + limited way (only as type hints). Because of this, a `type: ignore` comment is + needed to cast the sync stub to the async stub. + + Please see the example below for a recommended way to implement this property. + Some extra tools are provided to make it easier to write API clients: - [call_stub_method()][frequenz.client.base.client.call_stub_method] is a function @@ -29,31 +42,12 @@ class BaseApiClient(abc.ABC): a class that helps sending messages from a gRPC stream to a [Broadcast][frequenz.channels.Broadcast] channel. - Note: - Because grpcio doesn't provide proper type hints, a hack is needed to have - propepr async type hints for the stubs generated by protoc. When using - `mypy-protobuf`, a `XxxAsyncStub` class is generated for each `XxxStub` class - but in the `.pyi` file, so the type can be used to specify type hints, but - **not** in any other context, as the class doesn't really exist for the Python - interpreter. This include generics, and because of this, this class can't be - even parametrized using the async class, so the instantiation of the stub can't - be done in the base class. - - Because of this, subclasses need to create the stubs by themselves, using the - real stub class and casting it to the `XxxAsyncStub` class, so `mypy` can use - the async version of the stubs. - - It is recommended to define a `stub` property that returns the async stub, so - this hack is completely hidden from clients, even if they need to access the - stub for more advanced uses. - Example: This example illustrates how to create a simple API client that connects to a gRPC server and calls a method on a stub. ```python from collections.abc import AsyncIterable - from typing import cast from frequenz.client.base.client import BaseApiClient, call_stub_method from frequenz.client.base.streaming import GrpcStreamBroadcaster from frequenz.channels import Receiver @@ -67,13 +61,13 @@ class ExampleResponse: float_value: float class ExampleStub: - async def example_method( + def example_method( self, request: ExampleRequest # pylint: disable=unused-argument ) -> ExampleResponse: ... - def example_stream(self, _: ExampleRequest) -> AsyncIterable[ExampleResponse]: + def example_stream(self) -> AsyncIterable[ExampleResponse]: ... class ExampleAsyncStub: @@ -83,18 +77,18 @@ async def example_method( ) -> ExampleResponse: ... - def example_stream(self, _: ExampleRequest) -> AsyncIterable[ExampleResponse]: + def example_stream(self) -> AsyncIterable[ExampleResponse]: ... # End of generated classes class ExampleResponseWrapper: - def __init__(self, response: ExampleResponse) -> None: + def __init__(self, response: ExampleResponse): self.transformed_value = f"{response.float_value:.2f}" # Change defaults as needed DEFAULT_CHANNEL_OPTIONS = ChannelOptions() - class MyApiClient(BaseApiClient): + class MyApiClient(BaseApiClient[ExampleStub]): def __init__( self, server_url: str, @@ -102,9 +96,8 @@ def __init__( connect: bool = True, channel_defaults: ChannelOptions = DEFAULT_CHANNEL_OPTIONS, ) -> None: - super().__init__(server_url, connect=connect, channel_defaults=channel_defaults) - self._stub = cast( - ExampleAsyncStub, ExampleStub(self.channel) + super().__init__( + server_url, ExampleStub, connect=connect, channel_defaults=channel_defaults ) self._broadcaster = GrpcStreamBroadcaster( "stream", @@ -114,9 +107,13 @@ def __init__( @property def stub(self) -> ExampleAsyncStub: - if self._channel is None: + if self.channel is None or self._stub is None: raise ClientNotConnected(server_url=self.server_url, operation="stub") - return self._stub + # This type: ignore is needed because we need to cast the sync stub to + # the async stub, but we can't use cast because the async stub doesn't + # actually exists to the eyes of the interpreter, it only exists for the + # type-checker, so it can only be used for type hints. + return self._stub # type: ignore async def example_method( self, int_value: int, str_value: str @@ -156,6 +153,7 @@ async def main(): def __init__( self, server_url: str, + create_stub: Callable[[Channel], StubT], *, connect: bool = True, channel_defaults: ChannelOptions = ChannelOptions(), @@ -164,6 +162,7 @@ def __init__( Args: server_url: The URL of the server to connect to. + create_stub: A function that creates a stub from a channel. connect: Whether to connect to the server as soon as a client instance is created. If `False`, the client will not connect to the server until [connect()][frequenz.client.base.client.BaseApiClient.connect] is @@ -172,8 +171,10 @@ def __init__( the server URL. """ self._server_url: str = server_url + self._create_stub: Callable[[Channel], StubT] = create_stub self._channel_defaults: ChannelOptions = channel_defaults self._channel: Channel | None = None + self._stub: StubT | None = None if connect: self.connect(server_url) @@ -224,6 +225,7 @@ def connect(self, server_url: str | None = None) -> None: elif self.is_connected: return self._channel = parse_grpc_uri(self._server_url, self._channel_defaults) + self._stub = self._create_stub(self._channel) async def disconnect(self) -> None: """Disconnect from the server. @@ -248,6 +250,7 @@ async def __aexit__( return None result = await self._channel.__aexit__(_exc_type, _exc_val, _exc_tb) self._channel = None + self._stub = None return result @@ -260,7 +263,7 @@ async def __aexit__( @overload async def call_stub_method( - client: BaseApiClient, + client: BaseApiClient[StubT], stub_method: Callable[[], Awaitable[StubOutT]], *, method_name: str | None = None, @@ -270,7 +273,7 @@ async def call_stub_method( @overload async def call_stub_method( - client: BaseApiClient, + client: BaseApiClient[StubT], stub_method: Callable[[], Awaitable[StubOutT]], *, method_name: str | None = None, @@ -281,7 +284,7 @@ async def call_stub_method( # We need the `noqa: DOC503` because `pydoclint` can't figure out that # `ApiClientError.from_grpc_error()` returns a `GrpcError` instance. async def call_stub_method( # noqa: DOC503 - client: BaseApiClient, + client: BaseApiClient[StubT], stub_method: Callable[[], Awaitable[StubOutT]], *, method_name: str | None = None, diff --git a/tests/test_client.py b/tests/test_client.py index dcf3d5f..f0fa99b 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -12,7 +12,7 @@ import pytest_mock from frequenz.client.base.channel import ChannelOptions, SslOptions -from frequenz.client.base.client import BaseApiClient, call_stub_method +from frequenz.client.base.client import BaseApiClient, StubT, call_stub_method from frequenz.client.base.exception import ClientNotConnected, UnknownError @@ -20,7 +20,7 @@ def _auto_connect_name(auto_connect: bool) -> str: return f"{auto_connect=}" -def _assert_is_disconnected(client: BaseApiClient) -> None: +def _assert_is_disconnected(client: BaseApiClient[StubT]) -> None: """Assert that the client is disconnected.""" assert not client.is_connected @@ -30,9 +30,17 @@ def _assert_is_disconnected(client: BaseApiClient) -> None: assert exc.server_url == _DEFAULT_SERVER_URL assert exc.operation == "channel" + with pytest.raises(ClientNotConnected, match=r"") as exc_info: + _ = client.channel + exc = exc_info.value + assert exc.server_url == _DEFAULT_SERVER_URL + assert exc.operation == "channel" + @dataclass(kw_only=True, frozen=True) class _ClientMocks: + stub: mock.MagicMock + create_stub: mock.MagicMock channel: mock.MagicMock parse_grpc_uri: mock.MagicMock @@ -46,8 +54,10 @@ def create_client_with_mocks( auto_connect: bool = True, server_url: str = _DEFAULT_SERVER_URL, channel_defaults: ChannelOptions | None = None, -) -> tuple[BaseApiClient, _ClientMocks]: +) -> tuple[BaseApiClient[mock.MagicMock], _ClientMocks]: """Create a BaseApiClient instance with mocks.""" + mock_stub = mock.MagicMock(name="stub") + mock_create_stub = mock.MagicMock(name="create_stub", return_value=mock_stub) mock_channel = mock.MagicMock(name="channel", spec=grpc.aio.Channel) mock_parse_grpc_uri = mocker.patch( "frequenz.client.base.client.parse_grpc_uri", return_value=mock_channel @@ -57,10 +67,13 @@ def create_client_with_mocks( kwargs["channel_defaults"] = channel_defaults client = BaseApiClient( server_url=server_url, + create_stub=mock_create_stub, connect=auto_connect, **kwargs, ) return client, _ClientMocks( + stub=mock_stub, + create_stub=mock_create_stub, channel=mock_channel, parse_grpc_uri=mock_parse_grpc_uri, ) @@ -79,10 +92,13 @@ def test_base_api_client_init( client.server_url, ChannelOptions() ) assert client.channel is mocks.channel + assert client._stub is mocks.stub # pylint: disable=protected-access assert client.is_connected + mocks.create_stub.assert_called_once_with(mocks.channel) else: _assert_is_disconnected(client) mocks.parse_grpc_uri.assert_not_called() + mocks.create_stub.assert_not_called() def test_base_api_client_init_with_channel_defaults( @@ -94,7 +110,9 @@ def test_base_api_client_init_with_channel_defaults( assert client.server_url == _DEFAULT_SERVER_URL mocks.parse_grpc_uri.assert_called_once_with(client.server_url, channel_defaults) assert client.channel is mocks.channel + assert client._stub is mocks.stub # pylint: disable=protected-access assert client.is_connected + mocks.create_stub.assert_called_once_with(mocks.channel) @pytest.mark.parametrize( @@ -111,10 +129,12 @@ def test_base_api_client_connect( # We want to check only what happens when we call connect, so we reset the mocks # that were called during initialization mocks.parse_grpc_uri.reset_mock() + mocks.create_stub.reset_mock() client.connect(new_server_url) assert client.channel is mocks.channel + assert client._stub is mocks.stub # pylint: disable=protected-access assert client.is_connected same_url = new_server_url is None or new_server_url == _DEFAULT_SERVER_URL @@ -128,10 +148,12 @@ def test_base_api_client_connect( # reconnect if auto_connect and same_url: mocks.parse_grpc_uri.assert_not_called() + mocks.create_stub.assert_not_called() else: mocks.parse_grpc_uri.assert_called_once_with( client.server_url, ChannelOptions() ) + mocks.create_stub.assert_called_once_with(mocks.channel) async def test_base_api_client_disconnect(mocker: pytest_mock.MockFixture) -> None: @@ -155,19 +177,23 @@ async def test_base_api_client_async_context_manager( # We want to check only what happens when we enter the context manager, so we reset # the mocks that were called during initialization mocks.parse_grpc_uri.reset_mock() + mocks.create_stub.reset_mock() async with client: assert client.channel is mocks.channel + assert client._stub is mocks.stub # pylint: disable=protected-access assert client.is_connected mocks.channel.__aexit__.assert_not_called() # If we were previously connected, the client should not reconnect when entering # the context manager if auto_connect: mocks.parse_grpc_uri.assert_not_called() + mocks.create_stub.assert_not_called() else: mocks.parse_grpc_uri.assert_called_once_with( client.server_url, ChannelOptions() ) + mocks.create_stub.assert_called_once_with(mocks.channel) mocks.channel.__aexit__.assert_called_once_with(None, None, None) assert client.server_url == _DEFAULT_SERVER_URL