|
155 | 155 |
|
156 | 156 | from abc import ABC, abstractmethod |
157 | 157 | from collections.abc import Callable |
158 | | -from typing import Generic, Self |
| 158 | +from typing import TYPE_CHECKING, Any, Generic, Self, TypeGuard |
159 | 159 |
|
160 | 160 | from ._exceptions import Error |
161 | 161 | from ._generic import MappedMessageT_co, ReceiverMessageT_co |
162 | 162 |
|
| 163 | +if TYPE_CHECKING: |
| 164 | + from ._select import Selected |
| 165 | + |
163 | 166 |
|
164 | 167 | class Receiver(ABC, Generic[ReceiverMessageT_co]): |
165 | 168 | """An endpoint to receive messages.""" |
@@ -281,6 +284,30 @@ def filter( |
281 | 284 | """ |
282 | 285 | return _Filter(receiver=self, filter_function=filter_function) |
283 | 286 |
|
| 287 | + def is_selected( |
| 288 | + self, selected: Selected[Any] |
| 289 | + ) -> TypeGuard[Selected[ReceiverMessageT_co]]: |
| 290 | + """Check whether this receiver was selected by [`select()`][frequenz.channels.select]. |
| 291 | +
|
| 292 | + This method is used in conjunction with the |
| 293 | + [`Selected`][frequenz.channels.Selected] class to determine which receiver was |
| 294 | + selected in `select()` iteration. |
| 295 | +
|
| 296 | + It also works as a [type guard][typing_extensions.TypeIs] to narrow the type of the |
| 297 | + `Selected` instance to the type of the receiver. |
| 298 | +
|
| 299 | + Please see [`select()`][frequenz.channels.select] for an example. |
| 300 | +
|
| 301 | + Args: |
| 302 | + selected: The result of a `select()` iteration. |
| 303 | +
|
| 304 | + Returns: |
| 305 | + Whether this receiver was selected. |
| 306 | + """ |
| 307 | + if handled := selected._recv is self: # pylint: disable=protected-access |
| 308 | + selected._handled = True # pylint: disable=protected-access |
| 309 | + return handled |
| 310 | + |
284 | 311 |
|
285 | 312 | class ReceiverError(Error, Generic[ReceiverMessageT_co]): |
286 | 313 | """An error that originated in a [Receiver][frequenz.channels.Receiver]. |
@@ -370,9 +397,7 @@ def consume(self) -> MappedMessageT_co: # noqa: DOC502 |
370 | 397 | ReceiverStoppedError: If the receiver stopped producing messages. |
371 | 398 | ReceiverError: If there is a problem with the receiver. |
372 | 399 | """ |
373 | | - return self._mapping_function( |
374 | | - self._receiver.consume() |
375 | | - ) # pylint: disable=protected-access |
| 400 | + return self._mapping_function(self._receiver.consume()) |
376 | 401 |
|
377 | 402 | def __str__(self) -> str: |
378 | 403 | """Return a string representation of the mapper.""" |
|
0 commit comments