Skip to content

Commit 8bc8f3c

Browse files
authored
🎨 ♻️ Improves task cancellation with new cancel_and_wait utility (#7956)
1 parent 2d1134f commit 8bc8f3c

File tree

41 files changed

+252
-131
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+252
-131
lines changed

packages/celery-library/src/celery_library/task.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
AbortableTask,
1313
)
1414
from celery.exceptions import Ignore # type: ignore[import-untyped]
15+
from common_library.async_tools import cancel_wait_task
1516
from pydantic import NonNegativeInt
16-
from servicelib.async_utils import cancel_wait_task
1717
from servicelib.celery.models import TaskID
1818

1919
from .errors import encode_celery_transferrable_error

packages/common-library/src/common_library/async_tools.py

Lines changed: 73 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
11
import asyncio
2+
import datetime
23
import functools
3-
from collections.abc import Awaitable, Callable
4+
import logging
5+
from collections.abc import Awaitable, Callable, Coroutine
46
from concurrent.futures import Executor
7+
from functools import wraps
58
from inspect import isawaitable
6-
from typing import ParamSpec, TypeVar, overload
9+
from typing import Any, ParamSpec, TypeVar, overload
10+
11+
_logger = logging.getLogger(__name__)
712

813
R = TypeVar("R")
914
P = ParamSpec("P")
@@ -62,3 +67,69 @@ async def maybe_await(
6267
return await obj
6368
assert not isawaitable(obj) # nosec
6469
return obj
70+
71+
72+
async def cancel_wait_task(
73+
task: asyncio.Task, *, max_delay: float | None = None
74+
) -> None:
75+
"""Cancels the given task and waits for it to complete
76+
77+
Arguments:
78+
task -- task to be canceled
79+
80+
81+
Keyword Arguments:
82+
max_delay -- duration (in seconds) to wait before giving
83+
up the cancellation. This timeout should be an upper bound to the
84+
time needed for the task to cleanup after being canceled and
85+
avoids that the cancellation takes forever. If None the timeout is not
86+
set. (default: {None})
87+
88+
Raises:
89+
TimeoutError: raised if cannot cancel the task.
90+
CancelledError: raised ONLY if owner is being cancelled.
91+
"""
92+
93+
cancelling = task.cancel()
94+
if not cancelling:
95+
return # task was alredy cancelled
96+
97+
assert task.cancelling() # nosec
98+
assert not task.cancelled() # nosec
99+
100+
try:
101+
102+
await asyncio.shield(
103+
# NOTE shield ensures that cancellation of the caller function won't stop you
104+
# from observing the cancellation/finalization of task.
105+
asyncio.wait_for(task, timeout=max_delay)
106+
)
107+
108+
except asyncio.CancelledError:
109+
if not task.cancelled():
110+
# task owner function is being cancelled -> propagate cancellation
111+
raise
112+
113+
# else: task cancellation is complete, we can safely ignore it
114+
_logger.debug(
115+
"Task %s cancellation is complete",
116+
task.get_name(),
117+
)
118+
119+
120+
def delayed_start(
121+
delay: datetime.timedelta,
122+
) -> Callable[
123+
[Callable[P, Coroutine[Any, Any, R]]], Callable[P, Coroutine[Any, Any, R]]
124+
]:
125+
def _decorator(
126+
func: Callable[P, Coroutine[Any, Any, R]],
127+
) -> Callable[P, Coroutine[Any, Any, R]]:
128+
@wraps(func)
129+
async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
130+
await asyncio.sleep(delay.total_seconds())
131+
return await func(*args, **kwargs)
132+
133+
return _wrapper
134+
135+
return _decorator

packages/common-library/tests/test_async_tools.py

Lines changed: 125 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,16 @@
11
import asyncio
2+
import time
23
from concurrent.futures import ThreadPoolExecutor
4+
from datetime import timedelta
35
from typing import Any
46

57
import pytest
6-
from common_library.async_tools import make_async, maybe_await
8+
from common_library.async_tools import (
9+
cancel_wait_task,
10+
delayed_start,
11+
make_async,
12+
maybe_await,
13+
)
714

815

916
@make_async()
@@ -13,7 +20,8 @@ def sync_function(x: int, y: int) -> int:
1320

1421
@make_async()
1522
def sync_function_with_exception() -> None:
16-
raise ValueError("This is an error!")
23+
msg = "This is an error!"
24+
raise ValueError(msg)
1725

1826

1927
@pytest.mark.asyncio
@@ -93,3 +101,118 @@ def fetchone(self) -> Any: # pylint: disable=no-self-use
93101

94102
sync_result = await maybe_await(SyncResultProxy().fetchone())
95103
assert sync_result == {"id": 2, "name": "test2"}
104+
105+
106+
async def test_cancel_and_wait():
107+
state = {"started": False, "cancelled": False, "cleaned_up": False}
108+
SLEEP_TIME = 5 # seconds
109+
110+
async def coro():
111+
try:
112+
state["started"] = True
113+
await asyncio.sleep(SLEEP_TIME)
114+
except asyncio.CancelledError:
115+
state["cancelled"] = True
116+
raise
117+
finally:
118+
state["cleaned_up"] = True
119+
120+
task = asyncio.create_task(coro())
121+
await asyncio.sleep(0.1) # Let coro start
122+
123+
start = time.time()
124+
await cancel_wait_task(task)
125+
126+
elapsed = time.time() - start
127+
assert elapsed < SLEEP_TIME, "Task should be cancelled quickly"
128+
assert task.done()
129+
assert task.cancelled()
130+
assert state["started"]
131+
assert state["cancelled"]
132+
assert state["cleaned_up"]
133+
134+
135+
async def test_cancel_and_wait_propagates_external_cancel():
136+
"""
137+
This test ensures that if the caller of cancel_and_wait is cancelled,
138+
the CancelledError is not swallowed.
139+
"""
140+
141+
async def coro():
142+
try:
143+
await asyncio.sleep(4)
144+
except asyncio.CancelledError:
145+
await asyncio.sleep(1) # simulate cleanup
146+
raise
147+
148+
inner_task = asyncio.create_task(coro())
149+
150+
async def outer_coro():
151+
try:
152+
await cancel_wait_task(inner_task)
153+
except asyncio.CancelledError:
154+
assert (
155+
not inner_task.cancelled()
156+
), "Internal Task DOES NOT RAISE CancelledError"
157+
raise
158+
159+
# Cancel the wrapper after a short delay
160+
outer_task = asyncio.create_task(outer_coro())
161+
await asyncio.sleep(0.1)
162+
outer_task.cancel()
163+
164+
with pytest.raises(asyncio.CancelledError):
165+
await outer_task
166+
167+
# Ensure the task was cancelled
168+
assert inner_task.cancelled() is False, "Task should not be cancelled initially"
169+
170+
done_event = asyncio.Event()
171+
172+
def on_done(_):
173+
done_event.set()
174+
175+
inner_task.add_done_callback(on_done)
176+
await done_event.wait()
177+
178+
179+
async def test_cancel_and_wait_timeout_on_slow_cleanup():
180+
"""Test that cancel_and_wait raises TimeoutError when cleanup takes longer than max_delay"""
181+
182+
CLEANUP_TIME = 2 # seconds
183+
184+
async def slow_cleanup_coro():
185+
try:
186+
await asyncio.sleep(10) # Long running task
187+
except asyncio.CancelledError:
188+
# Simulate slow cleanup that exceeds max_delay!
189+
await asyncio.sleep(CLEANUP_TIME)
190+
raise
191+
192+
task = asyncio.create_task(slow_cleanup_coro())
193+
await asyncio.sleep(0.1) # Let the task start
194+
195+
# Cancel with a max_delay shorter than cleanup time
196+
with pytest.raises(TimeoutError):
197+
await cancel_wait_task(
198+
task, max_delay=CLEANUP_TIME / 10
199+
) # 0.2 seconds < 2 seconds cleanup
200+
201+
assert task.cancelled()
202+
203+
204+
async def test_with_delay():
205+
@delayed_start(timedelta(seconds=0.2))
206+
async def decorated_awaitable() -> int:
207+
return 42
208+
209+
assert await decorated_awaitable() == 42
210+
211+
async def another_awaitable() -> int:
212+
return 42
213+
214+
decorated_another_awaitable = delayed_start(timedelta(seconds=0.2))(
215+
another_awaitable
216+
)
217+
218+
assert await decorated_another_awaitable() == 42

packages/service-library/src/servicelib/async_utils.py

Lines changed: 1 addition & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
import asyncio
2-
import contextlib
3-
import datetime
42
import logging
53
from collections import deque
6-
from collections.abc import Awaitable, Callable, Coroutine
4+
from collections.abc import Awaitable, Callable
75
from contextlib import suppress
86
from dataclasses import dataclass
97
from functools import wraps
@@ -212,40 +210,3 @@ async def worker(in_q: Queue[QueueElement], out_q: Queue) -> None:
212210
return wrapper
213211

214212
return decorator
215-
216-
217-
def delayed_start(
218-
delay: datetime.timedelta,
219-
) -> Callable[
220-
[Callable[P, Coroutine[Any, Any, R]]], Callable[P, Coroutine[Any, Any, R]]
221-
]:
222-
def _decorator(
223-
func: Callable[P, Coroutine[Any, Any, R]],
224-
) -> Callable[P, Coroutine[Any, Any, R]]:
225-
@wraps(func)
226-
async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
227-
await asyncio.sleep(delay.total_seconds())
228-
return await func(*args, **kwargs)
229-
230-
return _wrapper
231-
232-
return _decorator
233-
234-
235-
async def cancel_wait_task(
236-
task: asyncio.Task,
237-
*,
238-
max_delay: float | None = None,
239-
) -> None:
240-
"""Cancel a asyncio.Task and waits for it to finish.
241-
242-
:param task: task to be canceled
243-
:param max_delay: duration (in seconds) to wait before giving
244-
up the cancellation. If None it waits forever.
245-
:raises TimeoutError: raised if cannot cancel the task.
246-
"""
247-
248-
task.cancel()
249-
async with asyncio.timeout(max_delay):
250-
with contextlib.suppress(asyncio.CancelledError):
251-
await task

packages/service-library/src/servicelib/background_task.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@
66
from collections.abc import AsyncIterator, Awaitable, Callable, Coroutine
77
from typing import Any, Final, ParamSpec, TypeVar
88

9+
from common_library.async_tools import cancel_wait_task, delayed_start
910
from tenacity import TryAgain, before_sleep_log, retry, retry_if_exception_type
1011
from tenacity.wait import wait_fixed
1112

12-
from .async_utils import cancel_wait_task, delayed_start
1313
from .logging_utils import log_catch, log_context
1414

1515
_logger = logging.getLogger(__name__)
@@ -142,4 +142,4 @@ async def periodic_task(
142142
if asyncio_task is not None:
143143
# NOTE: this stopping is shielded to prevent the cancellation to propagate
144144
# into the stopping procedure
145-
await asyncio.shield(cancel_wait_task(asyncio_task, max_delay=stop_timeout))
145+
await cancel_wait_task(asyncio_task, max_delay=stop_timeout)

packages/service-library/src/servicelib/long_running_tasks/task.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99
from typing import Any, Final, Protocol, TypeAlias
1010
from uuid import uuid4
1111

12+
from common_library.async_tools import cancel_wait_task
1213
from models_library.api_schemas_long_running_tasks.base import TaskProgress
1314
from pydantic import PositiveFloat
14-
from servicelib.async_utils import cancel_wait_task
1515
from servicelib.background_task import create_periodic_task
1616
from servicelib.logging_utils import log_catch
1717

packages/service-library/src/servicelib/redis/_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,11 @@
88

99
import redis.asyncio as aioredis
1010
import redis.exceptions
11+
from common_library.async_tools import cancel_wait_task
1112
from redis.asyncio.lock import Lock
1213
from redis.asyncio.retry import Retry
1314
from redis.backoff import ExponentialBackoff
1415

15-
from ..async_utils import cancel_wait_task
1616
from ..background_task import periodic
1717
from ..logging_utils import log_catch, log_context
1818
from ._constants import (

packages/service-library/tests/rabbitmq/test_rabbitmq_rpc_interfaces_async_jobs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from dataclasses import dataclass, field
55

66
import pytest
7+
from common_library.async_tools import cancel_wait_task
78
from faker import Faker
89
from models_library.api_schemas_rpc_async_jobs.async_jobs import (
910
AsyncJobGet,
@@ -16,7 +17,6 @@
1617
from models_library.progress_bar import ProgressReport
1718
from models_library.rabbitmq_basic_types import RPCMethodName, RPCNamespace
1819
from pydantic import TypeAdapter
19-
from servicelib.async_utils import cancel_wait_task
2020
from servicelib.rabbitmq import RabbitMQRPCClient, RemoteMethodNotRegisteredError
2121
from servicelib.rabbitmq.rpc_interfaces.async_jobs.async_jobs import (
2222
list_jobs,

packages/service-library/tests/redis/test_project_lock.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@
1010
from uuid import UUID
1111

1212
import pytest
13+
from common_library.async_tools import cancel_wait_task
1314
from faker import Faker
1415
from models_library.projects import ProjectID
1516
from models_library.projects_access import Owner
1617
from models_library.projects_state import ProjectLocked, ProjectStatus
17-
from servicelib.async_utils import cancel_wait_task
1818
from servicelib.redis import (
1919
ProjectLockError,
2020
RedisClientSDK,

0 commit comments

Comments
 (0)