Skip to content

Commit fdd6e16

Browse files
committed
Revert "no need for task group here"
This reverts commit 876c682.
1 parent b946905 commit fdd6e16

File tree

1 file changed

+35
-18
lines changed

1 file changed

+35
-18
lines changed

packages/service-library/src/servicelib/redis/_semaphore_decorator.py

Lines changed: 35 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
)
1919
from ._errors import (
2020
SemaphoreAcquisitionError,
21+
SemaphoreLostError,
2122
SemaphoreNotAcquiredError,
2223
)
2324
from ._semaphore import DistributedSemaphore
@@ -42,26 +43,42 @@ async def _managed_semaphore_execution(
4243
raise SemaphoreAcquisitionError(name=semaphore_key, capacity=semaphore.capacity)
4344

4445
try:
45-
started_event = asyncio.Event()
46-
47-
# Create auto-renewal task
48-
@periodic(interval=ttl / 3, raise_on_error=True)
49-
async def _periodic_renewer() -> None:
50-
await semaphore.reacquire()
51-
if not started_event.is_set():
52-
started_event.set()
53-
54-
renewal_task = asyncio.create_task(
55-
_periodic_renewer(),
56-
name=f"semaphore/autorenewal_{semaphore_key}_{semaphore.instance_id}",
57-
)
58-
# Wait for first renewal to complete (ensures task is running)
59-
await started_event.wait()
46+
# Use TaskGroup for proper exception propagation
47+
async with asyncio.TaskGroup() as tg:
48+
started_event = asyncio.Event()
49+
50+
# Create auto-renewal task
51+
@periodic(interval=ttl / 3, raise_on_error=True)
52+
async def _periodic_renewer() -> None:
53+
await semaphore.reacquire()
54+
if not started_event.is_set():
55+
started_event.set()
56+
57+
# Start the renewal task
58+
renewal_task = tg.create_task(
59+
_periodic_renewer(),
60+
name=f"semaphore/autorenewal_{semaphore_key}_{semaphore.instance_id}",
61+
)
62+
63+
# Wait for first renewal to complete (ensures task is running)
64+
await started_event.wait()
65+
66+
# Yield control back to caller
67+
yield
68+
69+
# Cancel renewal task when execution is done
70+
await cancel_wait_task(renewal_task, max_delay=None)
6071

61-
yield
72+
except BaseExceptionGroup as eg:
73+
semaphore_lost_errors, other_errors = eg.split(SemaphoreLostError)
74+
# If there are any other errors, re-raise them
75+
if other_errors:
76+
assert len(other_errors.exceptions) == 1 # nosec
77+
raise other_errors.exceptions[0] from eg
6278

63-
# Cancel renewal task when execution is done
64-
await cancel_wait_task(renewal_task, max_delay=None)
79+
assert semaphore_lost_errors is not None # nosec
80+
assert len(semaphore_lost_errors.exceptions) == 1 # nosec
81+
raise semaphore_lost_errors.exceptions[0] from eg
6582

6683
finally:
6784
# Always attempt to release the semaphore

0 commit comments

Comments
 (0)