Skip to content

Commit bcfbef9

Browse files
committed
no need for task group here
1 parent badf633 commit bcfbef9

File tree

1 file changed

+18
-35
lines changed

1 file changed

+18
-35
lines changed

packages/service-library/src/servicelib/redis/_semaphore_decorator.py

Lines changed: 18 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
)
1919
from ._errors import (
2020
SemaphoreAcquisitionError,
21-
SemaphoreLostError,
2221
SemaphoreNotAcquiredError,
2322
)
2423
from ._semaphore import DistributedSemaphore
@@ -43,42 +42,26 @@ async def _managed_semaphore_execution(
4342
raise SemaphoreAcquisitionError(name=semaphore_key, capacity=semaphore.capacity)
4443

4544
try:
46-
# Use TaskGroup for proper exception propagation
47-
async with asyncio.TaskGroup() as tg:
48-
started_event = asyncio.Event()
49-
50-
# Create auto-renewal task
51-
@periodic(interval=ttl / 3, raise_on_error=True)
52-
async def _periodic_renewer() -> None:
53-
await semaphore.reacquire()
54-
if not started_event.is_set():
55-
started_event.set()
56-
57-
# Start the renewal task
58-
renewal_task = tg.create_task(
59-
_periodic_renewer(),
60-
name=f"semaphore/autorenewal_{semaphore_key}_{semaphore.instance_id}",
61-
)
62-
63-
# Wait for first renewal to complete (ensures task is running)
64-
await started_event.wait()
65-
66-
# Yield control back to caller
67-
yield
68-
69-
# Cancel renewal task when execution is done
70-
await cancel_wait_task(renewal_task, max_delay=None)
45+
started_event = asyncio.Event()
46+
47+
# Create auto-renewal task
48+
@periodic(interval=ttl / 3, raise_on_error=True)
49+
async def _periodic_renewer() -> None:
50+
await semaphore.reacquire()
51+
if not started_event.is_set():
52+
started_event.set()
53+
54+
renewal_task = asyncio.create_task(
55+
_periodic_renewer(),
56+
name=f"semaphore/autorenewal_{semaphore_key}_{semaphore.instance_id}",
57+
)
58+
# Wait for first renewal to complete (ensures task is running)
59+
await started_event.wait()
7160

72-
except BaseExceptionGroup as eg:
73-
semaphore_lost_errors, other_errors = eg.split(SemaphoreLostError)
74-
# If there are any other errors, re-raise them
75-
if other_errors:
76-
assert len(other_errors.exceptions) == 1 # nosec
77-
raise other_errors.exceptions[0] from eg
61+
yield
7862

79-
assert semaphore_lost_errors is not None # nosec
80-
assert len(semaphore_lost_errors.exceptions) == 1 # nosec
81-
raise semaphore_lost_errors.exceptions[0] from eg
63+
# Cancel renewal task when execution is done
64+
await cancel_wait_task(renewal_task, max_delay=None)
8265

8366
finally:
8467
# Always attempt to release the semaphore

0 commit comments

Comments
 (0)