Skip to content

Commit 3a7f7e9

Browse files
improve error handling
1 parent 99cb6b7 commit 3a7f7e9

File tree

7 files changed

+95
-38
lines changed

7 files changed

+95
-38
lines changed

services/storage/src/simcore_service_storage/api/rpc/_async_jobs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ async def get_result(
5656
assert job_id # nosec
5757
assert job_id_data # nosec
5858

59-
result = await get_celery_client(app).get_result(
59+
result = await get_celery_client(app).get_task_result(
6060
task_context=job_id_data.model_dump(),
6161
task_uuid=job_id,
6262
)

services/storage/src/simcore_service_storage/core/application.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,15 +83,16 @@ def create_app(settings: ApplicationSettings) -> FastAPI:
8383
setup_client_session(app)
8484

8585
setup_rabbitmq(app)
86-
setup_rpc_api_routes(app)
86+
if not settings.STORAGE_WORKER_MODE:
87+
setup_rpc_api_routes(app)
8788
setup_rest_api_long_running_tasks_for_uploads(app)
8889
setup_rest_api_routes(app, API_VTAG)
8990
set_exception_handlers(app)
9091

9192
setup_redis(app)
9293

9394
setup_dsm(app)
94-
if settings.STORAGE_CLEANER_INTERVAL_S:
95+
if settings.STORAGE_CLEANER_INTERVAL_S and not settings.STORAGE_WORKER_MODE:
9596
setup_dsm_cleaner(app)
9697

9798
if settings.STORAGE_PROFILING:

services/storage/src/simcore_service_storage/modules/celery/_common.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,16 @@
1+
from collections.abc import Callable
2+
from functools import wraps
13
import logging
4+
import traceback
25

3-
from celery import Celery
6+
from celery import Celery, Task
7+
from celery.exceptions import Ignore
8+
from celery.contrib.abortable import AbortableTask
49
from settings_library.celery import CelerySettings
510
from settings_library.redis import RedisDatabase
611

12+
from .models import TaskError
13+
714
_logger = logging.getLogger(__name__)
815

916

@@ -18,5 +25,36 @@ def create_app(celery_settings: CelerySettings) -> Celery:
1825
)
1926
app.conf.result_expires = celery_settings.CELERY_RESULT_EXPIRES
2027
app.conf.result_extended = True # original args are included in the results
28+
app.conf.result_serializer = "json"
2129
app.conf.task_track_started = True
2230
return app
31+
32+
33+
def error_handling(func: Callable):
34+
@wraps(func)
35+
def wrapper(task: Task, *args, **kwargs):
36+
try:
37+
return func(task, *args, **kwargs)
38+
except Exception as exc:
39+
exc_type = type(exc).__name__
40+
exc_message = f"{exc}"
41+
exc_traceback = traceback.format_exc().split('\n')
42+
43+
task.update_state(
44+
state="ERROR",
45+
meta=TaskError(
46+
exc_type=exc_type,
47+
exc_msg=exc_message,
48+
).model_dump(mode="json"),
49+
traceback=exc_traceback
50+
)
51+
raise Ignore from exc
52+
return wrapper
53+
54+
55+
def define_task(app: Celery, fn: Callable, task_name: str | None = None):
56+
app.task(
57+
name=task_name or fn.__name__,
58+
bind=True,
59+
base=AbortableTask,
60+
)(error_handling(fn))

services/storage/src/simcore_service_storage/modules/celery/client.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
import contextlib
22
import logging
3-
from typing import Any, Final
3+
from typing import Any, Final, Type
44
from uuid import uuid4
55

66
from celery import Celery
77
from celery.contrib.abortable import AbortableAsyncResult
88
from common_library.async_tools import make_async
99
from models_library.progress_bar import ProgressReport
10-
from pydantic import ValidationError
10+
from pydantic import TypeAdapter, ValidationError
1111
from servicelib.logging_utils import log_context
1212

13-
from .models import TaskContext, TaskID, TaskState, TaskStatus, TaskUUID
13+
from .models import TaskContext, TaskError, TaskID, TaskResult, TaskState, TaskStatus, TaskUUID
1414

1515
_logger = logging.getLogger(__name__)
1616

@@ -28,20 +28,23 @@
2828
"RUNNING": TaskState.RUNNING,
2929
"SUCCESS": TaskState.SUCCESS,
3030
"ABORTED": TaskState.ABORTED,
31-
"FAILURE": TaskState.FAILURE,
31+
"FAILURE": TaskState.ERROR,
32+
"ERROR": TaskState.ERROR,
3233
}
34+
_CELERY_TASK_ID_KEY_SEPARATOR: Final[str] = ":"
35+
_CELERY_TASK_ID_KEY_ENCODING = "utf-8"
3336

3437

3538
def _build_context_prefix(task_context: TaskContext) -> list[str]:
3639
return [f"{task_context[key]}" for key in sorted(task_context)]
3740

3841

3942
def _build_task_id_prefix(task_context: TaskContext) -> str:
40-
return ":".join(_build_context_prefix(task_context))
43+
return _CELERY_TASK_ID_KEY_SEPARATOR.join(_build_context_prefix(task_context))
4144

4245

4346
def _build_task_id(task_context: TaskContext, task_uuid: TaskUUID) -> TaskID:
44-
return ":".join([_build_task_id_prefix(task_context), f"{task_uuid}"])
47+
return _CELERY_TASK_ID_KEY_SEPARATOR.join([_build_task_id_prefix(task_context), f"{task_uuid}"])
4548

4649

4750
class CeleryTaskQueueClient:
@@ -76,9 +79,11 @@ def abort_task( # pylint: disable=R6301
7679
AbortableAsyncResult(task_id).abort()
7780

7881
@make_async()
79-
def get_result(self, task_context: TaskContext, task_uuid: TaskUUID) -> Any:
82+
def get_task_result(self, task_context: TaskContext, task_uuid: TaskUUID) -> TaskResult:
8083
task_id = _build_task_id(task_context, task_uuid)
81-
return self._celery_app.AsyncResult(task_id).result
84+
return TypeAdapter(TaskResult).validate_python(
85+
self._celery_app.AsyncResult(task_id).result
86+
)
8287

8388
def _get_progress_report(
8489
self, task_context: TaskContext, task_uuid: TaskUUID
@@ -91,7 +96,7 @@ def _get_progress_report(
9196
return ProgressReport.model_validate(result)
9297
if state in (
9398
TaskState.ABORTED,
94-
TaskState.FAILURE,
99+
TaskState.ERROR,
95100
TaskState.SUCCESS,
96101
):
97102
return ProgressReport(actual_value=100.0)
@@ -113,12 +118,12 @@ def get_task_status(
113118

114119
def _get_completed_task_uuids(self, task_context: TaskContext) -> set[TaskUUID]:
115120
search_key = (
116-
_CELERY_TASK_META_PREFIX + _build_task_id_prefix(task_context) + "*"
121+
_CELERY_TASK_META_PREFIX + _build_task_id_prefix(task_context)
117122
)
118123
redis = self._celery_app.backend.client
119-
if hasattr(redis, "keys") and (keys := redis.keys(search_key)):
124+
if hasattr(redis, "keys") and (keys := redis.keys(search_key + "*")):
120125
return {
121-
TaskUUID(f"{key}".removeprefix(_CELERY_TASK_META_PREFIX))
126+
TaskUUID(f"{key.decode(_CELERY_TASK_ID_KEY_ENCODING).removeprefix(search_key + _CELERY_TASK_ID_KEY_SEPARATOR)}")
122127
for key in keys
123128
}
124129
return set()

services/storage/src/simcore_service_storage/modules/celery/models.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from enum import StrEnum, auto
2-
from typing import Any, Final, Self, TypeAlias
2+
from typing import Annotated, Any, Final, Self, TypeAlias
33
from uuid import UUID
44

55
from models_library.progress_bar import ProgressReport
6-
from pydantic import BaseModel, model_validator
6+
from pydantic import BaseModel, Field, model_validator
77

88
TaskContext: TypeAlias = dict[str, Any]
99
TaskID: TypeAlias = str
@@ -17,11 +17,11 @@ class TaskState(StrEnum):
1717
PENDING = auto()
1818
RUNNING = auto()
1919
SUCCESS = auto()
20-
FAILURE = auto()
20+
ERROR = auto()
2121
ABORTED = auto()
2222

2323

24-
_TASK_DONE = {TaskState.SUCCESS, TaskState.FAILURE, TaskState.ABORTED}
24+
_TASK_DONE = {TaskState.SUCCESS, TaskState.ERROR, TaskState.ABORTED}
2525

2626

2727
class TaskStatus(BaseModel):
@@ -42,11 +42,21 @@ def _check_consistency(self) -> Self:
4242
TaskState.RUNNING: _MIN_PROGRESS <= value <= _MAX_PROGRESS,
4343
TaskState.SUCCESS: value == _MAX_PROGRESS,
4444
TaskState.ABORTED: value == _MAX_PROGRESS,
45-
TaskState.FAILURE: value == _MAX_PROGRESS,
45+
TaskState.ERROR: value == _MAX_PROGRESS,
4646
}
4747

4848
if not valid_states.get(self.task_state, True):
4949
msg = f"Inconsistent progress actual value for state={self.task_state}: {value}"
5050
raise ValueError(msg)
5151

5252
return self
53+
54+
55+
class TaskError(BaseModel):
56+
exc_type: str
57+
exc_msg: str
58+
59+
60+
TaskResult: TypeAlias = Annotated[
61+
TaskError | Any, Field(union_mode="left_to_right")
62+
]

services/storage/src/simcore_service_storage/modules/celery/worker_main.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import logging
44

5-
from celery.contrib.abortable import AbortableTask
65
from celery.signals import worker_init, worker_shutdown
76
from servicelib.logging_utils import config_all_loggers
87
from simcore_service_storage.modules.celery.signals import (
@@ -11,8 +10,8 @@
1110
)
1211

1312
from ...core.settings import ApplicationSettings
14-
from ._common import create_app as create_celery_app
15-
from .tasks import export_data
13+
from ._common import create_app as create_celery_app, define_task
14+
from .tasks import export_data, export_data_with_error
1615

1716
_settings = ApplicationSettings.create_from_envs()
1817

@@ -30,4 +29,6 @@
3029
app = create_celery_app(_settings.STORAGE_CELERY)
3130
worker_init.connect(on_worker_init)
3231
worker_shutdown.connect(on_worker_shutdown)
33-
app.task(name="export_data", bind=True, base=AbortableTask)(export_data)
32+
33+
define_task(app, export_data)
34+
define_task(app, export_data_with_error)

services/storage/tests/unit/modules/celery/test_celery.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,16 @@
44
from collections.abc import Callable
55
from random import randint
66

7+
from pydantic import ValidationError
78
import pytest
89
from celery import Celery, Task
910
from celery.contrib.abortable import AbortableTask
1011
from models_library.progress_bar import ProgressReport
1112
from servicelib.logging_utils import log_context
1213
from simcore_service_storage.modules.celery import get_event_loop
14+
from simcore_service_storage.modules.celery._common import define_task
1315
from simcore_service_storage.modules.celery.client import CeleryTaskQueueClient
14-
from simcore_service_storage.modules.celery.models import TaskContext, TaskState
16+
from simcore_service_storage.modules.celery.models import TaskContext, TaskError, TaskState
1517
from simcore_service_storage.modules.celery.utils import (
1618
get_celery_worker,
1719
get_fastapi_app,
@@ -69,11 +71,9 @@ def dreamer_task(task: AbortableTask) -> list[int]:
6971
@pytest.fixture
7072
def register_celery_tasks() -> Callable[[Celery], None]:
7173
def _(celery_app: Celery) -> None:
72-
celery_app.task(name="sync_archive", bind=True)(sync_archive)
73-
celery_app.task(name="failure_task", bind=True)(failure_task)
74-
celery_app.task(name="dreamer_task", base=AbortableTask, bind=True)(
75-
dreamer_task
76-
)
74+
define_task(celery_app, sync_archive)
75+
define_task(celery_app, failure_task)
76+
define_task(celery_app, dreamer_task)
7777

7878
return _
7979

@@ -99,6 +99,9 @@ async def test_sumitting_task_calling_async_function_results_with_success_state(
9999
progress = await celery_client.get_task_status(task_context, task_uuid)
100100
assert progress.task_state == TaskState.SUCCESS
101101

102+
assert (
103+
await celery_client.get_task_result(task_context, task_uuid)
104+
) == "archive.zip"
102105
assert (
103106
await celery_client.get_task_status(task_context, task_uuid)
104107
).task_state == TaskState.SUCCESS
@@ -113,20 +116,19 @@ async def test_submitting_task_with_failure_results_with_error(
113116
task_uuid = await celery_client.send_task("failure_task", task_context=task_context)
114117

115118
for attempt in Retrying(
116-
retry=retry_if_exception_type(AssertionError),
119+
retry=retry_if_exception_type((AssertionError, ValidationError)),
117120
wait=wait_fixed(1),
118121
stop=stop_after_delay(30),
119122
):
120123
with attempt:
121-
result = await celery_client.get_result(task_context, task_uuid)
122-
assert isinstance(result, ValueError)
124+
result = await celery_client.get_task_result(task_context, task_uuid)
125+
assert isinstance(result, TaskError)
123126

124127
assert (
125128
await celery_client.get_task_status(task_context, task_uuid)
126-
).task_state == TaskState.FAILURE
127-
result = await celery_client.get_result(task_context, task_uuid)
128-
assert isinstance(result, ValueError)
129-
assert f"{result}" == "my error here"
129+
).task_state == TaskState.ERROR
130+
result = await celery_client.get_task_result(task_context, task_uuid)
131+
assert f"{result.exc_msg}" == "my error here"
130132

131133

132134
@pytest.mark.usefixtures("celery_worker")

0 commit comments

Comments
 (0)