Skip to content

Commit 5b47b03

Browse files
Avoid dropping of messages after breaking from Select blocks (#42)
The current implementation of `Select.ready()`, instead of just checking that some of the receivers have messages, also fetches them. And in cases where there are multiple receivers with messages waiting to be read, `Select.ready()` consumes the latest message from each of them, and if subsequent user code decides to drop the `Select` object, and recreates it for some reason, the unprocessed messages that `Select.ready()` had fetched will get lost. This is incorrect usage of `Select`, but the implementation should still try to ensure messages don't get dropped in such cases.
2 parents 4cceb39 + 6138ee0 commit 5b47b03

File tree

16 files changed

+324
-191
lines changed

16 files changed

+324
-191
lines changed

benchmarks/benchmark_anycast.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,7 @@ async def benchmark_anycast(
5151
recv_trackers = [0]
5252

5353
async def update_tracker_on_receive(chan: Receiver[int]) -> None:
54-
while True:
55-
msg = await chan.receive()
56-
if msg is None:
57-
return
54+
async for _ in chan:
5855
recv_trackers[0] += 1
5956

6057
receivers = []

benchmarks/benchmark_broadcast.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,7 @@ async def benchmark_broadcast(
6969
recv_trackers = [0]
7070

7171
async def update_tracker_on_receive(chan: Receiver[int]) -> None:
72-
while True:
73-
msg = await chan.receive()
74-
if msg is None:
75-
return
72+
async for _ in chan:
7673
recv_trackers[0] += 1
7774

7875
receivers = []

src/frequenz/channels/anycast.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from collections import deque
1010
from typing import Deque, Generic, Optional
1111

12+
from frequenz.channels.base_classes import ChannelClosedError
1213
from frequenz.channels.base_classes import Receiver as BaseReceiver
1314
from frequenz.channels.base_classes import Sender as BaseSender
1415
from frequenz.channels.base_classes import T
@@ -162,23 +163,36 @@ def __init__(self, chan: Anycast[T]) -> None:
162163
chan: A reference to the channel that this receiver belongs to.
163164
"""
164165
self._chan = chan
166+
self._next: Optional[T] = None
165167

166-
async def receive(self) -> Optional[T]:
167-
"""Receive a message from the channel.
168+
async def ready(self) -> None:
169+
"""Wait until the receiver is ready with a value.
168170
169-
Waits for an message to become available, and returns that message.
170-
When there are multiple receivers for the channel, only one receiver
171-
will receive each message.
172-
173-
Returns:
174-
`None`, if the channel is closed, a message otherwise.
171+
Raises:
172+
ChannelClosedError: if the underlying channel is closed.
175173
"""
174+
# if a message is already ready, then return immediately.
175+
if self._next is not None:
176+
return
177+
176178
while len(self._chan.deque) == 0:
177179
if self._chan.closed:
178-
return None
180+
raise ChannelClosedError()
179181
async with self._chan.recv_cv:
180182
await self._chan.recv_cv.wait()
181-
ret = self._chan.deque.popleft()
183+
self._next = self._chan.deque.popleft()
182184
async with self._chan.send_cv:
183185
self._chan.send_cv.notify(1)
184-
return ret
186+
187+
def consume(self) -> T:
188+
"""Return the latest value once `ready()` is complete.
189+
190+
Returns:
191+
The next value that was received.
192+
"""
193+
assert (
194+
self._next is not None
195+
), "calls to `consume()` must be follow a call to `ready()`"
196+
next_val = self._next
197+
self._next = None
198+
return next_val

src/frequenz/channels/base_classes.py

Lines changed: 82 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,41 @@
66
from __future__ import annotations
77

88
from abc import ABC, abstractmethod
9-
from typing import Callable, Generic, Optional, TypeVar
9+
from typing import Any, Callable, Generic, Optional, TypeVar
1010

1111
T = TypeVar("T")
1212
U = TypeVar("U")
1313

1414

15+
class ChannelError(RuntimeError):
16+
"""Base channel error.
17+
18+
All exceptions generated by channels inherit from this exception.
19+
"""
20+
21+
def __init__(self, message: Any, channel: Any = None):
22+
"""Create a ChannelError instance.
23+
24+
Args:
25+
message: An error message.
26+
channel: A reference to the channel that encountered the error.
27+
"""
28+
super().__init__(message)
29+
self.channel: Any = channel
30+
31+
32+
class ChannelClosedError(ChannelError):
33+
"""Error raised when trying to operate on a closed channel."""
34+
35+
def __init__(self, channel: Any = None):
36+
"""Create a `ChannelClosedError` instance.
37+
38+
Args:
39+
channel: A reference to the channel that was closed.
40+
"""
41+
super().__init__(f"Channel {channel} was closed", channel)
42+
43+
1544
class Sender(ABC, Generic[T]):
1645
"""A channel Sender."""
1746

@@ -31,12 +60,43 @@ async def send(self, msg: T) -> bool:
3160
class Receiver(ABC, Generic[T]):
3261
"""A channel Receiver."""
3362

63+
async def __anext__(self) -> T:
64+
"""Await the next value in the async iteration over received values.
65+
66+
Returns:
67+
The next value received.
68+
69+
Raises:
70+
StopAsyncIteration: if the underlying channel is closed.
71+
"""
72+
try:
73+
await self.ready()
74+
return self.consume()
75+
except ChannelClosedError as exc:
76+
raise StopAsyncIteration() from exc
77+
3478
@abstractmethod
35-
async def receive(self) -> Optional[T]:
36-
"""Receive a message from the channel.
79+
async def ready(self) -> None:
80+
"""Wait until the receiver is ready with a value.
81+
82+
Once a call to `ready()` has finished, the value should be read with a call to
83+
`consume()`.
84+
85+
Raises:
86+
ChannelClosedError: if the underlying channel is closed.
87+
"""
88+
89+
@abstractmethod
90+
def consume(self) -> T:
91+
"""Return the latest value once `ready()` is complete.
92+
93+
`ready()` must be called before each call to `consume()`.
3794
3895
Returns:
39-
`None`, if the channel is closed, a message otherwise.
96+
The next value received.
97+
98+
Raises:
99+
ChannelClosedError: if the underlying channel is closed.
40100
"""
41101

42102
def __aiter__(self) -> Receiver[T]:
@@ -47,19 +107,19 @@ def __aiter__(self) -> Receiver[T]:
47107
"""
48108
return self
49109

50-
async def __anext__(self) -> T:
51-
"""Await the next value in the async iteration over received values.
52-
53-
Returns:
54-
The next value received.
110+
async def receive(self) -> T:
111+
"""Receive a message from the channel.
55112
56113
Raises:
57-
StopAsyncIteration: if we receive `None`, i.e. if the underlying
58-
channel is closed.
114+
ChannelClosedError: if the underlying channel is closed.
115+
116+
Returns:
117+
The received message.
59118
"""
60-
received = await self.receive()
61-
if received is None:
62-
raise StopAsyncIteration
119+
try:
120+
received = await self.__anext__() # pylint: disable=unnecessary-dunder-call
121+
except StopAsyncIteration as exc:
122+
raise ChannelClosedError() from exc
63123
return received
64124

65125
def map(self, call: Callable[[T], U]) -> Receiver[U]:
@@ -136,13 +196,14 @@ def __init__(self, recv: Receiver[T], transform: Callable[[T], U]) -> None:
136196
self._recv = recv
137197
self._transform = transform
138198

139-
async def receive(self) -> Optional[U]:
140-
"""Return a transformed message received from the input channel.
199+
async def ready(self) -> None:
200+
"""Wait until the receiver is ready with a value."""
201+
await self._recv.ready() # pylint: disable=protected-access
202+
203+
def consume(self) -> U:
204+
"""Return a transformed value once `ready()` is complete.
141205
142206
Returns:
143-
`None`, if the channel is closed, a message otherwise.
207+
The next value that was received.
144208
"""
145-
msg = await self._recv.receive()
146-
if msg is None:
147-
return None
148-
return self._transform(msg)
209+
return self._transform(self._recv.consume()) # pylint: disable=protected-access

src/frequenz/channels/bidirectional.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from __future__ import annotations
77

8-
from typing import Generic, Optional
8+
from typing import Generic
99

1010
from frequenz.channels.base_classes import Receiver, Sender, T, U
1111
from frequenz.channels.broadcast import Broadcast
@@ -82,10 +82,14 @@ async def send(self, msg: T) -> bool:
8282
"""
8383
return await self._sender.send(msg)
8484

85-
async def receive(self) -> Optional[U]:
86-
"""Receive a value from the other side.
85+
async def ready(self) -> None:
86+
"""Wait until the receiver is ready with a value."""
87+
await self._receiver.ready() # pylint: disable=protected-access
88+
89+
def consume(self) -> U:
90+
"""Return the latest value once `_ready` is complete.
8791
8892
Returns:
89-
Received value, or `None` if the channels are closed.
93+
The next value that was received.
9094
"""
91-
return await self._receiver.receive()
95+
return self._receiver.consume() # pylint: disable=protected-access

src/frequenz/channels/broadcast.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from typing import Deque, Dict, Generic, Optional
1313
from uuid import UUID, uuid4
1414

15-
from frequenz.channels.base_classes import BufferedReceiver
15+
from frequenz.channels.base_classes import BufferedReceiver, ChannelClosedError
1616
from frequenz.channels.base_classes import Peekable as BasePeekable
1717
from frequenz.channels.base_classes import Sender as BaseSender
1818
from frequenz.channels.base_classes import T
@@ -249,31 +249,33 @@ def __len__(self) -> int:
249249
"""
250250
return len(self._q)
251251

252-
async def receive(self) -> Optional[T]:
253-
"""Receive a message from the Broadcast channel.
254-
255-
Waits until there are messages available in the channel and returns
256-
them. If there are no remaining messages in the buffer and the channel
257-
is closed, returns `None` immediately.
258-
259-
If [into_peekable()][frequenz.channels.Receiver.into_peekable] is called
260-
on a broadcast `Receiver`, further calls to `receive`, will raise an
261-
`EOFError`.
252+
async def ready(self) -> None:
253+
"""Wait until the receiver is ready with a value.
262254
263255
Raises:
264-
EOFError: when the receiver has been converted into a `Peekable`.
265-
266-
Returns:
267-
`None`, if the channel is closed, a message otherwise.
256+
EOFError: if this receiver is no longer active.
257+
ChannelClosedError: if the underlying channel is closed.
268258
"""
269259
if not self._active:
270260
raise EOFError("This receiver is no longer active.")
271261

262+
# Use a while loop here, to handle spurious wakeups of condition variables.
263+
#
264+
# The condition also makes sure that if there are already messages ready to be
265+
# consumed, then we return immediately.
272266
while len(self._q) == 0:
273267
if self._chan.closed:
274-
return None
268+
raise ChannelClosedError()
275269
async with self._chan.recv_cv:
276270
await self._chan.recv_cv.wait()
271+
272+
def consume(self) -> T:
273+
"""Return the latest value once `ready` is complete.
274+
275+
Returns:
276+
The next value that was received.
277+
"""
278+
assert self._q, "calls to `consume()` must be follow a call to `ready()`"
277279
ret = self._q.popleft()
278280
return ret
279281

src/frequenz/channels/merge.py

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55

66
import asyncio
77
from collections import deque
8-
from typing import Any, Deque, Optional, Set
8+
from typing import Any, Deque, Set
99

10-
from frequenz.channels.base_classes import Receiver, T
10+
from frequenz.channels.base_classes import ChannelClosedError, Receiver, T
1111

1212

1313
class Merge(Receiver[T]):
@@ -34,7 +34,7 @@ def __init__(self, *args: Receiver[T]) -> None:
3434
"""
3535
self._receivers = {str(id): recv for id, recv in enumerate(args)}
3636
self._pending: Set[asyncio.Task[Any]] = {
37-
asyncio.create_task(recv.receive(), name=name)
37+
asyncio.create_task(recv.__anext__(), name=name)
3838
for name, recv in self._receivers.items()
3939
}
4040
self._results: Deque[T] = deque(maxlen=len(self._receivers))
@@ -44,31 +44,45 @@ def __del__(self) -> None:
4444
for task in self._pending:
4545
task.cancel()
4646

47-
async def receive(self) -> Optional[T]:
48-
"""Wait until there's a message in any of the channels.
47+
async def ready(self) -> None:
48+
"""Wait until the receiver is ready with a value.
4949
50-
Returns:
51-
The next message that was received, or `None`, if all channels have
52-
closed.
50+
Raises:
51+
ChannelClosedError: if the underlying channel is closed.
5352
"""
5453
# we use a while loop to continue to wait for new data, in case the
5554
# previous `wait` completed because a channel was closed.
5655
while True:
56+
# if there are messages waiting to be consumed, return immediately.
5757
if len(self._results) > 0:
58-
return self._results.popleft()
58+
return
5959

6060
if len(self._pending) == 0:
61-
return None
61+
raise ChannelClosedError()
6262
done, self._pending = await asyncio.wait(
6363
self._pending, return_when=asyncio.FIRST_COMPLETED
6464
)
6565
for item in done:
6666
name = item.get_name()
67-
result = item.result()
6867
# if channel is closed, don't add a task for it again.
69-
if result is None:
68+
if isinstance(item.exception(), StopAsyncIteration):
7069
continue
70+
result = item.result()
7171
self._results.append(result)
7272
self._pending.add(
73-
asyncio.create_task(self._receivers[name].receive(), name=name)
73+
# pylint: disable=unnecessary-dunder-call
74+
asyncio.create_task(self._receivers[name].__anext__(), name=name)
7475
)
76+
77+
def consume(self) -> T:
78+
"""Return the latest value once `ready` is complete.
79+
80+
Raises:
81+
EOFError: When called before a call to `ready()` finishes.
82+
83+
Returns:
84+
The next value that was received.
85+
"""
86+
assert self._results, "calls to `consume()` must be follow a call to `ready()`"
87+
88+
return self._results.popleft()

0 commit comments

Comments
 (0)