From 4407451942c5d09e0862058ee3349bd2fd536cf8 Mon Sep 17 00:00:00 2001 From: Zac Hatfield-Dodds Date: Tue, 16 Sep 2025 21:30:52 -0700 Subject: [PATCH] fix broken-channel bug --- newsfragments/3331.bugfix.rst | 3 +++ src/trio/_channel.py | 11 ++++++-- src/trio/_tests/test_channel.py | 46 ++++++++++++++++++++++++++++++++- 3 files changed, 57 insertions(+), 3 deletions(-) create mode 100644 newsfragments/3331.bugfix.rst diff --git a/newsfragments/3331.bugfix.rst b/newsfragments/3331.bugfix.rst new file mode 100644 index 000000000..da30bc9d5 --- /dev/null +++ b/newsfragments/3331.bugfix.rst @@ -0,0 +1,3 @@ +Fixed a bug where iterating over an ``@as_safe_channel``-derived ``ReceiveChannel`` +would raise `~trio.BrokenResourceError` if the channel was closed by another task. +It now shuts down cleanly. diff --git a/src/trio/_channel.py b/src/trio/_channel.py index 2afca9d7c..1398e766e 100644 --- a/src/trio/_channel.py +++ b/src/trio/_channel.py @@ -17,7 +17,7 @@ import trio from ._abc import ReceiveChannel, ReceiveType, SendChannel, SendType, T -from ._core import Abort, RaiseCancelT, Task, enable_ki_protection +from ._core import Abort, BrokenResourceError, RaiseCancelT, Task, enable_ki_protection from ._util import ( MultipleExceptionError, NoPublicConstructor, @@ -577,12 +577,19 @@ async def _move_elems_to_channel( while True: # wait for receiver to call next on the aiter await send_semaphore.acquire() + if not send_chan._state.open_receive_channels: + # skip the possibly-expensive computation in the generator, + # if we know it will be impossible to send the result. + break try: value = await agen.__anext__() except StopAsyncIteration: return # Send the value to the channel - await send_chan.send(value) + try: + await send_chan.send(value) + except BrokenResourceError: + break # closed since we checked above finally: # work around `.aclose()` not suppressing GeneratorExit in an # ExceptionGroup: diff --git a/src/trio/_tests/test_channel.py b/src/trio/_tests/test_channel.py index 85cd982a4..cfdf904fd 100644 --- a/src/trio/_tests/test_channel.py +++ b/src/trio/_tests/test_channel.py @@ -434,7 +434,7 @@ async def test_as_safe_channel_broken_resource() -> None: @as_safe_channel async def agen() -> AsyncGenerator[int]: yield 1 - yield 2 + yield 2 # pragma: no cover async with agen() as recv_chan: assert await recv_chan.__anext__() == 1 @@ -695,3 +695,47 @@ async def agen(ex: type[BaseException]) -> AsyncGenerator[None]: async with agen(ValueError) as g: async for _ in g: break + + +async def test_as_safe_channel_close_between_iteration() -> None: + @as_safe_channel + async def agen() -> AsyncGenerator[None]: + while True: + yield + + async with agen() as chan, trio.open_nursery() as nursery: + + async def close_channel() -> None: + await trio.lowlevel.checkpoint() + await chan.aclose() + + nursery.start_soon(close_channel) + with pytest.raises(trio.ClosedResourceError): + async for _ in chan: + pass + + +async def test_as_safe_channel_close_before_iteration() -> None: + @as_safe_channel + async def agen() -> AsyncGenerator[None]: + raise AssertionError("should be unreachable") # pragma: no cover + yield # pragma: no cover + + async with agen() as chan: + await chan.aclose() + with pytest.raises(trio.ClosedResourceError): + await chan.receive() + + +async def test_as_safe_channel_close_during_iteration() -> None: + @as_safe_channel + async def agen() -> AsyncGenerator[None]: + await chan.aclose() + while True: + yield + + for _ in range(10): # 20% missed-alarm rate, so run ten times + async with agen() as chan: + with pytest.raises(trio.ClosedResourceError): + async for _ in chan: + pass