Skip to content

Commit 486d1be

Browse files
authored
Add a Receiver.triggered method (#328)
This is an alternative to the `selected_from` method.
2 parents 1eee0c8 + 22d76a8 commit 486d1be

File tree

3 files changed

+42
-11
lines changed

3 files changed

+42
-11
lines changed

RELEASE_NOTES.md

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,15 @@
1010

1111
## New Features
1212

13-
<!-- Here goes the main new features and examples or instructions on how to use them -->
13+
- There is a new `Receiver.triggered` method that can be used instead of `selected_from`:
14+
15+
```python
16+
async for selected in select(recv1, recv2):
17+
if recv1.triggered(selected):
18+
print('Received from recv1:', selected.message)
19+
if recv2.triggered(selected):
20+
print('Received from recv2:', selected.message)
21+
```
1422

1523
## Bug Fixes
1624

src/frequenz/channels/_receiver.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -155,11 +155,14 @@
155155

156156
from abc import ABC, abstractmethod
157157
from collections.abc import Callable
158-
from typing import Generic, Self
158+
from typing import TYPE_CHECKING, Any, Generic, Self, TypeGuard
159159

160160
from ._exceptions import Error
161161
from ._generic import MappedMessageT_co, ReceiverMessageT_co
162162

163+
if TYPE_CHECKING:
164+
from ._select import Selected
165+
163166

164167
class Receiver(ABC, Generic[ReceiverMessageT_co]):
165168
"""An endpoint to receive messages."""
@@ -284,6 +287,30 @@ def filter(
284287
"""
285288
return _Filter(receiver=self, filter_function=filter_function)
286289

290+
def triggered(
291+
self, selected: Selected[Any]
292+
) -> TypeGuard[Selected[ReceiverMessageT_co]]:
293+
"""Check whether this receiver was selected by [`select()`][frequenz.channels.select].
294+
295+
This method is used in conjunction with the
296+
[`Selected`][frequenz.channels.Selected] class to determine which receiver was
297+
selected in `select()` iteration.
298+
299+
It also works as a [type guard][typing.TypeGuard] to narrow the type of the
300+
`Selected` instance to the type of the receiver.
301+
302+
Please see [`select()`][frequenz.channels.select] for an example.
303+
304+
Args:
305+
selected: The result of a `select()` iteration.
306+
307+
Returns:
308+
Whether this receiver was selected.
309+
"""
310+
if handled := selected._recv is self: # pylint: disable=protected-access
311+
selected._handled = True # pylint: disable=protected-access
312+
return handled
313+
287314

288315
class ReceiverError(Error, Generic[ReceiverMessageT_co]):
289316
"""An error that originated in a [Receiver][frequenz.channels.Receiver].
@@ -373,9 +400,7 @@ def consume(self) -> MappedMessageT_co: # noqa: DOC502
373400
ReceiverStoppedError: If the receiver stopped producing messages.
374401
ReceiverError: If there is a problem with the receiver.
375402
"""
376-
return self._mapping_function(
377-
self._receiver.consume()
378-
) # pylint: disable=protected-access
403+
return self._mapping_function(self._receiver.consume())
379404

380405
def __str__(self) -> str:
381406
"""Return a string representation of the mapper."""

src/frequenz/channels/_select.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -269,9 +269,7 @@ def selected_from(
269269
Returns:
270270
Whether the given receiver was selected.
271271
"""
272-
if handled := selected._recv is receiver: # pylint: disable=protected-access
273-
selected._handled = True # pylint: disable=protected-access
274-
return handled
272+
return receiver.triggered(selected)
275273

276274

277275
class SelectError(Error):
@@ -378,14 +376,14 @@ async def select( # noqa: DOC503
378376
import datetime
379377
from typing import assert_never
380378
381-
from frequenz.channels import ReceiverStoppedError, select, selected_from
379+
from frequenz.channels import ReceiverStoppedError, select
382380
from frequenz.channels.timer import SkipMissedAndDrift, Timer, TriggerAllMissed
383381
384382
timer1 = Timer(datetime.timedelta(seconds=1), TriggerAllMissed())
385383
timer2 = Timer(datetime.timedelta(seconds=0.5), SkipMissedAndDrift())
386384
387385
async for selected in select(timer1, timer2):
388-
if selected_from(selected, timer1):
386+
if timer1.triggered(selected):
389387
# Beware: `selected.message` might raise an exception, you can always
390388
# check for exceptions with `selected.exception` first or use
391389
# a try-except block. You can also quickly check if the receiver was
@@ -395,7 +393,7 @@ async def select( # noqa: DOC503
395393
continue
396394
print(f"timer1: now={datetime.datetime.now()} drift={selected.message}")
397395
timer2.stop()
398-
elif selected_from(selected, timer2):
396+
elif timer2.triggered(selected):
399397
# Explicitly handling of exceptions
400398
match selected.exception:
401399
case ReceiverStoppedError():

0 commit comments

Comments
 (0)