Skip to content

Commit 0f975c3

Browse files
authored
Merge pull request #3331 from Zac-HD/close-on-broken-resource
fix broken-channel bug in `@trio.as_safe_channel`
2 parents 0aa5ee3 + 4407451 commit 0f975c3

File tree

3 files changed

+57
-3
lines changed

3 files changed

+57
-3
lines changed

newsfragments/3331.bugfix.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Fixed a bug where iterating over an ``@as_safe_channel``-derived ``ReceiveChannel``
2+
would raise `~trio.BrokenResourceError` if the channel was closed by another task.
3+
It now shuts down cleanly.

src/trio/_channel.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import trio
1818

1919
from ._abc import ReceiveChannel, ReceiveType, SendChannel, SendType, T
20-
from ._core import Abort, RaiseCancelT, Task, enable_ki_protection
20+
from ._core import Abort, BrokenResourceError, RaiseCancelT, Task, enable_ki_protection
2121
from ._util import (
2222
MultipleExceptionError,
2323
NoPublicConstructor,
@@ -577,12 +577,19 @@ async def _move_elems_to_channel(
577577
while True:
578578
# wait for receiver to call next on the aiter
579579
await send_semaphore.acquire()
580+
if not send_chan._state.open_receive_channels:
581+
# skip the possibly-expensive computation in the generator,
582+
# if we know it will be impossible to send the result.
583+
break
580584
try:
581585
value = await agen.__anext__()
582586
except StopAsyncIteration:
583587
return
584588
# Send the value to the channel
585-
await send_chan.send(value)
589+
try:
590+
await send_chan.send(value)
591+
except BrokenResourceError:
592+
break # closed since we checked above
586593
finally:
587594
# work around `.aclose()` not suppressing GeneratorExit in an
588595
# ExceptionGroup:

src/trio/_tests/test_channel.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -434,7 +434,7 @@ async def test_as_safe_channel_broken_resource() -> None:
434434
@as_safe_channel
435435
async def agen() -> AsyncGenerator[int]:
436436
yield 1
437-
yield 2
437+
yield 2 # pragma: no cover
438438

439439
async with agen() as recv_chan:
440440
assert await recv_chan.__anext__() == 1
@@ -695,3 +695,47 @@ async def agen(ex: type[BaseException]) -> AsyncGenerator[None]:
695695
async with agen(ValueError) as g:
696696
async for _ in g:
697697
break
698+
699+
700+
async def test_as_safe_channel_close_between_iteration() -> None:
701+
@as_safe_channel
702+
async def agen() -> AsyncGenerator[None]:
703+
while True:
704+
yield
705+
706+
async with agen() as chan, trio.open_nursery() as nursery:
707+
708+
async def close_channel() -> None:
709+
await trio.lowlevel.checkpoint()
710+
await chan.aclose()
711+
712+
nursery.start_soon(close_channel)
713+
with pytest.raises(trio.ClosedResourceError):
714+
async for _ in chan:
715+
pass
716+
717+
718+
async def test_as_safe_channel_close_before_iteration() -> None:
719+
@as_safe_channel
720+
async def agen() -> AsyncGenerator[None]:
721+
raise AssertionError("should be unreachable") # pragma: no cover
722+
yield # pragma: no cover
723+
724+
async with agen() as chan:
725+
await chan.aclose()
726+
with pytest.raises(trio.ClosedResourceError):
727+
await chan.receive()
728+
729+
730+
async def test_as_safe_channel_close_during_iteration() -> None:
731+
@as_safe_channel
732+
async def agen() -> AsyncGenerator[None]:
733+
await chan.aclose()
734+
while True:
735+
yield
736+
737+
for _ in range(10): # 20% missed-alarm rate, so run ten times
738+
async with agen() as chan:
739+
with pytest.raises(trio.ClosedResourceError):
740+
async for _ in chan:
741+
pass

0 commit comments

Comments
 (0)