Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion src/trio/_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,7 +590,31 @@ async def _move_elems_to_channel(
try:
await agen.aclose()
except BaseExceptionGroup as exceptions:
_, narrowed_exceptions = exceptions.split(GeneratorExit)
removed, narrowed_exceptions = exceptions.split(GeneratorExit)

# TODO: extract a helper to flatten exception groups
removed_exceptions = [removed]
for e in removed_exceptions:
if isinstance(e, BaseExceptionGroup):
removed_exceptions.extend(e.exceptions) # noqa: B909

if (
len(
[
e
for e in removed_exceptions
if isinstance(e, GeneratorExit)
]
)
> 1
):
exc = AssertionError("More than one GeneratorExit found.")
if narrowed_exceptions is None:
narrowed_exceptions = exceptions.derive([exc])
else:
narrowed_exceptions = narrowed_exceptions.derive(
[*narrowed_exceptions.exceptions, exc]
)
if narrowed_exceptions is not None:
raise narrowed_exceptions from None

Expand Down
23 changes: 21 additions & 2 deletions src/trio/_tests/test_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from ..testing import Matcher, RaisesGroup, assert_checkpoints, wait_all_tasks_blocked

if sys.version_info < (3, 11):
from exceptiongroup import BaseExceptionGroup, ExceptionGroup
from exceptiongroup import ExceptionGroup

if TYPE_CHECKING:
from collections.abc import AsyncGenerator
Expand Down Expand Up @@ -634,7 +634,7 @@ async def agen() -> AsyncGenerator[None]:
async with trio.open_nursery():
yield
except BaseException as e:
assert isinstance(e, BaseExceptionGroup) # noqa: PT017 # we reraise
assert pytest.RaisesGroup(GeneratorExit).matches(e) # noqa: PT017
raise

async with agen() as g:
Expand Down Expand Up @@ -670,3 +670,22 @@ async def agen() -> AsyncGenerator[None]:
async with agen() as g:
async for _ in g:
break


async def test_as_safe_channel_swallowing_extra_exceptions() -> None:
async def wait_then_raise() -> None:
try:
await trio.sleep_forever()
except trio.Cancelled:
raise GeneratorExit from None

@as_safe_channel
async def agen() -> AsyncGenerator[None]:
async with trio.open_nursery() as nursery:
nursery.start_soon(wait_then_raise)
yield

with pytest.RaisesGroup(AssertionError):
async with agen() as g:
async for _ in g:
break
Loading