Skip to content

Commit b51a93d

Browse files
✨ Add Celery task manager to Web Server ⚠️ (#8436)
On behalf of @giancarloromeo
1 parent c7bdff6 commit b51a93d

File tree

39 files changed, with 919 additions and 275 deletions.

api/specs/web-server/_long_running_tasks.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,27 +12,24 @@
1212
from servicelib.aiohttp.long_running_tasks._routes import _PathParam
1313
from servicelib.long_running_tasks.models import TaskGet, TaskStatus
1414
from simcore_service_webserver._meta import API_VTAG
15-
from simcore_service_webserver.tasks._exception_handlers import (
16-
_TO_HTTP_ERROR_MAP as export_data_http_error_map,
15+
from simcore_service_webserver.tasks._controller._rest_exceptions import (
16+
_TO_HTTP_ERROR_MAP,
1717
)
1818

1919
router = APIRouter(
2020
prefix=f"/{API_VTAG}",
2121
tags=[
2222
"long-running-tasks",
2323
],
24+
responses={
25+
i.status_code: {"model": EnvelopedError} for i in _TO_HTTP_ERROR_MAP.values()
26+
},
2427
)
2528

26-
_export_data_responses: dict[int | str, dict[str, Any]] = {
27-
i.status_code: {"model": EnvelopedError}
28-
for i in export_data_http_error_map.values()
29-
}
30-
3129

3230
@router.get(
3331
"/tasks",
3432
response_model=Envelope[list[TaskGet]],
35-
responses=_export_data_responses,
3633
)
3734
def get_async_jobs():
3835
"""Lists all long running tasks"""
@@ -41,7 +38,6 @@ def get_async_jobs():
4138
@router.get(
4239
"/tasks/{task_id}",
4340
response_model=Envelope[TaskStatus],
44-
responses=_export_data_responses,
4541
)
4642
def get_async_job_status(
4743
_path_params: Annotated[_PathParam, Depends()],
@@ -51,7 +47,6 @@ def get_async_job_status(
5147

5248
@router.delete(
5349
"/tasks/{task_id}",
54-
responses=_export_data_responses,
5550
status_code=status.HTTP_204_NO_CONTENT,
5651
)
5752
def cancel_async_job(
@@ -63,7 +58,6 @@ def cancel_async_job(
6358
@router.get(
6459
"/tasks/{task_id}/result",
6560
response_model=Any,
66-
responses=_export_data_responses,
6761
)
6862
def get_async_job_result(
6963
_path_params: Annotated[_PathParam, Depends()],

api/specs/web-server/_storage.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# pylint: disable=too-many-arguments
55

66

7-
from typing import Annotated, Any, TypeAlias
7+
from typing import Annotated, TypeAlias
88

99
from fastapi import APIRouter, Depends, Query, status
1010
from models_library.api_schemas_long_running_tasks.tasks import (
@@ -35,8 +35,8 @@
3535
from servicelib.fastapi.rest_pagination import CustomizedPathsCursorPage
3636
from simcore_service_webserver._meta import API_VTAG
3737
from simcore_service_webserver.storage.schemas import DatasetMetaData, FileMetaData
38-
from simcore_service_webserver.tasks._exception_handlers import (
39-
_TO_HTTP_ERROR_MAP as export_data_http_error_map,
38+
from simcore_service_webserver.tasks._controller._rest_exceptions import (
39+
_TO_HTTP_ERROR_MAP,
4040
)
4141

4242
router = APIRouter(
@@ -220,19 +220,14 @@ async def is_completed_upload_file(
220220
"""Returns state of upload completion"""
221221

222222

223-
# data export
224-
_export_data_responses: dict[int | str, dict[str, Any]] = {
225-
i.status_code: {"model": EnvelopedError}
226-
for i in export_data_http_error_map.values()
227-
}
228-
229-
230223
@router.post(
231224
"/storage/locations/{location_id}/export-data",
232225
response_model=Envelope[TaskGet],
233226
name="export_data",
234227
description="Export data",
235-
responses=_export_data_responses,
228+
responses={
229+
i.status_code: {"model": EnvelopedError} for i in _TO_HTTP_ERROR_MAP.values()
230+
},
236231
)
237232
async def export_data(export_data: DataExportPost, location_id: LocationID):
238233
"""Trigger data export. Returns async job id for getting status and results"""

packages/celery-library/src/celery_library/backends/redis.py

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import contextlib
22
import logging
3+
from dataclasses import dataclass
34
from datetime import timedelta
45
from typing import TYPE_CHECKING, Final
56

@@ -10,35 +11,36 @@
1011
ExecutionMetadata,
1112
OwnerMetadata,
1213
Task,
13-
TaskInfoStore,
1414
TaskKey,
15+
TaskStore,
1516
)
1617
from servicelib.redis import RedisClientSDK, handle_redis_returns_union_types
1718

18-
_CELERY_TASK_INFO_PREFIX: Final[str] = "celery-task-info-"
19+
_CELERY_TASK_PREFIX: Final[str] = "celery-task-"
1920
_CELERY_TASK_ID_KEY_ENCODING = "utf-8"
2021
_CELERY_TASK_SCAN_COUNT_PER_BATCH: Final[int] = 1000
2122
_CELERY_TASK_METADATA_KEY: Final[str] = "metadata"
2223
_CELERY_TASK_PROGRESS_KEY: Final[str] = "progress"
2324

25+
2426
_logger = logging.getLogger(__name__)
2527

2628

27-
def _build_key(task_key: TaskKey) -> str:
28-
return _CELERY_TASK_INFO_PREFIX + task_key
29+
def _build_redis_task_key(task_key: TaskKey) -> str:
    """Return the Redis key under which the given task is stored.

    The key is the module-level ``_CELERY_TASK_PREFIX`` followed directly by
    ``task_key``; no separator is inserted between the two.
    """
    redis_key: str = _CELERY_TASK_PREFIX + task_key
    return redis_key
2931

3032

31-
class RedisTaskInfoStore:
32-
def __init__(self, redis_client_sdk: RedisClientSDK) -> None:
33-
self._redis_client_sdk = redis_client_sdk
33+
@dataclass(frozen=True)
34+
class RedisTaskStore:
35+
_redis_client_sdk: RedisClientSDK
3436

3537
async def create_task(
3638
self,
3739
task_key: TaskKey,
3840
execution_metadata: ExecutionMetadata,
3941
expiry: timedelta,
4042
) -> None:
41-
redis_key = _build_key(task_key)
43+
redis_key = _build_redis_task_key(task_key)
4244
await handle_redis_returns_union_types(
4345
self._redis_client_sdk.redis.hset(
4446
name=redis_key,
@@ -54,7 +56,8 @@ async def create_task(
5456
async def get_task_metadata(self, task_key: TaskKey) -> ExecutionMetadata | None:
5557
raw_result = await handle_redis_returns_union_types(
5658
self._redis_client_sdk.redis.hget(
57-
_build_key(task_key), _CELERY_TASK_METADATA_KEY
59+
_build_redis_task_key(task_key),
60+
_CELERY_TASK_METADATA_KEY,
5861
)
5962
)
6063
if not raw_result:
@@ -73,7 +76,8 @@ async def get_task_metadata(self, task_key: TaskKey) -> ExecutionMetadata | None
7376
async def get_task_progress(self, task_key: TaskKey) -> ProgressReport | None:
7477
raw_result = await handle_redis_returns_union_types(
7578
self._redis_client_sdk.redis.hget(
76-
_build_key(task_key), _CELERY_TASK_PROGRESS_KEY
79+
_build_redis_task_key(task_key),
80+
_CELERY_TASK_PROGRESS_KEY,
7781
)
7882
)
7983
if not raw_result:
@@ -90,7 +94,7 @@ async def get_task_progress(self, task_key: TaskKey) -> ProgressReport | None:
9094
return None
9195

9296
async def list_tasks(self, owner_metadata: OwnerMetadata) -> list[Task]:
93-
search_key = _CELERY_TASK_INFO_PREFIX + owner_metadata.model_dump_task_key(
97+
search_key = _CELERY_TASK_PREFIX + owner_metadata.model_dump_task_key(
9498
task_uuid=WILDCARD
9599
)
96100

@@ -127,24 +131,28 @@ async def list_tasks(self, owner_metadata: OwnerMetadata) -> list[Task]:
127131
return tasks
128132

129133
async def remove_task(self, task_key: TaskKey) -> None:
130-
await self._redis_client_sdk.redis.delete(_build_key(task_key))
134+
await self._redis_client_sdk.redis.delete(
135+
_build_redis_task_key(task_key),
136+
)
131137

132138
async def set_task_progress(
133139
self, task_key: TaskKey, report: ProgressReport
134140
) -> None:
135141
await handle_redis_returns_union_types(
136142
self._redis_client_sdk.redis.hset(
137-
name=_build_key(task_key),
143+
name=_build_redis_task_key(task_key),
138144
key=_CELERY_TASK_PROGRESS_KEY,
139145
value=report.model_dump_json(),
140146
)
141147
)
142148

143149
async def task_exists(self, task_key: TaskKey) -> bool:
144-
n = await self._redis_client_sdk.redis.exists(_build_key(task_key))
150+
n = await self._redis_client_sdk.redis.exists(
151+
_build_redis_task_key(task_key),
152+
)
145153
assert isinstance(n, int) # nosec
146154
return n > 0
147155

148156

149157
if TYPE_CHECKING:
150-
_: type[TaskInfoStore] = RedisTaskInfoStore
158+
_: type[TaskStore] = RedisTaskStore

packages/celery-library/src/celery_library/errors.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import base64
22
import pickle
3+
from functools import wraps
34

5+
from celery.exceptions import CeleryError # type: ignore[import-untyped]
46
from common_library.errors_classes import OsparcErrorMixin
57

68

@@ -32,3 +34,18 @@ class TaskSubmissionError(OsparcErrorMixin, Exception):
3234

3335
class TaskNotFoundError(OsparcErrorMixin, Exception):
3436
msg_template = "Task with uuid '{task_uuid}' and owner_metadata '{owner_metadata}' was not found"
37+
38+
39+
class TaskManagerError(OsparcErrorMixin, Exception):
    """Generic task-manager failure with a fixed, detail-free message.

    ``handle_celery_errors`` raises this in place of a ``CeleryError``; the
    original exception remains reachable through ``__cause__``.
    """

    msg_template = "An internal error occurred"
41+
42+
43+
def handle_celery_errors(func):
    """Decorate an async callable so Celery failures surface as ``TaskManagerError``.

    The returned coroutine function awaits ``func`` with the arguments it was
    given and hands back its result. A ``CeleryError`` raised while awaiting is
    re-raised as ``TaskManagerError`` chained to the original exception; any
    other exception propagates unchanged. The metadata of ``func`` is copied
    onto the wrapper through ``functools.wraps``.
    """

    async def _translate_celery_failures(*args, **kwargs):
        try:
            outcome = await func(*args, **kwargs)
        except CeleryError as celery_exc:
            # Keep the Celery exception as the cause of the raised error.
            raise TaskManagerError from celery_exc
        return outcome

    return wraps(func)(_translate_celery_failures)

packages/celery-library/src/celery_library/task_manager.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,17 @@
1212
ExecutionMetadata,
1313
OwnerMetadata,
1414
Task,
15-
TaskInfoStore,
1615
TaskKey,
1716
TaskState,
1817
TaskStatus,
18+
TaskStore,
1919
TaskUUID,
2020
)
2121
from servicelib.celery.task_manager import TaskManager
2222
from servicelib.logging_utils import log_context
2323
from settings_library.celery import CelerySettings
2424

25-
from .errors import TaskNotFoundError, TaskSubmissionError
25+
from .errors import TaskNotFoundError, TaskSubmissionError, handle_celery_errors
2626

2727
_logger = logging.getLogger(__name__)
2828

@@ -35,8 +35,9 @@
3535
class CeleryTaskManager:
3636
_celery_app: Celery
3737
_celery_settings: CelerySettings
38-
_task_info_store: TaskInfoStore
38+
_task_info_store: TaskStore
3939

40+
@handle_celery_errors
4041
async def submit_task(
4142
self,
4243
execution_metadata: ExecutionMetadata,
@@ -85,6 +86,7 @@ async def submit_task(
8586

8687
return task_uuid
8788

89+
@handle_celery_errors
8890
async def cancel_task(
8991
self, owner_metadata: OwnerMetadata, task_uuid: TaskUUID
9092
) -> None:
@@ -109,6 +111,7 @@ async def task_exists(self, task_key: TaskKey) -> bool:
109111
def _forget_task(self, task_key: TaskKey) -> None:
110112
self._celery_app.AsyncResult(task_key).forget()
111113

114+
@handle_celery_errors
112115
async def get_task_result(
113116
self, owner_metadata: OwnerMetadata, task_uuid: TaskUUID
114117
) -> Any:
@@ -154,6 +157,7 @@ async def _get_task_progress_report(
154157
def _get_task_celery_state(self, task_key: TaskKey) -> TaskState:
155158
return TaskState(self._celery_app.AsyncResult(task_key).state)
156159

160+
@handle_celery_errors
157161
async def get_task_status(
158162
self, owner_metadata: OwnerMetadata, task_uuid: TaskUUID
159163
) -> TaskStatus:
@@ -177,6 +181,7 @@ async def get_task_status(
177181
),
178182
)
179183

184+
@handle_celery_errors
180185
async def list_tasks(self, owner_metadata: OwnerMetadata) -> list[Task]:
181186
with log_context(
182187
_logger,
@@ -185,6 +190,7 @@ async def list_tasks(self, owner_metadata: OwnerMetadata) -> list[Task]:
185190
):
186191
return await self._task_info_store.list_tasks(owner_metadata)
187192

193+
@handle_celery_errors
188194
async def set_task_progress(
189195
self, task_key: TaskKey, report: ProgressReport
190196
) -> None:

packages/celery-library/tests/conftest.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
)
1616
from celery.signals import worker_init, worker_shutdown
1717
from celery.worker.worker import WorkController
18-
from celery_library.backends.redis import RedisTaskInfoStore
18+
from celery_library.backends.redis import RedisTaskStore
1919
from celery_library.signals import on_worker_init, on_worker_shutdown
2020
from celery_library.task_manager import CeleryTaskManager
2121
from celery_library.types import register_celery_types
@@ -66,7 +66,7 @@ async def run_until_shutdown(
6666
self._task_manager = CeleryTaskManager(
6767
self._app,
6868
self._settings,
69-
RedisTaskInfoStore(redis_client_sdk),
69+
RedisTaskStore(redis_client_sdk),
7070
)
7171

7272
startup_completed_event.set()
@@ -156,11 +156,11 @@ async def mock_celery_app(celery_config: dict[str, Any]) -> Celery:
156156

157157

158158
@pytest.fixture
159-
async def celery_task_manager(
159+
async def task_manager(
160160
mock_celery_app: Celery,
161161
celery_settings: CelerySettings,
162162
use_in_memory_redis: RedisSettings,
163-
) -> AsyncIterator[CeleryTaskManager]:
163+
) -> AsyncIterator[TaskManager]:
164164
register_celery_types()
165165

166166
try:
@@ -173,7 +173,7 @@ async def celery_task_manager(
173173
yield CeleryTaskManager(
174174
mock_celery_app,
175175
celery_settings,
176-
RedisTaskInfoStore(redis_client_sdk),
176+
RedisTaskStore(redis_client_sdk),
177177
)
178178
finally:
179179
await redis_client_sdk.shutdown()

packages/celery-library/tests/unit/test_async_jobs.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -138,13 +138,13 @@ async def async_job(task: Task, task_key: TaskKey, action: Action, payload: Any)
138138

139139
@pytest.fixture
140140
async def register_rpc_routes(
141-
async_jobs_rabbitmq_rpc_client: RabbitMQRPCClient, celery_task_manager: TaskManager
141+
async_jobs_rabbitmq_rpc_client: RabbitMQRPCClient, task_manager: TaskManager
142142
) -> None:
143143
await async_jobs_rabbitmq_rpc_client.register_router(
144-
_async_jobs.router, ASYNC_JOBS_RPC_NAMESPACE, task_manager=celery_task_manager
144+
_async_jobs.router, ASYNC_JOBS_RPC_NAMESPACE, task_manager=task_manager
145145
)
146146
await async_jobs_rabbitmq_rpc_client.register_router(
147-
router, ASYNC_JOBS_RPC_NAMESPACE, task_manager=celery_task_manager
147+
router, ASYNC_JOBS_RPC_NAMESPACE, task_manager=task_manager
148148
)
149149

150150

0 commit comments

Comments
 (0)