Skip to content
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
fe327f4
Monitor Celery tasks cancellation
giancarloromeo Apr 11, 2025
59a41cd
use Redis hash
giancarloromeo Apr 14, 2025
9db3d5b
move set_progress
giancarloromeo Apr 14, 2025
fa25ca2
setup task worker
giancarloromeo Apr 14, 2025
004ae64
handle exception
giancarloromeo Apr 14, 2025
f6860be
fix exc
giancarloromeo Apr 14, 2025
3a6f1ea
check progress_bar
giancarloromeo Apr 14, 2025
204e8d2
use task_group
giancarloromeo Apr 14, 2025
e52ef34
typecheck
giancarloromeo Apr 14, 2025
4260e22
fix name
giancarloromeo Apr 14, 2025
defac3c
Merge branch 'master' into monitor-celery-tasks-cancellation
giancarloromeo Apr 14, 2025
c579dad
continue
giancarloromeo Apr 14, 2025
33b860d
Merge branch 'monitor-celery-tasks-cancellation' of github.com:gianca…
giancarloromeo Apr 14, 2025
86f4799
typecheck
giancarloromeo Apr 14, 2025
7dcec1d
fix abort
giancarloromeo Apr 14, 2025
6cfc905
fix exceptions
giancarloromeo Apr 14, 2025
0c6bec3
comment
giancarloromeo Apr 14, 2025
09ad5e8
fix errors
giancarloromeo Apr 14, 2025
ec59035
continue
giancarloromeo Apr 14, 2025
eab796f
fix test
giancarloromeo Apr 14, 2025
ff0af05
Merge branch 'master' into monitor-celery-tasks-cancellation
giancarloromeo Apr 14, 2025
2fdc60b
Merge remote-tracking branch 'upstream/master' into monitor-celery-tasks-cancellation
giancarloromeo Apr 15, 2025
dea1b71
Merge branch 'monitor-celery-tasks-cancellation' of github.com:gianca…
giancarloromeo Apr 15, 2025
4681148
fix
giancarloromeo Apr 15, 2025
4660d31
add app
giancarloromeo Apr 15, 2025
46431e4
update
giancarloromeo Apr 15, 2025
1c15a29
fix test
giancarloromeo Apr 15, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,6 @@ async def with_progress_bytes_iter(
self, progress_bar: ProgressBarData
) -> BytesIter:
async for chunk in self.bytes_iter_callable():
await progress_bar.update(len(chunk))
if progress_bar.is_running():
await progress_bar.update(len(chunk))
yield chunk
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@ async def _task_progress_cb(
) -> None:
worker = get_celery_worker(task.app)
assert task.name # nosec
worker.set_task_progress(
task_name=task.name,
await worker.set_progress(
task_id=task_id,
report=report,
)
Expand Down Expand Up @@ -88,7 +87,7 @@ async def export_data(

async def _progress_cb(report: ProgressReport) -> None:
assert task.name # nosec
get_celery_worker(task.app).set_task_progress(task.name, task_id, report)
await get_celery_worker(task.app).set_progress(task_id, report)
_logger.debug("'%s' progress %s", task_id, report.percent_value)

async with ProgressBarData(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
StorageQueryParamsBase,
UploadLinks,
)
from ...modules.celery.client import CeleryTaskQueueClient
from ...modules.celery.client import CeleryTaskClient
from ...modules.celery.models import TaskUUID
from ...simcore_s3_dsm import SimcoreS3DataManager
from .._worker_tasks._files import complete_upload_file as remote_complete_upload_file
Expand Down Expand Up @@ -270,7 +270,7 @@ async def abort_upload_file(
status_code=status.HTTP_202_ACCEPTED,
)
async def complete_upload_file(
celery_client: Annotated[CeleryTaskQueueClient, Depends(get_celery_client)],
celery_client: Annotated[CeleryTaskClient, Depends(get_celery_client)],
query_params: Annotated[StorageQueryParamsBase, Depends()],
location_id: LocationID,
file_id: StorageFileID,
Expand Down Expand Up @@ -324,7 +324,7 @@ async def complete_upload_file(
response_model=Envelope[FileUploadCompleteFutureResponse],
)
async def is_completed_upload_file(
celery_client: Annotated[CeleryTaskQueueClient, Depends(get_celery_client)],
celery_client: Annotated[CeleryTaskClient, Depends(get_celery_client)],
query_params: Annotated[StorageQueryParamsBase, Depends()],
location_id: LocationID,
file_id: StorageFileID,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
from servicelib.fastapi.dependencies import get_app

from ....modules.celery import get_celery_client as _get_celery_client_from_app
from ....modules.celery.client import CeleryTaskQueueClient
from ....modules.celery.client import CeleryTaskClient


def get_celery_client(
app: Annotated[FastAPI, Depends(get_app)],
) -> CeleryTaskQueueClient:
) -> CeleryTaskClient:
return _get_celery_client_from_app(app)
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ async def result(

if _status.task_state == TaskState.ABORTED:
raise JobAbortedError(job_id=job_id)
if _status.task_state == TaskState.ERROR:
if _status.task_state == TaskState.FAILURE:
# fallback exception to report
exc_type = type(_result).__name__
exc_msg = f"{_result}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from ...core.settings import get_application_settings
from ._celery_types import register_celery_types
from ._common import create_app
from .backends._redis import RedisTaskMetadataStore
from .client import CeleryTaskQueueClient
from .backends._redis import RedisTaskInfoStore
from .client import CeleryTaskClient

_logger = logging.getLogger(__name__)

Expand All @@ -28,21 +28,21 @@ async def on_startup() -> None:
client_name=f"{APP_NAME}.celery_tasks",
)

app.state.celery_client = CeleryTaskQueueClient(
app.state.celery_client = CeleryTaskClient(
celery_app,
celery_settings,
RedisTaskMetadataStore(redis_client_sdk),
RedisTaskInfoStore(redis_client_sdk),
)

register_celery_types()

app.add_event_handler("startup", on_startup)


def get_celery_client(app: FastAPI) -> CeleryTaskQueueClient:
def get_celery_client(app: FastAPI) -> CeleryTaskClient:
assert hasattr(app.state, "celery_client") # nosec
celery_client = app.state.celery_client
assert isinstance(celery_client, CeleryTaskQueueClient)
assert isinstance(celery_client, CeleryTaskClient)
return celery_client


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,27 +7,36 @@
from typing import Any, Concatenate, Final, ParamSpec, TypeVar, overload

from celery import Celery # type: ignore[import-untyped]
from celery.contrib.abortable import AbortableTask # type: ignore[import-untyped]
from celery.contrib.abortable import ( # type: ignore[import-untyped]
AbortableAsyncResult,
AbortableTask,
)
from celery.exceptions import Ignore # type: ignore[import-untyped]
from pydantic import NonNegativeInt
from servicelib.async_utils import cancel_wait_task

from . import get_event_loop
from .errors import encore_celery_transferrable_error
from .models import TaskId
from .models import TaskID, TaskId
from .utils import get_fastapi_app

_logger = logging.getLogger(__name__)

_DEFAULT_TASK_TIMEOUT: Final[timedelta | None] = None
_DEFAULT_MAX_RETRIES: Final[NonNegativeInt] = 3
_DEFAULT_WAIT_BEFORE_RETRY: Final[timedelta] = timedelta(seconds=5)
_DEFAULT_DONT_AUTORETRY_FOR: Final[tuple[type[Exception], ...]] = tuple()

_DEFAULT_DONT_AUTORETRY_FOR: Final[tuple[type[Exception], ...]] = ()
_DEFAULT_ABORT_TASK_TIMEOUT: Final[timedelta] = timedelta(seconds=1)
_DEFAULT_CANCEL_TASK_TIMEOUT: Final[timedelta] = timedelta(seconds=5)

T = TypeVar("T")
P = ParamSpec("P")
R = TypeVar("R")


class TaskAbortedError(Exception):
    """Signal that a running task was cancelled after an abort was detected."""


def _async_task_wrapper(
app: Celery,
) -> Callable[
Expand All @@ -40,11 +49,45 @@ def decorator(
@wraps(coro)
def wrapper(task: AbortableTask, *args: P.args, **kwargs: P.kwargs) -> R:
fastapi_app = get_fastapi_app(app)
_logger.debug("task id: %s", task.request.id)
# NOTE: task.request is a thread local object, so we need to pass the id explicitly
assert task.request.id is not None # nosec

async def run_task(task_id: TaskID) -> R:
    """Run the wrapped coroutine while watching for an abort request.

    Two tasks are started in one ``asyncio.TaskGroup``: the wrapped
    coroutine itself and a monitor that polls
    ``AbortableAsyncResult(...).is_aborted()``. When an abort is seen the
    main task is cancelled and ``TaskAbortedError`` is raised.

    Args:
        task_id: id of the Celery task; passed explicitly because
            ``task.request`` is thread-local.

    Returns:
        Whatever the wrapped coroutine returns.

    Raises:
        TaskAbortedError: the task was aborted while running.
        BaseException: the single exception raised by the wrapped
            coroutine, unwrapped from the task group's exception group.
    """
    poll_interval = _DEFAULT_ABORT_TASK_TIMEOUT.total_seconds()
    cancel_timeout = _DEFAULT_CANCEL_TASK_TIMEOUT.total_seconds()

    try:
        async with asyncio.TaskGroup() as tg:
            main_task = tg.create_task(
                coro(task, task_id, *args, **kwargs), name=f"task_{task_id}"
            )

            async def abort_monitor() -> None:
                while not main_task.done():
                    # NOTE(review): a new result object is built on every
                    # poll, as before; hoisting it out of the loop looks
                    # safe but depends on Celery's result caching - confirm.
                    if AbortableAsyncResult(task_id, app=app).is_aborted():
                        await cancel_wait_task(main_task, max_delay=cancel_timeout)
                        raise TaskAbortedError
                    # Wait on the main task rather than sleeping blindly:
                    # this returns as soon as the task completes, so the
                    # group is not held open for the rest of the poll
                    # interval after the real work has finished.
                    await asyncio.wait({main_task}, timeout=poll_interval)

            tg.create_task(abort_monitor(), name=f"abort_monitor_{task_id}")

        # The group has exited, so main_task is finished at this point.
        return main_task.result()
    except BaseExceptionGroup as eg:
        # Unwrap the group so callers see a plain exception; an abort takes
        # precedence over any other failure.
        task_aborted_errors, other_errors = eg.split(TaskAbortedError)

        if task_aborted_errors:
            assert task_aborted_errors is not None  # nosec
            assert len(task_aborted_errors.exceptions) == 1  # nosec
            raise task_aborted_errors.exceptions[0] from eg

        assert other_errors is not None  # nosec
        assert len(other_errors.exceptions) == 1  # nosec
        raise other_errors.exceptions[0] from eg

return asyncio.run_coroutine_threadsafe(
coro(task, task.request.id, *args, **kwargs),
run_task(task.request.id),
get_event_loop(fastapi_app),
).result()

Expand All @@ -68,6 +111,9 @@ def decorator(
def wrapper(task: AbortableTask, *args: P.args, **kwargs: P.kwargs) -> R:
try:
return func(task, *args, **kwargs)
except TaskAbortedError as exc:
_logger.warning("Task %s was cancelled", task.request.id)
raise Ignore from exc
except Exception as exc:
if isinstance(exc, dont_autoretry_for):
_logger.debug("Not retrying for exception %s", type(exc).__name__)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,26 @@
from typing import Final

from celery.result import AsyncResult # type: ignore[import-untyped]
from models_library.progress_bar import ProgressReport
from servicelib.redis._client import RedisClientSDK

from ..models import TaskContext, TaskID, TaskMetadata, TaskUUID, build_task_id_prefix

_CELERY_TASK_METADATA_PREFIX: Final[str] = "celery-task-metadata-"
_CELERY_TASK_INFO_PREFIX: Final[str] = "celery-task-info-"
_CELERY_TASK_ID_KEY_ENCODING = "utf-8"
_CELERY_TASK_ID_KEY_SEPARATOR: Final[str] = ":"
_CELERY_TASK_SCAN_COUNT_PER_BATCH: Final[int] = 10000
_CELERY_TASK_METADATA_KEY: Final[str] = "metadata"
_CELERY_TASK_PROGRESS_KEY: Final[str] = "progress"

_logger = logging.getLogger(__name__)


def _build_key(task_id: TaskID) -> str:
return _CELERY_TASK_METADATA_PREFIX + task_id
return _CELERY_TASK_INFO_PREFIX + task_id


class RedisTaskMetadataStore:
class RedisTaskInfoStore:
def __init__(self, redis_client_sdk: RedisClientSDK) -> None:
self._redis_client_sdk = redis_client_sdk

Expand All @@ -28,13 +31,17 @@ async def exists(self, task_id: TaskID) -> bool:
assert isinstance(n, int) # nosec
return n > 0

async def get(self, task_id: TaskID) -> TaskMetadata | None:
result = await self._redis_client_sdk.redis.get(_build_key(task_id))
async def get_metadata(self, task_id: TaskID) -> TaskMetadata | None:
result = await self._redis_client_sdk.redis.hget(_build_key(task_id), _CELERY_TASK_METADATA_KEY) # type: ignore
return TaskMetadata.model_validate_json(result) if result else None

async def get_progress(self, task_id: TaskID) -> ProgressReport | None:
    """Read the ``progress`` field of the task's Redis hash.

    Returns the stored JSON parsed into a ``ProgressReport``, or ``None``
    when the field is missing or empty.
    """
    redis_key = _build_key(task_id)
    raw_report = await self._redis_client_sdk.redis.hget(redis_key, _CELERY_TASK_PROGRESS_KEY)  # type: ignore
    if not raw_report:
        return None
    return ProgressReport.model_validate_json(raw_report)

async def get_uuids(self, task_context: TaskContext) -> set[TaskUUID]:
search_key = (
_CELERY_TASK_METADATA_PREFIX
_CELERY_TASK_INFO_PREFIX
+ build_task_id_prefix(task_context)
+ _CELERY_TASK_ID_KEY_SEPARATOR
)
Expand All @@ -55,11 +62,22 @@ async def remove(self, task_id: TaskID) -> None:
await self._redis_client_sdk.redis.delete(_build_key(task_id))
AsyncResult(task_id).forget()

async def set(
async def set_metadata(
self, task_id: TaskID, task_metadata: TaskMetadata, expiry: timedelta
) -> None:
await self._redis_client_sdk.redis.set(
await self._redis_client_sdk.redis.hset(
name=_build_key(task_id),
key=_CELERY_TASK_METADATA_KEY,
value=task_metadata.model_dump_json(),
) # type: ignore
await self._redis_client_sdk.redis.expire(
_build_key(task_id),
task_metadata.model_dump_json(),
ex=expiry,
expiry,
)

async def set_progress(self, task_id: TaskID, report: ProgressReport) -> None:
    """Write ``report`` as JSON into the ``progress`` field of the task's Redis hash."""
    redis_key = _build_key(task_id)
    serialized_report = report.model_dump_json()
    await self._redis_client_sdk.redis.hset(
        name=redis_key,
        key=_CELERY_TASK_PROGRESS_KEY,
        value=serialized_report,
    )  # type: ignore
Loading
Loading