1818)
1919from ._errors import (
2020 SemaphoreAcquisitionError ,
21+ SemaphoreLostError ,
2122 SemaphoreNotAcquiredError ,
2223)
2324from ._semaphore import DistributedSemaphore
@@ -42,26 +43,42 @@ async def _managed_semaphore_execution(
4243 raise SemaphoreAcquisitionError (name = semaphore_key , capacity = semaphore .capacity )
4344
4445 try :
45- started_event = asyncio .Event ()
46-
47- # Create auto-renewal task
48- @periodic (interval = ttl / 3 , raise_on_error = True )
49- async def _periodic_renewer () -> None :
50- await semaphore .reacquire ()
51- if not started_event .is_set ():
52- started_event .set ()
53-
54- renewal_task = asyncio .create_task (
55- _periodic_renewer (),
56- name = f"semaphore/autorenewal_{ semaphore_key } _{ semaphore .instance_id } " ,
57- )
58- # Wait for first renewal to complete (ensures task is running)
59- await started_event .wait ()
46+ # Use TaskGroup for proper exception propagation
47+ async with asyncio .TaskGroup () as tg :
48+ started_event = asyncio .Event ()
49+
50+ # Create auto-renewal task
51+ @periodic (interval = ttl / 3 , raise_on_error = True )
52+ async def _periodic_renewer () -> None :
53+ await semaphore .reacquire ()
54+ if not started_event .is_set ():
55+ started_event .set ()
56+
57+ # Start the renewal task
58+ renewal_task = tg .create_task (
59+ _periodic_renewer (),
60+ name = f"semaphore/autorenewal_{ semaphore_key } _{ semaphore .instance_id } " ,
61+ )
62+
63+ # Wait for first renewal to complete (ensures task is running)
64+ await started_event .wait ()
65+
66+ # Yield control back to caller
67+ yield
68+
69+ # Cancel renewal task when execution is done
70+ await cancel_wait_task (renewal_task , max_delay = None )
6071
61- yield
72+ except BaseExceptionGroup as eg :
73+ semaphore_lost_errors , other_errors = eg .split (SemaphoreLostError )
74+ # If there are any other errors, re-raise them
75+ if other_errors :
76+ assert len (other_errors .exceptions ) == 1 # nosec
77+ raise other_errors .exceptions [0 ] from eg
6278
63- # Cancel renewal task when execution is done
64- await cancel_wait_task (renewal_task , max_delay = None )
79+ assert semaphore_lost_errors is not None # nosec
80+ assert len (semaphore_lost_errors .exceptions ) == 1 # nosec
81+ raise semaphore_lost_errors .exceptions [0 ] from eg
6582
6683 finally :
6784 # Always attempt to release the semaphore
0 commit comments