6 changes: 3 additions & 3 deletions RELEASE_NOTES.md
@@ -10,10 +10,10 @@

## New Features

-* The streaming client now also sends state change events out. Usage example:
+* The streaming client, when using `new_receiver(include_events=True)`, will now return a receiver that yields stream notification events, such as `StreamStarted`, `StreamRetrying`, and `StreamFatalError`. This allows you to monitor the state of the stream:

-```python
-recv = streamer.new_receiver()
+```python
+recv = streamer.new_receiver(include_events=True)

for msg in recv:
    match msg:
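A fuller, self-contained sketch of the usage described in this release note; the `async_range` stream method and the `main` wrapper are illustrative assumptions, not part of the release notes:

```python
import asyncio
from collections.abc import AsyncIterable

from frequenz.client.base import (
    GrpcStreamBroadcaster,
    StreamFatalError,
    StreamRetrying,
    StreamStarted,
)


async def async_range() -> AsyncIterable[int]:
    """Yield a small, finite stream of integers (stand-in for a gRPC stream)."""
    for i in range(10):
        yield i


async def main() -> None:
    streamer = GrpcStreamBroadcaster(
        stream_name="example_stream",
        stream_method=async_range,
        transform=lambda msg: msg,
        retry_on_exhausted_stream=False,
    )
    recv = streamer.new_receiver(include_events=True)

    async for msg in recv:
        match msg:
            case StreamStarted():
                print("Stream started")
            case StreamRetrying(delay, error):
                print(f"Stream retrying in {delay}: {error or 'closed'}")
            case StreamFatalError(error):
                print(f"Stream fatal error: {error}")
                break
            case int() as output:
                print(f"Received message: {output}")

    await streamer.stop()


asyncio.run(main())
```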
197 changes: 155 additions & 42 deletions src/frequenz/client/base/streaming.py
@@ -8,7 +8,7 @@
from collections.abc import Callable
from dataclasses import dataclass
from datetime import timedelta
-from typing import AsyncIterable, Generic, TypeAlias, TypeVar
+from typing import AsyncIterable, Generic, Literal, TypeAlias, TypeVar, overload

import grpc.aio

@@ -58,6 +58,7 @@ class StreamFatalError:
"""Type alias for the events that can be sent over the stream."""


+# pylint: disable-next=too-many-instance-attributes
class GrpcStreamBroadcaster(Generic[InputT, OutputT]):
"""Helper class to handle grpc streaming methods.

@@ -69,36 +70,86 @@ class GrpcStreamBroadcaster(Generic[InputT, OutputT]):
with the `stop` method. New receivers can be created with the
`new_receiver` method, which will receive the streamed messages.

-Additionally to the transformed messages, the broadcaster will also send
-state change messages indicating whether the stream is connecting,
-connected, or disconnected. These messages can be used to monitor the
-state of the stream.
+If `include_events=True` is passed to `new_receiver`, the receiver will
+also get state change messages (`StreamStarted`, `StreamRetrying`,
+`StreamFatalError`) indicating the state of the stream.

Example:
```python
-from frequenz.client.base import GrpcStreamBroadcaster
+import asyncio

+import grpc.aio

+from collections.abc import AsyncIterable

+from frequenz.client.base import (
+    GrpcStreamBroadcaster,
+    StreamEvent,
+    StreamFatalError,
+    StreamRetrying,
+    StreamStarted,
+)
+from frequenz.channels import Receiver  # Assuming Receiver is available

+# Dummy async iterable for demonstration
+async def async_range(fail_after: int = -1) -> AsyncIterable[int]:
+    for i in range(10):
+        if fail_after != -1 and i >= fail_after:
+            raise grpc.aio.AioRpcError(
+                code=grpc.StatusCode.UNAVAILABLE,
+                initial_metadata=grpc.aio.Metadata(),
+                trailing_metadata=grpc.aio.Metadata(),
+                details="Simulated error",
+            )
+        yield i
+        await asyncio.sleep(0.1)

+async def main():
+    streamer = GrpcStreamBroadcaster(
+        stream_name="example_stream",
+        stream_method=lambda: async_range(fail_after=3),
+        transform=lambda msg: msg * 2,  # transform messages
+        retry_on_exhausted_stream=False,
+    )

-def async_range() -> AsyncIterable[int]:
-    yield from range(10)
+    # Receiver for data only
+    data_recv: Receiver[int] = streamer.new_receiver()

-streamer = GrpcStreamBroadcaster(
-    stream_name="example_stream",
-    stream_method=async_range,
-    transform=lambda msg: msg,
-)
+    # Receiver for data and events
+    mixed_recv: Receiver[int | StreamEvent] = streamer.new_receiver(
+        include_events=True
+    )

-recv = streamer.new_receiver()
-
-for msg in recv:
-    match msg:
-        case StreamStarted():
-            print("Stream started")
-        case StreamRetrying(delay, error):
-            print(f"Stream stopped and will retry in {delay}: {error or 'closed'}")
-        case StreamFatalError(error):
-            print(f"Stream will stop because of a fatal error: {error}")
-        case int() as output:
-            print(f"Received message: {output}")
+    async def consume_mixed():
+        async for msg in mixed_recv:
+            match msg:
+                case StreamStarted():
+                    print("Mixed: Stream started")
+                case StreamRetrying(delay, error):
+                    print(
+                        "Mixed: Stream retrying in "
+                        f"{delay.total_seconds():.1f}s: {error or 'closed'}"
+                    )
+                case StreamFatalError(error):
+                    print(f"Mixed: Stream fatal error: {error}")
+                    break  # Stop consuming on fatal error
+                case int() as output:
+                    print(f"Mixed: Received data: {output}")
+        print("Mixed: Consumer finished")

+    async def consume_data():
+        async for data_msg in data_recv:
+            print(f"DataOnly: Received data: {data_msg}")
+        print("DataOnly: Consumer finished")

+    mixed_consumer_task = asyncio.create_task(consume_mixed())
+    data_consumer_task = asyncio.create_task(consume_data())

+    await asyncio.sleep(5)  # Let it run for a bit
+    print("Stopping streamer...")
+    await streamer.stop()
+    await mixed_consumer_task
+    await data_consumer_task
+    print("Streamer stopped.")

+if __name__ == "__main__":
+    asyncio.run(main())
```
"""

@@ -130,27 +181,77 @@ def __init__(  # pylint: disable=too-many-arguments,too-many-positional-arguments
)
self._retry_on_exhausted_stream = retry_on_exhausted_stream

-        self._channel: channels.Broadcast[StreamEvent | OutputT] = channels.Broadcast(
-            name=f"GrpcStreamBroadcaster-{stream_name}"
+        # Channel for transformed data messages (OutputT)
+        self._data_channel: channels.Broadcast[OutputT] = channels.Broadcast(
+            name=f"GrpcStreamBroadcaster-{stream_name}-Data"
        )

+        # Channel for stream events (StreamEvent), created on demand
+        self._event_channel: channels.Broadcast[StreamEvent] | None = None
+        self._event_sender: channels.Sender[StreamEvent] | None = None
        self._task = asyncio.create_task(self._run())

+    @overload
+    def new_receiver(
+        self,
+        maxsize: int = 50,
+        warn_on_overflow: bool = True,
+        *,
+        include_events: Literal[False] = False,
+    ) -> channels.Receiver[OutputT]: ...

+    @overload
+    def new_receiver(
+        self,
+        maxsize: int = 50,
+        warn_on_overflow: bool = True,
+        *,
+        include_events: Literal[True],
+    ) -> channels.Receiver[StreamEvent | OutputT]: ...

    def new_receiver(
-        self, maxsize: int = 50, warn_on_overflow: bool = True
-    ) -> channels.Receiver[StreamEvent | OutputT]:
+        self,
+        maxsize: int = 50,
+        warn_on_overflow: bool = True,
+        *,
+        include_events: bool = False,
+    ) -> channels.Receiver[OutputT] | channels.Receiver[StreamEvent | OutputT]:

Comment on lines +214 to +216 (Contributor): Not really part of this PR, but since you are adding the `*` here, it makes me think that we should probably make all the other arguments keyword-only too, as it can't really be inferred what they are from the type.
"""Create a new receiver for the stream.

Args:
-            maxsize: The maximum number of messages to buffer.
-            warn_on_overflow: Whether to log a warning when the receiver's
+            maxsize: The maximum number of messages to buffer in underlying receivers.
+            warn_on_overflow: Whether to log a warning when a receiver's
                buffer is full and a message is dropped.
+            include_events: Whether to include stream events (e.g. StreamStarted,
+                StreamRetrying, StreamFatalError) in the receiver. If `False` (default),
+                only transformed data messages will be received.

        Returns:
-            A new receiver.
+            A new receiver. If `include_events` is True, the receiver will yield
+            both `OutputT` and `StreamEvent` types. Otherwise, only `OutputT`.
"""
-        return self._channel.new_receiver(
+        if not include_events:
+            return self._data_channel.new_receiver(
                limit=maxsize, warn_on_overflow=warn_on_overflow
            )

+        if self._event_channel is None:
+            _logger.debug(
+                "%s: First request for events, creating event channel.",
+                self._stream_name,
+            )
+            self._event_channel = channels.Broadcast[StreamEvent](
+                name=f"GrpcStreamBroadcaster-{self._stream_name}-Events"
+            )
+            self._event_sender = self._event_channel.new_sender()

+        data_rx = self._data_channel.new_receiver(
+            limit=maxsize, warn_on_overflow=warn_on_overflow
+        )
+        event_rx = self._event_channel.new_receiver(
+            limit=maxsize, warn_on_overflow=warn_on_overflow
+        )
+        return channels.merge(data_rx, event_rx)
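A minimal caller-side sketch of how the `@overload` pair above resolves for a static type checker; the `bytes`/`int` type parameters are illustrative assumptions, and `StreamEvent` is assumed to be importable from the package alongside the event classes:

```python
from frequenz.channels import Receiver

from frequenz.client.base import GrpcStreamBroadcaster, StreamEvent


def annotate(broadcaster: GrpcStreamBroadcaster[bytes, int]) -> None:
    # Default call: the Literal[False] overload applies, so the receiver
    # carries only transformed data.
    data_only: Receiver[int] = broadcaster.new_receiver()

    # include_events=True selects the Literal[True] overload, widening the
    # element type to include stream events.
    mixed: Receiver[StreamEvent | int] = broadcaster.new_receiver(include_events=True)
```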

@property
def is_running(self) -> bool:
@@ -171,28 +272,35 @@ async def stop(self) -> None:
await self._task
except asyncio.CancelledError:
pass
Comment (Copilot AI, Jun 4, 2025): The logic for closing the data and event channels is duplicated in both the `stop` and `_run` methods. Consider refactoring this duplicated code into a helper (the suggested change adds an `await self._close_channels()` call after the `pass`) to reduce repetition and improve maintainability.

-        await self._channel.close()
+        await self._data_channel.close()
+        if self._event_channel is not None:
+            await self._event_channel.close()
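A sketch of the helper that comment proposes (hypothetical; the merged diff keeps the close calls inline in both places):

```python
# Hypothetical method on GrpcStreamBroadcaster, per the review comment above.
async def _close_channels(self) -> None:
    """Close the data and event channels."""
    await self._data_channel.close()
    if self._event_channel is not None:
        await self._event_channel.close()
```

Both `stop` and the terminating branches of `_run` could then call `await self._close_channels()`.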

async def _run(self) -> None:
"""Run the streaming helper."""
-        sender = self._channel.new_sender()
+        data_sender = self._data_channel.new_sender()

        while True:
            error: Exception | None = None
            _logger.info("%s: starting to stream", self._stream_name)
            try:
                call = self._stream_method()
-                await sender.send(StreamStarted())
+
+                if self._event_sender:
+                    await self._event_sender.send(StreamStarted())
+
                async for msg in call:
-                    await sender.send(self._transform(msg))
+                    await data_sender.send(self._transform(msg))
except grpc.aio.AioRpcError as err:
error = err

if error is None and not self._retry_on_exhausted_stream:
_logger.info(
"%s: connection closed, stream exhausted", self._stream_name
)
-                await self._channel.close()
+                await self._data_channel.close()
+                if self._event_channel is not None:
+                    await self._event_channel.close()
break

interval = self._retry_strategy.next_interval()
@@ -204,9 +312,11 @@ async def _run(self) -> None:
self._retry_strategy.get_progress(),
error_str,
)
-                if error is not None:
-                    await sender.send(StreamFatalError(error))
-                await self._channel.close()
+                if error is not None and self._event_sender:
+                    await self._event_sender.send(StreamFatalError(error))
+                await self._data_channel.close()
+                if self._event_channel is not None:
+                    await self._event_channel.close()
break
_logger.warning(
"%s: connection ended, retrying %s in %0.3f seconds. %s.",
Expand All @@ -216,5 +326,8 @@ async def _run(self) -> None:
error_str,
)

-            await sender.send(StreamRetrying(timedelta(seconds=interval), error))
+            if self._event_sender:
+                await self._event_sender.send(
+                    StreamRetrying(timedelta(seconds=interval), error)
+                )
await asyncio.sleep(interval)
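Taken together, these branches produce a predictable sequence on the events channel; a sketch of what an events-enabled receiver observes for a stream that errors, reconnects once, and then exhausts its retry strategy (this mirrors the expectations in `test_messages_on_retry` below):

```python
# Hypothetical observed sequence; `error` stands for the AioRpcError that
# terminated each connection attempt.
expected_events = [
    StreamStarted(),                                # initial connection
    StreamRetrying(timedelta(seconds=0.0), error),  # retry scheduled after the error
    StreamStarted(),                                # reconnected
    StreamFatalError(error),                        # retry strategy gave up
]
```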
29 changes: 20 additions & 9 deletions tests/streaming/test_grpc_stream_broadcaster.py
@@ -8,6 +8,7 @@
from collections.abc import AsyncIterator
from contextlib import AsyncExitStack
from datetime import timedelta
+from typing import Literal
from unittest import mock

import grpc
@@ -260,9 +261,11 @@ async def test_streaming_error(  # pylint: disable=too-many-arguments
]


+@pytest.mark.parametrize("include_events", [True, False])
async def test_retry_next_interval_zero( # pylint: disable=too-many-arguments
receiver_ready_event: asyncio.Event, # pylint: disable=redefined-outer-name
caplog: pytest.LogCaptureFixture,
+    include_events: Literal[True] | Literal[False],
) -> None:
"""Test retry logic when next_interval returns 0."""
caplog.set_level(logging.WARNING)
Expand All @@ -279,14 +282,17 @@ async def test_retry_next_interval_zero( # pylint: disable=too-many-arguments
)

    items: list[str] = []
+    events: list[StreamEvent] = []
    async with AsyncExitStack() as stack:
        stack.push_async_callback(helper.stop)

-        receiver = helper.new_receiver()
+        receiver = helper.new_receiver(include_events=include_events)
        receiver_ready_event.set()
-        items, _ = await _split_message(receiver)
+        items, events = await _split_message(receiver)

    assert not items
+    assert bool(events) == include_events

assert mock_retry.next_interval.mock_calls == [mock.call(), mock.call()]
assert caplog.record_tuples == [
(
@@ -304,8 +310,10 @@ async def test_retry_next_interval_zero(  # pylint: disable=too-many-arguments
]


+@pytest.mark.parametrize("include_events", [True, False])
async def test_messages_on_retry(
receiver_ready_event: asyncio.Event, # pylint: disable=redefined-outer-name
+    include_events: Literal[True] | Literal[False],
) -> None:
"""Test that messages are sent on retry."""
# We need to use a specific instance for all the test here because 2 errors created
Expand All @@ -328,7 +336,7 @@ async def test_messages_on_retry(
async with AsyncExitStack() as stack:
stack.push_async_callback(helper.stop)

-        receiver = helper.new_receiver()
+        receiver = helper.new_receiver(include_events=include_events)
receiver_ready_event.set()
items, events = await _split_message(receiver)

@@ -338,9 +346,12 @@ async def test_messages_on_retry(
"transformed_0",
"transformed_1",
]
-    assert events == [
-        StreamStarted(),
-        StreamRetrying(timedelta(seconds=0.0), error),
-        StreamStarted(),
-        StreamFatalError(error),
-    ]
+    if include_events:
+        assert events == [
+            StreamStarted(),
+            StreamRetrying(timedelta(seconds=0.0), error),
+            StreamStarted(),
+            StreamFatalError(error),
+        ]
+    else:
+        assert events == []
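The assertions above rely on a `_split_message` helper defined elsewhere in this test module; a hypothetical sketch of its likely shape, assuming it drains the receiver and separates transformed data from stream events (import paths are assumptions):

```python
from frequenz.channels import Receiver
from frequenz.client.base.streaming import (
    StreamEvent,
    StreamFatalError,
    StreamRetrying,
    StreamStarted,
)


async def _split_message(
    receiver: Receiver[str | StreamEvent],
) -> tuple[list[str], list[StreamEvent]]:
    """Drain the receiver, separating transformed data from stream events."""
    items: list[str] = []
    events: list[StreamEvent] = []
    async for msg in receiver:
        match msg:
            case StreamStarted() | StreamRetrying() | StreamFatalError():
                events.append(msg)
            case str() as item:
                items.append(item)
    return items, events
```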