diff --git a/newsfragments/3324.bugfix.rst b/newsfragments/3324.bugfix.rst new file mode 100644 index 000000000..0fd5a0539 --- /dev/null +++ b/newsfragments/3324.bugfix.rst @@ -0,0 +1,2 @@ +Avoid having `trio.as_safe_channel` raise if closing the generator wrapped +`GeneratorExit` in a `BaseExceptionGroup`. diff --git a/src/trio/_channel.py b/src/trio/_channel.py index 9caa7f769..2afca9d7c 100644 --- a/src/trio/_channel.py +++ b/src/trio/_channel.py @@ -570,6 +570,8 @@ async def _move_elems_to_channel( # `async with send_chan` will eat exceptions, # see https://github.com/python-trio/trio/issues/1559 with send_chan: + # replace try-finally with contextlib.aclosing once python39 is + # dropped: try: task_status.started() while True: @@ -582,7 +584,32 @@ async def _move_elems_to_channel( # Send the value to the channel await send_chan.send(value) finally: - # replace try-finally with contextlib.aclosing once python39 is dropped - await agen.aclose() + # work around `.aclose()` not suppressing GeneratorExit in an + # ExceptionGroup: + # TODO: make an issue on CPython about this + try: + await agen.aclose() + except BaseExceptionGroup as exceptions: + removed, narrowed_exceptions = exceptions.split(GeneratorExit) + + # TODO: extract a helper to flatten exception groups + removed_exceptions: list[BaseException | None] = [removed] + genexits_seen = 0 + for e in removed_exceptions: + if isinstance(e, BaseExceptionGroup): + removed_exceptions.extend(e.exceptions) # noqa: B909 + else: + genexits_seen += 1 + + if genexits_seen > 1: + exc = AssertionError("More than one GeneratorExit found.") + if narrowed_exceptions is None: + narrowed_exceptions = exceptions.derive([exc]) + else: + narrowed_exceptions = narrowed_exceptions.derive( + [*narrowed_exceptions.exceptions, exc] + ) + if narrowed_exceptions is not None: + raise narrowed_exceptions from None return context_manager diff --git a/src/trio/_tests/test_channel.py b/src/trio/_tests/test_channel.py index f1556a153..85cd982a4 100644 --- a/src/trio/_tests/test_channel.py +++ b/src/trio/_tests/test_channel.py @@ -625,3 +625,73 @@ async def agen(events: list[str]) -> AsyncGenerator[None]: events.append("body cancel") raise assert events == ["body cancel", "agen cancel"] + + +async def test_as_safe_channel_genexit_exception_group() -> None: + @as_safe_channel + async def agen() -> AsyncGenerator[None]: + try: + async with trio.open_nursery(): + yield + except BaseException as e: + assert pytest.RaisesGroup(GeneratorExit).matches(e) # noqa: PT017 + raise + + async with agen() as g: + async for _ in g: + break + + +async def test_as_safe_channel_does_not_suppress_nested_genexit() -> None: + @as_safe_channel + async def agen() -> AsyncGenerator[None]: + yield + + with pytest.RaisesGroup(GeneratorExit): + async with agen() as g, trio.open_nursery(): + await g.receive() # this is for coverage reasons + raise GeneratorExit + + +async def test_as_safe_channel_genexit_filter() -> None: + async def wait_then_raise() -> None: + try: + await trio.sleep_forever() + except trio.Cancelled: + raise ValueError from None + + @as_safe_channel + async def agen() -> AsyncGenerator[None]: + async with trio.open_nursery() as nursery: + nursery.start_soon(wait_then_raise) + yield + + with pytest.RaisesGroup(ValueError): + async with agen() as g: + async for _ in g: + break + + +async def test_as_safe_channel_swallowing_extra_exceptions() -> None: + async def wait_then_raise(ex: type[BaseException]) -> None: + try: + await trio.sleep_forever() + except trio.Cancelled: + raise ex from None + + @as_safe_channel + async def agen(ex: type[BaseException]) -> AsyncGenerator[None]: + async with trio.open_nursery() as nursery: + nursery.start_soon(wait_then_raise, ex) + nursery.start_soon(wait_then_raise, GeneratorExit) + yield + + with pytest.RaisesGroup(AssertionError): + async with agen(GeneratorExit) as g: + async for _ in g: + break + + with pytest.RaisesGroup(ValueError, AssertionError): + async with agen(ValueError) as g: + async for _ in g: + break