Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions RELEASE_NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
print('Received from recv2:', selected.message)
```

* `Receiver.filter()` can now properly handle `TypeGuard`s. The resulting receiver will now have the narrowed type when a `TypeGuard` is used.

## Bug Fixes

<!-- Here goes notable bug fixes that are worth a special mention or explanation -->
60 changes: 59 additions & 1 deletion src/frequenz/channels/_receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,14 +155,17 @@

from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Generic, Self, TypeGuard
from typing import TYPE_CHECKING, Any, Generic, Self, TypeGuard, TypeVar, overload

from ._exceptions import Error
from ._generic import MappedMessageT_co, ReceiverMessageT_co

if TYPE_CHECKING:
from ._select import Selected

FilteredMessageT_co = TypeVar("FilteredMessageT_co", covariant=True)
"""Type variable for the filtered message type."""


class Receiver(ABC, Generic[ReceiverMessageT_co]):
"""An endpoint to receive messages."""
Expand Down Expand Up @@ -267,11 +270,66 @@ def map(
"""
return _Mapper(receiver=self, mapping_function=mapping_function)

@overload
def filter(
self,
filter_function: Callable[
[ReceiverMessageT_co], TypeGuard[FilteredMessageT_co]
],
/,
) -> Receiver[FilteredMessageT_co]:
"""Apply a type guard on the messages on a receiver.

Tip:
The returned receiver type won't have all the methods of the original
receiver. If you need to access methods of the original receiver that are
not part of the `Receiver` interface you should save a reference to the
original receiver and use that instead.

Args:
filter_function: The function to be applied on incoming messages to
determine if they should be received.

Returns:
A new receiver that only receives messages that pass the filter.
"""
... # pylint: disable=unnecessary-ellipsis

@overload
def filter(
self, filter_function: Callable[[ReceiverMessageT_co], bool], /
) -> Receiver[ReceiverMessageT_co]:
"""Apply a filter function on the messages on a receiver.

Tip:
The returned receiver type won't have all the methods of the original
receiver. If you need to access methods of the original receiver that are
not part of the `Receiver` interface you should save a reference to the
original receiver and use that instead.

Args:
filter_function: The function to be applied on incoming messages to
determine if they should be received.

Returns:
A new receiver that only receives messages that pass the filter.
"""
... # pylint: disable=unnecessary-ellipsis

def filter(
self,
filter_function: (
Callable[[ReceiverMessageT_co], bool]
| Callable[[ReceiverMessageT_co], TypeGuard[FilteredMessageT_co]]
),
/,
) -> Receiver[ReceiverMessageT_co] | Receiver[FilteredMessageT_co]:
"""Apply a filter function on the messages on a receiver.

Note:
You can pass a [type guard][typing.TypeGuard] as the filter function to
narrow the type of the messages that pass the filter.

Tip:
The returned receiver type won't have all the methods of the original
receiver. If you need to access methods of the original receiver that are
Expand Down
26 changes: 26 additions & 0 deletions tests/test_broadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import asyncio
from dataclasses import dataclass
from typing import TypeGuard, assert_never

import pytest

Expand Down Expand Up @@ -248,6 +249,31 @@ async def test_broadcast_filter() -> None:
assert (await receiver.receive()) == 15


async def test_broadcast_filter_type_guard() -> None:
"""Ensure filter type guard works."""
chan = Broadcast[int | str](name="input-chan")
sender = chan.new_sender()

def _is_int(num: int | str) -> TypeGuard[int]:
return isinstance(num, int)

# filter out objects that are not integers.
receiver = chan.new_receiver().filter(_is_int)

await sender.send("hello")
await sender.send(8)

message = await receiver.receive()
assert message == 8
is_int = False
match message:
case int():
is_int = True
case unexpected:
assert_never(unexpected)
assert is_int


async def test_broadcast_receiver_drop() -> None:
"""Ensure deleted receivers get cleaned up."""
chan = Broadcast[int](name="input-chan")
Expand Down
Loading