Skip to content

Commit faa70c3

Browse files
committed
Limit Select to work only on Receivers
Signed-off-by: Sahas Subramanian <[email protected]>
1 parent 382ebef commit faa70c3

File tree

2 files changed

+21
-63
lines changed

2 files changed

+21
-63
lines changed

src/frequenz/channels/select.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
11
# License: MIT
22
# Copyright © 2022 Frequenz Energy-as-a-Service GmbH
33

4-
"""Select the first among multiple AsyncIterators.
4+
"""Select the first among multiple Receivers.
55
6-
Expects AsyncIterator class to raise `StopAsyncIteration`
6+
Expects Receiver class to raise `StopAsyncIteration`
77
exception once no more messages are expected or the channel
88
is closed in case of `Receiver` class.
99
"""
1010

1111
import asyncio
1212
import logging
1313
from dataclasses import dataclass
14-
from typing import Any, AsyncIterator, Dict, List, Optional, Set, TypeVar
14+
from typing import Any, Dict, List, Optional, Set, TypeVar
15+
16+
from frequenz.channels.base_classes import Receiver
1517

1618
logger = logging.Logger(__name__)
1719
T = TypeVar("T")
@@ -29,17 +31,17 @@ class _Selected:
2931

3032

3133
class Select:
32-
"""Select the next available message from a group of AsyncIterators.
34+
"""Select the next available message from a group of Receivers.
3335
34-
If `Select` was created with more `AsyncIterator` than what are read in
36+
If `Select` was created with more `Receiver` than what are read in
3537
the if-chain after each call to [ready()][frequenz.channels.Select.ready],
36-
messages coming in the additional async iterators are dropped, and
38+
messages coming in the additional receivers are dropped, and
3739
a warning message is logged.
3840
39-
[Receiver][frequenz.channels.Receiver]s also function as `AsyncIterator`.
41+
[Receiver][frequenz.channels.Receiver]s also function as `Receiver`.
4042
4143
Example:
42-
For example, if there are two async iterators that you want to
44+
For example, if there are two receivers that you want to
4345
simultaneously wait on, this can be done with:
4446
4547
```python
@@ -58,11 +60,11 @@ class Select:
5860
```
5961
"""
6062

61-
def __init__(self, **kwargs: AsyncIterator[Any]) -> None:
63+
def __init__(self, **kwargs: Receiver[Any]) -> None:
6264
"""Create a `Select` instance.
6365
6466
Args:
65-
**kwargs: sequence of async iterators
67+
**kwargs: sequence of receivers
6668
"""
6769
self._receivers = kwargs
6870
self._pending: Set[asyncio.Task[Any]] = set()
@@ -84,10 +86,10 @@ def __del__(self) -> None:
8486
task.cancel()
8587

8688
async def ready(self) -> bool:
87-
"""Wait until there is a message in any of the async iterators.
89+
"""Wait until there is a message in any of the receivers.
8890
8991
Returns `True` if there is a message available, and `False` if all
90-
async iterators have closed.
92+
receivers have closed.
9193
9294
Returns:
9395
Whether there are further messages or not.
@@ -102,7 +104,7 @@ async def ready(self) -> bool:
102104
self._ready_count = 0
103105
self._prev_ready_count = 0
104106
logger.warning(
105-
"Select.ready() dropped data from async iterator(s): %s, "
107+
"Select.ready() dropped data from receiver(s): %s, "
106108
"because no messages have been fetched since the last call to ready().",
107109
dropped_names,
108110
)
@@ -127,7 +129,7 @@ async def ready(self) -> bool:
127129
result = item.result()
128130
self._ready_count += 1
129131
self._result[name] = _Selected(result)
130-
# if channel or AsyncIterator is closed
132+
# if channel or Receiver is closed
131133
# don't add a task for it again.
132134
if result is None:
133135
continue
@@ -138,13 +140,13 @@ async def ready(self) -> bool:
138140
return True
139141

140142
def __getattr__(self, name: str) -> Optional[Any]:
141-
"""Return the latest unread message from a `AsyncIterator`, if available.
143+
"""Return the latest unread message from a `Receiver`, if available.
142144
143145
Args:
144146
name: Name of the channel.
145147
146148
Returns:
147-
Latest unread message for the specified `AsyncIterator`, or `None`.
149+
Latest unread message for the specified `Receiver`, or `None`.
148150
149151
Raises:
150152
KeyError: when the name was not specified when creating the

tests/test_select.py

Lines changed: 3 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -4,67 +4,34 @@
44
"""Tests for the Select implementation."""
55

66
import asyncio
7-
from asyncio import Queue
8-
from typing import List, Optional
7+
from typing import List
98

109
from frequenz.channels import Anycast, Select, Sender
1110

1211

13-
class AsyncIterable:
14-
"""Example AsyncIterator class"""
15-
16-
def __init__(self) -> None:
17-
self.queue: "Queue[int]" = Queue()
18-
self.done = False
19-
20-
def __aiter__(self) -> "AsyncIterable":
21-
return self
22-
23-
async def __anext__(self) -> Optional[int]:
24-
if not self.queue.empty():
25-
return self.queue.get_nowait()
26-
if self.done:
27-
raise StopAsyncIteration
28-
msg = await self.queue.get()
29-
return msg
30-
31-
async def add(self, msg: int) -> bool:
32-
"""Adds object to iterator"""
33-
await self.queue.put(msg)
34-
35-
return True
36-
37-
3812
async def test_select() -> None:
3913
"""Ensure select receives messages in order."""
4014
chan1 = Anycast[int]()
4115
chan2 = Anycast[int]()
4216
chan3 = Anycast[int]()
4317

44-
async def send(
45-
ch1: Sender[int], ch2: Sender[int], ch3: Sender[int], queue: AsyncIterable
46-
) -> None:
18+
async def send(ch1: Sender[int], ch2: Sender[int], ch3: Sender[int]) -> None:
4719
for ctr in range(5):
4820
await ch1.send(ctr + 1)
4921
await ch2.send(ctr + 101)
5022
await ch3.send(ctr + 201)
51-
await queue.add(ctr + 301)
5223
await chan1.close()
5324
await ch2.send(1000)
5425
await chan2.close()
5526
await chan3.close()
56-
queue.done = True
57-
58-
queue = AsyncIterable()
5927

6028
senders = asyncio.create_task(
61-
send(chan1.get_sender(), chan2.get_sender(), chan3.get_sender(), queue),
29+
send(chan1.get_sender(), chan2.get_sender(), chan3.get_sender()),
6230
)
6331
select = Select(
6432
ch1=chan1.get_receiver(),
6533
ch2=chan2.get_receiver(),
6634
ch3=chan3.get_receiver(),
67-
ch4=queue,
6835
)
6936

7037
# only check for messages from all iterators but `ch3`.
@@ -82,32 +49,21 @@ async def send(
8249
results.append(val)
8350
else:
8451
results.append(-2)
85-
elif item := select.ch4:
86-
if val := item.inner:
87-
results.append(val)
88-
else:
89-
results.append(-4)
9052
await senders
9153

9254
expected_results = [
9355
1,
9456
101,
95-
301,
9657
2,
9758
102,
98-
302,
9959
3,
10060
103,
101-
303,
10261
4,
10362
104,
104-
304,
10563
5,
10664
105,
107-
305,
10865
-1, # marks end of messages from channel 1
10966
1000,
110-
-4, # marks end of messages from channel 4
11167
-2, # marks end of messages from channel 2
11268
]
11369
assert results == expected_results

0 commit comments

Comments
 (0)