diff --git a/RELEASE_NOTES.md b/RELEASE_NOTES.md index b2bc245..9a57d26 100644 --- a/RELEASE_NOTES.md +++ b/RELEASE_NOTES.md @@ -13,16 +13,18 @@ * The streaming client now also sends state change events out. Usage example: ```python - recv = streamer.new_receiver() - - for msg in recv: - match msg: - case StreamStartedEvent(): - print("Stream started") - case StreamStoppedEvent() as event: - print(f"Stream stopped, reason {event.exception}, retry in {event.retry_time}") - case int() as output: - print(f"Received message: {output}") + recv = streamer.new_receiver() + + async for msg in recv: + match msg: + case StreamStarted(): + print("Stream started") + case StreamRetrying(delay, error): + print(f"Stream stopped and will retry in {delay}: {error or 'closed'}") + case StreamFatalError(error): + print(f"Stream will stop because of a fatal error: {error}") + case int() as output: + print(f"Received message: {output}") ``` ## Bug Fixes diff --git a/src/frequenz/client/base/streaming.py b/src/frequenz/client/base/streaming.py index 2e95bcd..8c5533d 100644 --- a/src/frequenz/client/base/streaming.py +++ b/src/frequenz/client/base/streaming.py @@ -26,28 +26,38 @@ """The output type of the stream.""" -@dataclass(frozen=True, kw_only=True) -class StreamStartedEvent: +@dataclass(frozen=True) +class StreamStarted: """Event indicating that the stream has started.""" -@dataclass(frozen=True, kw_only=True) -class StreamStoppedEvent: +@dataclass(frozen=True) +class StreamRetrying: """Event indicating that the stream has stopped.""" - retry_time: timedelta | None = None - """Time to wait before retrying the stream, if applicable.""" + delay: timedelta + """Time to wait before retrying to start the stream again.""" exception: Exception | None = None - """The exception that caused the stream to stop, if any.""" + """The exception that caused the stream to stop, if any. + If `None`, the stream was stopped without an error, e.g. the server closed the + stream. 
+ """ + + +@dataclass(frozen=True) +class StreamFatalError: + """Event indicating that the stream has stopped due to an unrecoverable error.""" + + exception: Exception + """The exception that caused the stream to stop.""" -StreamEvent: TypeAlias = StreamStartedEvent | StreamStoppedEvent + +StreamEvent: TypeAlias = StreamStarted | StreamRetrying | StreamFatalError """Type alias for the events that can be sent over the stream.""" -# Ignore D412: "No blank lines allowed between a section header and its content" -# flake8: noqa: D412 class GrpcStreamBroadcaster(Generic[InputT, OutputT]): """Helper class to handle grpc streaming methods. @@ -65,30 +75,31 @@ class GrpcStreamBroadcaster(Generic[InputT, OutputT]): state of the stream. Example: + ```python + from frequenz.client.base import GrpcStreamBroadcaster + + async def async_range() -> AsyncIterable[int]: + for i in range(10): yield i + + streamer = GrpcStreamBroadcaster( + stream_name="example_stream", + stream_method=async_range, + transform=lambda msg: msg, + ) - ```python - from frequenz.client.base import GrpcStreamBroadcaster - - def async_range() -> AsyncIterable[int]: - yield from range(10) - - streamer = GrpcStreamBroadcaster( - stream_name="example_stream", - stream_method=async_range, - transform=lambda msg: msg, - ) - - recv = streamer.new_receiver() - - for msg in recv: - match msg: - case StreamStartedEvent(): - print("Stream started") - case StreamStoppedEvent() as event: - print(f"Stream stopped, reason {event.exception}, retry in {event.retry_time}") - case int() as output: - print(f"Received message: {output}") - ``` + recv = streamer.new_receiver() + + async for msg in recv: + match msg: + case StreamStarted(): + print("Stream started") + case StreamRetrying(delay, error): + print(f"Stream stopped and will retry in {delay}: {error or 'closed'}") + case StreamFatalError(error): + print(f"Stream will stop because of a fatal error: {error}") + case int() as output: + print(f"Received message: {output}") + ``` """ def 
__init__( # pylint: disable=too-many-arguments,too-many-positional-arguments @@ -104,7 +115,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-positional-argument Args: stream_name: A name to identify the stream in the logs. stream_method: A function that returns the grpc stream. This function is - called everytime the connection is lost and we want to retry. + called every time the connection is lost and we want to retry. transform: A function to transform the input type to the output type. retry_strategy: The retry strategy to use, when the connection is lost. Defaults to retries every 3 seconds, with a jitter of 1 second, indefinitely. @@ -171,27 +182,20 @@ async def _run(self) -> None: _logger.info("%s: starting to stream", self._stream_name) try: call = self._stream_method() - await sender.send(StreamStartedEvent()) + await sender.send(StreamStarted()) async for msg in call: await sender.send(self._transform(msg)) except grpc.aio.AioRpcError as err: error = err - interval = self._retry_strategy.next_interval() - - await sender.send( - StreamStoppedEvent( - retry_time=timedelta(seconds=interval) if interval else None, - exception=error, - ) - ) - if error is None and not self._retry_on_exhausted_stream: _logger.info( "%s: connection closed, stream exhausted", self._stream_name ) await self._channel.close() break + + interval = self._retry_strategy.next_interval() error_str = f"Error: {error}" if error else "Stream exhausted" if interval is None: _logger.error( @@ -200,6 +204,8 @@ async def _run(self) -> None: self._retry_strategy.get_progress(), error_str, ) + if error is not None: + await sender.send(StreamFatalError(error)) await self._channel.close() break _logger.warning( @@ -209,4 +215,6 @@ async def _run(self) -> None: interval, error_str, ) + + await sender.send(StreamRetrying(timedelta(seconds=interval), error)) await asyncio.sleep(interval) diff --git a/tests/streaming/test_grpc_stream_broadcaster.py 
b/tests/streaming/test_grpc_stream_broadcaster.py index a9a0bef..b96fadf 100644 --- a/tests/streaming/test_grpc_stream_broadcaster.py +++ b/tests/streaming/test_grpc_stream_broadcaster.py @@ -10,6 +10,7 @@ from datetime import timedelta from unittest import mock +import grpc import grpc.aio import pytest from frequenz.channels import Receiver @@ -17,8 +18,9 @@ from frequenz.client.base import retry, streaming from frequenz.client.base.streaming import ( StreamEvent, - StreamStartedEvent, - StreamStoppedEvent, + StreamFatalError, + StreamRetrying, + StreamStarted, ) @@ -43,12 +45,12 @@ def no_retry() -> mock.MagicMock: return mock_retry -def mock_error() -> grpc.aio.AioRpcError: +def make_error() -> grpc.aio.AioRpcError: """Mock error for testing.""" return grpc.aio.AioRpcError( - code=mock.MagicMock(name="mock grpc code"), - initial_metadata=mock.MagicMock(), - trailing_metadata=mock.MagicMock(), + code=grpc.StatusCode.UNAVAILABLE, + initial_metadata=grpc.aio.Metadata(), + trailing_metadata=grpc.aio.Metadata(), details="mock details", debug_error_string="mock debug_error_string", ) @@ -95,7 +97,7 @@ async def _split_message( events: list[StreamEvent] = [] async for item in receiver: match item: - case StreamStartedEvent() | StreamStoppedEvent() as item: + case StreamStarted() | StreamRetrying() | StreamFatalError(): events.append(item) case str(): items.append(item) @@ -147,9 +149,7 @@ async def test_streaming_success_retry_on_exhausted( "transformed_3", "transformed_4", ] - assert events == [ - StreamStoppedEvent(exception=None, retry_time=None), - ] + assert events == [] assert caplog.record_tuples == [ ( @@ -180,7 +180,7 @@ async def test_streaming_success( receiver_ready_event.set() items, events = await _split_message(receiver) - no_retry.next_interval.assert_called_once_with() + no_retry.next_interval.assert_not_called() assert items == [ "transformed_0", @@ -189,9 +189,7 @@ async def test_streaming_success( "transformed_3", "transformed_4", ] - assert events 
== [ - StreamStoppedEvent(exception=None, retry_time=None), - ] + assert events == [] assert caplog.record_tuples == [ ( "frequenz.client.base.streaming", @@ -221,7 +219,7 @@ async def test_streaming_error( # pylint: disable=too-many-arguments """Test streaming errors.""" caplog.set_level(logging.INFO) - error = mock_error() + error = make_error() helper = streaming.GrpcStreamBroadcaster( stream_name="test_helper", @@ -268,7 +266,7 @@ async def test_retry_next_interval_zero( # pylint: disable=too-many-arguments ) -> None: """Test retry logic when next_interval returns 0.""" caplog.set_level(logging.WARNING) - error = mock_error() + error = make_error() mock_retry = mock.MagicMock(spec=retry.Strategy) mock_retry.next_interval.side_effect = [0, None] mock_retry.copy.return_value = mock_retry @@ -310,17 +308,18 @@ async def test_messages_on_retry( receiver_ready_event: asyncio.Event, # pylint: disable=redefined-outer-name ) -> None: """Test that messages are sent on retry.""" + # We need to use one error instance for the whole test because 2 errors created + # with the same arguments don't compare equal (grpc.aio.AioRpcError doesn't seem to + # provide a __eq__ method). 
+ error = make_error() + helper = streaming.GrpcStreamBroadcaster( stream_name="test_helper", stream_method=lambda: _ErroringAsyncIter( - mock_error(), - receiver_ready_event, + error, receiver_ready_event, num_successes=2 ), transform=_transformer, - retry_strategy=retry.LinearBackoff( - limit=1, - interval=0.01, - ), + retry_strategy=retry.LinearBackoff(limit=1, interval=0.0, jitter=0.0), retry_on_exhausted_stream=True, ) @@ -333,15 +332,15 @@ async def test_messages_on_retry( receiver_ready_event.set() items, events = await _split_message(receiver) - assert items == [] - assert [type(e) for e in events] == [ - type(e) - for e in [ - StreamStartedEvent(), - StreamStoppedEvent( - exception=mock_error(), retry_time=timedelta(seconds=0.01) - ), - StreamStartedEvent(), - StreamStoppedEvent(exception=mock_error(), retry_time=None), - ] + assert items == [ + "transformed_0", + "transformed_1", + "transformed_0", + "transformed_1", + ] + assert events == [ + StreamStarted(), + StreamRetrying(timedelta(seconds=0.0), error), + StreamStarted(), + StreamFatalError(error), ]