Skip to content

Commit 82267a9

Browse files
Introduce GrpcStreamBroadcaster to enable keep_alive options (#154)
After a sensitivity analysis, the default `ChannelOptions()` of the base client are used (with the option to change them), since even with different `interval` and `timeout` configurations of `KeepAliveOptions()`, the client sometimes has to retry.
2 parents 46bb5fd + 09d5ca0 commit 82267a9

File tree

3 files changed

+89
-41
lines changed

3 files changed

+89
-41
lines changed

RELEASE_NOTES.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010

1111
## New Features
1212

13-
<!-- Here goes the main new features and examples or instructions on how to use them -->
13+
* Introduced 'GrpcStreamBroadcaster' from the base client to enable keep-alive options for gRPC streaming.
14+
* The 'ChannelOptions' are currently set to the base client default, but can be changed via an input parameter.
1415

1516
## Bug Fixes
1617

src/frequenz/client/reporting/_client.py

Lines changed: 69 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,7 @@
77
from collections.abc import AsyncIterator, Iterable, Iterator
88
from dataclasses import dataclass
99
from datetime import datetime, timedelta, timezone
10-
from typing import cast
11-
12-
import grpc.aio as grpcaio
10+
from typing import Any, AsyncIterable, cast
1311

1412
# pylint: disable=no-name-in-module
1513
from frequenz.api.common.v1.microgrid.microgrid_pb2 import (
@@ -39,8 +37,10 @@
3937
)
4038
from frequenz.api.reporting.v1.reporting_pb2 import TimeFilter as PBTimeFilter
4139
from frequenz.api.reporting.v1.reporting_pb2_grpc import ReportingStub
40+
from frequenz.client.base.channel import ChannelOptions
4241
from frequenz.client.base.client import BaseApiClient
4342
from frequenz.client.base.exception import ClientNotConnected
43+
from frequenz.client.base.streaming import GrpcStreamBroadcaster
4444
from frequenz.client.common.metric import Metric
4545
from google.protobuf.timestamp_pb2 import Timestamp as PBTimestamp
4646

@@ -177,14 +177,29 @@ def sample(self) -> MetricSample:
177177
class ReportingApiClient(BaseApiClient[ReportingStub]):
178178
"""A client for the Reporting service."""
179179

180-
def __init__(self, server_url: str, key: str | None = None) -> None:
180+
def __init__(
181+
self,
182+
server_url: str,
183+
key: str | None = None,
184+
connect: bool = True,
185+
channel_defaults: ChannelOptions = ChannelOptions(), # default options
186+
) -> None:
181187
"""Create a new Reporting client.
182188
183189
Args:
184190
server_url: The URL of the Reporting service.
185191
key: The API key for the authorization.
192+
connect: Whether to connect to the server immediately.
193+
channel_defaults: The default channel options.
186194
"""
187-
super().__init__(server_url, ReportingStub)
195+
super().__init__(
196+
server_url,
197+
ReportingStub,
198+
connect=connect,
199+
channel_defaults=channel_defaults,
200+
)
201+
202+
self._broadcasters: dict[int, GrpcStreamBroadcaster[Any, Any]] = {}
188203

189204
self._metadata = (("key", key),) if key else ()
190205

@@ -294,10 +309,7 @@ async def _list_microgrid_components_data_batch(
294309
include_states: bool = False,
295310
include_bounds: bool = False,
296311
) -> AsyncIterator[ComponentsDataBatch]:
297-
"""Iterate over the component data batches in the stream.
298-
299-
Note: This does not yet support aggregating the data. It
300-
also does not yet support fetching bound and state data.
312+
"""Iterate over the component data batches in the stream using GrpcStreamBroadcaster.
301313
302314
Args:
303315
microgrid_components: A list of tuples of microgrid IDs and component IDs.
@@ -367,21 +379,33 @@ def dt2ts(dt: datetime) -> PBTimestamp:
367379
filter=stream_filter,
368380
)
369381

370-
try:
371-
stream = cast(
372-
AsyncIterator[PBReceiveMicrogridComponentsDataStreamResponse],
373-
self.stub.ReceiveMicrogridComponentsDataStream(
374-
request, metadata=self._metadata
375-
),
382+
def transform_response(
383+
response: PBReceiveMicrogridComponentsDataStreamResponse,
384+
) -> ComponentsDataBatch:
385+
return ComponentsDataBatch(response)
386+
387+
async def stream_method() -> (
388+
AsyncIterable[PBReceiveMicrogridComponentsDataStreamResponse]
389+
):
390+
call_iterator = self.stub.ReceiveMicrogridComponentsDataStream(
391+
request, metadata=self._metadata
376392
)
377-
async for response in stream:
378-
if not response:
379-
break
380-
yield ComponentsDataBatch(response)
393+
async for response in cast(
394+
AsyncIterable[PBReceiveMicrogridComponentsDataStreamResponse],
395+
call_iterator,
396+
):
397+
yield response
398+
399+
broadcaster = GrpcStreamBroadcaster(
400+
stream_name="microgrid-components-data-stream",
401+
stream_method=stream_method,
402+
transform=transform_response,
403+
retry_strategy=None,
404+
)
381405

382-
except grpcaio.AioRpcError as e:
383-
print(f"RPC failed: {e}")
384-
return
406+
receiver = broadcaster.new_receiver()
407+
async for data in receiver:
408+
yield data
385409

386410
async def receive_aggregated_data(
387411
self,
@@ -393,10 +417,9 @@ async def receive_aggregated_data(
393417
end: datetime | None,
394418
resampling_period: timedelta,
395419
) -> AsyncIterator[MetricSample]:
396-
"""Iterate over aggregated data for a single metric.
420+
"""Iterate over aggregated data for a single metric using GrpcStreamBroadcaster.
397421
398422
For now this only supports a single metric and aggregation formula.
399-
400423
Args:
401424
microgrid_id: The microgrid ID.
402425
metric: The metric name.
@@ -442,18 +465,26 @@ def dt2ts(dt: datetime) -> PBTimestamp:
442465
filter=stream_filter,
443466
)
444467

445-
try:
446-
stream = cast(
447-
AsyncIterator[PBAggregatedStreamResponse],
448-
self.stub.ReceiveAggregatedMicrogridComponentsDataStream(
449-
request, metadata=self._metadata
450-
),
468+
def transform_response(response: PBAggregatedStreamResponse) -> MetricSample:
469+
return AggregatedMetric(response).sample()
470+
471+
async def stream_method() -> AsyncIterable[PBAggregatedStreamResponse]:
472+
call_iterator = self.stub.ReceiveAggregatedMicrogridComponentsDataStream(
473+
request, metadata=self._metadata
451474
)
452-
async for response in stream:
453-
if not response:
454-
break
455-
yield AggregatedMetric(response).sample()
456-
457-
except grpcaio.AioRpcError as e:
458-
print(f"RPC failed: {e}")
459-
return
475+
476+
async for response in cast(
477+
AsyncIterable[PBAggregatedStreamResponse], call_iterator
478+
):
479+
yield response
480+
481+
broadcaster = GrpcStreamBroadcaster(
482+
stream_name="aggregated-microgrid-data-stream",
483+
stream_method=stream_method,
484+
transform=transform_response,
485+
retry_strategy=None,
486+
)
487+
488+
receiver = broadcaster.new_receiver()
489+
async for data in receiver:
490+
yield data

tests/test_client_reporting.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import pytest
88
from frequenz.api.reporting.v1.reporting_pb2_grpc import ReportingStub
9+
from frequenz.client.base.channel import ChannelOptions
910
from frequenz.client.base.client import BaseApiClient
1011

1112
from frequenz.client.reporting import ReportingApiClient
@@ -15,9 +16,24 @@
1516
@pytest.mark.asyncio
1617
async def test_client_initialization() -> None:
1718
"""Test that the client initializes the BaseApiClient."""
19+
# Parameters for the ReportingApiClient initialization
20+
server_url = "gprc://localhost:50051"
21+
key = "some-api-key"
22+
connect = True
23+
channel_defaults = ChannelOptions()
24+
1825
with patch.object(BaseApiClient, "__init__", return_value=None) as mock_base_init:
19-
client = ReportingApiClient("gprc://localhost:50051") # noqa: F841
20-
mock_base_init.assert_called_once_with("gprc://localhost:50051", ReportingStub)
26+
client = ReportingApiClient(
27+
server_url, key=key, connect=connect, channel_defaults=channel_defaults
28+
) # noqa: F841
29+
mock_base_init.assert_called_once_with(
30+
server_url,
31+
ReportingStub,
32+
connect=connect,
33+
channel_defaults=channel_defaults,
34+
)
35+
36+
assert client._metadata == (("key", key),) # pylint: disable=W0212
2137

2238

2339
def test_components_data_batch_is_empty_true() -> None:

0 commit comments

Comments
 (0)