diff --git a/RELEASE_NOTES.md b/RELEASE_NOTES.md index f70d38b..9c33130 100644 --- a/RELEASE_NOTES.md +++ b/RELEASE_NOTES.md @@ -11,7 +11,7 @@ ## New Features - +- `GrpcStreamBroadcaster` now supports restarting a stream once the retry strategy is exhausted. The new method `start()` can be used for this, but it will also implicitly restart when `new_receiver()` is called on an exhausted stream. ## Bug Fixes diff --git a/src/frequenz/client/base/streaming.py b/src/frequenz/client/base/streaming.py index 99db426..cd1b86e 100644 --- a/src/frequenz/client/base/streaming.py +++ b/src/frequenz/client/base/streaming.py @@ -50,11 +50,24 @@ def __init__( self._retry_strategy = ( retry.LinearBackoff() if retry_strategy is None else retry_strategy.copy() ) + self._task: asyncio.Task[None] | None = None + self._channel: channels.Broadcast[OutputT] - self._channel: channels.Broadcast[OutputT] = channels.Broadcast( - name=f"GrpcStreamBroadcaster-{stream_name}" - ) + self.start() + + def start(self) -> None: + """Start the streaming helper. + + Should be called after the channel was closed to restart the stream. + """ + if self._task is not None and not self._task.done(): + return + + self._retry_strategy.reset() self._task = asyncio.create_task(self._run()) + self._channel = channels.Broadcast( + name=f"GrpcStreamBroadcaster-{self._stream_name}" + ) def new_receiver(self, maxsize: int = 50) -> channels.Receiver[OutputT]: """Create a new receiver for the stream. @@ -65,11 +78,17 @@ def new_receiver(self, maxsize: int = 50) -> channels.Receiver[OutputT]: Returns: A new receiver. 
""" + if self._channel.is_closed: + _logger.warning( + "%s: stream has stopped, starting a new one.", self._stream_name + ) + self.start() + return self._channel.new_receiver(limit=maxsize) async def stop(self) -> None: """Stop the streaming helper.""" - if self._task.done(): + if not self._task or self._task.done(): return self._task.cancel() try: diff --git a/tests/streaming/test_grpc_stream_broadcaster.py b/tests/streaming/test_grpc_stream_broadcaster.py index 7d33466..1014670 100644 --- a/tests/streaming/test_grpc_stream_broadcaster.py +++ b/tests/streaming/test_grpc_stream_broadcaster.py @@ -237,3 +237,101 @@ async def test_retry_next_interval_zero( # pylint: disable=too-many-arguments f"giving up. Error: {expected_error_str}.", ), ] + + +async def test_new_receiver_after_error( + no_retry: mock.MagicMock, # pylint: disable=redefined-outer-name + receiver_ready_event: asyncio.Event, # pylint: disable=redefined-outer-name + caplog: pytest.LogCaptureFixture, +) -> None: + """Test that creating a new receiver after an error restarts the stream.""" + caplog.set_level(logging.INFO) + error = grpc.aio.AioRpcError( + code=_NamedMagicMock(name="mock grpc code"), + initial_metadata=mock.MagicMock(), + trailing_metadata=mock.MagicMock(), + details="mock details", + debug_error_string="mock debug_error_string", + ) + # Use the no_retry strategy + helper = streaming.GrpcStreamBroadcaster( + stream_name="test_helper", + stream_method=lambda: _ErroringAsyncIter( + error, receiver_ready_event, num_successes=1 + ), + transform=_transformer, + retry_strategy=no_retry, + ) + + items: list[str] = [] + async with AsyncExitStack() as stack: + stack.push_async_callback(helper.stop) + + receiver = helper.new_receiver() + receiver_ready_event.set() + # Consume the first item before the error occurs + async for item in receiver: + items.append(item) + + # Wait for the helper's task to complete + assert helper._task + await helper._task + assert helper._task.done() + + # At this point, 
the stream has ended due to the error + # Now, create a new receiver after the error + with mock.patch.object(helper, "start", wraps=helper.start) as mock_start: + receiver = helper.new_receiver() + # Ensure that helper.start() is called when the channel is closed + mock_start.assert_called_once() + + # Reset the event to allow the new stream to proceed + receiver_ready_event.clear() + receiver_ready_event.set() + async for item in receiver: + items.append(item) + + # Verify that items from both streams are collected + assert items == ["transformed_0", "transformed_0"] + + # Optionally, verify the logging output + expected_logs = [ + ( + "frequenz.client.base.streaming", + logging.INFO, + "test_helper: starting to stream", + ), + ( + "frequenz.client.base.streaming", + logging.ERROR, + "test_helper: connection ended, retry limit exceeded (mock progress), " + f"giving up. Error: {error}.", + ), + ( + "frequenz.client.base.streaming", + logging.WARNING, + "test_helper: stream has stopped, starting a new one.", + ), + ( + "frequenz.client.base.streaming", + logging.INFO, + "test_helper: starting to stream", + ), + ( + "frequenz.client.base.streaming", + logging.ERROR, + "test_helper: connection ended, retry limit exceeded (mock progress), " + f"giving up. Error: {error}.", + ), + ] + assert caplog.record_tuples == expected_logs