Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions newsfragments/3324.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Avoid having `trio.as_safe_channel` raise if closing the generator wrapped
`GeneratorExit` in a `BaseExceptionGroup`.
31 changes: 29 additions & 2 deletions src/trio/_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,8 @@ async def _move_elems_to_channel(
# `async with send_chan` will eat exceptions,
# see https://github.com/python-trio/trio/issues/1559
with send_chan:
# replace try-finally with contextlib.aclosing once python39 is
# dropped:
try:
task_status.started()
while True:
Expand All @@ -582,7 +584,32 @@ async def _move_elems_to_channel(
# Send the value to the channel
await send_chan.send(value)
finally:
# replace try-finally with contextlib.aclosing once python39 is dropped
await agen.aclose()
# work around `.aclose()` not suppressing GeneratorExit in an
# ExceptionGroup:
# TODO: make an issue on CPython about this
try:
await agen.aclose()
except BaseExceptionGroup as exceptions:
removed, narrowed_exceptions = exceptions.split(GeneratorExit)

# TODO: extract a helper to flatten exception groups
removed_exceptions: list[BaseException | None] = [removed]
genexits_seen = 0
for e in removed_exceptions:
if isinstance(e, BaseExceptionGroup):
removed_exceptions.extend(e.exceptions) # noqa: B909
else:
genexits_seen += 1

if genexits_seen > 1:
exc = AssertionError("More than one GeneratorExit found.")
if narrowed_exceptions is None:
narrowed_exceptions = exceptions.derive([exc])
else:
narrowed_exceptions = narrowed_exceptions.derive(
[*narrowed_exceptions.exceptions, exc]
)
if narrowed_exceptions is not None:
raise narrowed_exceptions from None

return context_manager
70 changes: 70 additions & 0 deletions src/trio/_tests/test_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,3 +625,73 @@ async def agen(events: list[str]) -> AsyncGenerator[None]:
events.append("body cancel")
raise
assert events == ["body cancel", "agen cancel"]


async def test_as_safe_channel_genexit_exception_group() -> None:
@as_safe_channel
async def agen() -> AsyncGenerator[None]:
try:
async with trio.open_nursery():
yield
except BaseException as e:
assert pytest.RaisesGroup(GeneratorExit).matches(e) # noqa: PT017
raise

async with agen() as g:
async for _ in g:
break


async def test_as_safe_channel_does_not_suppress_nested_genexit() -> None:
@as_safe_channel
async def agen() -> AsyncGenerator[None]:
yield

with pytest.RaisesGroup(GeneratorExit):
async with agen() as g, trio.open_nursery():
await g.receive() # this is for coverage reasons
raise GeneratorExit


async def test_as_safe_channel_genexit_filter() -> None:
async def wait_then_raise() -> None:
try:
await trio.sleep_forever()
except trio.Cancelled:
raise ValueError from None

@as_safe_channel
async def agen() -> AsyncGenerator[None]:
async with trio.open_nursery() as nursery:
nursery.start_soon(wait_then_raise)
yield

with pytest.RaisesGroup(ValueError):
async with agen() as g:
async for _ in g:
break


async def test_as_safe_channel_swallowing_extra_exceptions() -> None:
async def wait_then_raise(ex: type[BaseException]) -> None:
try:
await trio.sleep_forever()
except trio.Cancelled:
raise ex from None

@as_safe_channel
async def agen(ex: type[BaseException]) -> AsyncGenerator[None]:
async with trio.open_nursery() as nursery:
nursery.start_soon(wait_then_raise, ex)
nursery.start_soon(wait_then_raise, GeneratorExit)
yield

with pytest.RaisesGroup(AssertionError):
async with agen(GeneratorExit) as g:
async for _ in g:
break

with pytest.RaisesGroup(ValueError, AssertionError):
async with agen(ValueError) as g:
async for _ in g:
break
Loading