Skip to content

Commit 2565be8

Browse files
Add support/tests for None-value channels (#83)
2 parents d5e5f0a + caea4e0 commit 2565be8

File tree

3 files changed

+52
-7
lines changed

3 files changed

+52
-7
lines changed

src/frequenz/channels/_anycast.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from asyncio import Condition
99
from collections import deque
10-
from typing import Deque, Generic, Optional
10+
from typing import Deque, Generic, Type
1111

1212
from ._base_classes import Receiver as BaseReceiver
1313
from ._base_classes import Sender as BaseSender
@@ -151,6 +151,10 @@ async def send(self, msg: T) -> None:
151151
self._chan.recv_cv.notify(1)
152152

153153

154+
class _Empty:
155+
"""A sentinel value to indicate that a value has not been set."""
156+
157+
154158
class Receiver(BaseReceiver[T]):
155159
"""A receiver to receive messages from an Anycast channel.
156160
@@ -165,7 +169,7 @@ def __init__(self, chan: Anycast[T]) -> None:
165169
chan: A reference to the channel that this receiver belongs to.
166170
"""
167171
self._chan = chan
168-
self._next: Optional[T] = None
172+
self._next: T | Type[_Empty] = _Empty
169173

170174
async def ready(self) -> bool:
171175
"""Wait until the receiver is ready with a value or an error.
@@ -179,7 +183,7 @@ async def ready(self) -> bool:
179183
Whether the receiver is still active.
180184
"""
181185
# if a message is already ready, then return immediately.
182-
if self._next is not None:
186+
if self._next is not _Empty:
183187
return True
184188

185189
while len(self._chan.deque) == 0:
@@ -202,12 +206,15 @@ def consume(self) -> T:
202206
ReceiverStoppedError: if the receiver stopped producing messages.
203207
ReceiverError: if there is some problem with the receiver.
204208
"""
205-
if self._next is None and self._chan.closed:
209+
if self._next is _Empty and self._chan.closed:
206210
raise ReceiverStoppedError(self) from ChannelClosedError(self._chan)
207211

208212
assert (
209-
self._next is not None
213+
self._next is not _Empty
210214
), "`consume()` must be preceeded by a call to `ready()`"
211-
next_val = self._next
212-
self._next = None
215+
# mypy doesn't understand that the assert above ensures that self._next is not
216+
# _Sentinel. So we have to use a type ignore here.
217+
next_val: T = self._next # type: ignore[assignment]
218+
self._next = _Empty
219+
213220
return next_val

tests/test_anycast.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
"""Tests for the Channel implementation."""
55

6+
from __future__ import annotations
7+
68
import asyncio
79

810
import pytest
@@ -147,6 +149,23 @@ async def test_anycast_full() -> None:
147149
assert False
148150

149151

152+
async def test_anycast_none_values() -> None:
153+
"""Ensure None values can be sent and received."""
154+
acast: Anycast[int | None] = Anycast()
155+
156+
sender = acast.new_sender()
157+
receiver = acast.new_receiver()
158+
159+
await sender.send(5)
160+
assert await receiver.receive() == 5
161+
162+
await sender.send(None)
163+
assert await receiver.receive() is None
164+
165+
await sender.send(10)
166+
assert await receiver.receive() == 10
167+
168+
150169
async def test_anycast_async_iterator() -> None:
151170
"""Check that the anycast receiver works as an async iterator."""
152171
acast: Anycast[str] = Anycast()

tests/test_broadcast.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
"""Tests for the Broadcast implementation."""
55

6+
from __future__ import annotations
7+
68
import asyncio
79
from typing import Tuple
810

@@ -69,6 +71,23 @@ async def update_tracker_on_receive(receiver_id: int, recv: Receiver[int]) -> No
6971
assert actual_sum == expected_sum
7072

7173

74+
async def test_broadcast_none_values() -> None:
75+
"""Ensure None values can be sent and received."""
76+
bcast: Broadcast[int | None] = Broadcast("any_channel")
77+
78+
sender = bcast.new_sender()
79+
receiver = bcast.new_receiver()
80+
81+
await sender.send(5)
82+
assert await receiver.receive() == 5
83+
84+
await sender.send(None)
85+
assert await receiver.receive() is None
86+
87+
await sender.send(10)
88+
assert await receiver.receive() == 10
89+
90+
7291
async def test_broadcast_after_close() -> None:
7392
"""Ensure closed channels can't get new messages."""
7493
bcast: Broadcast[int] = Broadcast("meter_5")

0 commit comments

Comments
 (0)