Skip to content

Commit f46c9b3

Browse files
committed
Add take_while and drop_while to Receiver
The `take_while()` method is just an alias for `filter()`, with the intention to eliminate the ambiguity and make it more readable. On the other hand, `drop_while()` is the negation of `take_while()` and provided as a convenience and also to improve readability. These names are widely popular and used in other programming languages as well as the Python `itertools` module. Signed-off-by: Leandro Lucarella <[email protected]>
1 parent 90091ed commit f46c9b3

File tree

3 files changed

+233
-5
lines changed

3 files changed

+233
-5
lines changed

RELEASE_NOTES.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,12 @@
22

33
## New Features
44

5+
- `Receiver`
6+
7+
* Add `take_while()` as a less ambiguous and more readable alternative to `filter()`.
8+
* Add `drop_while()` as a convenience and more readable alternative to `filter()` with a negated predicate.
9+
* The usage of `filter()` is discouraged in favor of `take_while()` and `drop_while()`.
10+
511
### Experimental
612

713
- A new predicate, `OnlyIfPrevious`, to `filter()` messages based on the previous message.

src/frequenz/channels/_receiver.py

Lines changed: 140 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,21 +58,23 @@
5858
# Message Filtering
5959
6060
If you need to filter the received messages, receivers provide a
61-
[`filter()`][frequenz.channels.Receiver.filter] method to easily do so:
61+
[`take_while()`][frequenz.channels.Receiver.take_while] and a
62+
[`drop_while()`][frequenz.channels.Receiver.drop_while]
63+
method to easily do so:
6264
6365
```python show_lines="6:"
6466
from frequenz.channels import Anycast
6567
6668
channel = Anycast[int](name="test-channel")
6769
receiver = channel.new_receiver()
6870
69-
async for message in receiver.filter(lambda x: x % 2 == 0):
71+
async for message in receiver.take_while(lambda x: x % 2 == 0):
7072
print(message) # Only even numbers will be printed
7173
```
7274
7375
As with [`map()`][frequenz.channels.Receiver.map],
74-
[`filter()`][frequenz.channels.Receiver.filter] returns a new full receiver, so you can
75-
use it in any of the ways described above.
76+
[`take_while()`][frequenz.channels.Receiver.take_while] returns a new full receiver, so
77+
you can use it in any of the ways described above.
7678
7779
# Error Handling
7880
@@ -280,6 +282,11 @@ def filter(
280282
) -> Receiver[FilteredMessageT_co]:
281283
"""Apply a type guard on the messages on a receiver.
282284
285+
Tip:
286+
It is recommended to use the
287+
[`take_while()`][frequenz.channels.Receiver.take_while] method instead of
288+
this one, as it makes the intention more clear.
289+
283290
Tip:
284291
The returned receiver type won't have all the methods of the original
285292
receiver. If you need to access methods of the original receiver that are
@@ -301,6 +308,11 @@ def filter(
301308
) -> Receiver[ReceiverMessageT_co]:
302309
"""Apply a filter function on the messages on a receiver.
303310
311+
Tip:
312+
It is recommended to use the
313+
[`take_while()`][frequenz.channels.Receiver.take_while] method instead of
314+
this one, as it makes the intention more clear.
315+
304316
Tip:
305317
The returned receiver type won't have all the methods of the original
306318
receiver. If you need to access methods of the original receiver that are
@@ -326,6 +338,11 @@ def filter(
326338
) -> Receiver[ReceiverMessageT_co] | Receiver[FilteredMessageT_co]:
327339
"""Apply a filter function on the messages on a receiver.
328340
341+
Tip:
342+
It is recommended to use the
343+
[`take_while()`][frequenz.channels.Receiver.take_while] method instead of
344+
this one, as it makes the intention more clear.
345+
329346
Note:
330347
You can pass a [type guard][typing.TypeGuard] as the filter function to
331348
narrow the type of the messages that pass the filter.
@@ -345,6 +362,117 @@ def filter(
345362
"""
346363
return _Filter(receiver=self, filter_function=filter_function)
347364

365+
@overload
366+
def take_while(
367+
self,
368+
predicate: Callable[[ReceiverMessageT_co], TypeGuard[FilteredMessageT_co]],
369+
/,
370+
) -> Receiver[FilteredMessageT_co]:
371+
"""Take only the messages that fulfill a predicate, narrowing the type.
372+
373+
The returned receiver will only receive messages that fulfill the predicate
374+
(evaluates to `True`), and will drop messages that don't.
375+
376+
Tip:
377+
The returned receiver type won't have all the methods of the original
378+
receiver. If you need to access methods of the original receiver that are
379+
not part of the `Receiver` interface you should save a reference to the
380+
original receiver and use that instead.
381+
382+
Args:
383+
predicate: The predicate to be applied on incoming messages to
384+
determine if they should be taken.
385+
386+
Returns:
387+
A new receiver that only receives messages that fulfill the predicate.
388+
"""
389+
... # pylint: disable=unnecessary-ellipsis
390+
391+
@overload
392+
def take_while(
393+
self, predicate: Callable[[ReceiverMessageT_co], bool], /
394+
) -> Receiver[ReceiverMessageT_co]:
395+
"""Take only the messages that fulfill a predicate.
396+
397+
The returned receiver will only receive messages that fulfill the predicate
398+
(evaluates to `True`), and will drop messages that don't.
399+
400+
Tip:
401+
The returned receiver type won't have all the methods of the original
402+
receiver. If you need to access methods of the original receiver that are
403+
not part of the `Receiver` interface you should save a reference to the
404+
original receiver and use that instead.
405+
406+
Args:
407+
predicate: The predicate to be applied on incoming messages to
408+
determine if they should be taken.
409+
410+
Returns:
411+
A new receiver that only receives messages that fulfill the predicate.
412+
"""
413+
... # pylint: disable=unnecessary-ellipsis
414+
415+
def take_while(
416+
self,
417+
predicate: (
418+
Callable[[ReceiverMessageT_co], bool]
419+
| Callable[[ReceiverMessageT_co], TypeGuard[FilteredMessageT_co]]
420+
),
421+
/,
422+
) -> Receiver[ReceiverMessageT_co] | Receiver[FilteredMessageT_co]:
423+
"""Take only the messages that fulfill a predicate.
424+
425+
The returned receiver will only receive messages that fulfill the predicate
426+
(evaluates to `True`), and will drop messages that don't.
427+
428+
Note:
429+
You can pass a [type guard][typing.TypeGuard] as the predicate to narrow the
430+
type of the received messages.
431+
432+
Tip:
433+
The returned receiver type won't have all the methods of the original
434+
receiver. If you need to access methods of the original receiver that are
435+
not part of the `Receiver` interface you should save a reference to the
436+
original receiver and use that instead.
437+
438+
Args:
439+
predicate: The predicate to be applied on incoming messages to
440+
determine if they should be taken.
441+
442+
Returns:
443+
A new receiver that only receives messages that fulfill the predicate.
444+
"""
445+
return _Filter(receiver=self, filter_function=predicate)
446+
447+
def drop_while(
448+
self,
449+
predicate: Callable[[ReceiverMessageT_co], bool],
450+
/,
451+
) -> Receiver[ReceiverMessageT_co] | Receiver[ReceiverMessageT_co]:
452+
"""Drop the messages that fulfill a predicate.
453+
454+
The returned receiver will drop messages that fulfill the predicate
455+
(evaluates to `True`), and receive messages that don't.
456+
457+
Tip:
458+
If you need to narrow the type of the received messages, you can use the
459+
[`take_while()`][frequenz.channels.Receiver.take_while] method instead.
460+
461+
Tip:
462+
The returned receiver type won't have all the methods of the original
463+
receiver. If you need to access methods of the original receiver that are
464+
not part of the `Receiver` interface you should save a reference to the
465+
original receiver and use that instead.
466+
467+
Args:
468+
predicate: The predicate to be applied on incoming messages to
469+
determine if they should be dropped.
470+
471+
Returns:
472+
A new receiver that only receives messages that don't fulfill the predicate.
473+
"""
474+
return _Filter(receiver=self, filter_function=predicate, negate=True)
475+
348476
def triggered(
349477
self, selected: Selected[Any]
350478
) -> TypeGuard[Selected[ReceiverMessageT_co]]:
@@ -492,12 +620,14 @@ def __init__(
492620
*,
493621
receiver: Receiver[ReceiverMessageT_co],
494622
filter_function: Callable[[ReceiverMessageT_co], bool],
623+
negate: bool = False,
495624
) -> None:
496625
"""Initialize this receiver filter.
497626
498627
Args:
499628
receiver: The input receiver.
500629
filter_function: The function to apply on the input data.
630+
negate: Whether to negate the filter function.
501631
"""
502632
self._receiver: Receiver[ReceiverMessageT_co] = receiver
503633
"""The input receiver."""
@@ -507,6 +637,8 @@ def __init__(
507637

508638
self._next_message: ReceiverMessageT_co | _Sentinel = _SENTINEL
509639

640+
self._negate: bool = negate
641+
510642
self._recv_closed = False
511643

512644
async def ready(self) -> bool:
@@ -522,7 +654,10 @@ async def ready(self) -> bool:
522654
"""
523655
while await self._receiver.ready():
524656
message = self._receiver.consume()
525-
if self._filter_function(message):
657+
result = self._filter_function(message)
658+
if self._negate:
659+
result = not result
660+
if result:
526661
self._next_message = message
527662
return True
528663
self._recv_closed = True

tests/test_receiver.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
# License: MIT
2+
# Copyright © 2024 Frequenz Energy-as-a-Service GmbH
3+
4+
"""Tests for the Receiver class."""
5+
6+
import asyncio
7+
from collections.abc import Sequence
8+
from typing import TypeGuard, assert_type
9+
10+
import pytest
11+
from typing_extensions import override
12+
13+
from frequenz.channels import Receiver, ReceiverError, ReceiverStoppedError
14+
15+
16+
class _MockReceiver(Receiver[int | str]):
17+
def __init__(self, messages: Sequence[int | str] = ()) -> None:
18+
self.messages = list(messages)
19+
self.stopped = False
20+
21+
@override
22+
async def ready(self) -> bool:
23+
"""Return True if there are messages to consume or the receiver is stopped."""
24+
if self.stopped:
25+
return False
26+
if self.messages:
27+
await asyncio.sleep(0)
28+
return True
29+
return False
30+
31+
@override
32+
def consume(self) -> int | str:
33+
"""Return the next message."""
34+
if self.stopped or not self.messages:
35+
raise ReceiverStoppedError(self)
36+
return self.messages.pop(0)
37+
38+
def stop(self) -> None:
39+
"""Stop the receiver."""
40+
self.stopped = True
41+
42+
43+
async def test_receiver_take_while() -> None:
44+
"""Test the take_while method."""
45+
receiver = _MockReceiver([1, 2, 3, 4, 5])
46+
47+
filtered_receiver = receiver.take_while(lambda x: x % 2 == 0)
48+
async with asyncio.timeout(1):
49+
assert await filtered_receiver.receive() == 2
50+
assert await filtered_receiver.receive() == 4
51+
52+
with pytest.raises(ReceiverStoppedError):
53+
await filtered_receiver.receive()
54+
55+
56+
async def test_receiver_take_type_guard() -> None:
57+
"""Test the take_while method using a TypeGuard."""
58+
receiver = _MockReceiver([1, "1", 2, "2", 3, "3", 4, 5])
59+
60+
def is_even(x: int | str) -> TypeGuard[int]:
61+
if not isinstance(x, int):
62+
return False
63+
return x % 2 == 0
64+
65+
filtered_receiver = receiver.take_while(is_even)
66+
async with asyncio.timeout(1):
67+
received = await filtered_receiver.receive()
68+
assert_type(received, int)
69+
assert received == 2
70+
assert await filtered_receiver.receive() == 4
71+
72+
with pytest.raises(ReceiverStoppedError):
73+
await filtered_receiver.receive()
74+
75+
76+
async def test_receiver_drop_while() -> None:
77+
"""Test the drop_while method."""
78+
receiver = _MockReceiver([1, 2, 3, 4, 5])
79+
80+
filtered_receiver = receiver.drop_while(lambda x: x % 2 == 0)
81+
async with asyncio.timeout(1):
82+
assert await filtered_receiver.receive() == 1
83+
assert await filtered_receiver.receive() == 3
84+
assert await filtered_receiver.receive() == 5
85+
86+
with pytest.raises(ReceiverStoppedError):
87+
await filtered_receiver.receive()

0 commit comments

Comments
 (0)