77from collections .abc import AsyncIterator , Iterable , Iterator
88from dataclasses import dataclass
99from datetime import datetime , timedelta , timezone
10- from typing import cast
11-
12- import grpc .aio as grpcaio
10+ from typing import Any , AsyncIterable , cast
1311
1412# pylint: disable=no-name-in-module
1513from frequenz .api .common .v1 .microgrid .microgrid_pb2 import (
3937)
4038from frequenz .api .reporting .v1 .reporting_pb2 import TimeFilter as PBTimeFilter
4139from frequenz .api .reporting .v1 .reporting_pb2_grpc import ReportingStub
40+ from frequenz .client .base .channel import ChannelOptions
4241from frequenz .client .base .client import BaseApiClient
4342from frequenz .client .base .exception import ClientNotConnected
43+ from frequenz .client .base .streaming import GrpcStreamBroadcaster
4444from frequenz .client .common .metric import Metric
4545from google .protobuf .timestamp_pb2 import Timestamp as PBTimestamp
4646
@@ -177,14 +177,29 @@ def sample(self) -> MetricSample:
177177class ReportingApiClient (BaseApiClient [ReportingStub ]):
178178 """A client for the Reporting service."""
179179
180- def __init__ (self , server_url : str , key : str | None = None ) -> None :
180+ def __init__ (
181+ self ,
182+ server_url : str ,
183+ key : str | None = None ,
184+ connect : bool = True ,
185+ channel_defaults : ChannelOptions = ChannelOptions (), # default options
186+ ) -> None :
181187 """Create a new Reporting client.
182188
183189 Args:
184190 server_url: The URL of the Reporting service.
185191 key: The API key for the authorization.
192+ connect: Whether to connect to the server immediately.
193+ channel_defaults: The default channel options.
186194 """
187- super ().__init__ (server_url , ReportingStub )
195+ super ().__init__ (
196+ server_url ,
197+ ReportingStub ,
198+ connect = connect ,
199+ channel_defaults = channel_defaults ,
200+ )
201+
202+ self ._broadcasters : dict [int , GrpcStreamBroadcaster [Any , Any ]] = {}
188203
189204 self ._metadata = (("key" , key ),) if key else ()
190205
@@ -294,10 +309,7 @@ async def _list_microgrid_components_data_batch(
294309 include_states : bool = False ,
295310 include_bounds : bool = False ,
296311 ) -> AsyncIterator [ComponentsDataBatch ]:
297- """Iterate over the component data batches in the stream.
298-
299- Note: This does not yet support aggregating the data. It
300- also does not yet support fetching bound and state data.
312+ """Iterate over the component data batches in the stream using GrpcStreamBroadcaster.
301313
302314 Args:
303315 microgrid_components: A list of tuples of microgrid IDs and component IDs.
@@ -367,21 +379,33 @@ def dt2ts(dt: datetime) -> PBTimestamp:
367379 filter = stream_filter ,
368380 )
369381
370- try :
371- stream = cast (
372- AsyncIterator [PBReceiveMicrogridComponentsDataStreamResponse ],
373- self .stub .ReceiveMicrogridComponentsDataStream (
374- request , metadata = self ._metadata
375- ),
382+ def transform_response (
383+ response : PBReceiveMicrogridComponentsDataStreamResponse ,
384+ ) -> ComponentsDataBatch :
385+ return ComponentsDataBatch (response )
386+
387+ async def stream_method () -> (
388+ AsyncIterable [PBReceiveMicrogridComponentsDataStreamResponse ]
389+ ):
390+ call_iterator = self .stub .ReceiveMicrogridComponentsDataStream (
391+ request , metadata = self ._metadata
376392 )
377- async for response in stream :
378- if not response :
379- break
380- yield ComponentsDataBatch (response )
393+ async for response in cast (
394+ AsyncIterable [PBReceiveMicrogridComponentsDataStreamResponse ],
395+ call_iterator ,
396+ ):
397+ yield response
398+
399+ broadcaster = GrpcStreamBroadcaster (
400+ stream_name = "microgrid-components-data-stream" ,
401+ stream_method = stream_method ,
402+ transform = transform_response ,
403+ retry_strategy = None ,
404+ )
381405
382- except grpcaio . AioRpcError as e :
383- print ( f"RPC failed: { e } " )
384- return
406+ receiver = broadcaster . new_receiver ()
407+ async for data in receiver :
408+ yield data
385409
386410 async def receive_aggregated_data (
387411 self ,
@@ -393,10 +417,9 @@ async def receive_aggregated_data(
393417 end : datetime | None ,
394418 resampling_period : timedelta ,
395419 ) -> AsyncIterator [MetricSample ]:
396- """Iterate over aggregated data for a single metric.
420+ """Iterate over aggregated data for a single metric using GrpcStreamBroadcaster .
397421
398422 For now this only supports a single metric and aggregation formula.
399-
400423 Args:
401424 microgrid_id: The microgrid ID.
402425 metric: The metric name.
@@ -442,18 +465,26 @@ def dt2ts(dt: datetime) -> PBTimestamp:
442465 filter = stream_filter ,
443466 )
444467
445- try :
446- stream = cast (
447- AsyncIterator [ PBAggregatedStreamResponse ],
448- self . stub . ReceiveAggregatedMicrogridComponentsDataStream (
449- request , metadata = self ._metadata
450- ),
468+ def transform_response ( response : PBAggregatedStreamResponse ) -> MetricSample :
469+ return AggregatedMetric ( response ). sample ()
470+
471+ async def stream_method () -> AsyncIterable [ PBAggregatedStreamResponse ]:
472+ call_iterator = self .stub . ReceiveAggregatedMicrogridComponentsDataStream (
473+ request , metadata = self . _metadata
451474 )
452- async for response in stream :
453- if not response :
454- break
455- yield AggregatedMetric (response ).sample ()
456-
457- except grpcaio .AioRpcError as e :
458- print (f"RPC failed: { e } " )
459- return
475+
476+ async for response in cast (
477+ AsyncIterable [PBAggregatedStreamResponse ], call_iterator
478+ ):
479+ yield response
480+
481+ broadcaster = GrpcStreamBroadcaster (
482+ stream_name = "aggregated-microgrid-data-stream" ,
483+ stream_method = stream_method ,
484+ transform = transform_response ,
485+ retry_strategy = None ,
486+ )
487+
488+ receiver = broadcaster .new_receiver ()
489+ async for data in receiver :
490+ yield data
0 commit comments