diff --git a/RELEASE_NOTES.md b/RELEASE_NOTES.md
index f80f12e..9bde965 100644
--- a/RELEASE_NOTES.md
+++ b/RELEASE_NOTES.md
@@ -10,7 +10,8 @@
 
 ## New Features
 
-
+* Introduced 'GrpcStreamBroadcaster' from the base client to enable keep-alive options for gRPC streaming.
+* The 'ChannelOptions' are currently set to the base client default, but can be changed as an input.
 
 ## Bug Fixes
 
diff --git a/src/frequenz/client/reporting/_client.py b/src/frequenz/client/reporting/_client.py
index 97b23c3..3bc8737 100644
--- a/src/frequenz/client/reporting/_client.py
+++ b/src/frequenz/client/reporting/_client.py
@@ -7,9 +7,7 @@
 from collections.abc import AsyncIterator, Iterable, Iterator
 from dataclasses import dataclass
 from datetime import datetime, timedelta, timezone
-from typing import cast
-
-import grpc.aio as grpcaio
+from typing import Any, AsyncIterable, cast
 
 # pylint: disable=no-name-in-module
 from frequenz.api.common.v1.microgrid.microgrid_pb2 import (
@@ -39,8 +37,10 @@
 )
 from frequenz.api.reporting.v1.reporting_pb2 import TimeFilter as PBTimeFilter
 from frequenz.api.reporting.v1.reporting_pb2_grpc import ReportingStub
+from frequenz.client.base.channel import ChannelOptions
 from frequenz.client.base.client import BaseApiClient
 from frequenz.client.base.exception import ClientNotConnected
+from frequenz.client.base.streaming import GrpcStreamBroadcaster
 from frequenz.client.common.metric import Metric
 
 from google.protobuf.timestamp_pb2 import Timestamp as PBTimestamp
@@ -177,14 +177,29 @@ def sample(self) -> MetricSample:
 
 class ReportingApiClient(BaseApiClient[ReportingStub]):
     """A client for the Reporting service."""
 
-    def __init__(self, server_url: str, key: str | None = None) -> None:
+    def __init__(
+        self,
+        server_url: str,
+        key: str | None = None,
+        connect: bool = True,
+        channel_defaults: ChannelOptions = ChannelOptions(),  # default options
+    ) -> None:
         """Create a new Reporting client.
 
         Args:
             server_url: The URL of the Reporting service.
             key: The API key for the authorization.
+            connect: Whether to connect to the server immediately.
+            channel_defaults: The default channel options.
         """
-        super().__init__(server_url, ReportingStub)
+        super().__init__(
+            server_url,
+            ReportingStub,
+            connect=connect,
+            channel_defaults=channel_defaults,
+        )
+
+        self._broadcasters: dict[int, GrpcStreamBroadcaster[Any, Any]] = {}
 
         self._metadata = (("key", key),) if key else ()
@@ -294,10 +309,7 @@ async def _list_microgrid_components_data_batch(
         include_states: bool = False,
         include_bounds: bool = False,
     ) -> AsyncIterator[ComponentsDataBatch]:
-        """Iterate over the component data batches in the stream.
-
-        Note: This does not yet support aggregating the data. It
-        also does not yet support fetching bound and state data.
+        """Iterate over the component data batches in the stream using GrpcStreamBroadcaster.
 
         Args:
             microgrid_components: A list of tuples of microgrid IDs and component IDs.
@@ -367,21 +379,33 @@ def dt2ts(dt: datetime) -> PBTimestamp:
             filter=stream_filter,
         )
 
-        try:
-            stream = cast(
-                AsyncIterator[PBReceiveMicrogridComponentsDataStreamResponse],
-                self.stub.ReceiveMicrogridComponentsDataStream(
-                    request, metadata=self._metadata
-                ),
+        def transform_response(
+            response: PBReceiveMicrogridComponentsDataStreamResponse,
+        ) -> ComponentsDataBatch:
+            return ComponentsDataBatch(response)
+
+        async def stream_method() -> (
+            AsyncIterable[PBReceiveMicrogridComponentsDataStreamResponse]
+        ):
+            call_iterator = self.stub.ReceiveMicrogridComponentsDataStream(
+                request, metadata=self._metadata
             )
-            async for response in stream:
-                if not response:
-                    break
-                yield ComponentsDataBatch(response)
+            async for response in cast(
+                AsyncIterable[PBReceiveMicrogridComponentsDataStreamResponse],
+                call_iterator,
+            ):
+                yield response
+
+        broadcaster = GrpcStreamBroadcaster(
+            stream_name="microgrid-components-data-stream",
+            stream_method=stream_method,
+            transform=transform_response,
+            retry_strategy=None,
+        )
 
-        except grpcaio.AioRpcError as e:
-            print(f"RPC failed: {e}")
-            return
+        receiver = broadcaster.new_receiver()
+        async for data in receiver:
+            yield data
 
     async def receive_aggregated_data(
         self,
@@ -393,10 +417,9 @@ async def receive_aggregated_data(
         end: datetime | None,
         resampling_period: timedelta,
     ) -> AsyncIterator[MetricSample]:
-        """Iterate over aggregated data for a single metric.
+        """Iterate over aggregated data for a single metric using GrpcStreamBroadcaster.
 
         For now this only supports a single metric and aggregation formula.
-
         Args:
             microgrid_id: The microgrid ID.
             metric: The metric name.
@@ -442,18 +465,26 @@ def dt2ts(dt: datetime) -> PBTimestamp:
             filter=stream_filter,
         )
 
-        try:
-            stream = cast(
-                AsyncIterator[PBAggregatedStreamResponse],
-                self.stub.ReceiveAggregatedMicrogridComponentsDataStream(
-                    request, metadata=self._metadata
-                ),
+        def transform_response(response: PBAggregatedStreamResponse) -> MetricSample:
+            return AggregatedMetric(response).sample()
+
+        async def stream_method() -> AsyncIterable[PBAggregatedStreamResponse]:
+            call_iterator = self.stub.ReceiveAggregatedMicrogridComponentsDataStream(
+                request, metadata=self._metadata
             )
-            async for response in stream:
-                if not response:
-                    break
-                yield AggregatedMetric(response).sample()
-
-        except grpcaio.AioRpcError as e:
-            print(f"RPC failed: {e}")
-            return
+
+            async for response in cast(
+                AsyncIterable[PBAggregatedStreamResponse], call_iterator
+            ):
+                yield response
+
+        broadcaster = GrpcStreamBroadcaster(
+            stream_name="aggregated-microgrid-data-stream",
+            stream_method=stream_method,
+            transform=transform_response,
+            retry_strategy=None,
+        )
+
+        receiver = broadcaster.new_receiver()
+        async for data in receiver:
+            yield data
diff --git a/tests/test_client_reporting.py b/tests/test_client_reporting.py
index 306482b..6dd58c9 100644
--- a/tests/test_client_reporting.py
+++ b/tests/test_client_reporting.py
@@ -6,6 +6,7 @@
 import pytest
 
 from frequenz.api.reporting.v1.reporting_pb2_grpc import ReportingStub
+from frequenz.client.base.channel import ChannelOptions
 from frequenz.client.base.client import BaseApiClient
 
 from frequenz.client.reporting import ReportingApiClient
@@ -15,9 +16,24 @@
 @pytest.mark.asyncio
 async def test_client_initialization() -> None:
     """Test that the client initializes the BaseApiClient."""
+    # Parameters for the ReportingApiClient initialization
+    server_url = "gprc://localhost:50051"
+    key = "some-api-key"
+    connect = True
+    channel_defaults = ChannelOptions()
+
     with patch.object(BaseApiClient, "__init__", return_value=None) as mock_base_init:
-        client = ReportingApiClient("gprc://localhost:50051")  # noqa: F841
-        mock_base_init.assert_called_once_with("gprc://localhost:50051", ReportingStub)
+        client = ReportingApiClient(
+            server_url, key=key, connect=connect, channel_defaults=channel_defaults
+        )  # noqa: F841
+        mock_base_init.assert_called_once_with(
+            server_url,
+            ReportingStub,
+            connect=connect,
+            channel_defaults=channel_defaults,
+        )
+
+    assert client._metadata == (("key", key),)  # pylint: disable=W0212
 
 
 def test_components_data_batch_is_empty_true() -> None: