From 9e1480a7682e4c3c86fc635e0942f3410fd580f9 Mon Sep 17 00:00:00 2001 From: "Mathias L. Baumann" Date: Mon, 2 Jun 2025 18:31:15 +0200 Subject: [PATCH] Add stream filter function Signed-off-by: Mathias L. Baumann --- RELEASE_NOTES.md | 1 + src/frequenz/client/base/streaming.py | 31 ++++++++++- tests/streaming/test_receiver_filter.py | 71 +++++++++++++++++++++++++ 3 files changed, 102 insertions(+), 1 deletion(-) create mode 100644 tests/streaming/test_receiver_filter.py diff --git a/RELEASE_NOTES.md b/RELEASE_NOTES.md index 9a57d26..f1e824b 100644 --- a/RELEASE_NOTES.md +++ b/RELEASE_NOTES.md @@ -26,6 +26,7 @@ case int() as output: print(f"Received message: {output}") ``` +* In the `streaming` module, the new function `filter_stream_events` can be used to filter out stream events and retain the old behavior. ## Bug Fixes diff --git a/src/frequenz/client/base/streaming.py b/src/frequenz/client/base/streaming.py index 8c5533d..0b0e24c 100644 --- a/src/frequenz/client/base/streaming.py +++ b/src/frequenz/client/base/streaming.py @@ -8,7 +8,7 @@ from collections.abc import Callable from dataclasses import dataclass from datetime import timedelta -from typing import AsyncIterable, Generic, TypeAlias, TypeVar +from typing import AsyncIterable, Generic, Tuple, Type, TypeAlias, TypeGuard, TypeVar import grpc.aio @@ -58,6 +58,35 @@ class StreamFatalError: """Type alias for the events that can be sent over the stream.""" +FilteredOutputT = TypeVar("FilteredOutputT") +"""Type alias for the output type of the stream after filtering.""" + + +def filter_stream_events( + receiver: channels.Receiver[StreamEvent | FilteredOutputT], + ignore_events: Tuple[Type[StreamEvent], ...] = ( + StreamStarted, + StreamRetrying, + StreamFatalError, + ), +) -> channels.Receiver[FilteredOutputT]: + """Filter out specific stream events from the receiver. + + Args: + receiver: The receiver to filter. + ignore_events: A tuple of event types to filter out, by default all. + + Returns: + A new receiver that only returns the transformed output type. + """ + + def _filter(sample: FilteredOutputT | StreamEvent) -> TypeGuard[FilteredOutputT]: + """Check if the received message is of the output type.""" + return not isinstance(sample, ignore_events) + + return receiver.filter(_filter) + + class GrpcStreamBroadcaster(Generic[InputT, OutputT]): """Helper class to handle grpc streaming methods. diff --git a/tests/streaming/test_receiver_filter.py b/tests/streaming/test_receiver_filter.py new file mode 100644 index 0000000..982f7a7 --- /dev/null +++ b/tests/streaming/test_receiver_filter.py @@ -0,0 +1,71 @@ +# License: MIT +# Copyright © 2025 Frequenz Energy-as-a-Service GmbH + +"""Test filtering of stream events.""" + +import logging +from datetime import timedelta +from typing import Tuple, Type + +import pytest +from frequenz.channels import Broadcast + +from frequenz.client.base.streaming import ( + StreamEvent, + StreamFatalError, + StreamRetrying, + StreamStarted, + filter_stream_events, +) + + +@pytest.mark.parametrize( + "filter_events", + ( + (StreamStarted, StreamRetrying, StreamFatalError), + (StreamRetrying, StreamFatalError), + (StreamFatalError,), + (), + (StreamStarted, StreamRetrying), + ), +) +async def test_filter_stream_events( + filter_events: Tuple[Type[StreamEvent], ...], +) -> None: + """Test filtering all events.""" + channel = Broadcast[int | StreamEvent](name="FilterStreamEventsTestChannel") + + receiver = filter_stream_events(channel.new_receiver(), filter_events) + sender = channel.new_sender() + + events = ( + StreamStarted(), + 1, + 2, + 3, + StreamRetrying(delay=timedelta(seconds=1)), + 4, + 5, + 6, + StreamFatalError(exception=Exception("Test error")), + ) + + num_samples = 6 + num_received_samples = 0 + + for event in events: + logging.info("Sending event: %s", event) + await sender.send(event) + + await channel.close() + + async for event in receiver: + logging.info("Received event: %s", event) + if isinstance(event, int): + num_received_samples += 1 + else: + assert not isinstance( + event, filter_events + ), "Received unexpected event type" + + assert num_received_samples == num_samples, "Unexpected number of samples received"