Skip to content

Commit c579dad

Browse files
continue
1 parent 4260e22 commit c579dad

File tree

9 files changed

+52
-62
lines changed

9 files changed

+52
-62
lines changed

services/storage/src/simcore_service_storage/api/rest/_files.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
StorageQueryParamsBase,
3535
UploadLinks,
3636
)
37-
from ...modules.celery.client import CeleryTaskQueueClient
37+
from ...modules.celery.client import CeleryTaskClient
3838
from ...modules.celery.models import TaskUUID
3939
from ...simcore_s3_dsm import SimcoreS3DataManager
4040
from .._worker_tasks._files import complete_upload_file as remote_complete_upload_file
@@ -270,7 +270,7 @@ async def abort_upload_file(
270270
status_code=status.HTTP_202_ACCEPTED,
271271
)
272272
async def complete_upload_file(
273-
celery_client: Annotated[CeleryTaskQueueClient, Depends(get_celery_client)],
273+
celery_client: Annotated[CeleryTaskClient, Depends(get_celery_client)],
274274
query_params: Annotated[StorageQueryParamsBase, Depends()],
275275
location_id: LocationID,
276276
file_id: StorageFileID,
@@ -324,7 +324,7 @@ async def complete_upload_file(
324324
response_model=Envelope[FileUploadCompleteFutureResponse],
325325
)
326326
async def is_completed_upload_file(
327-
celery_client: Annotated[CeleryTaskQueueClient, Depends(get_celery_client)],
327+
celery_client: Annotated[CeleryTaskClient, Depends(get_celery_client)],
328328
query_params: Annotated[StorageQueryParamsBase, Depends()],
329329
location_id: LocationID,
330330
file_id: StorageFileID,

services/storage/src/simcore_service_storage/api/rest/dependencies/celery.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
from servicelib.fastapi.dependencies import get_app
55

66
from ....modules.celery import get_celery_client as _get_celery_client_from_app
7-
from ....modules.celery.client import CeleryTaskQueueClient
7+
from ....modules.celery.client import CeleryTaskClient
88

99

1010
def get_celery_client(
1111
app: Annotated[FastAPI, Depends(get_app)],
12-
) -> CeleryTaskQueueClient:
12+
) -> CeleryTaskClient:
1313
return _get_celery_client_from_app(app)

services/storage/src/simcore_service_storage/api/rpc/_async_jobs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ async def result(
9797

9898
if _status.task_state == TaskState.ABORTED:
9999
raise JobAbortedError(job_id=job_id)
100-
if _status.task_state == TaskState.ERROR:
100+
if _status.task_state == TaskState.FAILURE:
101101
# fallback exception to report
102102
exc_type = type(_result).__name__
103103
exc_msg = f"{_result}"

services/storage/src/simcore_service_storage/modules/celery/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from ._celery_types import register_celery_types
1111
from ._common import create_app
1212
from .backends._redis import RedisTaskInfoStore
13-
from .client import CeleryTaskQueueClient
13+
from .client import CeleryTaskClient
1414

1515
_logger = logging.getLogger(__name__)
1616

@@ -28,7 +28,7 @@ async def on_startup() -> None:
2828
client_name=f"{APP_NAME}.celery_tasks",
2929
)
3030

31-
app.state.celery_client = CeleryTaskQueueClient(
31+
app.state.celery_client = CeleryTaskClient(
3232
celery_app,
3333
celery_settings,
3434
RedisTaskInfoStore(redis_client_sdk),
@@ -39,10 +39,10 @@ async def on_startup() -> None:
3939
app.add_event_handler("startup", on_startup)
4040

4141

42-
def get_celery_client(app: FastAPI) -> CeleryTaskQueueClient:
42+
def get_celery_client(app: FastAPI) -> CeleryTaskClient:
4343
assert hasattr(app.state, "celery_client") # nosec
4444
celery_client = app.state.celery_client
45-
assert isinstance(celery_client, CeleryTaskQueueClient)
45+
assert isinstance(celery_client, CeleryTaskClient)
4646
return celery_client
4747

4848

services/storage/src/simcore_service_storage/modules/celery/_task.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
)
1515
from celery.exceptions import Ignore # type: ignore[import-untyped]
1616
from pydantic import NonNegativeInt
17+
from servicelib.async_utils import cancel_wait_task
1718

1819
from . import get_event_loop
1920
from .errors import encore_celery_transferrable_error
@@ -26,7 +27,8 @@
2627
_DEFAULT_MAX_RETRIES: Final[NonNegativeInt] = 3
2728
_DEFAULT_WAIT_BEFORE_RETRY: Final[timedelta] = timedelta(seconds=5)
2829
_DEFAULT_DONT_AUTORETRY_FOR: Final[tuple[type[Exception], ...]] = ()
29-
_DEFAULT_ABORT_TASK_TIMEOUT: Final[timedelta] = timedelta(seconds=0.5)
30+
_DEFAULT_ABORT_TASK_TIMEOUT: Final[timedelta] = timedelta(seconds=1)
31+
_DEFAULT_CANCEL_TASK_TIMEOUT: Final[timedelta] = timedelta(seconds=5)
3032

3133
T = TypeVar("T")
3234
P = ParamSpec("P")
@@ -60,10 +62,10 @@ async def run_task(task_id: TaskID) -> R:
6062
async def abort_monitor():
6163
while not main_task.done():
6264
if AbortableAsyncResult(task_id).is_aborted():
63-
_logger.warning(
64-
"Task %s aborted, cancelling main task", task_id
65+
await cancel_wait_task(
66+
main_task,
67+
max_delay=_DEFAULT_CANCEL_TASK_TIMEOUT.total_seconds(),
6568
)
66-
main_task.cancel()
6769
return
6870
await asyncio.sleep(
6971
_DEFAULT_ABORT_TASK_TIMEOUT.total_seconds()

services/storage/src/simcore_service_storage/modules/celery/client.py

Lines changed: 16 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1-
import contextlib
21
import logging
32
from dataclasses import dataclass
4-
from typing import Any, Final
3+
from typing import Any
54
from uuid import uuid4
65

76
from celery import Celery # type: ignore[import-untyped]
@@ -10,7 +9,6 @@
109
)
1110
from common_library.async_tools import make_async
1211
from models_library.progress_bar import ProgressReport
13-
from pydantic import ValidationError
1412
from servicelib.logging_utils import log_context
1513
from settings_library.celery import CelerySettings
1614

@@ -27,23 +25,13 @@
2725

2826
_logger = logging.getLogger(__name__)
2927

30-
_CELERY_STATES_MAPPING: Final[dict[str, TaskState]] = {
31-
"PENDING": TaskState.PENDING,
32-
"STARTED": TaskState.PENDING,
33-
"RETRY": TaskState.PENDING,
34-
"RUNNING": TaskState.RUNNING,
35-
"SUCCESS": TaskState.SUCCESS,
36-
"ABORTED": TaskState.ABORTED,
37-
"FAILURE": TaskState.ERROR,
38-
"ERROR": TaskState.ERROR,
39-
}
4028

4129
_MIN_PROGRESS_VALUE = 0.0
4230
_MAX_PROGRESS_VALUE = 1.0
4331

4432

4533
@dataclass
46-
class CeleryTaskQueueClient:
34+
class CeleryTaskClient:
4735
_celery_app: Celery
4836
_celery_settings: CelerySettings
4937
_task_store: TaskInfoStore
@@ -92,11 +80,6 @@ async def abort_task(self, task_context: TaskContext, task_uuid: TaskUUID) -> No
9280
task_id = build_task_id(task_context, task_uuid)
9381
await self._abort_task(task_id)
9482

95-
@make_async()
96-
def _get_result(self, task_context: TaskContext, task_uuid: TaskUUID) -> Any:
97-
task_id = build_task_id(task_context, task_uuid)
98-
return self._celery_app.AsyncResult(task_id).result
99-
10083
async def get_task_result(
10184
self, task_context: TaskContext, task_uuid: TaskUUID
10285
) -> Any:
@@ -114,28 +97,31 @@ async def get_task_result(
11497
await self._task_store.remove(task_id)
11598
return result
11699

117-
@staticmethod
118-
async def _get_progress_report(state, result) -> ProgressReport:
119-
if result and state == TaskState.RUNNING:
120-
with contextlib.suppress(ValidationError):
121-
# avoids exception if result is not a ProgressReport (or overwritten by a Celery's state update)
122-
return ProgressReport.model_validate(result)
100+
async def _get_progress_report(
101+
self, task_id: TaskID, state: TaskState
102+
) -> ProgressReport:
103+
if state in (TaskState.STARTED, TaskState.RETRY):
104+
progress = await self._task_store.get_progress(task_id)
105+
if progress is not None:
106+
return progress
123107
if state in (
124-
TaskState.ABORTED,
125-
TaskState.ERROR,
126108
TaskState.SUCCESS,
109+
TaskState.ABORTED,
110+
TaskState.FAILURE,
127111
):
128112
return ProgressReport(
129113
actual_value=_MAX_PROGRESS_VALUE, total=_MAX_PROGRESS_VALUE
130114
)
115+
116+
# task is pending
131117
return ProgressReport(
132118
actual_value=_MIN_PROGRESS_VALUE, total=_MAX_PROGRESS_VALUE
133119
)
134120

135121
@make_async()
136122
def _get_state(self, task_context: TaskContext, task_uuid: TaskUUID) -> TaskState:
137123
task_id = build_task_id(task_context, task_uuid)
138-
return _CELERY_STATES_MAPPING[self._celery_app.AsyncResult(task_id).state]
124+
return TaskState(self._celery_app.AsyncResult(task_id).state)
139125

140126
async def get_task_status(
141127
self, task_context: TaskContext, task_uuid: TaskUUID
@@ -146,11 +132,11 @@ async def get_task_status(
146132
msg=f"Getting task status: {task_context=} {task_uuid=}",
147133
):
148134
task_state = await self._get_state(task_context, task_uuid)
149-
result = await self._get_result(task_context, task_uuid)
135+
task_id = build_task_id(task_context, task_uuid)
150136
return TaskStatus(
151137
task_uuid=task_uuid,
152138
task_state=task_state,
153-
progress_report=await self._get_progress_report(task_state, result),
139+
progress_report=await self._get_progress_report(task_id, task_state),
154140
)
155141

156142
async def get_task_uuids(self, task_context: TaskContext) -> set[TaskUUID]:

services/storage/src/simcore_service_storage/modules/celery/models.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from datetime import timedelta
2-
from enum import StrEnum, auto
2+
from enum import StrEnum
33
from typing import Any, Final, Protocol, TypeAlias
44
from uuid import UUID
55

@@ -26,11 +26,12 @@ def build_task_id(task_context: TaskContext, task_uuid: TaskUUID) -> TaskID:
2626

2727

2828
class TaskState(StrEnum):
29-
PENDING = auto()
30-
RUNNING = auto()
31-
SUCCESS = auto()
32-
ERROR = auto()
33-
ABORTED = auto()
29+
PENDING = "PENDING"
30+
STARTED = "STARTED"
31+
RETRY = "RETRY"
32+
SUCCESS = "SUCCESS"
33+
FAILURE = "FAILURE"
34+
ABORTED = "ABORTED"
3435

3536

3637
class TasksQueue(StrEnum):
@@ -43,15 +44,15 @@ class TaskMetadata(BaseModel):
4344
queue: TasksQueue = TasksQueue.DEFAULT
4445

4546

46-
_TASK_DONE = {TaskState.SUCCESS, TaskState.ERROR, TaskState.ABORTED}
47+
_TASK_DONE = {TaskState.SUCCESS, TaskState.FAILURE, TaskState.ABORTED}
4748

4849

4950
class TaskInfoStore(Protocol):
5051
async def exists(self, task_id: TaskID) -> bool: ...
5152

52-
async def get_progress(self, task_id: TaskID) -> ProgressReport | None: ...
53+
async def get_progress(self, task_id: TaskID) -> ProgressReport: ...
5354

54-
async def get_metadata(self, task_id: TaskID) -> TaskMetadata | None: ...
55+
async def get_metadata(self, task_id: TaskID) -> TaskMetadata: ...
5556

5657
async def get_uuids(self, task_context: TaskContext) -> set[TaskUUID]: ...
5758

services/storage/src/simcore_service_storage/modules/celery/signals.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,14 @@
1414

1515
from ...core.application import create_app
1616
from ...core.settings import ApplicationSettings
17-
from ...modules.celery.backends._redis import RedisTaskInfoStore, set_event_loop
18-
from ...modules.celery.utils import (
17+
from . import set_event_loop
18+
from .backends._redis import RedisTaskInfoStore
19+
from .utils import (
1920
get_fastapi_app,
2021
set_celery_worker,
2122
set_fastapi_app,
2223
)
23-
from ...modules.celery.worker import CeleryTaskWorker
24+
from .worker import CeleryTaskWorker
2425

2526
_logger = logging.getLogger(__name__)
2627

services/storage/tests/unit/test_modules_celery.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from servicelib.logging_utils import log_context
2020
from simcore_service_storage.modules.celery import get_celery_client, get_event_loop
2121
from simcore_service_storage.modules.celery._task import register_task
22-
from simcore_service_storage.modules.celery.client import CeleryTaskQueueClient
22+
from simcore_service_storage.modules.celery.client import CeleryTaskClient
2323
from simcore_service_storage.modules.celery.errors import TransferrableCeleryError
2424
from simcore_service_storage.modules.celery.models import (
2525
TaskContext,
@@ -42,7 +42,7 @@
4242
def celery_client(
4343
initialized_app: FastAPI,
4444
with_storage_celery_worker: CeleryTaskWorker,
45-
) -> CeleryTaskQueueClient:
45+
) -> CeleryTaskClient:
4646
return get_celery_client(initialized_app)
4747

4848

@@ -106,7 +106,7 @@ def _(celery_app: Celery) -> None:
106106

107107

108108
async def test_submitting_task_calling_async_function_results_with_success_state(
109-
celery_client: CeleryTaskQueueClient,
109+
celery_client: CeleryTaskClient,
110110
):
111111
task_context = TaskContext(user_id=42)
112112

@@ -134,7 +134,7 @@ async def test_submitting_task_calling_async_function_results_with_success_state
134134

135135

136136
async def test_submitting_task_with_failure_results_with_error(
137-
celery_client: CeleryTaskQueueClient,
137+
celery_client: CeleryTaskClient,
138138
):
139139
task_context = TaskContext(user_id=42)
140140

@@ -157,7 +157,7 @@ async def test_submitting_task_with_failure_results_with_error(
157157

158158

159159
async def test_aborting_task_results_with_aborted_state(
160-
celery_client: CeleryTaskQueueClient,
160+
celery_client: CeleryTaskClient,
161161
):
162162
task_context = TaskContext(user_id=42)
163163

@@ -183,7 +183,7 @@ async def test_aborting_task_results_with_aborted_state(
183183

184184

185185
async def test_listing_task_uuids_contains_submitted_task(
186-
celery_client: CeleryTaskQueueClient,
186+
celery_client: CeleryTaskClient,
187187
):
188188
task_context = TaskContext(user_id=42)
189189

0 commit comments

Comments
 (0)