
Commit aef4e20

Merge branch 'master' into further-cleanup-of-async-jobs-framework

2 parents: 7e17747 + d7e6fd3

File tree: 4 files changed, +60 -33 lines

services/storage/src/simcore_service_storage/api/_worker_tasks/_paths.py

Lines changed: 3 additions & 1 deletion

```diff
@@ -8,14 +8,16 @@
 from servicelib.logging_utils import log_context

 from ...dsm import get_dsm_provider
+from ...modules.celery.models import TaskId
 from ...modules.celery.utils import get_fastapi_app

 _logger = logging.getLogger(__name__)


 async def compute_path_size(
-    task: Task, user_id: UserID, location_id: LocationID, path: Path
+    task: Task, task_id: TaskId, user_id: UserID, location_id: LocationID, path: Path
 ) -> ByteSize:
+    assert task_id # nosec
     with log_context(
         _logger,
         logging.INFO,
```

services/storage/src/simcore_service_storage/modules/celery/_celery_types.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -10,6 +10,7 @@
 from pydantic import BaseModel

 from ...models import FileMetaData
+from ...modules.celery.models import TaskError


 def _path_encoder(obj):
@@ -57,3 +58,4 @@ def register_celery_types() -> None:
     _register_pydantic_types(FileUploadCompletionBody)
     _register_pydantic_types(FileMetaData)
     _register_pydantic_types(FoldersBody)
+    _register_pydantic_types(TaskError)
```
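This makes `TaskError` serializable across the Celery boundary alongside the other Pydantic models. The `_register_pydantic_types` helper itself is outside this diff; below is a minimal sketch of what such a helper could look like, assuming kombu >= 5.3's `register_type` hook and Pydantic v2's `model_dump`/`model_validate` — not the repository's actual implementation:

```python
# Hypothetical sketch, not the repo's actual helper: teach kombu's JSON
# serializer to round-trip Pydantic models by registering an encoder/decoder
# pair per type.
from kombu.utils.json import register_type  # available in kombu >= 5.3
from pydantic import BaseModel


def _register_pydantic_types_sketch(*models: type[BaseModel]) -> None:
    for model in models:
        register_type(
            model,
            model.__name__,  # marker stored with the payload for decoding
            lambda obj: obj.model_dump(mode="json"),  # encoder
            model.model_validate,  # decoder, bound per iteration
        )
```

Without a registration of this kind, returning a `TaskError` from a task would typically fail inside kombu's JSON serializer with an `EncodeError`.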

services/storage/src/simcore_service_storage/modules/celery/_task.py

Lines changed: 14 additions & 16 deletions

```diff
@@ -6,10 +6,7 @@
 from functools import wraps
 from typing import Any, Concatenate, ParamSpec, TypeVar, overload

-from celery import ( # type: ignore[import-untyped]
-    Celery,
-    Task,
-)
+from celery import Celery # type: ignore[import-untyped]
 from celery.contrib.abortable import AbortableTask # type: ignore[import-untyped]
 from celery.exceptions import Ignore # type: ignore[import-untyped]

@@ -22,7 +19,7 @@

 def error_handling(func: Callable[..., Any]) -> Callable[..., Any]:
     @wraps(func)
-    def wrapper(task: Task, *args: Any, **kwargs: Any) -> Any:
+    def wrapper(task: AbortableTask, *args: Any, **kwargs: Any) -> Any:
         try:
             return func(task, *args, **kwargs)
         except Exception as exc:
@@ -31,8 +28,9 @@ def wrapper(task: Task, *args: Any, **kwargs: Any) -> Any:
             exc_traceback = traceback.format_exc().split("\n")

             _logger.exception(
-                "Task %s failed with exception: %s",
+                "Task %s failed with exception: %s:%s",
                 task.request.id,
+                exc_type,
                 exc_message,
             )

@@ -57,14 +55,14 @@ def wrapper(task: Task, *args: Any, **kwargs: Any) -> Any:
 def _async_task_wrapper(
     app: Celery,
 ) -> Callable[
-    [Callable[Concatenate[Task, TaskId, P], Coroutine[Any, Any, R]]],
-    Callable[Concatenate[Task, P], R],
+    [Callable[Concatenate[AbortableTask, TaskId, P], Coroutine[Any, Any, R]]],
+    Callable[Concatenate[AbortableTask, P], R],
 ]:
     def decorator(
-        coro: Callable[Concatenate[Task, TaskId, P], Coroutine[Any, Any, R]],
-    ) -> Callable[Concatenate[Task, P], R]:
+        coro: Callable[Concatenate[AbortableTask, TaskId, P], Coroutine[Any, Any, R]],
+    ) -> Callable[Concatenate[AbortableTask, P], R]:
         @wraps(coro)
-        def wrapper(task: Task, *args: P.args, **kwargs: P.kwargs) -> R:
+        def wrapper(task: AbortableTask, *args: P.args, **kwargs: P.kwargs) -> R:
             fastapi_app = get_fastapi_app(app)
             _logger.debug("task id: %s", task.request.id)
             # NOTE: task.request is a thread local object, so we need to pass the id explicitly
@@ -82,29 +80,29 @@ def wrapper(task: Task, *args: P.args, **kwargs: P.kwargs) -> R:
 @overload
 def define_task(
     app: Celery,
-    fn: Callable[Concatenate[Task, TaskId, P], Coroutine[Any, Any, R]],
+    fn: Callable[Concatenate[AbortableTask, TaskId, P], Coroutine[Any, Any, R]],
     task_name: str | None = None,
 ) -> None: ...


 @overload
 def define_task(
     app: Celery,
-    fn: Callable[Concatenate[Task, P], R],
+    fn: Callable[Concatenate[AbortableTask, P], R],
     task_name: str | None = None,
 ) -> None: ...


 def define_task( # type: ignore[misc]
     app: Celery,
     fn: (
-        Callable[Concatenate[Task, TaskId, P], Coroutine[Any, Any, R]]
-        | Callable[Concatenate[Task, P], R]
+        Callable[Concatenate[AbortableTask, TaskId, P], Coroutine[Any, Any, R]]
+        | Callable[Concatenate[AbortableTask, P], R]
     ),
     task_name: str | None = None,
 ) -> None:
     """Decorator to define a celery task with error handling and abortable support"""
-    wrapped_fn: Callable[Concatenate[Task, P], R]
+    wrapped_fn: Callable[Concatenate[AbortableTask, P], R]
     if asyncio.iscoroutinefunction(fn):
         wrapped_fn = _async_task_wrapper(app)(fn)
     else:
```
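The net effect: `Task` is replaced by `AbortableTask` throughout the decorator's typing, so registered functions are consistently typed against the abortable base class, and async tasks receive the task id as an explicit argument because `task.request` is thread-local and not readable from inside the worker's event loop (per the NOTE above). A usage sketch; `app` and `archive_files` are illustrative names, not part of this diff:

```python
# Illustrative only: the broker URL and the task below are hypothetical, but
# the signature convention (task, task_id, then domain parameters) follows
# the overloads above.
from celery import Celery
from celery.contrib.abortable import AbortableTask

app = Celery(broker="redis://localhost:6379/0", backend="redis://localhost:6379/1")


async def archive_files(task: AbortableTask, task_id: str, paths: list[str]) -> int:
    # task_id arrives explicitly: task.request.id is only valid on the thread
    # Celery invoked, not from within this coroutine.
    return len(paths)


# task_name omitted, presumably falling back to the function's name
define_task(app, archive_files)
```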

services/storage/src/simcore_service_storage/modules/celery/client.py

Lines changed: 41 additions & 16 deletions

```diff
@@ -1,5 +1,6 @@
 import contextlib
 import logging
+from dataclasses import dataclass
 from typing import Any, Final
 from uuid import uuid4

@@ -12,6 +13,7 @@
 from pydantic import ValidationError
 from servicelib.logging_utils import log_context

+from ...exceptions.errors import ConfigurationError
 from .models import TaskContext, TaskID, TaskState, TaskStatus, TaskUUID

 _logger = logging.getLogger(__name__)
@@ -53,36 +55,44 @@ def _build_task_id(task_context: TaskContext, task_uuid: TaskUUID) -> TaskID:
     )


+@dataclass
 class CeleryTaskQueueClient:
-    def __init__(self, celery_app: Celery):
-        self._celery_app = celery_app
+    _celery_app: Celery

     @make_async()
     def send_task(
         self, task_name: str, *, task_context: TaskContext, **task_params
     ) -> TaskUUID:
-        task_uuid = uuid4()
-        task_id = _build_task_id(task_context, task_uuid)
         with log_context(
             _logger,
             logging.DEBUG,
-            msg=f"Submitting task {task_name}: {task_id=} {task_params=}",
+            msg=f"Submit {task_name=}: {task_context=} {task_params=}",
         ):
+            task_uuid = uuid4()
+            task_id = _build_task_id(task_context, task_uuid)
             self._celery_app.send_task(task_name, task_id=task_id, kwargs=task_params)
             return task_uuid

+    @staticmethod
     @make_async()
-    def abort_task(  # pylint: disable=R6301
-        self, task_context: TaskContext, task_uuid: TaskUUID
-    ) -> None:
-        task_id = _build_task_id(task_context, task_uuid)
-        _logger.info("Aborting task %s", task_id)
-        AbortableAsyncResult(task_id).abort()
+    def abort_task(task_context: TaskContext, task_uuid: TaskUUID) -> None:
+        with log_context(
+            _logger,
+            logging.DEBUG,
+            msg=f"Abort task {task_uuid=}: {task_context=}",
+        ):
+            task_id = _build_task_id(task_context, task_uuid)
+            AbortableAsyncResult(task_id).abort()

     @make_async()
     def get_task_result(self, task_context: TaskContext, task_uuid: TaskUUID) -> Any:
-        task_id = _build_task_id(task_context, task_uuid)
-        return self._celery_app.AsyncResult(task_id).result
+        with log_context(
+            _logger,
+            logging.DEBUG,
+            msg=f"Get task {task_uuid=}: {task_context=} result",
+        ):
+            task_id = _build_task_id(task_context, task_uuid)
+            return self._celery_app.AsyncResult(task_id).result

     def _get_progress_report(
         self, task_context: TaskContext, task_uuid: TaskUUID
```
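With the `@dataclass` conversion the hand-written `__init__` disappears, and `abort_task` becomes a genuine `@staticmethod` rather than carrying a `pylint: disable=R6301` (no-self-use) suppression; all three operations now log uniformly through `log_context`. A hedged usage sketch, where the context keys and task parameters are assumptions:

```python
# Illustrative usage; the task_context keys and task parameters are not
# taken from this diff.
async def submit_and_cancel(client: CeleryTaskQueueClient) -> None:
    task_context = {"user_id": 42}
    task_uuid = await client.send_task(
        "compute_path_size",
        task_context=task_context,
        user_id=42,
        location_id=0,
        path="/data",
    )
    # abort_task is now a staticmethod, so it can also be called on the class
    await CeleryTaskQueueClient.abort_task(task_context, task_uuid)
```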
```diff
@@ -122,15 +132,30 @@ def get_task_status(

     def _get_completed_task_uuids(self, task_context: TaskContext) -> set[TaskUUID]:
         search_key = _CELERY_TASK_META_PREFIX + _build_task_id_prefix(task_context)
-        redis = self._celery_app.backend.client
-        if hasattr(redis, "keys") and (keys := redis.keys(search_key + "*")):
+        backend_client = self._celery_app.backend.client
+        if hasattr(backend_client, "keys") and (
+            keys := backend_client.keys(f"{search_key}*")
+        ):
             return {
                 TaskUUID(
                     f"{key.decode(_CELERY_TASK_ID_KEY_ENCODING).removeprefix(search_key + _CELERY_TASK_ID_KEY_SEPARATOR)}"
                 )
                 for key in keys
             }
-        return set()
+        if hasattr(backend_client, "cache"):
+            # NOTE: backend used in testing. It is a dict-like object
+            found_keys = set()
+            for key in backend_client.cache:
+                str_key = key.decode(_CELERY_TASK_ID_KEY_ENCODING)
+                if str_key.startswith(search_key):
+                    found_keys.add(
+                        TaskUUID(
+                            f"{str_key.removeprefix(search_key + _CELERY_TASK_ID_KEY_SEPARATOR)}"
+                        )
+                    )
+            return found_keys
+        msg = f"Unsupported backend {self._celery_app.backend.__class__.__name__}"
+        raise ConfigurationError(msg=msg)

     @make_async()
     def get_task_uuids(self, task_context: TaskContext) -> set[TaskUUID]:
```
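`_get_completed_task_uuids` now dispatches on backend capability: a Redis-backed result store exposes `keys()` for glob matching, the dict-like cache backend used in tests is scanned manually, and any other backend raises `ConfigurationError` instead of silently returning an empty set. A small illustration of the prefix-stripping both branches share; the constant values below are assumptions, since their real definitions sit elsewhere in `client.py`:

```python
# Assumed values for illustration; the real constants are defined in client.py.
_CELERY_TASK_META_PREFIX = "celery-task-meta-"
_CELERY_TASK_ID_KEY_SEPARATOR = ":"
_CELERY_TASK_ID_KEY_ENCODING = "utf-8"

search_key = _CELERY_TASK_META_PREFIX + "user_id=42"  # hypothetical context prefix
raw_key = b"celery-task-meta-user_id=42:3fa85f64-5717-4562-b3fc-2c963f66afa6"

str_key = raw_key.decode(_CELERY_TASK_ID_KEY_ENCODING)
if str_key.startswith(search_key):
    # strip "<search_key><separator>" to recover the bare task UUID
    print(str_key.removeprefix(search_key + _CELERY_TASK_ID_KEY_SEPARATOR))
    # -> 3fa85f64-5717-4562-b3fc-2c963f66afa6
```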
