Skip to content

Commit 779e807

Browse files
authored
Support filtering the messages on a receiver (#303)
A new `filter` method is added to the `Receiver` interface, which allows the application of a filter function on the messages on a receiver. Example: ```python async for message in receiver.filter(lambda num: num % 2): print(f"An even number: {message}") ```
2 parents c7e6096 + 068b04e commit 779e807

File tree

4 files changed

+148
-2
lines changed

4 files changed

+148
-2
lines changed

RELEASE_NOTES.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
- **Experimental**: `Pipe`, which provides a pipe between two channels, by connecting a `Receiver` to a `Sender`.
1818

19+
- `Receiver`s now have a `filter` method that applies a filter function on the messages on a receiver.
20+
1921
## Bug Fixes
2022

2123
<!-- Here goes notable bug fixes that are worth a special mention or explanation -->

src/frequenz/channels/_receiver.py

Lines changed: 112 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,25 @@ def map(
242242
"""
243243
return _Mapper(receiver=self, mapping_function=mapping_function)
244244

245+
def filter(
246+
self, filter_function: Callable[[ReceiverMessageT_co], bool], /
247+
) -> Receiver[ReceiverMessageT_co]:
248+
"""Apply a filter function on the messages on a receiver.
249+
250+
Tip:
251+
The returned receiver type won't have all the methods of the original
252+
receiver. If you need to access methods of the original receiver that are
253+
not part of the `Receiver` interface you should save a reference to the
254+
original receiver and use that instead.
255+
256+
Args:
257+
filter_function: The function to be applied on incoming messages.
258+
259+
Returns:
260+
A new receiver that applies the function on the received messages.
261+
"""
262+
return _Filter(receiver=self, filter_function=filter_function)
263+
245264

246265
class ReceiverError(Error, Generic[ReceiverMessageT_co]):
247266
"""An error that originated in a [Receiver][frequenz.channels.Receiver].
@@ -336,9 +355,100 @@ def consume(self) -> MappedMessageT_co: # noqa: DOC502
336355
) # pylint: disable=protected-access
337356

338357
def __str__(self) -> str:
339-
"""Return a string representation of the timer."""
358+
"""Return a string representation of the mapper."""
340359
return f"{type(self).__name__}:{self._receiver}:{self._mapping_function}"
341360

342361
def __repr__(self) -> str:
343-
"""Return a string representation of the timer."""
362+
"""Return a string representation of the mapper."""
344363
return f"{type(self).__name__}({self._receiver!r}, {self._mapping_function!r})"
364+
365+
366+
class _Sentinel:
367+
"""A sentinel object to represent no value received yet."""
368+
369+
def __str__(self) -> str:
370+
"""Return a string representation of this sentinel."""
371+
return "<No message ready to be consumed>"
372+
373+
def __repr__(self) -> str:
374+
"""Return a string representation of this sentinel."""
375+
return "<No message ready to be consumed>"
376+
377+
378+
_SENTINEL = _Sentinel()
379+
380+
381+
class _Filter(Receiver[ReceiverMessageT_co], Generic[ReceiverMessageT_co]):
382+
"""Apply a filter function on the messages on a receiver."""
383+
384+
def __init__(
385+
self,
386+
*,
387+
receiver: Receiver[ReceiverMessageT_co],
388+
filter_function: Callable[[ReceiverMessageT_co], bool],
389+
) -> None:
390+
"""Initialize this receiver filter.
391+
392+
Args:
393+
receiver: The input receiver.
394+
filter_function: The function to apply on the input data.
395+
"""
396+
self._receiver: Receiver[ReceiverMessageT_co] = receiver
397+
"""The input receiver."""
398+
399+
self._filter_function: Callable[[ReceiverMessageT_co], bool] = filter_function
400+
"""The function to apply on the input data."""
401+
402+
self._next_message: ReceiverMessageT_co | _Sentinel = _SENTINEL
403+
404+
self._recv_closed = False
405+
406+
async def ready(self) -> bool:
407+
"""Wait until the receiver is ready with a message or an error.
408+
409+
Once a call to `ready()` has finished, the message should be read with
410+
a call to `consume()` (`receive()` or iterated over). The receiver will
411+
remain ready (this method will return immediately) until it is
412+
consumed.
413+
414+
Returns:
415+
Whether the receiver is still active.
416+
"""
417+
while await self._receiver.ready():
418+
message = self._receiver.consume()
419+
if self._filter_function(message):
420+
self._next_message = message
421+
return True
422+
self._recv_closed = True
423+
return False
424+
425+
def consume(self) -> ReceiverMessageT_co:
426+
"""Return a transformed message once `ready()` is complete.
427+
428+
Returns:
429+
The next message that was received.
430+
431+
Raises:
432+
ReceiverStoppedError: If the receiver stopped producing messages.
433+
ReceiverError: If there is a problem with the receiver.
434+
"""
435+
if self._recv_closed:
436+
raise ReceiverStoppedError(self)
437+
assert not isinstance(
438+
self._next_message, _Sentinel
439+
), "`consume()` must be preceded by a call to `ready()`"
440+
441+
message = self._next_message
442+
self._next_message = _SENTINEL
443+
return message
444+
445+
def __str__(self) -> str:
446+
"""Return a string representation of the filter."""
447+
return f"{type(self).__name__}:{self._receiver}:{self._filter_function}"
448+
449+
def __repr__(self) -> str:
450+
"""Return a string representation of the filter."""
451+
return (
452+
f"<{type(self).__name__} receiver={self._receiver!r} "
453+
f"filter={self._filter_function!r} next_message={self._next_message!r}>"
454+
)

tests/test_anycast.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,3 +200,20 @@ async def test_anycast_map() -> None:
200200

201201
assert (await receiver.receive()) is False
202202
assert (await receiver.receive()) is True
203+
204+
205+
async def test_anycast_filter() -> None:
206+
"""Ensure filter keeps only the messages that pass the filter."""
207+
chan = Anycast[int](name="input-chan")
208+
sender = chan.new_sender()
209+
210+
# filter out all numbers less than 10.
211+
receiver: Receiver[int] = chan.new_receiver().filter(lambda num: num > 10)
212+
213+
await sender.send(8)
214+
await sender.send(12)
215+
await sender.send(5)
216+
await sender.send(15)
217+
218+
assert (await receiver.receive()) == 12
219+
assert (await receiver.receive()) == 15

tests/test_broadcast.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,23 @@ async def test_broadcast_map() -> None:
231231
assert (await receiver.receive()) is True
232232

233233

234+
async def test_broadcast_filter() -> None:
235+
"""Ensure filter keeps only the messages that pass the filter."""
236+
chan = Broadcast[int](name="input-chan")
237+
sender = chan.new_sender()
238+
239+
# filter out all numbers less than 10.
240+
receiver: Receiver[int] = chan.new_receiver().filter(lambda num: num > 10)
241+
242+
await sender.send(8)
243+
await sender.send(12)
244+
await sender.send(5)
245+
await sender.send(15)
246+
247+
assert (await receiver.receive()) == 12
248+
assert (await receiver.receive()) == 15
249+
250+
234251
async def test_broadcast_receiver_drop() -> None:
235252
"""Ensure deleted receivers get cleaned up."""
236253
chan = Broadcast[int](name="input-chan")

0 commit comments

Comments
 (0)