|
155 | 155 |
|
156 | 156 | from abc import ABC, abstractmethod |
157 | 157 | from collections.abc import Callable |
158 | | -from typing import Generic, Self |
| 158 | +from typing import TYPE_CHECKING, Any, Generic, Self, TypeGuard |
159 | 159 |
|
160 | 160 | from ._exceptions import Error |
161 | 161 | from ._generic import MappedMessageT_co, ReceiverMessageT_co |
162 | 162 |
|
| 163 | +if TYPE_CHECKING: |
| 164 | + from ._select import Selected |
| 165 | + |
163 | 166 |
|
164 | 167 | class Receiver(ABC, Generic[ReceiverMessageT_co]): |
165 | 168 | """An endpoint to receive messages.""" |
@@ -284,6 +287,30 @@ def filter( |
284 | 287 | """ |
285 | 288 | return _Filter(receiver=self, filter_function=filter_function) |
286 | 289 |
|
| 290 | + def matches( |
| 291 | + self, selected: Selected[Any] |
| 292 | + ) -> TypeGuard[Selected[ReceiverMessageT_co]]: |
| 293 | + """Check whether this receiver was selected by [`select()`][frequenz.channels.select]. |
| 294 | +
|
| 295 | + This method is used in conjunction with the |
| 296 | + [`Selected`][frequenz.channels.Selected] class to determine which receiver was |
| 297 | + selected in `select()` iteration. |
| 298 | +
|
| 299 | + It also works as a [type guard][typing_extensions.TypeIs] to narrow the type of the |
| 300 | + `Selected` instance to the type of the receiver. |
| 301 | +
|
| 302 | + Please see [`select()`][frequenz.channels.select] for an example. |
| 303 | +
|
| 304 | + Args: |
| 305 | + selected: The result of a `select()` iteration. |
| 306 | +
|
| 307 | + Returns: |
| 308 | + Whether this receiver was selected. |
| 309 | + """ |
| 310 | + if handled := selected._recv is self: # pylint: disable=protected-access |
| 311 | + selected._handled = True # pylint: disable=protected-access |
| 312 | + return handled |
| 313 | + |
287 | 314 |
|
288 | 315 | class ReceiverError(Error, Generic[ReceiverMessageT_co]): |
289 | 316 | """An error that originated in a [Receiver][frequenz.channels.Receiver]. |
@@ -373,9 +400,7 @@ def consume(self) -> MappedMessageT_co: # noqa: DOC502 |
373 | 400 | ReceiverStoppedError: If the receiver stopped producing messages. |
374 | 401 | ReceiverError: If there is a problem with the receiver. |
375 | 402 | """ |
376 | | - return self._mapping_function( |
377 | | - self._receiver.consume() |
378 | | - ) # pylint: disable=protected-access |
| 403 | + return self._mapping_function(self._receiver.consume()) |
379 | 404 |
|
380 | 405 | def __str__(self) -> str: |
381 | 406 | """Return a string representation of the mapper.""" |
|
0 commit comments