Skip to content

Commit 18000b8

Browse files
committed
Raise ChannelClosedError to signal that a channel is closed
... instead of using `None` Signed-off-by: Sahas Subramanian <[email protected]>
1 parent 9f554c9 commit 18000b8

File tree

7 files changed

+56
-21
lines changed

7 files changed

+56
-21
lines changed

benchmarks/benchmark_anycast.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,7 @@ async def benchmark_anycast(
5151
recv_trackers = [0]
5252

5353
async def update_tracker_on_receive(chan: Receiver[int]) -> None:
54-
while True:
55-
msg = await chan.receive()
56-
if msg is None:
57-
return
54+
async for _ in chan:
5855
recv_trackers[0] += 1
5956

6057
receivers = []

benchmarks/benchmark_broadcast.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,7 @@ async def benchmark_broadcast(
6969
recv_trackers = [0]
7070

7171
async def update_tracker_on_receive(chan: Receiver[int]) -> None:
72-
while True:
73-
msg = await chan.receive()
74-
if msg is None:
75-
return
72+
async for _ in chan:
7673
recv_trackers[0] += 1
7774

7875
receivers = []

src/frequenz/channels/base_classes.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,41 @@
66
from __future__ import annotations
77

88
from abc import ABC, abstractmethod
9-
from typing import Callable, Generic, Optional, TypeVar
9+
from typing import Any, Callable, Generic, Optional, TypeVar
1010

1111
T = TypeVar("T")
1212
U = TypeVar("U")
1313

1414

15+
class ChannelError(RuntimeError):
16+
"""Base channel error.
17+
18+
All exceptions generated by channels inherit from this exception.
19+
"""
20+
21+
def __init__(self, message: Any, channel: Any = None):
22+
"""Create a ChannelError instance.
23+
24+
Args:
25+
message: An error message.
26+
channel: A reference to the channel that encountered the error.
27+
"""
28+
super().__init__(message)
29+
self.channel: Any = channel
30+
31+
32+
class ChannelClosedError(ChannelError):
33+
"""Error raised when trying to operate on a closed channel."""
34+
35+
def __init__(self, channel: Any = None):
36+
"""Create a `ChannelClosedError` instance.
37+
38+
Args:
39+
channel: A reference to the channel that was closed.
40+
"""
41+
super().__init__(f"Channel {channel} was closed", channel)
42+
43+
1544
class Sender(ABC, Generic[T]):
1645
"""A channel Sender."""
1746

@@ -50,16 +79,19 @@ def __aiter__(self) -> Receiver[T]:
5079
"""
5180
return self
5281

53-
async def receive(self) -> Optional[T]:
82+
async def receive(self) -> T:
5483
"""Receive a message from the channel.
5584
85+
Raises:
86+
ChannelClosedError: if the underlying channel is closed.
87+
5688
Returns:
5789
The received message.
5890
"""
5991
try:
6092
received = await self.__anext__() # pylint: disable=unnecessary-dunder-call
61-
except StopAsyncIteration:
62-
return None
93+
except StopAsyncIteration as exc:
94+
raise ChannelClosedError() from exc
6395
return received
6496

6597
def map(self, call: Callable[[T], U]) -> Receiver[U]:

tests/test_anycast.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@
55

66
import asyncio
77

8+
import pytest
9+
810
from frequenz.channels import Anycast, Receiver, Sender
11+
from frequenz.channels.base_classes import ChannelClosedError
912

1013

1114
async def test_anycast() -> None:
@@ -29,8 +32,9 @@ async def send_msg(chan: Sender[int]) -> None:
2932

3033
async def update_tracker_on_receive(receiver_id: int, chan: Receiver[int]) -> None:
3134
while True:
32-
msg = await chan.receive()
33-
if msg is None:
35+
try:
36+
msg = await chan.receive()
37+
except ChannelClosedError:
3438
return
3539
recv_trackers[receiver_id] += msg
3640
# without the sleep, decomissioning receivers temporarily, all
@@ -56,7 +60,8 @@ async def update_tracker_on_receive(receiver_id: int, chan: Receiver[int]) -> No
5660
await receivers_runs
5761

5862
assert await after_close_sender.send(5) is False
59-
assert await after_close_receiver.receive() is None
63+
with pytest.raises(ChannelClosedError):
64+
await after_close_receiver.receive()
6065

6166
actual_sum = 0
6267
for ctr in recv_trackers:
@@ -79,7 +84,8 @@ async def test_anycast_after_close() -> None:
7984

8085
assert await sender.send(5) is False
8186
assert await receiver.receive() == 2
82-
assert await receiver.receive() is None
87+
with pytest.raises(ChannelClosedError):
88+
await receiver.receive()
8389

8490

8591
async def test_anycast_full() -> None:

tests/test_broadcast.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import pytest
1010

1111
from frequenz.channels import Broadcast, Receiver, Sender
12+
from frequenz.channels.base_classes import ChannelClosedError
1213

1314

1415
async def test_broadcast() -> None:
@@ -32,8 +33,9 @@ async def send_msg(chan: Sender[int]) -> None:
3233

3334
async def update_tracker_on_receive(receiver_id: int, chan: Receiver[int]) -> None:
3435
while True:
35-
msg = await chan.receive()
36-
if msg is None:
36+
try:
37+
msg = await chan.receive()
38+
except ChannelClosedError:
3739
return
3840
recv_trackers[receiver_id] += msg
3941

@@ -68,7 +70,8 @@ async def test_broadcast_after_close() -> None:
6870
await bcast.close()
6971

7072
assert await sender.send(5) is False
71-
assert await receiver.receive() is None
73+
with pytest.raises(ChannelClosedError):
74+
await receiver.receive()
7275

7376

7477
async def test_broadcast_overflow() -> None:

tests/test_merge.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ async def send(ch1: Sender[int], ch2: Sender[int]) -> None:
2626

2727
merge = Merge(chan1.get_receiver(), chan2.get_receiver())
2828
results: List[int] = []
29-
while item := await merge.receive():
29+
async for item in merge:
3030
results.append(item)
3131
await senders
3232
for ctr in range(5):

tests/test_mergenamed.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ async def send(ch1: Sender[int], ch2: Sender[int]) -> None:
2727

2828
merge = MergeNamed(**recvs)
2929
results: List[Tuple[str, int]] = []
30-
while item := await merge.receive():
30+
async for item in merge:
3131
results.append(item)
3232
await senders
3333
for ctr in range(5):

0 commit comments

Comments
 (0)