Skip to content

Commit cc148cf

Browse files
authored
Suppress GeneratorExit from .aclose() for as_safe_channel (#3325)
1 parent fc352b2 commit cc148cf

File tree

3 files changed

+101
-2
lines changed

3 files changed

+101
-2
lines changed

newsfragments/3324.bugfix.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Avoid having `trio.as_safe_channel` raise if closing the generator wrapped
2+
`GeneratorExit` in a `BaseExceptionGroup`.

src/trio/_channel.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,8 @@ async def _move_elems_to_channel(
570570
# `async with send_chan` will eat exceptions,
571571
# see https://github.com/python-trio/trio/issues/1559
572572
with send_chan:
573+
# replace try-finally with contextlib.aclosing once python39 is
574+
# dropped:
573575
try:
574576
task_status.started()
575577
while True:
@@ -582,7 +584,32 @@ async def _move_elems_to_channel(
582584
# Send the value to the channel
583585
await send_chan.send(value)
584586
finally:
585-
# replace try-finally with contextlib.aclosing once python39 is dropped
586-
await agen.aclose()
587+
# work around `.aclose()` not suppressing GeneratorExit in an
588+
# ExceptionGroup:
589+
# TODO: make an issue on CPython about this
590+
try:
591+
await agen.aclose()
592+
except BaseExceptionGroup as exceptions:
593+
removed, narrowed_exceptions = exceptions.split(GeneratorExit)
594+
595+
# TODO: extract a helper to flatten exception groups
596+
removed_exceptions: list[BaseException | None] = [removed]
597+
genexits_seen = 0
598+
for e in removed_exceptions:
599+
if isinstance(e, BaseExceptionGroup):
600+
removed_exceptions.extend(e.exceptions) # noqa: B909
601+
else:
602+
genexits_seen += 1
603+
604+
if genexits_seen > 1:
605+
exc = AssertionError("More than one GeneratorExit found.")
606+
if narrowed_exceptions is None:
607+
narrowed_exceptions = exceptions.derive([exc])
608+
else:
609+
narrowed_exceptions = narrowed_exceptions.derive(
610+
[*narrowed_exceptions.exceptions, exc]
611+
)
612+
if narrowed_exceptions is not None:
613+
raise narrowed_exceptions from None
587614

588615
return context_manager

src/trio/_tests/test_channel.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -625,3 +625,73 @@ async def agen(events: list[str]) -> AsyncGenerator[None]:
625625
events.append("body cancel")
626626
raise
627627
assert events == ["body cancel", "agen cancel"]
628+
629+
630+
async def test_as_safe_channel_genexit_exception_group() -> None:
631+
@as_safe_channel
632+
async def agen() -> AsyncGenerator[None]:
633+
try:
634+
async with trio.open_nursery():
635+
yield
636+
except BaseException as e:
637+
assert pytest.RaisesGroup(GeneratorExit).matches(e) # noqa: PT017
638+
raise
639+
640+
async with agen() as g:
641+
async for _ in g:
642+
break
643+
644+
645+
async def test_as_safe_channel_does_not_suppress_nested_genexit() -> None:
646+
@as_safe_channel
647+
async def agen() -> AsyncGenerator[None]:
648+
yield
649+
650+
with pytest.RaisesGroup(GeneratorExit):
651+
async with agen() as g, trio.open_nursery():
652+
await g.receive() # this is for coverage reasons
653+
raise GeneratorExit
654+
655+
656+
async def test_as_safe_channel_genexit_filter() -> None:
657+
async def wait_then_raise() -> None:
658+
try:
659+
await trio.sleep_forever()
660+
except trio.Cancelled:
661+
raise ValueError from None
662+
663+
@as_safe_channel
664+
async def agen() -> AsyncGenerator[None]:
665+
async with trio.open_nursery() as nursery:
666+
nursery.start_soon(wait_then_raise)
667+
yield
668+
669+
with pytest.RaisesGroup(ValueError):
670+
async with agen() as g:
671+
async for _ in g:
672+
break
673+
674+
675+
async def test_as_safe_channel_swallowing_extra_exceptions() -> None:
676+
async def wait_then_raise(ex: type[BaseException]) -> None:
677+
try:
678+
await trio.sleep_forever()
679+
except trio.Cancelled:
680+
raise ex from None
681+
682+
@as_safe_channel
683+
async def agen(ex: type[BaseException]) -> AsyncGenerator[None]:
684+
async with trio.open_nursery() as nursery:
685+
nursery.start_soon(wait_then_raise, ex)
686+
nursery.start_soon(wait_then_raise, GeneratorExit)
687+
yield
688+
689+
with pytest.RaisesGroup(AssertionError):
690+
async with agen(GeneratorExit) as g:
691+
async for _ in g:
692+
break
693+
694+
with pytest.RaisesGroup(ValueError, AssertionError):
695+
async with agen(ValueError) as g:
696+
async for _ in g:
697+
break

0 commit comments

Comments
 (0)