Skip to content

Commit 8b2b087

Browse files
committed
Add a matches method to Receiver
This is an alternative for the `selected_from` method and could eventually replace it. Signed-off-by: Sahas Subramanian <[email protected]>
1 parent f92e1ef commit 8b2b087

File tree

1 file changed

+29
-4
lines changed

1 file changed

+29
-4
lines changed

src/frequenz/channels/_receiver.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -155,11 +155,14 @@
155155

156156
from abc import ABC, abstractmethod
157157
from collections.abc import Callable
158-
from typing import Generic, Self
158+
from typing import TYPE_CHECKING, Any, Generic, Self, TypeGuard
159159

160160
from ._exceptions import Error
161161
from ._generic import MappedMessageT_co, ReceiverMessageT_co
162162

163+
if TYPE_CHECKING:
164+
from ._select import Selected
165+
163166

164167
class Receiver(ABC, Generic[ReceiverMessageT_co]):
165168
"""An endpoint to receive messages."""
@@ -284,6 +287,30 @@ def filter(
284287
"""
285288
return _Filter(receiver=self, filter_function=filter_function)
286289

290+
def matches(
291+
self, selected: Selected[Any]
292+
) -> TypeGuard[Selected[ReceiverMessageT_co]]:
293+
"""Check whether this receiver was selected by [`select()`][frequenz.channels.select].
294+
295+
This method is used in conjunction with the
296+
[`Selected`][frequenz.channels.Selected] class to determine which receiver was
297+
selected in `select()` iteration.
298+
299+
It also works as a [type guard][typing_extensions.TypeIs] to narrow the type of the
300+
`Selected` instance to the type of the receiver.
301+
302+
Please see [`select()`][frequenz.channels.select] for an example.
303+
304+
Args:
305+
selected: The result of a `select()` iteration.
306+
307+
Returns:
308+
Whether this receiver was selected.
309+
"""
310+
if handled := selected._recv is self: # pylint: disable=protected-access
311+
selected._handled = True # pylint: disable=protected-access
312+
return handled
313+
287314

288315
class ReceiverError(Error, Generic[ReceiverMessageT_co]):
289316
"""An error that originated in a [Receiver][frequenz.channels.Receiver].
@@ -373,9 +400,7 @@ def consume(self) -> MappedMessageT_co: # noqa: DOC502
373400
ReceiverStoppedError: If the receiver stopped producing messages.
374401
ReceiverError: If there is a problem with the receiver.
375402
"""
376-
return self._mapping_function(
377-
self._receiver.consume()
378-
) # pylint: disable=protected-access
403+
return self._mapping_function(self._receiver.consume())
379404

380405
def __str__(self) -> str:
381406
"""Return a string representation of the mapper."""

0 commit comments

Comments
 (0)