Skip to content

Commit e0f4b12

Browse files
committed
Add tests for all Receiver.close() implementations
Signed-off-by: Sahas Subramanian <[email protected]>
1 parent cea6cfc commit e0f4b12

File tree

6 files changed

+276
-0
lines changed

6 files changed

+276
-0
lines changed

tests/test_anycast.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,3 +217,30 @@ async def test_anycast_filter() -> None:
217217

218218
assert (await receiver.receive()) == 12
219219
assert (await receiver.receive()) == 15
220+
221+
222+
async def test_anycast_close_receiver() -> None:
223+
"""Ensure closing a receiver stops the receiver."""
224+
chan = Anycast[int](name="input-chan")
225+
sender = chan.new_sender()
226+
227+
receiver_1 = chan.new_receiver()
228+
receiver_2 = chan.new_receiver()
229+
230+
await sender.send(1)
231+
232+
assert (await receiver_1.receive()) == 1
233+
234+
receiver_1.close()
235+
236+
await sender.send(2)
237+
238+
with pytest.raises(ReceiverStoppedError):
239+
_ = await receiver_1.receive()
240+
241+
assert (await receiver_2.receive()) == 2
242+
243+
receiver_2.close()
244+
245+
with pytest.raises(ReceiverStoppedError):
246+
_ = await receiver_2.receive()

tests/test_broadcast.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,42 @@ async def test_broadcast_map() -> None:
232232
assert (await receiver.receive()) is True
233233

234234

235+
async def test_broadcast_map_close_receiver() -> None:
236+
"""Ensure closing a map stops the receiver."""
237+
chan = Broadcast[int](name="input-chan")
238+
sender = chan.new_sender()
239+
240+
receiver_1 = chan.new_receiver()
241+
receiver_2 = chan.new_receiver()
242+
plus_100_rx = receiver_1.map(lambda num: num + 100)
243+
244+
await sender.send(1)
245+
246+
assert (await plus_100_rx.receive()) == 101
247+
assert (await receiver_2.receive()) == 1
248+
249+
plus_100_rx.close()
250+
251+
await sender.send(2)
252+
253+
with pytest.raises(ReceiverStoppedError):
254+
_ = await plus_100_rx.receive()
255+
256+
with pytest.raises(ReceiverStoppedError):
257+
_ = await receiver_1.receive()
258+
259+
assert (await receiver_2.receive()) == 2
260+
261+
await sender.send(3)
262+
263+
assert (await receiver_2.receive()) == 3
264+
265+
receiver_2.close()
266+
267+
with pytest.raises(ReceiverStoppedError):
268+
_ = await receiver_2.receive()
269+
270+
235271
async def test_broadcast_filter() -> None:
236272
"""Ensure filter keeps only the messages that pass the filter."""
237273
chan = Broadcast[int](name="input-chan")
@@ -249,6 +285,43 @@ async def test_broadcast_filter() -> None:
249285
assert (await receiver.receive()) == 15
250286

251287

288+
async def test_broadcast_filter_close_receiver() -> None:
289+
"""Ensure closing a filter stops the receiver."""
290+
chan = Broadcast[int](name="input-chan")
291+
sender = chan.new_sender()
292+
293+
receiver_1 = chan.new_receiver()
294+
receiver_2 = chan.new_receiver()
295+
296+
gt_10_rx = receiver_1.filter(lambda num: num > 10)
297+
298+
await sender.send(1)
299+
assert (await receiver_2.receive()) == 1
300+
301+
await sender.send(100)
302+
assert (await gt_10_rx.receive()) == 100
303+
assert (await receiver_2.receive()) == 100
304+
305+
gt_10_rx.close()
306+
307+
await sender.send(2)
308+
309+
with pytest.raises(ReceiverStoppedError):
310+
_ = await gt_10_rx.receive()
311+
with pytest.raises(ReceiverStoppedError):
312+
_ = await receiver_1.receive()
313+
314+
assert (await receiver_2.receive()) == 2
315+
316+
await sender.send(3)
317+
assert (await receiver_2.receive()) == 3
318+
319+
receiver_2.close()
320+
321+
with pytest.raises(ReceiverStoppedError):
322+
_ = await receiver_2.receive()
323+
324+
252325
async def test_broadcast_filter_type_guard() -> None:
253326
"""Ensure filter type guard works."""
254327
chan = Broadcast[int | str](name="input-chan")
@@ -320,3 +393,35 @@ class Narrower(Actual):
320393

321394
await sender.send(Narrower(10))
322395
assert (await receiver.receive()).value == 10
396+
397+
398+
async def test_broadcast_close_receiver() -> None:
399+
"""Ensure closing a receiver stops the receiver."""
400+
chan = Broadcast[int](name="input-chan")
401+
sender = chan.new_sender()
402+
403+
receiver_1 = chan.new_receiver()
404+
receiver_2 = chan.new_receiver()
405+
406+
await sender.send(1)
407+
408+
assert (await receiver_1.receive()) == 1
409+
assert (await receiver_2.receive()) == 1
410+
411+
receiver_1.close()
412+
413+
await sender.send(2)
414+
415+
with pytest.raises(ReceiverStoppedError):
416+
_ = await receiver_1.receive()
417+
418+
assert (await receiver_2.receive()) == 2
419+
420+
await sender.send(3)
421+
422+
assert (await receiver_2.receive()) == 3
423+
424+
receiver_2.close()
425+
426+
with pytest.raises(ReceiverStoppedError):
427+
_ = await receiver_2.receive()

tests/test_event.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,51 @@ async def wait_for_event() -> None:
5757
assert not event.is_set
5858

5959
await event_task
60+
61+
62+
async def test_event_close_receiver() -> None:
63+
"""Ensure that closing an event stops the receiver."""
64+
event = Event()
65+
assert not event.is_set
66+
assert not event.is_stopped
67+
68+
is_ready = False
69+
70+
async def wait_for_event() -> None:
71+
nonlocal is_ready
72+
await event.ready()
73+
is_ready = True
74+
75+
event_task = _asyncio.create_task(wait_for_event())
76+
77+
await _asyncio.sleep(0) # Yield so the wait_for_event task can run.
78+
79+
assert not is_ready
80+
assert not event.is_set
81+
assert not event.is_stopped
82+
83+
event.set()
84+
85+
await _asyncio.sleep(0) # Yield so the wait_for_event task can run.
86+
assert is_ready
87+
assert event.is_set
88+
assert not event.is_stopped
89+
90+
event.consume()
91+
assert not event.is_set
92+
assert not event.is_stopped
93+
assert event_task.done()
94+
assert event_task.result() is None
95+
assert not event_task.cancelled()
96+
97+
event.close()
98+
assert not event.is_set
99+
assert event.is_stopped
100+
101+
await event.ready()
102+
with _pytest.raises(ReceiverStoppedError):
103+
event.consume()
104+
assert event.is_stopped
105+
assert not event.is_set
106+
107+
await event_task

tests/test_file_watcher_integration.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,3 +150,40 @@ async def test_file_watcher_exit_iterator(tmp_path: pathlib.Path) -> None:
150150
file_watcher.consume()
151151

152152
assert number_of_writes == expected_number_of_writes
153+
154+
155+
@pytest.mark.integration
156+
async def test_file_watcher_close_receiver(tmp_path: pathlib.Path) -> None:
157+
"""Ensure closing the file watcher stops the receiver.
158+
159+
Args:
160+
tmp_path: A tmp directory to run the file watcher on. Created by pytest.
161+
"""
162+
filename = tmp_path / "test-file"
163+
164+
number_of_writes = 0
165+
expected_number_of_writes = 3
166+
167+
file_watcher = FileWatcher(
168+
paths=[str(tmp_path)],
169+
force_polling=True,
170+
polling_interval=timedelta(seconds=0.05),
171+
)
172+
timer = Timer(timedelta(seconds=0.1), SkipMissedAndDrift())
173+
174+
async for selected in select(file_watcher, timer):
175+
if selected_from(selected, timer):
176+
filename.write_text(f"{selected.message}")
177+
elif selected_from(selected, file_watcher):
178+
number_of_writes += 1
179+
if number_of_writes == expected_number_of_writes:
180+
file_watcher.close()
181+
break
182+
183+
ready = await file_watcher.ready()
184+
assert ready is False
185+
186+
with pytest.raises(ReceiverStoppedError):
187+
file_watcher.consume()
188+
189+
assert number_of_writes == expected_number_of_writes

tests/test_merge_integration.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
import pytest
99

1010
from frequenz.channels import Anycast, Sender, merge
11+
from frequenz.channels._broadcast import Broadcast
12+
from frequenz.channels._receiver import ReceiverStoppedError
1113

1214

1315
@pytest.mark.integration
@@ -39,3 +41,44 @@ async def send(ch1: Sender[int], ch2: Sender[int]) -> None:
3941
# succession.
4042
assert set(results[idx : idx + 2]) == {ctr + 1, ctr + 101}
4143
assert results[-1] == 1000
44+
45+
46+
async def test_merge_close_receiver() -> None:
47+
"""Ensure merge() closes when a receiver is closed."""
48+
chan1 = Broadcast[int](name="chan1")
49+
chan2 = Broadcast[int](name="chan2")
50+
51+
async def send(ch1: Sender[int], ch2: Sender[int]) -> None:
52+
for ctr in range(5):
53+
await ch1.send(ctr + 1)
54+
await ch2.send(ctr + 101)
55+
await chan1.close()
56+
await chan2.close()
57+
58+
rx1 = chan1.new_receiver()
59+
rx2 = chan2.new_receiver()
60+
closing_merge = merge(rx1, rx2)
61+
prx1 = chan1.new_receiver()
62+
prx2 = chan2.new_receiver()
63+
completing_merge = merge(prx1, prx2)
64+
65+
senders = asyncio.create_task(send(chan1.new_sender(), chan2.new_sender()))
66+
67+
results: list[int] = []
68+
async for item in closing_merge:
69+
results.append(item)
70+
if item == 3:
71+
closing_merge.close()
72+
await senders
73+
assert set(results) == {1, 101, 2, 102, 3, 103}
74+
75+
with pytest.raises(ReceiverStoppedError):
76+
_ = await rx1.receive()
77+
78+
with pytest.raises(ReceiverStoppedError):
79+
_ = await rx2.receive()
80+
81+
comp_results: set[int] = set()
82+
async for item in completing_merge:
83+
comp_results.add(item)
84+
assert comp_results == {1, 101, 2, 102, 3, 103, 4, 104, 5, 105}

tests/test_timer.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import pytest
1414
from hypothesis import strategies as st
1515

16+
from frequenz.channels import ReceiverStoppedError
1617
from frequenz.channels.timer import (
1718
SkipMissedAndDrift,
1819
SkipMissedAndResync,
@@ -331,6 +332,21 @@ async def test_timer_construction_wrong_args() -> None:
331332
)
332333

333334

335+
async def test_timer_close_receiver() -> None:
336+
"""Test the autostart of a periodic timer."""
337+
event_loop = asyncio.get_running_loop()
338+
339+
timer = Timer(timedelta(seconds=1.0), TriggerAllMissed())
340+
341+
drift = await timer.receive()
342+
assert drift == pytest.approx(timedelta(seconds=0.0))
343+
assert event_loop.time() == pytest.approx(1.0)
344+
345+
timer.close()
346+
with pytest.raises(ReceiverStoppedError):
347+
await timer.receive()
348+
349+
334350
async def test_timer_autostart() -> None:
335351
"""Test the autostart of a periodic timer."""
336352
event_loop = asyncio.get_running_loop()

0 commit comments

Comments
 (0)