Skip to content

Commit e7c34b1

Browse files
authored
Merge branch 'master' into feature/project-to-pipeline
2 parents 67420ca + aac0abd commit e7c34b1

File tree

19 files changed

+207
-141
lines changed

19 files changed

+207
-141
lines changed

packages/service-library/src/servicelib/bytes_iters/_models.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,6 @@ async def with_progress_bytes_iter(
1414
self, progress_bar: ProgressBarData
1515
) -> BytesIter:
1616
async for chunk in self.bytes_iter_callable():
17-
await progress_bar.update(len(chunk))
17+
if progress_bar.is_running():
18+
await progress_bar.update(len(chunk))
1819
yield chunk

services/storage/src/simcore_service_storage/api/_worker_tasks/_simcore_s3.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,7 @@ async def _task_progress_cb(
2626
) -> None:
2727
worker = get_celery_worker(task.app)
2828
assert task.name # nosec
29-
worker.set_task_progress(
30-
task_name=task.name,
29+
await worker.set_progress(
3130
task_id=task_id,
3231
report=report,
3332
)
@@ -88,7 +87,7 @@ async def export_data(
8887

8988
async def _progress_cb(report: ProgressReport) -> None:
9089
assert task.name # nosec
91-
get_celery_worker(task.app).set_task_progress(task.name, task_id, report)
90+
await get_celery_worker(task.app).set_progress(task_id, report)
9291
_logger.debug("'%s' progress %s", task_id, report.percent_value)
9392

9493
async with ProgressBarData(

services/storage/src/simcore_service_storage/api/rest/_files.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
StorageQueryParamsBase,
3535
UploadLinks,
3636
)
37-
from ...modules.celery.client import CeleryTaskQueueClient
37+
from ...modules.celery.client import CeleryTaskClient
3838
from ...modules.celery.models import TaskUUID
3939
from ...simcore_s3_dsm import SimcoreS3DataManager
4040
from .._worker_tasks._files import complete_upload_file as remote_complete_upload_file
@@ -270,7 +270,7 @@ async def abort_upload_file(
270270
status_code=status.HTTP_202_ACCEPTED,
271271
)
272272
async def complete_upload_file(
273-
celery_client: Annotated[CeleryTaskQueueClient, Depends(get_celery_client)],
273+
celery_client: Annotated[CeleryTaskClient, Depends(get_celery_client)],
274274
query_params: Annotated[StorageQueryParamsBase, Depends()],
275275
location_id: LocationID,
276276
file_id: StorageFileID,
@@ -324,7 +324,7 @@ async def complete_upload_file(
324324
response_model=Envelope[FileUploadCompleteFutureResponse],
325325
)
326326
async def is_completed_upload_file(
327-
celery_client: Annotated[CeleryTaskQueueClient, Depends(get_celery_client)],
327+
celery_client: Annotated[CeleryTaskClient, Depends(get_celery_client)],
328328
query_params: Annotated[StorageQueryParamsBase, Depends()],
329329
location_id: LocationID,
330330
file_id: StorageFileID,

services/storage/src/simcore_service_storage/api/rest/dependencies/celery.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
from servicelib.fastapi.dependencies import get_app
55

66
from ....modules.celery import get_celery_client as _get_celery_client_from_app
7-
from ....modules.celery.client import CeleryTaskQueueClient
7+
from ....modules.celery.client import CeleryTaskClient
88

99

1010
def get_celery_client(
1111
app: Annotated[FastAPI, Depends(get_app)],
12-
) -> CeleryTaskQueueClient:
12+
) -> CeleryTaskClient:
1313
return _get_celery_client_from_app(app)

services/storage/src/simcore_service_storage/api/rpc/_async_jobs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ async def result(
9797

9898
if _status.task_state == TaskState.ABORTED:
9999
raise JobAbortedError(job_id=job_id)
100-
if _status.task_state == TaskState.ERROR:
100+
if _status.task_state == TaskState.FAILURE:
101101
# fallback exception to report
102102
exc_type = type(_result).__name__
103103
exc_msg = f"{_result}"

services/storage/src/simcore_service_storage/modules/celery/__init__.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
from ...core.settings import get_application_settings
1010
from ._celery_types import register_celery_types
1111
from ._common import create_app
12-
from .backends._redis import RedisTaskMetadataStore
13-
from .client import CeleryTaskQueueClient
12+
from .backends._redis import RedisTaskInfoStore
13+
from .client import CeleryTaskClient
1414

1515
_logger = logging.getLogger(__name__)
1616

@@ -28,21 +28,21 @@ async def on_startup() -> None:
2828
client_name=f"{APP_NAME}.celery_tasks",
2929
)
3030

31-
app.state.celery_client = CeleryTaskQueueClient(
31+
app.state.celery_client = CeleryTaskClient(
3232
celery_app,
3333
celery_settings,
34-
RedisTaskMetadataStore(redis_client_sdk),
34+
RedisTaskInfoStore(redis_client_sdk),
3535
)
3636

3737
register_celery_types()
3838

3939
app.add_event_handler("startup", on_startup)
4040

4141

42-
def get_celery_client(app: FastAPI) -> CeleryTaskQueueClient:
42+
def get_celery_client(app: FastAPI) -> CeleryTaskClient:
4343
assert hasattr(app.state, "celery_client") # nosec
4444
celery_client = app.state.celery_client
45-
assert isinstance(celery_client, CeleryTaskQueueClient)
45+
assert isinstance(celery_client, CeleryTaskClient)
4646
return celery_client
4747

4848

services/storage/src/simcore_service_storage/modules/celery/_task.py

Lines changed: 53 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,27 +7,36 @@
77
from typing import Any, Concatenate, Final, ParamSpec, TypeVar, overload
88

99
from celery import Celery # type: ignore[import-untyped]
10-
from celery.contrib.abortable import AbortableTask # type: ignore[import-untyped]
10+
from celery.contrib.abortable import ( # type: ignore[import-untyped]
11+
AbortableAsyncResult,
12+
AbortableTask,
13+
)
14+
from celery.exceptions import Ignore # type: ignore[import-untyped]
1115
from pydantic import NonNegativeInt
16+
from servicelib.async_utils import cancel_wait_task
1217

1318
from . import get_event_loop
1419
from .errors import encore_celery_transferrable_error
15-
from .models import TaskId
20+
from .models import TaskID, TaskId
1621
from .utils import get_fastapi_app
1722

1823
_logger = logging.getLogger(__name__)
1924

2025
_DEFAULT_TASK_TIMEOUT: Final[timedelta | None] = None
2126
_DEFAULT_MAX_RETRIES: Final[NonNegativeInt] = 3
2227
_DEFAULT_WAIT_BEFORE_RETRY: Final[timedelta] = timedelta(seconds=5)
23-
_DEFAULT_DONT_AUTORETRY_FOR: Final[tuple[type[Exception], ...]] = tuple()
24-
28+
_DEFAULT_DONT_AUTORETRY_FOR: Final[tuple[type[Exception], ...]] = ()
29+
_DEFAULT_ABORT_TASK_TIMEOUT: Final[timedelta] = timedelta(seconds=1)
30+
_DEFAULT_CANCEL_TASK_TIMEOUT: Final[timedelta] = timedelta(seconds=5)
2531

2632
T = TypeVar("T")
2733
P = ParamSpec("P")
2834
R = TypeVar("R")
2935

3036

37+
class TaskAbortedError(Exception): ...
38+
39+
3140
def _async_task_wrapper(
3241
app: Celery,
3342
) -> Callable[
@@ -40,11 +49,46 @@ def decorator(
4049
@wraps(coro)
4150
def wrapper(task: AbortableTask, *args: P.args, **kwargs: P.kwargs) -> R:
4251
fastapi_app = get_fastapi_app(app)
43-
_logger.debug("task id: %s", task.request.id)
4452
# NOTE: task.request is a thread local object, so we need to pass the id explicitly
4553
assert task.request.id is not None # nosec
54+
55+
async def run_task(task_id: TaskID) -> R:
56+
try:
57+
async with asyncio.TaskGroup() as tg:
58+
main_task = tg.create_task(
59+
coro(task, task_id, *args, **kwargs),
60+
)
61+
62+
async def abort_monitor():
63+
abortable_result = AbortableAsyncResult(task_id, app=app)
64+
while not main_task.done():
65+
if abortable_result.is_aborted():
66+
await cancel_wait_task(
67+
main_task,
68+
max_delay=_DEFAULT_CANCEL_TASK_TIMEOUT.total_seconds(),
69+
)
70+
raise TaskAbortedError
71+
await asyncio.sleep(
72+
_DEFAULT_ABORT_TASK_TIMEOUT.total_seconds()
73+
)
74+
75+
tg.create_task(abort_monitor())
76+
77+
return main_task.result()
78+
except BaseExceptionGroup as eg:
79+
task_aborted_errors, other_errors = eg.split(TaskAbortedError)
80+
81+
if task_aborted_errors:
82+
assert task_aborted_errors is not None # nosec
83+
assert len(task_aborted_errors.exceptions) == 1 # nosec
84+
raise task_aborted_errors.exceptions[0] from eg
85+
86+
assert other_errors is not None # nosec
87+
assert len(other_errors.exceptions) == 1 # nosec
88+
raise other_errors.exceptions[0] from eg
89+
4690
return asyncio.run_coroutine_threadsafe(
47-
coro(task, task.request.id, *args, **kwargs),
91+
run_task(task.request.id),
4892
get_event_loop(fastapi_app),
4993
).result()
5094

@@ -68,6 +112,9 @@ def decorator(
68112
def wrapper(task: AbortableTask, *args: P.args, **kwargs: P.kwargs) -> R:
69113
try:
70114
return func(task, *args, **kwargs)
115+
except TaskAbortedError as exc:
116+
_logger.warning("Task %s was cancelled", task.request.id)
117+
raise Ignore from exc
71118
except Exception as exc:
72119
if isinstance(exc, dont_autoretry_for):
73120
_logger.debug("Not retrying for exception %s", type(exc).__name__)

services/storage/src/simcore_service_storage/modules/celery/backends/_redis.py

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,23 +3,26 @@
33
from typing import Final
44

55
from celery.result import AsyncResult # type: ignore[import-untyped]
6+
from models_library.progress_bar import ProgressReport
67
from servicelib.redis._client import RedisClientSDK
78

89
from ..models import TaskContext, TaskID, TaskMetadata, TaskUUID, build_task_id_prefix
910

10-
_CELERY_TASK_METADATA_PREFIX: Final[str] = "celery-task-metadata-"
11+
_CELERY_TASK_INFO_PREFIX: Final[str] = "celery-task-info-"
1112
_CELERY_TASK_ID_KEY_ENCODING = "utf-8"
1213
_CELERY_TASK_ID_KEY_SEPARATOR: Final[str] = ":"
1314
_CELERY_TASK_SCAN_COUNT_PER_BATCH: Final[int] = 10000
15+
_CELERY_TASK_METADATA_KEY: Final[str] = "metadata"
16+
_CELERY_TASK_PROGRESS_KEY: Final[str] = "progress"
1417

1518
_logger = logging.getLogger(__name__)
1619

1720

1821
def _build_key(task_id: TaskID) -> str:
19-
return _CELERY_TASK_METADATA_PREFIX + task_id
22+
return _CELERY_TASK_INFO_PREFIX + task_id
2023

2124

22-
class RedisTaskMetadataStore:
25+
class RedisTaskInfoStore:
2326
def __init__(self, redis_client_sdk: RedisClientSDK) -> None:
2427
self._redis_client_sdk = redis_client_sdk
2528

@@ -28,13 +31,17 @@ async def exists(self, task_id: TaskID) -> bool:
2831
assert isinstance(n, int) # nosec
2932
return n > 0
3033

31-
async def get(self, task_id: TaskID) -> TaskMetadata | None:
32-
result = await self._redis_client_sdk.redis.get(_build_key(task_id))
34+
async def get_metadata(self, task_id: TaskID) -> TaskMetadata | None:
35+
result = await self._redis_client_sdk.redis.hget(_build_key(task_id), _CELERY_TASK_METADATA_KEY) # type: ignore
3336
return TaskMetadata.model_validate_json(result) if result else None
3437

38+
async def get_progress(self, task_id: TaskID) -> ProgressReport | None:
39+
result = await self._redis_client_sdk.redis.hget(_build_key(task_id), _CELERY_TASK_PROGRESS_KEY) # type: ignore
40+
return ProgressReport.model_validate_json(result) if result else None
41+
3542
async def get_uuids(self, task_context: TaskContext) -> set[TaskUUID]:
3643
search_key = (
37-
_CELERY_TASK_METADATA_PREFIX
44+
_CELERY_TASK_INFO_PREFIX
3845
+ build_task_id_prefix(task_context)
3946
+ _CELERY_TASK_ID_KEY_SEPARATOR
4047
)
@@ -55,11 +62,22 @@ async def remove(self, task_id: TaskID) -> None:
5562
await self._redis_client_sdk.redis.delete(_build_key(task_id))
5663
AsyncResult(task_id).forget()
5764

58-
async def set(
65+
async def set_metadata(
5966
self, task_id: TaskID, task_metadata: TaskMetadata, expiry: timedelta
6067
) -> None:
61-
await self._redis_client_sdk.redis.set(
68+
await self._redis_client_sdk.redis.hset(
69+
name=_build_key(task_id),
70+
key=_CELERY_TASK_METADATA_KEY,
71+
value=task_metadata.model_dump_json(),
72+
) # type: ignore
73+
await self._redis_client_sdk.redis.expire(
6274
_build_key(task_id),
63-
task_metadata.model_dump_json(),
64-
ex=expiry,
75+
expiry,
6576
)
77+
78+
async def set_progress(self, task_id: TaskID, report: ProgressReport) -> None:
79+
await self._redis_client_sdk.redis.hset(
80+
name=_build_key(task_id),
81+
key=_CELERY_TASK_PROGRESS_KEY,
82+
value=report.model_dump_json(),
83+
) # type: ignore

0 commit comments

Comments
 (0)