Skip to content

Commit 402338c

Browse files
committed
Handle type guards properly in Receiver.filter()
Now the `Receiver` type returned by `Receiver.filter()` will have the narrowed type when a `TypeGuard` is used. Signed-off-by: Leandro Lucarella <[email protected]>
1 parent 486d1be commit 402338c

File tree

3 files changed

+86
-1
lines changed

3 files changed

+86
-1
lines changed

RELEASE_NOTES.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
print('Received from recv2:', selected.message)
2121
```
2222

23+
* `Receiver.filter()` can now properly handle `TypeGuard`s. The resulting receiver will now have the narrowed type when a `TypeGuard` is used.
24+
2325
## Bug Fixes
2426

2527
<!-- Here goes notable bug fixes that are worth a special mention or explanation -->

src/frequenz/channels/_receiver.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,14 +155,17 @@
155155

156156
from abc import ABC, abstractmethod
157157
from collections.abc import Callable
158-
from typing import TYPE_CHECKING, Any, Generic, Self, TypeGuard
158+
from typing import TYPE_CHECKING, Any, Generic, Self, TypeGuard, TypeVar, overload
159159

160160
from ._exceptions import Error
161161
from ._generic import MappedMessageT_co, ReceiverMessageT_co
162162

163163
if TYPE_CHECKING:
164164
from ._select import Selected
165165

166+
FilteredMessageT_co = TypeVar("FilteredMessageT_co", covariant=True)
167+
"""Type variable for the filtered message type."""
168+
166169

167170
class Receiver(ABC, Generic[ReceiverMessageT_co]):
168171
"""An endpoint to receive messages."""
@@ -267,11 +270,67 @@ def map(
267270
"""
268271
return _Mapper(receiver=self, mapping_function=mapping_function)
269272

273+
@overload
274+
def filter(
275+
self,
276+
filter_function: Callable[
277+
[ReceiverMessageT_co], TypeGuard[FilteredMessageT_co]
278+
],
279+
/,
280+
) -> Receiver[FilteredMessageT_co]:
281+
"""Apply a type guard on the messages on a receiver.
282+
283+
Tip:
284+
The returned receiver type won't have all the methods of the original
285+
receiver. If you need to access methods of the original receiver that are
286+
not part of the `Receiver` interface you should save a reference to the
287+
original receiver and use that instead.
288+
289+
Args:
290+
filter_function: The function to be applied on incoming messages to
291+
determine if they should be received.
292+
293+
Returns:
294+
A new receiver that only receives messages that pass the filter.
295+
"""
296+
... # pylint: disable=unnecessary-ellipsis
297+
298+
@overload
270299
def filter(
271300
self, filter_function: Callable[[ReceiverMessageT_co], bool], /
272301
) -> Receiver[ReceiverMessageT_co]:
273302
"""Apply a filter function on the messages on a receiver.
274303
304+
Tip:
305+
The returned receiver type won't have all the methods of the original
306+
receiver. If you need to access methods of the original receiver that are
307+
not part of the `Receiver` interface you should save a reference to the
308+
original receiver and use that instead.
309+
310+
Args:
311+
filter_function: The function to be applied on incoming messages to
312+
determine if they should be received.
313+
314+
Returns:
315+
A new receiver that only receives messages that pass the filter.
316+
"""
317+
... # pylint: disable=unnecessary-ellipsis
318+
319+
# We need to use Any here because otherwise _Filter would have to deal with two
320+
# different signatures. We can create two filter classes, one for regular functions
321+
# and one for type guards, but then there is no way to tell at runtime which
322+
# function is a type guard and which isn't to instantiate the correct class.
323+
# Using Any here has no impact though, as thanks to the overloads, only the
324+
# overloaded types will be accepted.
325+
def filter(
326+
self, filter_function: Callable[[ReceiverMessageT_co], Any], /
327+
) -> Receiver[Any]:
328+
"""Apply a filter function on the messages on a receiver.
329+
330+
Note:
331+
You can pass a [type guard][typing.TypeGuard] as the filter function to
332+
narrow the type of the messages that pass the filter.
333+
275334
Tip:
276335
The returned receiver type won't have all the methods of the original
277336
receiver. If you need to access methods of the original receiver that are

tests/test_broadcast.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import asyncio
88
from dataclasses import dataclass
9+
from typing import TypeGuard, assert_never
910

1011
import pytest
1112

@@ -248,6 +249,29 @@ async def test_broadcast_filter() -> None:
248249
assert (await receiver.receive()) == 15
249250

250251

252+
async def test_broadcast_filter_type_guard() -> None:
253+
"""Ensure filter type guard works."""
254+
chan = Broadcast[int | str](name="input-chan")
255+
sender = chan.new_sender()
256+
257+
def _is_int(num: int | str) -> TypeGuard[int]:
258+
return isinstance(num, int)
259+
260+
# filter out all numbers less than 10.
261+
receiver = chan.new_receiver().filter(_is_int)
262+
263+
await sender.send("hello")
264+
await sender.send(8)
265+
266+
message = await receiver.receive()
267+
assert message == 8
268+
match message:
269+
case int():
270+
assert message == 8
271+
case unexpected:
272+
assert_never(unexpected)
273+
274+
251275
async def test_broadcast_receiver_drop() -> None:
252276
"""Ensure deleted receivers get cleaned up."""
253277
chan = Broadcast[int](name="input-chan")

0 commit comments

Comments
 (0)