From e0216eaacf058f02be1f7236037e38720a8b8f48 Mon Sep 17 00:00:00 2001 From: "Mathias L. Baumann" Date: Wed, 4 Jun 2025 13:17:01 +0200 Subject: [PATCH] Streaming: Make events optional Signed-off-by: Mathias L. Baumann --- RELEASE_NOTES.md | 6 +- src/frequenz/client/base/streaming.py | 197 ++++++++++++++---- .../streaming/test_grpc_stream_broadcaster.py | 29 ++- 3 files changed, 178 insertions(+), 54 deletions(-) diff --git a/RELEASE_NOTES.md b/RELEASE_NOTES.md index 9a57d26..ac5a24e 100644 --- a/RELEASE_NOTES.md +++ b/RELEASE_NOTES.md @@ -10,10 +10,10 @@ ## New Features -* The streaming client now also sends state change events out. Usage example: +* The streaming client, when using `new_receiver(include_events=True)`, will now return a receiver that yields stream notification events, such as `StreamStarted`, `StreamRetrying`, and `StreamFatalError`. This allows you to monitor the state of the stream: - ```python - recv = streamer.new_receiver() + ```python + recv = streamer.new_receiver(include_events=True) for msg in recv: match msg: diff --git a/src/frequenz/client/base/streaming.py b/src/frequenz/client/base/streaming.py index 8c5533d..61610ad 100644 --- a/src/frequenz/client/base/streaming.py +++ b/src/frequenz/client/base/streaming.py @@ -8,7 +8,7 @@ from collections.abc import Callable from dataclasses import dataclass from datetime import timedelta -from typing import AsyncIterable, Generic, TypeAlias, TypeVar +from typing import AsyncIterable, Generic, Literal, TypeAlias, TypeVar, overload import grpc.aio @@ -58,6 +58,7 @@ class StreamFatalError: """Type alias for the events that can be sent over the stream.""" +# pylint: disable-next=too-many-instance-attributes class GrpcStreamBroadcaster(Generic[InputT, OutputT]): """Helper class to handle grpc streaming methods. @@ -69,36 +70,86 @@ class GrpcStreamBroadcaster(Generic[InputT, OutputT]): with the `stop` method. New receivers can be created with the `new_receiver` method, which will receive the streamed messages. - Additionally to the transformed messages, the broadcaster will also send - state change messages indicating whether the stream is connecting, - connected, or disconnected. These messages can be used to monitor the - state of the stream. + If `include_events=True` is passed to `new_receiver`, the receiver will + also get state change messages (`StreamStarted`, `StreamRetrying`, + `StreamFatalError`) indicating the state of the stream. Example: ```python - from frequenz.client.base import GrpcStreamBroadcaster + from frequenz.client.base import ( + GrpcStreamBroadcaster, + StreamFatalError, + StreamRetrying, + StreamStarted, + ) + from frequenz.channels import Receiver # Assuming Receiver is available + + # Dummy async iterable for demonstration + async def async_range(fail_after: int = -1) -> AsyncIterable[int]: + for i in range(10): + if fail_after != -1 and i >= fail_after: + raise grpc.aio.AioRpcError( + code=grpc.StatusCode.UNAVAILABLE, + initial_metadata=grpc.aio.Metadata(), + trailing_metadata=grpc.aio.Metadata(), + details="Simulated error" + ) + yield i + await asyncio.sleep(0.1) + + async def main(): + streamer = GrpcStreamBroadcaster( + stream_name="example_stream", + stream_method=lambda: async_range(fail_after=3), + transform=lambda msg: msg * 2, # transform messages + retry_on_exhausted_stream=False, + ) - def async_range() -> AsyncIterable[int]: - yield from range(10) + # Receiver for data only + data_recv: Receiver[int] = streamer.new_receiver() - streamer = GrpcStreamBroadcaster( - stream_name="example_stream", - stream_method=async_range, - transform=lambda msg: msg, - ) + # Receiver for data and events + mixed_recv: Receiver[int | StreamEvent] = streamer.new_receiver( + include_events=True + ) - recv = streamer.new_receiver() - - for msg in recv: - match msg: - case StreamStarted(): - print("Stream started") - case StreamRetrying(delay, error): - print(f"Stream stopped and will retry in {delay}: {error or 'closed'}") - case StreamFatalError(error): - print(f"Stream will stop because of a fatal error: {error}") - case int() as output: - print(f"Received message: {output}") + async def consume_mixed(): + async for msg in mixed_recv: + match msg: + case StreamStarted(): + print("Mixed: Stream started") + case StreamRetrying(delay, error): + print( + "Mixed: Stream retrying in " + + f"{delay.total_seconds():.1f}s: {error or 'closed'}" + ) + case StreamFatalError(error): + print(f"Mixed: Stream fatal error: {error}") + break # Stop consuming on fatal error + case int() as output: + print(f"Mixed: Received data: {output}") + if isinstance(msg, StreamFatalError): + break + print("Mixed: Consumer finished") + + + async def consume_data(): + async for data_msg in data_recv: + print(f"DataOnly: Received data: {data_msg}") + print("DataOnly: Consumer finished") + + mixed_consumer_task = asyncio.create_task(consume_mixed()) + data_consumer_task = asyncio.create_task(consume_data()) + + await asyncio.sleep(5) # Let it run for a bit + print("Stopping streamer...") + await streamer.stop() + await mixed_consumer_task + await data_consumer_task + print("Streamer stopped.") + + if __name__ == "__main__": + asyncio.run(main()) ``` """ @@ -130,27 +181,77 @@ def __init__( # pylint: disable=too-many-arguments,too-many-positional-argument ) self._retry_on_exhausted_stream = retry_on_exhausted_stream - self._channel: channels.Broadcast[StreamEvent | OutputT] = channels.Broadcast( - name=f"GrpcStreamBroadcaster-{stream_name}" + # Channel for transformed data messages (OutputT) + self._data_channel: channels.Broadcast[OutputT] = channels.Broadcast( + name=f"GrpcStreamBroadcaster-{stream_name}-Data" ) + + # Channel for stream events (StreamEvent), created on demand + self._event_channel: channels.Broadcast[StreamEvent] | None = None + self._event_sender: channels.Sender[StreamEvent] | None = None self._task = asyncio.create_task(self._run()) + @overload + def new_receiver( + self, + maxsize: int = 50, + warn_on_overflow: bool = True, + *, + include_events: Literal[False] = False, + ) -> channels.Receiver[OutputT]: ... + + @overload + def new_receiver( + self, + maxsize: int = 50, + warn_on_overflow: bool = True, + *, + include_events: Literal[True], + ) -> channels.Receiver[StreamEvent | OutputT]: ... + def new_receiver( - self, maxsize: int = 50, warn_on_overflow: bool = True - ) -> channels.Receiver[StreamEvent | OutputT]: + self, + maxsize: int = 50, + warn_on_overflow: bool = True, + *, + include_events: bool = False, + ) -> channels.Receiver[OutputT] | channels.Receiver[StreamEvent | OutputT]: """Create a new receiver for the stream. Args: - maxsize: The maximum number of messages to buffer. - warn_on_overflow: Whether to log a warning when the receiver's + maxsize: The maximum number of messages to buffer in underlying receivers. + warn_on_overflow: Whether to log a warning when a receiver's buffer is full and a message is dropped. + include_events: Whether to include stream events (e.g. StreamStarted, + StreamRetrying, StreamFatalError) in the receiver. If `False` (default), + only transformed data messages will be received. Returns: - A new receiver. + A new receiver. If `include_events` is True, the receiver will yield + both `OutputT` and `StreamEvent` types. Otherwise, only `OutputT`. """ - return self._channel.new_receiver( + if not include_events: + return self._data_channel.new_receiver( + limit=maxsize, warn_on_overflow=warn_on_overflow + ) + + if self._event_channel is None: + _logger.debug( + "%s: First request for events, creating event channel.", + self._stream_name, + ) + self._event_channel = channels.Broadcast[StreamEvent]( + name=f"GrpcStreamBroadcaster-{self._stream_name}-Events" + ) + self._event_sender = self._event_channel.new_sender() + + data_rx = self._data_channel.new_receiver( limit=maxsize, warn_on_overflow=warn_on_overflow ) + event_rx = self._event_channel.new_receiver( + limit=maxsize, warn_on_overflow=warn_on_overflow + ) + return channels.merge(data_rx, event_rx) @property def is_running(self) -> bool: @@ -171,20 +272,25 @@ async def stop(self) -> None: await self._task except asyncio.CancelledError: pass - await self._channel.close() + await self._data_channel.close() + if self._event_channel is not None: + await self._event_channel.close() async def _run(self) -> None: """Run the streaming helper.""" - sender = self._channel.new_sender() + data_sender = self._data_channel.new_sender() while True: error: Exception | None = None _logger.info("%s: starting to stream", self._stream_name) try: call = self._stream_method() - await sender.send(StreamStarted()) + + if self._event_sender: + await self._event_sender.send(StreamStarted()) + async for msg in call: - await sender.send(self._transform(msg)) + await data_sender.send(self._transform(msg)) except grpc.aio.AioRpcError as err: error = err @@ -192,7 +298,9 @@ async def _run(self) -> None: _logger.info( "%s: connection closed, stream exhausted", self._stream_name ) - await self._channel.close() + await self._data_channel.close() + if self._event_channel is not None: + await self._event_channel.close() break interval = self._retry_strategy.next_interval() @@ -204,9 +312,11 @@ async def _run(self) -> None: self._retry_strategy.get_progress(), error_str, ) - if error is not None: - await sender.send(StreamFatalError(error)) - await self._channel.close() + if error is not None and self._event_sender: + await self._event_sender.send(StreamFatalError(error)) + await self._data_channel.close() + if self._event_channel is not None: + await self._event_channel.close() break _logger.warning( "%s: connection ended, retrying %s in %0.3f seconds. %s.", @@ -216,5 +326,8 @@ async def _run(self) -> None: error_str, ) - await sender.send(StreamRetrying(timedelta(seconds=interval), error)) + if self._event_sender: + await self._event_sender.send( + StreamRetrying(timedelta(seconds=interval), error) + ) await asyncio.sleep(interval) diff --git a/tests/streaming/test_grpc_stream_broadcaster.py b/tests/streaming/test_grpc_stream_broadcaster.py index b96fadf..6c6241f 100644 --- a/tests/streaming/test_grpc_stream_broadcaster.py +++ b/tests/streaming/test_grpc_stream_broadcaster.py @@ -8,6 +8,7 @@ from collections.abc import AsyncIterator from contextlib import AsyncExitStack from datetime import timedelta +from typing import Literal from unittest import mock import grpc @@ -260,9 +261,11 @@ async def test_streaming_error( # pylint: disable=too-many-arguments ] +@pytest.mark.parametrize("include_events", [True, False]) async def test_retry_next_interval_zero( # pylint: disable=too-many-arguments receiver_ready_event: asyncio.Event, # pylint: disable=redefined-outer-name caplog: pytest.LogCaptureFixture, + include_events: Literal[True] | Literal[False], ) -> None: """Test retry logic when next_interval returns 0.""" caplog.set_level(logging.WARNING) @@ -279,14 +282,17 @@ async def test_retry_next_interval_zero( # pylint: disable=too-many-arguments ) items: list[str] = [] + events: list[StreamEvent] = [] async with AsyncExitStack() as stack: stack.push_async_callback(helper.stop) - receiver = helper.new_receiver() + receiver = helper.new_receiver(include_events=include_events) receiver_ready_event.set() - items, _ = await _split_message(receiver) + items, events = await _split_message(receiver) assert not items + assert bool(events) == include_events + assert mock_retry.next_interval.mock_calls == [mock.call(), mock.call()] assert caplog.record_tuples == [ ( @@ -304,8 +310,10 @@ async def test_retry_next_interval_zero( # pylint: disable=too-many-arguments ] +@pytest.mark.parametrize("include_events", [True, False]) async def test_messages_on_retry( receiver_ready_event: asyncio.Event, # pylint: disable=redefined-outer-name + include_events: Literal[True] | Literal[False], ) -> None: """Test that messages are sent on retry.""" # We need to use a specific instance for all the test here because 2 errors created @@ -328,7 +336,7 @@ async def test_messages_on_retry( async with AsyncExitStack() as stack: stack.push_async_callback(helper.stop) - receiver = helper.new_receiver() + receiver = helper.new_receiver(include_events=include_events) receiver_ready_event.set() items, events = await _split_message(receiver) @@ -338,9 +346,12 @@ async def test_messages_on_retry( "transformed_0", "transformed_1", ] - assert events == [ - StreamStarted(), - StreamRetrying(timedelta(seconds=0.0), error), - StreamStarted(), - StreamFatalError(error), - ] + if include_events: + assert events == [ + StreamStarted(), + StreamRetrying(timedelta(seconds=0.0), error), + StreamStarted(), + StreamFatalError(error), + ] + else: + assert events == []