Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions RELEASE_NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
case int() as output:
print(f"Received message: {output}")
```
* In the `streaming` module, the new function `filter_stream_events` can be used to filter out stream events and retain the old behavior.

## Bug Fixes

Expand Down
31 changes: 30 additions & 1 deletion src/frequenz/client/base/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from collections.abc import Callable
from dataclasses import dataclass
from datetime import timedelta
from typing import AsyncIterable, Generic, TypeAlias, TypeVar
from typing import AsyncIterable, Generic, Tuple, Type, TypeAlias, TypeGuard, TypeVar

import grpc.aio

Expand Down Expand Up @@ -58,6 +58,35 @@ class StreamFatalError:
"""Type alias for the events that can be sent over the stream."""


FilteredOutputT = TypeVar("FilteredOutputT")
"""Type alias for the output type of the stream after filtering."""


def filter_stream_events(
receiver: channels.Receiver[StreamEvent | FilteredOutputT],
ignore_events: Tuple[Type[StreamEvent], ...] = (
StreamStarted,
StreamRetrying,
StreamFatalError,
),
) -> channels.Receiver[FilteredOutputT]:
"""Filter out specific stream events from the receiver.

Args:
receiver: The receiver to filter.
ignore_events: A tuple of event types to filter out, by default all.

Returns:
A new receiver that only returns the transformed output type.
"""

def _filter(sample: FilteredOutputT | StreamEvent) -> TypeGuard[FilteredOutputT]:
"""Check if the received message is of the output type."""
return not isinstance(sample, ignore_events)

return receiver.filter(_filter)


class GrpcStreamBroadcaster(Generic[InputT, OutputT]):
"""Helper class to handle grpc streaming methods.

Expand Down
71 changes: 71 additions & 0 deletions tests/streaming/test_receiver_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# License: MIT
# Copyright © 2025 Frequenz Energy-as-a-Service GmbH

"""Test filtering of stream events."""

import logging
from datetime import timedelta
from typing import Tuple, Type

import pytest
from frequenz.channels import Broadcast

from frequenz.client.base.streaming import (
StreamEvent,
StreamFatalError,
StreamRetrying,
StreamStarted,
filter_stream_events,
)


@pytest.mark.parametrize(
"filter_events",
(
(StreamStarted, StreamRetrying, StreamFatalError),
(StreamRetrying, StreamFatalError),
(StreamFatalError,),
(),
(StreamStarted, StreamRetrying),
),
)
async def test_filter_stream_events(
filter_events: Tuple[Type[StreamEvent], ...],
) -> None:
"""Test filtering all events."""
channel = Broadcast[int | StreamEvent](name="FilterStreamEventsTestChannel")

receiver = filter_stream_events(channel.new_receiver(), filter_events)
sender = channel.new_sender()

events = (
StreamStarted(),
1,
2,
3,
StreamRetrying(delay=timedelta(seconds=1)),
4,
5,
6,
StreamFatalError(exception=Exception("Test error")),
)

num_samples = 6
num_received_samples = 0

for event in events:
logging.info("Sending event: %s", event)
await sender.send(event)

await channel.close()

async for event in receiver:
logging.info("Received event: %s", event)
if isinstance(event, int):
num_received_samples += 1
else:
assert not isinstance(
event, filter_events
), "Received unexpected event type"

assert num_received_samples == num_samples, "Unexpected number of samples received"
Loading