Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions newsfragments/3324.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Avoid having `trio.as_safe_channel` raise if closing the generator wrapped
`GeneratorExit` in a `BaseExceptionGroup`.
31 changes: 29 additions & 2 deletions src/trio/_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,8 @@ async def _move_elems_to_channel(
# `async with send_chan` will eat exceptions,
# see https://github.com/python-trio/trio/issues/1559
with send_chan:
# replace try-finally with contextlib.aclosing once python39 is
# dropped:
try:
task_status.started()
while True:
Expand All @@ -582,7 +584,32 @@ async def _move_elems_to_channel(
# Send the value to the channel
await send_chan.send(value)
finally:
# replace try-finally with contextlib.aclosing once python39 is dropped
await agen.aclose()
# work around `.aclose()` not suppressing GeneratorExit in an
# ExceptionGroup:
# TODO: make an issue on CPython about this
try:
await agen.aclose()
except BaseExceptionGroup as exceptions:
removed, narrowed_exceptions = exceptions.split(GeneratorExit)

# TODO: extract a helper to flatten exception groups
removed_exceptions: list[BaseException | None] = [removed]
genexits_seen = 0
for e in removed_exceptions:
if isinstance(e, BaseExceptionGroup):
removed_exceptions.extend(e.exceptions) # noqa: B909
else:
genexits_seen += 1

if genexits_seen > 1:
exc = AssertionError("More than one GeneratorExit found.")
if narrowed_exceptions is None:
narrowed_exceptions = exceptions.derive([exc])
else:
narrowed_exceptions = narrowed_exceptions.derive(
[*narrowed_exceptions.exceptions, exc]
)
if narrowed_exceptions is not None:
raise narrowed_exceptions from None

return context_manager
64 changes: 64 additions & 0 deletions src/trio/_tests/test_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,3 +625,67 @@ async def agen(events: list[str]) -> AsyncGenerator[None]:
events.append("body cancel")
raise
assert events == ["body cancel", "agen cancel"]


async def test_as_safe_channel_genexit_exception_group() -> None:
@as_safe_channel
async def agen() -> AsyncGenerator[None]:
try:
async with trio.open_nursery():
yield
except BaseException as e:
assert pytest.RaisesGroup(GeneratorExit).matches(e) # noqa: PT017
raise

async with agen() as g:
async for _ in g:
break


async def test_as_safe_channel_does_not_suppress_nested_genexit() -> None:
@as_safe_channel
async def agen() -> AsyncGenerator[None]:
yield

with pytest.RaisesGroup(GeneratorExit):
async with agen() as g, trio.open_nursery():
await g.receive() # this is for coverage reasons
raise GeneratorExit


async def test_as_safe_channel_genexit_filter() -> None:
async def wait_then_raise() -> None:
try:
await trio.sleep_forever()
except trio.Cancelled:
raise ValueError from None

@as_safe_channel
async def agen() -> AsyncGenerator[None]:
async with trio.open_nursery() as nursery:
nursery.start_soon(wait_then_raise)
yield

with pytest.RaisesGroup(ValueError):
async with agen() as g:
async for _ in g:
break


async def test_as_safe_channel_swallowing_extra_exceptions() -> None:
async def wait_then_raise() -> None:
try:
await trio.sleep_forever()
except trio.Cancelled:
raise GeneratorExit from None

@as_safe_channel
async def agen() -> AsyncGenerator[None]:
async with trio.open_nursery() as nursery:
nursery.start_soon(wait_then_raise)
yield

with pytest.RaisesGroup(AssertionError):
async with agen() as g:
async for _ in g:
break
Loading