Skip to content
Merged
2 changes: 1 addition & 1 deletion packages/celery-library/src/celery_library/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
AbortableTask,
)
from celery.exceptions import Ignore # type: ignore[import-untyped]
from common_library.async_tools import cancel_wait_task
from pydantic import NonNegativeInt
from servicelib.async_utils import cancel_wait_task
from servicelib.celery.models import TaskID

from .errors import encode_celery_transferrable_error
Expand Down
75 changes: 73 additions & 2 deletions packages/common-library/src/common_library/async_tools.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
import asyncio
import datetime
import functools
from collections.abc import Awaitable, Callable
import logging
from collections.abc import Awaitable, Callable, Coroutine
from concurrent.futures import Executor
from functools import wraps
from inspect import isawaitable
from typing import ParamSpec, TypeVar, overload
from typing import Any, ParamSpec, TypeVar, overload

_logger = logging.getLogger(__name__)

R = TypeVar("R")
P = ParamSpec("P")
Expand Down Expand Up @@ -62,3 +67,69 @@ async def maybe_await(
return await obj
assert not isawaitable(obj) # nosec
return obj


async def cancel_wait_task(
task: asyncio.Task, *, max_delay: float | None = None
) -> None:
"""Cancels the given task and waits for it to complete

Arguments:
task -- task to be canceled


Keyword Arguments:
max_delay -- duration (in seconds) to wait before giving
up the cancellation. This timeout should be an upper bound to the
time needed for the task to cleanup after being canceled and
avoids that the cancellation takes forever. If None the timeout is not
set. (default: {None})

Raises:
TimeoutError: raised if cannot cancel the task.
CancelledError: raised ONLY if owner is being cancelled.
"""

cancelling = task.cancel()
if not cancelling:
return # task was alredy cancelled

assert task.cancelling() # nosec
assert not task.cancelled() # nosec

try:

await asyncio.shield(
# NOTE shield ensures that cancellation of the caller function won't stop you
# from observing the cancellation/finalization of task.
asyncio.wait_for(task, timeout=max_delay)
)

except asyncio.CancelledError:
if not task.cancelled():
# task owner function is being cancelled -> propagate cancellation
raise

# else: task cancellation is complete, we can safely ignore it
_logger.debug(
"Task %s cancellation is complete",
task.get_name(),
)


def delayed_start(
delay: datetime.timedelta,
) -> Callable[
[Callable[P, Coroutine[Any, Any, R]]], Callable[P, Coroutine[Any, Any, R]]
]:
def _decorator(
func: Callable[P, Coroutine[Any, Any, R]],
) -> Callable[P, Coroutine[Any, Any, R]]:
@wraps(func)
async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
await asyncio.sleep(delay.total_seconds())
return await func(*args, **kwargs)

return _wrapper

return _decorator
127 changes: 125 additions & 2 deletions packages/common-library/tests/test_async_tools.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
import asyncio
import time
from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta
from typing import Any

import pytest
from common_library.async_tools import make_async, maybe_await
from common_library.async_tools import (
cancel_wait_task,
delayed_start,
make_async,
maybe_await,
)


@make_async()
Expand All @@ -13,7 +20,8 @@ def sync_function(x: int, y: int) -> int:

@make_async()
def sync_function_with_exception() -> None:
raise ValueError("This is an error!")
msg = "This is an error!"
raise ValueError(msg)


@pytest.mark.asyncio
Expand Down Expand Up @@ -93,3 +101,118 @@ def fetchone(self) -> Any: # pylint: disable=no-self-use

sync_result = await maybe_await(SyncResultProxy().fetchone())
assert sync_result == {"id": 2, "name": "test2"}


async def test_cancel_and_wait():
state = {"started": False, "cancelled": False, "cleaned_up": False}
SLEEP_TIME = 5 # seconds

async def coro():
try:
state["started"] = True
await asyncio.sleep(SLEEP_TIME)
except asyncio.CancelledError:
state["cancelled"] = True
raise
finally:
state["cleaned_up"] = True

task = asyncio.create_task(coro())
await asyncio.sleep(0.1) # Let coro start

start = time.time()
await cancel_wait_task(task)

elapsed = time.time() - start
assert elapsed < SLEEP_TIME, "Task should be cancelled quickly"
assert task.done()
assert task.cancelled()
assert state["started"]
assert state["cancelled"]
assert state["cleaned_up"]


async def test_cancel_and_wait_propagates_external_cancel():
"""
This test ensures that if the caller of cancel_and_wait is cancelled,
the CancelledError is not swallowed.
"""

async def coro():
try:
await asyncio.sleep(4)
except asyncio.CancelledError:
await asyncio.sleep(1) # simulate cleanup
raise

inner_task = asyncio.create_task(coro())

async def outer_coro():
try:
await cancel_wait_task(inner_task)
except asyncio.CancelledError:
assert (
not inner_task.cancelled()
), "Internal Task DOES NOT RAISE CancelledError"
raise

# Cancel the wrapper after a short delay
outer_task = asyncio.create_task(outer_coro())
await asyncio.sleep(0.1)
outer_task.cancel()

with pytest.raises(asyncio.CancelledError):
await outer_task

# Ensure the task was cancelled
assert inner_task.cancelled() is False, "Task should not be cancelled initially"

done_event = asyncio.Event()

def on_done(_):
done_event.set()

inner_task.add_done_callback(on_done)
await done_event.wait()


async def test_cancel_and_wait_timeout_on_slow_cleanup():
"""Test that cancel_and_wait raises TimeoutError when cleanup takes longer than max_delay"""

CLEANUP_TIME = 2 # seconds

async def slow_cleanup_coro():
try:
await asyncio.sleep(10) # Long running task
except asyncio.CancelledError:
# Simulate slow cleanup that exceeds max_delay!
await asyncio.sleep(CLEANUP_TIME)
raise

task = asyncio.create_task(slow_cleanup_coro())
await asyncio.sleep(0.1) # Let the task start

# Cancel with a max_delay shorter than cleanup time
with pytest.raises(TimeoutError):
await cancel_wait_task(
task, max_delay=CLEANUP_TIME / 10
) # 0.2 seconds < 2 seconds cleanup

assert task.cancelled()


async def test_with_delay():
@delayed_start(timedelta(seconds=0.2))
async def decorated_awaitable() -> int:
return 42

assert await decorated_awaitable() == 42

async def another_awaitable() -> int:
return 42

decorated_another_awaitable = delayed_start(timedelta(seconds=0.2))(
another_awaitable
)

assert await decorated_another_awaitable() == 42
41 changes: 1 addition & 40 deletions packages/service-library/src/servicelib/async_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import asyncio
import contextlib
import datetime
import logging
from collections import deque
from collections.abc import Awaitable, Callable, Coroutine
from collections.abc import Awaitable, Callable
from contextlib import suppress
from dataclasses import dataclass
from functools import wraps
Expand Down Expand Up @@ -212,40 +210,3 @@ async def worker(in_q: Queue[QueueElement], out_q: Queue) -> None:
return wrapper

return decorator


def delayed_start(
delay: datetime.timedelta,
) -> Callable[
[Callable[P, Coroutine[Any, Any, R]]], Callable[P, Coroutine[Any, Any, R]]
]:
def _decorator(
func: Callable[P, Coroutine[Any, Any, R]],
) -> Callable[P, Coroutine[Any, Any, R]]:
@wraps(func)
async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
await asyncio.sleep(delay.total_seconds())
return await func(*args, **kwargs)

return _wrapper

return _decorator


async def cancel_wait_task(
task: asyncio.Task,
*,
max_delay: float | None = None,
) -> None:
"""Cancel a asyncio.Task and waits for it to finish.

:param task: task to be canceled
:param max_delay: duration (in seconds) to wait before giving
up the cancellation. If None it waits forever.
:raises TimeoutError: raised if cannot cancel the task.
"""

task.cancel()
async with asyncio.timeout(max_delay):
with contextlib.suppress(asyncio.CancelledError):
await task
4 changes: 2 additions & 2 deletions packages/service-library/src/servicelib/background_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
from collections.abc import AsyncIterator, Awaitable, Callable, Coroutine
from typing import Any, Final, ParamSpec, TypeVar

from common_library.async_tools import cancel_wait_task, delayed_start
from tenacity import TryAgain, before_sleep_log, retry, retry_if_exception_type
from tenacity.wait import wait_fixed

from .async_utils import cancel_wait_task, delayed_start
from .logging_utils import log_catch, log_context

_logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -142,4 +142,4 @@ async def periodic_task(
if asyncio_task is not None:
# NOTE: this stopping is shielded to prevent the cancellation to propagate
# into the stopping procedure
await asyncio.shield(cancel_wait_task(asyncio_task, max_delay=stop_timeout))
await cancel_wait_task(asyncio_task, max_delay=stop_timeout)
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
from typing import Any, Final, Protocol, TypeAlias
from uuid import uuid4

from common_library.async_tools import cancel_wait_task
from models_library.api_schemas_long_running_tasks.base import TaskProgress
from pydantic import PositiveFloat
from servicelib.async_utils import cancel_wait_task
from servicelib.background_task import create_periodic_task
from servicelib.logging_utils import log_catch

Expand Down
2 changes: 1 addition & 1 deletion packages/service-library/src/servicelib/redis/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@

import redis.asyncio as aioredis
import redis.exceptions
from common_library.async_tools import cancel_wait_task
from redis.asyncio.lock import Lock
from redis.asyncio.retry import Retry
from redis.backoff import ExponentialBackoff

from ..async_utils import cancel_wait_task
from ..background_task import periodic
from ..logging_utils import log_catch, log_context
from ._constants import (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from dataclasses import dataclass, field

import pytest
from common_library.async_tools import cancel_wait_task
from faker import Faker
from models_library.api_schemas_rpc_async_jobs.async_jobs import (
AsyncJobGet,
Expand All @@ -16,7 +17,6 @@
from models_library.progress_bar import ProgressReport
from models_library.rabbitmq_basic_types import RPCMethodName, RPCNamespace
from pydantic import TypeAdapter
from servicelib.async_utils import cancel_wait_task
from servicelib.rabbitmq import RabbitMQRPCClient, RemoteMethodNotRegisteredError
from servicelib.rabbitmq.rpc_interfaces.async_jobs.async_jobs import (
list_jobs,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
from uuid import UUID

import pytest
from common_library.async_tools import cancel_wait_task
from faker import Faker
from models_library.projects import ProjectID
from models_library.projects_access import Owner
from models_library.projects_state import ProjectLocked, ProjectStatus
from servicelib.async_utils import cancel_wait_task
from servicelib.redis import (
ProjectLockError,
RedisClientSDK,
Expand Down
Loading
Loading