diff --git a/RELEASE_NOTES.md b/RELEASE_NOTES.md index fd3ec90f..9b31be10 100644 --- a/RELEASE_NOTES.md +++ b/RELEASE_NOTES.md @@ -2,6 +2,12 @@ ## New Features +- `Receiver` + + * Add `take_while()` as a less ambiguous and more readable alternative to `filter()`. + * Add `drop_while()` as a convenience and more readable alternative to `filter()` with a negated predicate. + * The usage of `filter()` is discouraged in favor of `take_while()` and `drop_while()`. + ### Experimental - A new predicate, `OnlyIfPrevious`, to `filter()` messages based on the previous message. diff --git a/src/frequenz/channels/_receiver.py b/src/frequenz/channels/_receiver.py index 53862a45..25367d09 100644 --- a/src/frequenz/channels/_receiver.py +++ b/src/frequenz/channels/_receiver.py @@ -58,7 +58,9 @@ # Message Filtering If you need to filter the received messages, receivers provide a -[`filter()`][frequenz.channels.Receiver.filter] method to easily do so: +[`take_while()`][frequenz.channels.Receiver.take_while] and a +[`drop_while()`][frequenz.channels.Receiver.drop_while] +method to easily do so: ```python show_lines="6:" from frequenz.channels import Anycast @@ -66,13 +68,17 @@ channel = Anycast[int](name="test-channel") receiver = channel.new_receiver() -async for message in receiver.filter(lambda x: x % 2 == 0): +async for message in receiver.take_while(lambda x: x % 2 == 0): print(message) # Only even numbers will be printed ``` As with [`map()`][frequenz.channels.Receiver.map], -[`filter()`][frequenz.channels.Receiver.filter] returns a new full receiver, so you can -use it in any of the ways described above. +[`take_while()`][frequenz.channels.Receiver.take_while] returns a new full receiver, so +you can use it in any of the ways described above. + +[`take_while()`][frequenz.channels.Receiver.take_while] can even receive a +[type guard][typing.TypeGuard] as the predicate to narrow the type of the received +messages. # Error Handling @@ -280,6 +286,11 @@ def filter( ) -> Receiver[FilteredMessageT_co]: """Apply a type guard on the messages on a receiver. + Tip: + It is recommended to use the + [`take_while()`][frequenz.channels.Receiver.take_while] method instead of + this one, as it makes the intention more clear. + Tip: The returned receiver type won't have all the methods of the original receiver. If you need to access methods of the original receiver that are @@ -301,6 +312,11 @@ def filter( ) -> Receiver[ReceiverMessageT_co]: """Apply a filter function on the messages on a receiver. + Tip: + It is recommended to use the + [`take_while()`][frequenz.channels.Receiver.take_while] method instead of + this one, as it makes the intention more clear. + Tip: The returned receiver type won't have all the methods of the original receiver. If you need to access methods of the original receiver that are @@ -326,6 +342,11 @@ def filter( ) -> Receiver[ReceiverMessageT_co] | Receiver[FilteredMessageT_co]: """Apply a filter function on the messages on a receiver. + Tip: + It is recommended to use the + [`take_while()`][frequenz.channels.Receiver.take_while] method instead of + this one, as it makes the intention more clear. + Note: You can pass a [type guard][typing.TypeGuard] as the filter function to narrow the type of the messages that pass the filter. @@ -345,6 +366,117 @@ def filter( """ return _Filter(receiver=self, filter_function=filter_function) + @overload + def take_while( + self, + predicate: Callable[[ReceiverMessageT_co], TypeGuard[FilteredMessageT_co]], + /, + ) -> Receiver[FilteredMessageT_co]: + """Take only the messages that fulfill a predicate, narrowing the type. + + The returned receiver will only receive messages that fulfill the predicate + (evaluates to `True`), and will drop messages that don't. + + Tip: + The returned receiver type won't have all the methods of the original + receiver. If you need to access methods of the original receiver that are + not part of the `Receiver` interface you should save a reference to the + original receiver and use that instead. + + Args: + predicate: The predicate to be applied on incoming messages to + determine if they should be taken. + + Returns: + A new receiver that only receives messages that fulfill the predicate. + """ + ... # pylint: disable=unnecessary-ellipsis + + @overload + def take_while( + self, predicate: Callable[[ReceiverMessageT_co], bool], / + ) -> Receiver[ReceiverMessageT_co]: + """Take only the messages that fulfill a predicate. + + The returned receiver will only receive messages that fulfill the predicate + (evaluates to `True`), and will drop messages that don't. + + Tip: + The returned receiver type won't have all the methods of the original + receiver. If you need to access methods of the original receiver that are + not part of the `Receiver` interface you should save a reference to the + original receiver and use that instead. + + Args: + predicate: The predicate to be applied on incoming messages to + determine if they should be taken. + + Returns: + A new receiver that only receives messages that fulfill the predicate. + """ + ... # pylint: disable=unnecessary-ellipsis + + def take_while( + self, + predicate: ( + Callable[[ReceiverMessageT_co], bool] + | Callable[[ReceiverMessageT_co], TypeGuard[FilteredMessageT_co]] + ), + /, + ) -> Receiver[ReceiverMessageT_co] | Receiver[FilteredMessageT_co]: + """Take only the messages that fulfill a predicate. + + The returned receiver will only receive messages that fulfill the predicate + (evaluates to `True`), and will drop messages that don't. + + Note: + You can pass a [type guard][typing.TypeGuard] as the predicate to narrow the + type of the received messages. + + Tip: + The returned receiver type won't have all the methods of the original + receiver. If you need to access methods of the original receiver that are + not part of the `Receiver` interface you should save a reference to the + original receiver and use that instead. + + Args: + predicate: The predicate to be applied on incoming messages to + determine if they should be taken. + + Returns: + A new receiver that only receives messages that fulfill the predicate. + """ + return _Filter(receiver=self, filter_function=predicate) + + def drop_while( + self, + predicate: Callable[[ReceiverMessageT_co], bool], + /, + ) -> Receiver[ReceiverMessageT_co] | Receiver[ReceiverMessageT_co]: + """Drop the messages that fulfill a predicate. + + The returned receiver will drop messages that fulfill the predicate + (evaluates to `True`), and receive messages that don't. + + Tip: + If you need to narrow the type of the received messages, you can use the + [`take_while()`][frequenz.channels.Receiver.take_while] method instead. + + Tip: + The returned receiver type won't have all the methods of the original + receiver. If you need to access methods of the original receiver that are + not part of the `Receiver` interface you should save a reference to the + original receiver and use that instead. + + Args: + predicate: The predicate to be applied on incoming messages to + determine if they should be dropped. + + Returns: + A new receiver that only receives messages that don't fulfill the predicate. + """ + return _Filter(receiver=self, filter_function=predicate, negate=True) + def triggered( self, selected: Selected[Any] ) -> TypeGuard[Selected[ReceiverMessageT_co]]: @@ -492,12 +624,14 @@ def __init__( *, receiver: Receiver[ReceiverMessageT_co], filter_function: Callable[[ReceiverMessageT_co], bool], + negate: bool = False, ) -> None: """Initialize this receiver filter. Args: receiver: The input receiver. filter_function: The function to apply on the input data. + negate: Whether to negate the filter function. """ self._receiver: Receiver[ReceiverMessageT_co] = receiver """The input receiver.""" @@ -507,6 +641,8 @@ def __init__( self._next_message: ReceiverMessageT_co | _Sentinel = _SENTINEL + self._negate: bool = negate + self._recv_closed = False async def ready(self) -> bool: @@ -522,7 +658,10 @@ async def ready(self) -> bool: """ while await self._receiver.ready(): message = self._receiver.consume() - if self._filter_function(message): + result = self._filter_function(message) + if self._negate: + result = not result + if result: self._next_message = message return True self._recv_closed = True diff --git a/tests/test_anycast.py b/tests/test_anycast.py index c6db0d9a..cae3d451 100644 --- a/tests/test_anycast.py +++ b/tests/test_anycast.py @@ -162,58 +162,3 @@ async def test_anycast_none_messages() -> None: await sender.send(10) assert await receiver.receive() == 10 - - -async def test_anycast_async_iterator() -> None: - """Check that the anycast receiver works as an async iterator.""" - acast: Anycast[str] = Anycast(name="test") - - sender = acast.new_sender() - receiver = acast.new_receiver() - - async def send_messages() -> None: - for val in ["one", "two", "three", "four", "five"]: - await sender.send(val) - await acast.close() - - sender_task = asyncio.create_task(send_messages()) - - received = [] - async for recv in receiver: - received.append(recv) - - assert received == ["one", "two", "three", "four", "five"] - - await sender_task - - -async def test_anycast_map() -> None: - """Ensure map runs on all incoming messages.""" - chan: Anycast[int] = Anycast(name="test") - sender = chan.new_sender() - - # transform int receiver into bool receiver. - receiver: Receiver[bool] = chan.new_receiver().map(lambda num: num > 10) - - await sender.send(8) - await sender.send(12) - - assert (await receiver.receive()) is False - assert (await receiver.receive()) is True - - -async def test_anycast_filter() -> None: - """Ensure filter keeps only the messages that pass the filter.""" - chan = Anycast[int](name="input-chan") - sender = chan.new_sender() - - # filter out all numbers less than 10. - receiver: Receiver[int] = chan.new_receiver().filter(lambda num: num > 10) - - await sender.send(8) - await sender.send(12) - await sender.send(5) - await sender.send(15) - - assert (await receiver.receive()) == 12 - assert (await receiver.receive()) == 15 diff --git a/tests/test_broadcast.py b/tests/test_broadcast.py index c8a2e9cf..b23686c2 100644 --- a/tests/test_broadcast.py +++ b/tests/test_broadcast.py @@ -6,7 +6,6 @@ import asyncio from dataclasses import dataclass -from typing import TypeGuard, assert_never import pytest @@ -194,86 +193,6 @@ async def test_broadcast_no_resend_latest() -> None: assert await new_recv.receive() == 100 -async def test_broadcast_async_iterator() -> None: - """Check that the broadcast receiver works as an async iterator.""" - bcast: Broadcast[int] = Broadcast(name="iter_test") - - sender = bcast.new_sender() - receiver = bcast.new_receiver() - - async def send_messages() -> None: - for val in range(0, 10): - await sender.send(val) - await bcast.close() - - sender_task = asyncio.create_task(send_messages()) - - received = [] - async for recv in receiver: - received.append(recv) - - assert received == list(range(0, 10)) - - await sender_task - - -async def test_broadcast_map() -> None: - """Ensure map runs on all incoming messages.""" - chan = Broadcast[int](name="input-chan") - sender = chan.new_sender() - - # transform int receiver into bool receiver. - receiver: Receiver[bool] = chan.new_receiver().map(lambda num: num > 10) - - await sender.send(8) - await sender.send(12) - - assert (await receiver.receive()) is False - assert (await receiver.receive()) is True - - -async def test_broadcast_filter() -> None: - """Ensure filter keeps only the messages that pass the filter.""" - chan = Broadcast[int](name="input-chan") - sender = chan.new_sender() - - # filter out all numbers less than 10. - receiver: Receiver[int] = chan.new_receiver().filter(lambda num: num > 10) - - await sender.send(8) - await sender.send(12) - await sender.send(5) - await sender.send(15) - - assert (await receiver.receive()) == 12 - assert (await receiver.receive()) == 15 - - -async def test_broadcast_filter_type_guard() -> None: - """Ensure filter type guard works.""" - chan = Broadcast[int | str](name="input-chan") - sender = chan.new_sender() - - def _is_int(num: int | str) -> TypeGuard[int]: - return isinstance(num, int) - - # filter out objects that are not integers. - receiver = chan.new_receiver().filter(_is_int) - - await sender.send("hello") - await sender.send(8) - - message = await receiver.receive() - assert message == 8 - is_int = False - match message: - case int(): - is_int = True - case unexpected: - assert_never(unexpected) - assert is_int - - async def test_broadcast_receiver_drop() -> None: """Ensure deleted receivers get cleaned up.""" chan = Broadcast[int](name="input-chan") diff --git a/tests/test_receiver.py b/tests/test_receiver.py new file mode 100644 index 00000000..a517b0af --- /dev/null +++ b/tests/test_receiver.py @@ -0,0 +1,146 @@ +# License: MIT +# Copyright © 2024 Frequenz Energy-as-a-Service GmbH + +"""Tests for the Receiver class.""" + +import asyncio +from collections.abc import Sequence +from typing import TypeGuard, assert_type +from unittest.mock import MagicMock + +import pytest +from typing_extensions import override + +from frequenz.channels import Receiver, ReceiverError, ReceiverStoppedError + + +class _MockReceiver(Receiver[int | str]): + def __init__(self, messages: Sequence[int | str] = ()) -> None: + self.messages = list(messages) + self.stopped = False + + @override + async def ready(self) -> bool: + """Return True if there are messages to consume or the receiver is stopped.""" + if self.stopped: + return False + if self.messages: + await asyncio.sleep(0) + return True + return False + + @override + def consume(self) -> int | str: + """Return the next message.""" + if self.stopped or not self.messages: + raise ReceiverStoppedError(self) + return self.messages.pop(0) + + def stop(self) -> None: + """Stop the receiver.""" + self.stopped = True + + +async def test_receiver_take_while() -> None: + """Test the take_while method.""" + receiver = _MockReceiver([1, 2, 3, 4, 5]) + + filtered_receiver = receiver.take_while(lambda x: x % 2 == 0) + async with asyncio.timeout(1): + assert await filtered_receiver.receive() == 2 + assert await filtered_receiver.receive() == 4 + + with pytest.raises(ReceiverStoppedError): + await filtered_receiver.receive() + + +async def test_receiver_take_type_guard() -> None: + """Test the take_while method using a TypeGuard.""" + receiver = _MockReceiver([1, "1", 2, "2", 3, "3", 4, 5]) + + def is_even(x: int | str) -> TypeGuard[int]: + if not isinstance(x, int): + return False + return x % 2 == 0 + + filtered_receiver = receiver.take_while(is_even) + async with asyncio.timeout(1): + received = await filtered_receiver.receive() + assert_type(received, int) + assert received == 2 + assert await filtered_receiver.receive() == 4 + + with pytest.raises(ReceiverStoppedError): + await filtered_receiver.receive() + + +async def test_receiver_drop_while() -> None: + """Test the drop_while method.""" + receiver = _MockReceiver([1, 2, 3, 4, 5]) + + filtered_receiver = receiver.drop_while(lambda x: x % 2 == 0) + async with asyncio.timeout(1): + assert await filtered_receiver.receive() == 1 + assert await filtered_receiver.receive() == 3 + assert await filtered_receiver.receive() == 5 + + with pytest.raises(ReceiverStoppedError): + await filtered_receiver.receive() + + +async def test_receiver_async_iteration() -> None: + """Test async iteration over the receiver.""" + receiver = _MockReceiver([1, 2]) + + received = [] + async with asyncio.timeout(1): + async for message in receiver: + received.append(message) + + assert received == [1, 2] + + +async def test_receiver_map() -> None: + """Test mapping a function over the receiver's messages.""" + receiver = _MockReceiver([1, 2]) + + mapped_receiver = receiver.map(lambda x: f"{x} + 1") + assert await mapped_receiver.receive() == "1 + 1" + assert await mapped_receiver.receive() == "2 + 1" + + +async def test_receiver_filter() -> None: + """Test filtering the receiver's messages.""" + receiver = _MockReceiver([1, 2, 3, 4, 5]) + + filtered_receiver = receiver.filter(lambda x: x % 2 == 0) + async with asyncio.timeout(1): + assert await filtered_receiver.receive() == 2 + assert await filtered_receiver.receive() == 4 + + with pytest.raises(ReceiverStoppedError): + await filtered_receiver.receive() + + +async def test_receiver_triggered() -> None: + """Test the triggered method.""" + receiver = _MockReceiver() + selected = MagicMock() + selected._recv = receiver # pylint: disable=protected-access + + assert receiver.triggered(selected) + assert selected._handled # pylint: disable=protected-access + + +async def test_receiver_error_handling() -> None: + """Test error handling in the receiver.""" + receiver = _MockReceiver([1]) + receiver.stop() + + async with asyncio.timeout(1): + with pytest.raises(ReceiverStoppedError): + await receiver.receive() + + receiver = _MockReceiver() + with pytest.raises(ReceiverError): + await receiver.receive()