diff --git a/src/frequenz/client/microgrid/_client.py b/src/frequenz/client/microgrid/_client.py index 55de9a9b..6ef8cc2a 100644 --- a/src/frequenz/client/microgrid/_client.py +++ b/src/frequenz/client/microgrid/_client.py @@ -5,17 +5,15 @@ import asyncio import logging -from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Set +from collections.abc import Callable, Iterable, Set from dataclasses import replace -from typing import Any, TypeVar, cast +from typing import Any, TypeVar -import grpc.aio from frequenz.api.common import components_pb2, metrics_pb2 from frequenz.api.microgrid import microgrid_pb2, microgrid_pb2_grpc from frequenz.channels import Receiver from frequenz.client.base import channel, client, retry, streaming from google.protobuf.empty_pb2 import Empty -from google.protobuf.timestamp_pb2 import Timestamp from ._component import ( Component, @@ -91,10 +89,13 @@ def __init__( connect=connect, channel_defaults=channel_defaults, ) + self._async_stub: microgrid_pb2_grpc.MicrogridAsyncStub = self.stub # type: ignore self._broadcasters: dict[int, streaming.GrpcStreamBroadcaster[Any, Any]] = {} self._retry_strategy = retry_strategy - async def components(self) -> Iterable[Component]: + async def components( # noqa: DOC502 (raises ApiClientError indirectly) + self, + ) -> Iterable[Component]: """Fetch all the components present in the microgrid. Returns: @@ -105,22 +106,14 @@ async def components(self) -> Iterable[Component]: most likely a subclass of [GrpcError][frequenz.client.microgrid.GrpcError]. """ - try: - # grpc.aio is missing types and mypy thinks this is not awaitable, - # but it is - component_list = await cast( - Awaitable[microgrid_pb2.ComponentList], - self.stub.ListComponents( - microgrid_pb2.ComponentFilter(), - timeout=int(DEFAULT_GRPC_CALL_TIMEOUT), - ), - ) - except grpc.aio.AioRpcError as grpc_error: - raise ApiClientError.from_grpc_error( - server_url=self._server_url, - operation="ListComponents", - grpc_error=grpc_error, - ) from grpc_error + component_list = await client.call_stub_method( + self, + lambda: self._async_stub.ListComponents( + microgrid_pb2.ComponentFilter(), + timeout=int(DEFAULT_GRPC_CALL_TIMEOUT), + ), + method_name="ListComponents", + ) components_only = filter( lambda c: c.category @@ -150,14 +143,15 @@ async def metadata(self) -> Metadata: """ microgrid_metadata: microgrid_pb2.MicrogridMetadata | None = None try: - microgrid_metadata = await cast( - Awaitable[microgrid_pb2.MicrogridMetadata], - self.stub.GetMicrogridMetadata( + microgrid_metadata = await client.call_stub_method( + self, + lambda: self._async_stub.GetMicrogridMetadata( Empty(), timeout=int(DEFAULT_GRPC_CALL_TIMEOUT), ), + method_name="GetMicrogridMetadata", ) - except grpc.aio.AioRpcError: + except ApiClientError: _logger.exception("The microgrid metadata is not available.") if not microgrid_metadata: @@ -172,7 +166,7 @@ async def metadata(self) -> Metadata: return Metadata(microgrid_id=microgrid_metadata.microgrid_id, location=location) - async def connections( + async def connections( # noqa: DOC502 (raises ApiClientError indirectly) self, starts: Set[int] = frozenset(), ends: Set[int] = frozenset(), @@ -194,25 +188,18 @@ async def connections( [GrpcError][frequenz.client.microgrid.GrpcError]. """ connection_filter = microgrid_pb2.ConnectionFilter(starts=starts, ends=ends) - try: - valid_components, all_connections = await asyncio.gather( - self.components(), - # grpc.aio is missing types and mypy thinks this is not - # awaitable, but it is - cast( - Awaitable[microgrid_pb2.ConnectionList], - self.stub.ListConnections( - connection_filter, - timeout=int(DEFAULT_GRPC_CALL_TIMEOUT), - ), + valid_components, all_connections = await asyncio.gather( + self.components(), + client.call_stub_method( + self, + lambda: self._async_stub.ListConnections( + connection_filter, + timeout=int(DEFAULT_GRPC_CALL_TIMEOUT), ), - ) - except grpc.aio.AioRpcError as grpc_error: - raise ApiClientError.from_grpc_error( - server_url=self._server_url, - operation="ListConnections", - grpc_error=grpc_error, - ) from grpc_error + method_name="ListConnections", + ), + ) + # Filter out the components filtered in `components` method. # id=0 is an exception indicating grid component. valid_ids = {c.component_id for c in valid_components} @@ -261,15 +248,10 @@ async def _new_component_data_receiver( if broadcaster is None: broadcaster = streaming.GrpcStreamBroadcaster( f"raw-component-data-{component_id}", - # We need to cast here because grpc says StreamComponentData is - # a grpc.CallIterator[microgrid_pb2.ComponentData] which is not an - # AsyncIterator, but it is a grpc.aio.UnaryStreamCall[..., - # microgrid_pb2.ComponentData], which it is. - lambda: cast( - AsyncIterator[microgrid_pb2.ComponentData], - self.stub.StreamComponentData( + lambda: aiter( + self._async_stub.StreamComponentData( microgrid_pb2.ComponentIdParam(id=component_id) - ), + ) ), transform, retry_strategy=self._retry_strategy, @@ -405,7 +387,9 @@ async def ev_charger_data( # noqa: DOC502 (ValueError is raised indirectly by _ maxsize=maxsize, ) - async def set_power(self, component_id: int, power_w: float) -> None: + async def set_power( # noqa: DOC502 (raises ApiClientError indirectly) + self, component_id: int, power_w: float + ) -> None: """Send request to the Microgrid to set power for component. If power > 0, then component will be charged with this power. @@ -422,22 +406,16 @@ async def set_power(self, component_id: int, power_w: float) -> None: most likely a subclass of [GrpcError][frequenz.client.microgrid.GrpcError]. """ - try: - await cast( - Awaitable[Empty], - self.stub.SetPowerActive( - microgrid_pb2.SetPowerActiveParam( - component_id=component_id, power=power_w - ), - timeout=int(DEFAULT_GRPC_CALL_TIMEOUT), + await client.call_stub_method( + self, + lambda: self._async_stub.SetPowerActive( + microgrid_pb2.SetPowerActiveParam( + component_id=component_id, power=power_w ), - ) - except grpc.aio.AioRpcError as grpc_error: - raise ApiClientError.from_grpc_error( - server_url=self._server_url, - operation="SetPowerActive", - grpc_error=grpc_error, - ) from grpc_error + timeout=int(DEFAULT_GRPC_CALL_TIMEOUT), + ), + method_name="SetPowerActive", + ) async def set_bounds( self, @@ -467,21 +445,15 @@ async def set_bounds( target_metric = ( microgrid_pb2.SetBoundsParam.TargetMetric.TARGET_METRIC_POWER_ACTIVE ) - try: - await cast( - Awaitable[Timestamp], - self.stub.AddInclusionBounds( - microgrid_pb2.SetBoundsParam( - component_id=component_id, - target_metric=target_metric, - bounds=metrics_pb2.Bounds(lower=lower, upper=upper), - ), - timeout=int(DEFAULT_GRPC_CALL_TIMEOUT), + await client.call_stub_method( + self, + lambda: self._async_stub.AddInclusionBounds( + microgrid_pb2.SetBoundsParam( + component_id=component_id, + target_metric=target_metric, + bounds=metrics_pb2.Bounds(lower=lower, upper=upper), ), - ) - except grpc.aio.AioRpcError as grpc_error: - raise ApiClientError.from_grpc_error( - server_url=self._server_url, - operation="AddInclusionBounds", - grpc_error=grpc_error, - ) from grpc_error + timeout=int(DEFAULT_GRPC_CALL_TIMEOUT), + ), + method_name="AddInclusionBounds", + ) diff --git a/tests/test_client.py b/tests/test_client.py index 6ca18475..5466f79b 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -47,6 +47,7 @@ def __init__(self, *, retry_strategy: retry.Strategy | None = None) -> None: super().__init__("grpc://mock_host:1234", retry_strategy=retry_strategy) self.mock_stub = mock_stub self._stub = mock_stub # pylint: disable=protected-access + self._async_stub = mock_stub # pylint: disable=protected-access async def test_components() -> None: