@@ -232,6 +232,42 @@ async def test_broadcast_map() -> None:
232232 assert (await receiver .receive ()) is True
233233
234234
235+ async def test_broadcast_map_close_receiver () -> None :
236+ """Ensure closing a map stops the receiver."""
237+ chan = Broadcast [int ](name = "input-chan" )
238+ sender = chan .new_sender ()
239+
240+ receiver_1 = chan .new_receiver ()
241+ receiver_2 = chan .new_receiver ()
242+ plus_100_rx = receiver_1 .map (lambda num : num + 100 )
243+
244+ await sender .send (1 )
245+
246+ assert (await plus_100_rx .receive ()) == 101
247+ assert (await receiver_2 .receive ()) == 1
248+
249+ plus_100_rx .close ()
250+
251+ await sender .send (2 )
252+
253+ with pytest .raises (ReceiverStoppedError ):
254+ _ = await plus_100_rx .receive ()
255+
256+ with pytest .raises (ReceiverStoppedError ):
257+ _ = await receiver_1 .receive ()
258+
259+ assert (await receiver_2 .receive ()) == 2
260+
261+ await sender .send (3 )
262+
263+ assert (await receiver_2 .receive ()) == 3
264+
265+ receiver_2 .close ()
266+
267+ with pytest .raises (ReceiverStoppedError ):
268+ _ = await receiver_2 .receive ()
269+
270+
235271async def test_broadcast_filter () -> None :
236272 """Ensure filter keeps only the messages that pass the filter."""
237273 chan = Broadcast [int ](name = "input-chan" )
@@ -249,6 +285,43 @@ async def test_broadcast_filter() -> None:
249285 assert (await receiver .receive ()) == 15
250286
251287
288+ async def test_broadcast_filter_close_receiver () -> None :
289+ """Ensure closing a filter stops the receiver."""
290+ chan = Broadcast [int ](name = "input-chan" )
291+ sender = chan .new_sender ()
292+
293+ receiver_1 = chan .new_receiver ()
294+ receiver_2 = chan .new_receiver ()
295+
296+ gt_10_rx = receiver_1 .filter (lambda num : num > 10 )
297+
298+ await sender .send (1 )
299+ assert (await receiver_2 .receive ()) == 1
300+
301+ await sender .send (100 )
302+ assert (await gt_10_rx .receive ()) == 100
303+ assert (await receiver_2 .receive ()) == 100
304+
305+ gt_10_rx .close ()
306+
307+ await sender .send (2 )
308+
309+ with pytest .raises (ReceiverStoppedError ):
310+ _ = await gt_10_rx .receive ()
311+ with pytest .raises (ReceiverStoppedError ):
312+ _ = await receiver_1 .receive ()
313+
314+ assert (await receiver_2 .receive ()) == 2
315+
316+ await sender .send (3 )
317+ assert (await receiver_2 .receive ()) == 3
318+
319+ receiver_2 .close ()
320+
321+ with pytest .raises (ReceiverStoppedError ):
322+ _ = await receiver_2 .receive ()
323+
324+
252325async def test_broadcast_filter_type_guard () -> None :
253326 """Ensure filter type guard works."""
254327 chan = Broadcast [int | str ](name = "input-chan" )
@@ -320,3 +393,35 @@ class Narrower(Actual):
320393
321394 await sender .send (Narrower (10 ))
322395 assert (await receiver .receive ()).value == 10
396+
397+
398+ async def test_broadcast_close_receiver () -> None :
399+ """Ensure closing a receiver stops the receiver."""
400+ chan = Broadcast [int ](name = "input-chan" )
401+ sender = chan .new_sender ()
402+
403+ receiver_1 = chan .new_receiver ()
404+ receiver_2 = chan .new_receiver ()
405+
406+ await sender .send (1 )
407+
408+ assert (await receiver_1 .receive ()) == 1
409+ assert (await receiver_2 .receive ()) == 1
410+
411+ receiver_1 .close ()
412+
413+ await sender .send (2 )
414+
415+ with pytest .raises (ReceiverStoppedError ):
416+ _ = await receiver_1 .receive ()
417+
418+ assert (await receiver_2 .receive ()) == 2
419+
420+ await sender .send (3 )
421+
422+ assert (await receiver_2 .receive ()) == 3
423+
424+ receiver_2 .close ()
425+
426+ with pytest .raises (ReceiverStoppedError ):
427+ _ = await receiver_2 .receive ()
0 commit comments