diff --git a/RELEASE_NOTES.md b/RELEASE_NOTES.md index 4bd60ba..046d496 100644 --- a/RELEASE_NOTES.md +++ b/RELEASE_NOTES.md @@ -10,7 +10,19 @@ ## New Features - +* The streaming client now also sends state change events out. Usage example: +```python + recv = streamer.new_receiver() + + async for msg in recv: + match msg: + case StreamStartedEvent(): + print("Stream started") + case StreamStoppedEvent() as event: + print(f"Stream stopped, reason {event.exception}, retry in {event.retry_time}") + case int() as output: + print(f"Received message: {output}") +``` ## Bug Fixes diff --git a/src/frequenz/client/base/streaming.py b/src/frequenz/client/base/streaming.py index f690ee1..2e95bcd 100644 --- a/src/frequenz/client/base/streaming.py +++ b/src/frequenz/client/base/streaming.py @@ -6,7 +6,9 @@ import asyncio import logging from collections.abc import Callable -from typing import AsyncIterable, Generic, TypeVar +from dataclasses import dataclass +from datetime import timedelta +from typing import AsyncIterable, Generic, TypeAlias, TypeVar import grpc.aio @@ -24,8 +26,70 @@ """The output type of the stream.""" +@dataclass(frozen=True, kw_only=True) +class StreamStartedEvent: + """Event indicating that the stream has started.""" + + +@dataclass(frozen=True, kw_only=True) +class StreamStoppedEvent: + """Event indicating that the stream has stopped.""" + + retry_time: timedelta | None = None + """Time to wait before retrying the stream, if applicable.""" + + exception: Exception | None = None + """The exception that caused the stream to stop, if any.""" + + +StreamEvent: TypeAlias = StreamStartedEvent | StreamStoppedEvent +"""Type alias for the events that can be sent over the stream.""" + + +# Ignore D412: "No blank lines allowed between a section header and its content" +# flake8: noqa: D412 class GrpcStreamBroadcaster(Generic[InputT, OutputT]): - """Helper class to handle grpc streaming methods.""" + """Helper class to handle grpc streaming methods. 
+ + This class handles the grpc streaming methods, automatically reconnecting + when the connection is lost, and broadcasting the received messages to + multiple receivers. + + The stream is started when the class is initialized, and can be stopped + with the `stop` method. New receivers can be created with the + `new_receiver` method, which will receive the streamed messages. + + Additionally to the transformed messages, the broadcaster will also send + state change messages indicating whether the stream is connecting, + connected, or disconnected. These messages can be used to monitor the + state of the stream. + + Example: + + ```python + from frequenz.client.base import GrpcStreamBroadcaster + + async def async_range() -> AsyncIterable[int]: + for i in range(10): yield i + + streamer = GrpcStreamBroadcaster( + stream_name="example_stream", + stream_method=async_range, + transform=lambda msg: msg, + ) + + recv = streamer.new_receiver() + + async for msg in recv: + match msg: + case StreamStartedEvent(): + print("Stream started") + case StreamStoppedEvent() as event: + print(f"Stream stopped, reason {event.exception}, retry in {event.retry_time}") + case int() as output: + print(f"Received message: {output}") + ``` + """ def __init__( # pylint: disable=too-many-arguments,too-many-positional-arguments self, @@ -55,14 +119,14 @@ def __init__( # pylint: disable=too-many-arguments,too-many-positional-argument ) self._retry_on_exhausted_stream = retry_on_exhausted_stream - self._channel: channels.Broadcast[OutputT] = channels.Broadcast( + self._channel: channels.Broadcast[StreamEvent | OutputT] = channels.Broadcast( name=f"GrpcStreamBroadcaster-{stream_name}" ) self._task = asyncio.create_task(self._run()) def new_receiver( self, maxsize: int = 50, warn_on_overflow: bool = True - ) -> channels.Receiver[OutputT]: + ) -> channels.Receiver[StreamEvent | OutputT]: """Create a new receiver for the stream. 
Args: @@ -107,16 +171,21 @@ async def _run(self) -> None: _logger.info("%s: starting to stream", self._stream_name) try: call = self._stream_method() + await sender.send(StreamStartedEvent()) async for msg in call: await sender.send(self._transform(msg)) except grpc.aio.AioRpcError as err: error = err - except Exception as err: # pylint: disable=broad-except - _logger.exception( - "%s: raise an unexpected exception", - self._stream_name, + + interval = self._retry_strategy.next_interval() + + await sender.send( + StreamStoppedEvent( + retry_time=timedelta(seconds=interval) if interval else None, + exception=error, ) - error = err + ) + if error is None and not self._retry_on_exhausted_stream: _logger.info( "%s: connection closed, stream exhausted", self._stream_name @@ -124,7 +193,6 @@ async def _run(self) -> None: await self._channel.close() break error_str = f"Error: {error}" if error else "Stream exhausted" - interval = self._retry_strategy.next_interval() if interval is None: _logger.error( "%s: connection ended, retry limit exceeded (%s), giving up. 
%s.", diff --git a/tests/streaming/test_grpc_stream_broadcaster.py b/tests/streaming/test_grpc_stream_broadcaster.py index 87001a9..a9a0bef 100644 --- a/tests/streaming/test_grpc_stream_broadcaster.py +++ b/tests/streaming/test_grpc_stream_broadcaster.py @@ -7,12 +7,19 @@ import logging from collections.abc import AsyncIterator from contextlib import AsyncExitStack +from datetime import timedelta from unittest import mock import grpc.aio import pytest +from frequenz.channels import Receiver from frequenz.client.base import retry, streaming +from frequenz.client.base.streaming import ( + StreamEvent, + StreamStartedEvent, + StreamStoppedEvent, +) def _transformer(x: int) -> str: @@ -36,6 +43,17 @@ def no_retry() -> mock.MagicMock: return mock_retry +def mock_error() -> grpc.aio.AioRpcError: + """Mock error for testing.""" + return grpc.aio.AioRpcError( + code=mock.MagicMock(name="mock grpc code"), + initial_metadata=mock.MagicMock(), + trailing_metadata=mock.MagicMock(), + details="mock details", + debug_error_string="mock debug_error_string", + ) + + @pytest.fixture async def ok_helper( no_retry: mock.MagicMock, # pylint: disable=redefined-outer-name @@ -62,6 +80,28 @@ async def asynciter(ready_event: asyncio.Event) -> AsyncIterator[int]: await helper.stop() +async def _split_message( + receiver: Receiver[StreamEvent | str], +) -> tuple[list[str], list[StreamEvent]]: + """Split the items received from the receiver into items and messages. + + Args: + receiver: The receiver to process. + + Returns: + A tuple containing a list of transformed items and a list of messages. 
+ """ + items: list[str] = [] + events: list[StreamEvent] = [] + async for item in receiver: + match item: + case StreamStartedEvent() | StreamStoppedEvent() as item: + events.append(item) + case str(): + items.append(item) + return items, events + + class _ErroringAsyncIter(AsyncIterator[int]): """Async iterator that raises an error after a certain number of successes.""" @@ -93,11 +133,12 @@ async def test_streaming_success_retry_on_exhausted( """Test streaming success.""" caplog.set_level(logging.INFO) items: list[str] = [] + events: list[StreamEvent] = [] async with asyncio.timeout(1): receiver = ok_helper.new_receiver() receiver_ready_event.set() - async for item in receiver: - items.append(item) + items, events = await _split_message(receiver) + no_retry.next_interval.assert_called_once_with() assert items == [ "transformed_0", @@ -106,6 +147,10 @@ async def test_streaming_success_retry_on_exhausted( "transformed_3", "transformed_4", ] + assert events == [ + StreamStoppedEvent(exception=None, retry_time=None), + ] + assert caplog.record_tuples == [ ( "frequenz.client.base.streaming", @@ -128,14 +173,14 @@ async def test_streaming_success( """Test streaming success.""" caplog.set_level(logging.INFO) items: list[str] = [] + events: list[StreamEvent] = [] + async with asyncio.timeout(1): receiver = ok_helper.new_receiver() receiver_ready_event.set() - async for item in receiver: - items.append(item) - assert ( - no_retry.next_interval.call_count == 0 - ), "next_interval should not be called when streaming is successful" + items, events = await _split_message(receiver) + + no_retry.next_interval.assert_called_once_with() assert items == [ "transformed_0", @@ -144,6 +189,9 @@ async def test_streaming_success( "transformed_3", "transformed_4", ] + assert events == [ + StreamStoppedEvent(exception=None, retry_time=None), + ] assert caplog.record_tuples == [ ( "frequenz.client.base.streaming", @@ -173,13 +221,7 @@ async def test_streaming_error( # pylint: 
disable=too-many-arguments """Test streaming errors.""" caplog.set_level(logging.INFO) - error = grpc.aio.AioRpcError( - code=_NamedMagicMock(name="mock grpc code"), - initial_metadata=mock.MagicMock(), - trailing_metadata=mock.MagicMock(), - details="mock details", - debug_error_string="mock debug_error_string", - ) + error = mock_error() helper = streaming.GrpcStreamBroadcaster( stream_name="test_helper", @@ -196,8 +238,7 @@ async def test_streaming_error( # pylint: disable=too-many-arguments receiver = helper.new_receiver() receiver_ready_event.set() - async for item in receiver: - items.append(item) + items, _ = await _split_message(receiver) no_retry.next_interval.assert_called_once_with() assert items == [f"transformed_{i}" for i in range(successes)] @@ -211,12 +252,7 @@ async def test_streaming_error( # pylint: disable=too-many-arguments "frequenz.client.base.streaming", logging.ERROR, "test_helper: connection ended, retry limit exceeded (mock progress), " - "giving up. Error: " - ".", + f"giving up. 
Error: {error}.", ), ( "frequenz.client.base.streaming", @@ -232,13 +268,7 @@ async def test_retry_next_interval_zero( # pylint: disable=too-many-arguments ) -> None: """Test retry logic when next_interval returns 0.""" caplog.set_level(logging.WARNING) - error = grpc.aio.AioRpcError( - code=_NamedMagicMock(name="mock grpcio code"), - initial_metadata=mock.MagicMock(), - trailing_metadata=mock.MagicMock(), - details="mock details", - debug_error_string="mock debug_error_string", - ) + error = mock_error() mock_retry = mock.MagicMock(spec=retry.Strategy) mock_retry.next_interval.side_effect = [0, None] mock_retry.copy.return_value = mock_retry @@ -256,29 +286,62 @@ async def test_retry_next_interval_zero( # pylint: disable=too-many-arguments receiver = helper.new_receiver() receiver_ready_event.set() - async for item in receiver: - items.append(item) + items, _ = await _split_message(receiver) assert not items assert mock_retry.next_interval.mock_calls == [mock.call(), mock.call()] - expected_error_str = ( - "" - ) assert caplog.record_tuples == [ ( "frequenz.client.base.streaming", logging.WARNING, "test_helper: connection ended, retrying mock progress in 0.000 " - f"seconds. Error: {expected_error_str}.", + f"seconds. Error: {error}.", ), ( "frequenz.client.base.streaming", logging.ERROR, "test_helper: connection ended, retry limit exceeded (mock progress), " - f"giving up. Error: {expected_error_str}.", + f"giving up. 
Error: {error}.", + ), + ] + + +async def test_messages_on_retry( + receiver_ready_event: asyncio.Event, # pylint: disable=redefined-outer-name +) -> None: + """Test that messages are sent on retry.""" + helper = streaming.GrpcStreamBroadcaster( + stream_name="test_helper", + stream_method=lambda: _ErroringAsyncIter( + mock_error(), + receiver_ready_event, + ), + transform=_transformer, + retry_strategy=retry.LinearBackoff( + limit=1, + interval=0.01, ), + retry_on_exhausted_stream=True, + ) + + items: list[str] = [] + events: list[StreamEvent] = [] + async with AsyncExitStack() as stack: + stack.push_async_callback(helper.stop) + + receiver = helper.new_receiver() + receiver_ready_event.set() + items, events = await _split_message(receiver) + + assert items == [] + assert [type(e) for e in events] == [ + type(e) + for e in [ + StreamStartedEvent(), + StreamStoppedEvent( + exception=mock_error(), retry_time=timedelta(seconds=0.01) + ), + StreamStartedEvent(), + StreamStoppedEvent(exception=mock_error(), retry_time=None), + ] ]