Skip to content

Commit 26ed1c6

Browse files
committed
do everything but unwrapping the exception from inside the group
1 parent bfa981c commit 26ed1c6

File tree

2 files changed

+43
-138
lines changed

2 files changed

+43
-138
lines changed

src/trio/_channel.py

Lines changed: 33 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -458,16 +458,16 @@ async def aclose(self) -> None:
458458

459459
class RecvChanWrapper(ReceiveChannel[T]):
460460
def __init__(
461-
self, recv_chan: MemoryReceiveChannel[T], send_semaphore: trio.Semaphore | None
461+
self, recv_chan: MemoryReceiveChannel[T], send_semaphore: trio.Semaphore
462462
) -> None:
463463
self.recv_chan = recv_chan
464464
self.send_semaphore = send_semaphore
465465

466-
# TODO: should this allow clones?
466+
# TODO: should this allow clones? We'd signal that by inheriting from
467+
# MemoryReceiveChannel.
467468

468469
async def receive(self) -> T:
469-
if self.send_semaphore is not None:
470-
self.send_semaphore.release()
470+
self.send_semaphore.release()
471471
return await self.recv_chan.receive()
472472

473473
async def aclose(self) -> None:
@@ -485,12 +485,9 @@ def __exit__(
485485
self.recv_chan.close()
486486

487487

488-
def background_with_channel(max_buffer_size: int | None = 0) -> Callable[
489-
[
490-
Callable[P, AsyncGenerator[T, None]],
491-
],
492-
Callable[P, AbstractAsyncContextManager[trio.abc.ReceiveChannel[T]]],
493-
]:
488+
def background_with_channel(
489+
fn: Callable[P, AsyncGenerator[T, None]],
490+
) -> Callable[P, AbstractAsyncContextManager[ReceiveChannel[T]]]:
494491
"""Decorate an async generator function to make it cancellation-safe.
495492
496493
This is mostly a drop-in replacement, except for the fact that it will
@@ -511,7 +508,7 @@ def background_with_channel(max_buffer_size: int | None = 0) -> Callable[
511508
offering only the safe interface, and you can still write your iterables
512509
with the convenience of ``yield``. For example::
513510
514-
@background_with_channel()
511+
@background_with_channel
515512
async def my_async_iterable(arg, *, kwarg=True):
516513
while ...:
517514
item = await ...
@@ -531,46 +528,30 @@ async def my_async_iterable(arg, *, kwarg=True):
531528
# Perhaps a future PEP will adopt `async with for` syntax, like
532529
# https://coconut.readthedocs.io/en/master/DOCS.html#async-with-for
533530

534-
if not isinstance(max_buffer_size, int) and max_buffer_size is not None:
535-
raise TypeError(
536-
"`max_buffer_size` must be int or None, not {type(max_buffer_size)}. "
537-
"Did you forget the parentheses in `@background_with_channel()`?"
538-
)
539-
540-
def decorator(
541-
fn: Callable[P, AsyncGenerator[T, None]],
542-
) -> Callable[P, AbstractAsyncContextManager[trio._channel.RecvChanWrapper[T]]]:
543-
@asynccontextmanager
544-
@wraps(fn)
545-
async def context_manager(
546-
*args: P.args, **kwargs: P.kwargs
547-
) -> AsyncGenerator[trio._channel.RecvChanWrapper[T], None]:
548-
max_buf_size_float = inf if max_buffer_size is None else max_buffer_size
549-
send_chan, recv_chan = trio.open_memory_channel[T](max_buf_size_float)
550-
async with trio.open_nursery(strict_exception_groups=True) as nursery:
551-
agen = fn(*args, **kwargs)
552-
send_semaphore = (
553-
None if max_buffer_size is None else trio.Semaphore(max_buffer_size)
554-
)
555-
# `nursery.start` to make sure that we will clean up send_chan & agen
556-
# If this errors we don't close `recv_chan`, but the caller
557-
# never gets access to it, so that's not a problem.
558-
await nursery.start(
559-
_move_elems_to_channel, agen, send_chan, send_semaphore
560-
)
561-
# `async with recv_chan` could eat exceptions, so use sync cm
562-
with RecvChanWrapper(recv_chan, send_semaphore) as wrapped_recv_chan:
563-
yield wrapped_recv_chan
564-
# User has exited context manager, cancel to immediately close the
565-
# abandoned generator if it's still alive.
566-
nursery.cancel_scope.cancel()
567-
568-
return context_manager
531+
@asynccontextmanager
532+
@wraps(fn)
533+
async def context_manager(
534+
*args: P.args, **kwargs: P.kwargs
535+
) -> AsyncGenerator[trio._channel.RecvChanWrapper[T], None]:
536+
send_chan, recv_chan = trio.open_memory_channel[T](0)
537+
async with trio.open_nursery(strict_exception_groups=True) as nursery:
538+
agen = fn(*args, **kwargs)
539+
send_semaphore = trio.Semaphore(0)
540+
# `nursery.start` to make sure that we will clean up send_chan & agen
541+
# If this errors we don't close `recv_chan`, but the caller
542+
# never gets access to it, so that's not a problem.
543+
await nursery.start(_move_elems_to_channel, agen, send_chan, send_semaphore)
544+
# `async with recv_chan` could eat exceptions, so use sync cm
545+
with RecvChanWrapper(recv_chan, send_semaphore) as wrapped_recv_chan:
546+
yield wrapped_recv_chan
547+
# User has exited context manager, cancel to immediately close the
548+
# abandoned generator if it's still alive.
549+
nursery.cancel_scope.cancel()
569550

570551
async def _move_elems_to_channel(
571552
agen: AsyncGenerator[T, None],
572553
send_chan: trio.MemorySendChannel[T],
573-
send_semaphore: trio.Semaphore | None,
554+
send_semaphore: trio.Semaphore,
574555
task_status: trio.TaskStatus,
575556
) -> None:
576557
# `async with send_chan` will eat exceptions,
@@ -579,22 +560,16 @@ async def _move_elems_to_channel(
579560
try:
580561
task_status.started()
581562
while True:
582-
# wait for send_chan to be unblocked
583-
if send_semaphore is not None:
584-
await send_semaphore.acquire()
563+
# wait for receiver to call next on the aiter
564+
await send_semaphore.acquire()
585565
try:
586566
value = await agen.__anext__()
587567
except StopAsyncIteration:
588568
return
589-
try:
590-
# Send the value to the channel
591-
await send_chan.send(value)
592-
except trio.BrokenResourceError:
593-
# Closing the corresponding receive channel should cause
594-
# a clean shutdown of the generator.
595-
return
569+
# Send the value to the channel
570+
await send_chan.send(value)
596571
finally:
597572
# replace try-finally with contextlib.aclosing once python39 is dropped
598573
await agen.aclose()
599574

600-
return decorator
575+
return context_manager

src/trio/_tests/test_channel.py

Lines changed: 10 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -417,7 +417,7 @@ async def do_send(s: trio.MemorySendChannel[int], v: int) -> None:
417417

418418

419419
async def test_background_with_channel() -> None:
420-
@background_with_channel(1)
420+
@background_with_channel
421421
async def agen() -> AsyncGenerator[int]:
422422
yield 1
423423
await trio.sleep_forever() # simulate deadlock
@@ -429,7 +429,7 @@ async def agen() -> AsyncGenerator[int]:
429429

430430

431431
async def test_background_with_channel_exhaust() -> None:
432-
@background_with_channel()
432+
@background_with_channel
433433
async def agen() -> AsyncGenerator[int]:
434434
yield 1
435435

@@ -439,7 +439,7 @@ async def agen() -> AsyncGenerator[int]:
439439

440440

441441
async def test_background_with_channel_broken_resource() -> None:
442-
@background_with_channel()
442+
@background_with_channel
443443
async def agen() -> AsyncGenerator[int]:
444444
yield 1
445445
yield 2
@@ -460,7 +460,7 @@ async def agen() -> AsyncGenerator[int]:
460460
async def test_background_with_channel_cancelled() -> None:
461461
with trio.CancelScope() as cs:
462462

463-
@background_with_channel()
463+
@background_with_channel
464464
async def agen() -> AsyncGenerator[None]: # pragma: no cover
465465
raise AssertionError(
466466
"cancel before consumption means generator should not be iterated"
@@ -476,7 +476,7 @@ async def test_background_with_channel_recv_closed(
476476
) -> None:
477477
event = trio.Event()
478478

479-
@background_with_channel(1)
479+
@background_with_channel
480480
async def agen() -> AsyncGenerator[int]:
481481
await event.wait()
482482
yield 1
@@ -491,7 +491,7 @@ async def agen() -> AsyncGenerator[int]:
491491
async def test_background_with_channel_no_race() -> None:
492492
# this previously led to a race condition due to
493493
# https://github.com/python-trio/trio/issues/1559
494-
@background_with_channel()
494+
@background_with_channel
495495
async def agen() -> AsyncGenerator[int]:
496496
yield 1
497497
raise ValueError("oae")
@@ -505,7 +505,7 @@ async def agen() -> AsyncGenerator[int]:
505505
async def test_background_with_channel_buffer_size_too_small(
506506
autojump_clock: trio.testing.MockClock,
507507
) -> None:
508-
@background_with_channel(0)
508+
@background_with_channel
509509
async def agen() -> AsyncGenerator[int]:
510510
yield 1
511511
raise AssertionError(
@@ -519,27 +519,8 @@ async def agen() -> AsyncGenerator[int]:
519519
await trio.sleep_forever()
520520

521521

522-
async def test_background_with_channel_buffer_size_just_right(
523-
autojump_clock: trio.testing.MockClock,
524-
) -> None:
525-
event = trio.Event()
526-
527-
@background_with_channel(2)
528-
async def agen() -> AsyncGenerator[int]:
529-
yield 1
530-
event.set()
531-
yield 2
532-
533-
async with agen() as recv_chan:
534-
await event.wait()
535-
assert await recv_chan.__anext__() == 1
536-
assert await recv_chan.__anext__() == 2
537-
with pytest.raises(StopAsyncIteration):
538-
await recv_chan.__anext__()
539-
540-
541522
async def test_background_with_channel_no_interleave() -> None:
542-
@background_with_channel()
523+
@background_with_channel
543524
async def agen() -> AsyncGenerator[int]:
544525
yield 1
545526
raise AssertionError # pragma: no cover
@@ -549,30 +530,10 @@ async def agen() -> AsyncGenerator[int]:
549530
await trio.lowlevel.checkpoint()
550531

551532

552-
async def test_background_with_channel_multiple_errors() -> None:
553-
event = trio.Event()
554-
555-
@background_with_channel(1)
556-
async def agen() -> AsyncGenerator[int]:
557-
yield 1
558-
event.set()
559-
raise ValueError("agen")
560-
561-
with RaisesGroup(
562-
Matcher(ValueError, match="^agen$"),
563-
Matcher(TypeError, match="^iterator$"),
564-
):
565-
async with agen() as recv_chan:
566-
async for i in recv_chan: # pragma: no branch
567-
assert i == 1
568-
await event.wait()
569-
raise TypeError("iterator")
570-
571-
572533
async def test_background_with_channel_genexit_finally() -> None:
573534
events: list[str] = []
574535

575-
@background_with_channel()
536+
@background_with_channel
576537
async def agen(stuff: list[str]) -> AsyncGenerator[int]:
577538
try:
578539
yield 1
@@ -596,7 +557,7 @@ async def agen(stuff: list[str]) -> AsyncGenerator[int]:
596557

597558

598559
async def test_background_with_channel_nested_loop() -> None:
599-
@background_with_channel()
560+
@background_with_channel
600561
async def agen() -> AsyncGenerator[int]:
601562
for i in range(2):
602563
yield i
@@ -610,34 +571,3 @@ async def agen() -> AsyncGenerator[int]:
610571
assert (i, j) == (ii, jj)
611572
jj += 1
612573
ii += 1
613-
614-
615-
async def test_background_with_channel_no_parens() -> None:
616-
with pytest.raises(TypeError, match="must be int or None"):
617-
618-
@background_with_channel # type: ignore[arg-type]
619-
async def agen() -> AsyncGenerator[None]:
620-
yield # pragma: no cover
621-
622-
623-
async def test_background_with_channel_inf_buffer() -> None:
624-
event = trio.Event()
625-
626-
# agen immediately starts yielding numbers
627-
# into the buffer upon entering the cm
628-
@background_with_channel(None)
629-
async def agen() -> AsyncGenerator[int]:
630-
for i in range(10):
631-
yield i
632-
event.set()
633-
# keep agen alive to receive values
634-
await trio.sleep_forever()
635-
636-
async with agen() as recv_chan:
637-
await event.wait()
638-
j = 0
639-
async for i in recv_chan:
640-
assert i == j
641-
j += 1
642-
if j == 10:
643-
break

0 commit comments

Comments
 (0)