Skip to content
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
2ebf1b7
initial commit
giancarloromeo Apr 16, 2025
347e323
fix keys generation
giancarloromeo Apr 16, 2025
792de64
update signature
giancarloromeo Apr 16, 2025
bf3e5b7
continue
giancarloromeo Apr 16, 2025
c390a3e
return task name
giancarloromeo Apr 16, 2025
0c10248
fix methods
giancarloromeo Apr 16, 2025
9bfad29
add task_name
giancarloromeo Apr 16, 2025
17a896d
fix test
giancarloromeo Apr 16, 2025
d02fe75
update name
giancarloromeo Apr 16, 2025
2416227
fix tests
giancarloromeo Apr 16, 2025
7ce6297
Merge branch 'master' into is7528/add-name-when-listing-tasks
giancarloromeo Apr 16, 2025
5bbcefc
typecheck
giancarloromeo Apr 16, 2025
eda9e7f
Merge branch 'is7528/add-name-when-listing-tasks' of github.com:gianc…
giancarloromeo Apr 16, 2025
bf6af2d
move forget
giancarloromeo Apr 16, 2025
5e70c09
legacy
giancarloromeo Apr 16, 2025
e612858
Merge branch 'master' into is7528/add-name-when-listing-tasks
giancarloromeo Apr 16, 2025
96cdf2a
fix test
giancarloromeo Apr 16, 2025
c1ec12b
Merge branch 'is7528/add-name-when-listing-tasks' of github.com:gianc…
giancarloromeo Apr 16, 2025
726f4ce
fix test
giancarloromeo Apr 16, 2025
417876f
fix async routine
giancarloromeo Apr 16, 2025
0cbec0d
rename
giancarloromeo Apr 16, 2025
53187d3
add legacy name
giancarloromeo Apr 16, 2025
2e96147
add check
giancarloromeo Apr 16, 2025
f941163
add exception handling
giancarloromeo Apr 16, 2025
939279d
update logger level
giancarloromeo Apr 16, 2025
2243c57
add validators
giancarloromeo Apr 17, 2025
584395c
Merge branch 'master' into is7528/add-name-when-listing-tasks
giancarloromeo Apr 17, 2025
0cccc5b
Merge branch 'master' into is7528/add-name-when-listing-tasks
giancarloromeo Apr 17, 2025
247538b
Merge branch 'master' into is7528/add-name-when-listing-tasks
odeimaiz Apr 22, 2025
0f3a393
Merge remote-tracking branch 'upstream/master' into is7528/add-name-w…
giancarloromeo Apr 22, 2025
442ac2f
Merge branch 'is7528/add-name-when-listing-tasks' of github.com:gianc…
giancarloromeo Apr 22, 2025
8ea6823
minors
giancarloromeo Apr 22, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from typing import Any, TypeAlias
from uuid import UUID

from models_library.users import UserID
from pydantic import BaseModel

from ..products import ProductName
from ..progress_bar import ProgressReport
from ..users import UserID

AsyncJobId: TypeAlias = UUID
AsyncJobName: TypeAlias = str


class AsyncJobStatus(BaseModel):
Expand All @@ -21,6 +23,7 @@ class AsyncJobResult(BaseModel):

class AsyncJobGet(BaseModel):
job_id: AsyncJobId
job_name: AsyncJobName


class AsyncJobAbort(BaseModel):
Expand All @@ -31,5 +34,5 @@ class AsyncJobAbort(BaseModel):
class AsyncJobNameData(BaseModel):
"""Data for controlling access to an async job"""

product_name: ProductName
user_id: UserID
product_name: str
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ async def _task_progress_cb(
) -> None:
worker = get_celery_worker(task.app)
assert task.name # nosec
await worker.set_progress(
await worker.set_task_progress(
task_id=task_id,
report=report,
)
Expand Down Expand Up @@ -87,7 +87,7 @@ async def export_data(

async def _progress_cb(report: ProgressReport) -> None:
assert task.name # nosec
await get_celery_worker(task.app).set_progress(task_id, report)
await get_celery_worker(task.app).set_task_progress(task_id, report)
_logger.debug("'%s' progress %s", task_id, report.percent_value)

async with ProgressBarData(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
UploadLinks,
)
from ...modules.celery.client import CeleryTaskClient
from ...modules.celery.models import TaskUUID
from ...modules.celery.models import TaskMetadata, TaskUUID
from ...simcore_s3_dsm import SimcoreS3DataManager
from .._worker_tasks._files import complete_upload_file as remote_complete_upload_file
from .dependencies.celery import get_celery_client
Expand Down Expand Up @@ -284,8 +284,10 @@ async def complete_upload_file(
user_id=query_params.user_id,
product_name=_UNDEFINED_PRODUCT_NAME_FOR_WORKER_TASKS, # NOTE: I would need to change the API here
)
task_uuid = await celery_client.send_task(
remote_complete_upload_file.__name__,
task_uuid = await celery_client.submit_task(
TaskMetadata(
name=remote_complete_upload_file.__name__,
),
task_context=async_job_name_data.model_dump(),
user_id=async_job_name_data.user_id,
location_id=location_id,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,12 @@ async def list_jobs(
_ = filter_
assert app # nosec
try:
task_uuids = await get_celery_client(app).get_task_uuids(
tasks = await get_celery_client(app).list_tasks(
task_context=job_id_data.model_dump(),
)
except CeleryError as exc:
raise JobSchedulerError(exc=f"{exc}") from exc

return [AsyncJobGet(job_id=task_uuid) for task_uuid in task_uuids]
return [
AsyncJobGet(job_id=task.uuid, job_name=task.metadata.name) for task in tasks
]
19 changes: 13 additions & 6 deletions services/storage/src/simcore_service_storage/api/rpc/_paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from servicelib.rabbitmq import RPCRouter

from ...modules.celery import get_celery_client
from ...modules.celery.models import TaskMetadata
from .._worker_tasks._paths import compute_path_size as remote_compute_path_size
from .._worker_tasks._paths import delete_paths as remote_delete_paths

Expand All @@ -24,15 +25,18 @@ async def compute_path_size(
location_id: LocationID,
path: Path,
) -> AsyncJobGet:
task_uuid = await get_celery_client(app).send_task(
remote_compute_path_size.__name__,
task_name = remote_compute_path_size.__name__
task_uuid = await get_celery_client(app).submit_task(
task_metadata=TaskMetadata(
name=task_name,
),
task_context=job_id_data.model_dump(),
user_id=job_id_data.user_id,
location_id=location_id,
path=path,
)

return AsyncJobGet(job_id=task_uuid)
return AsyncJobGet(job_id=task_uuid, job_name=task_name)


@router.expose(reraise_if_error_type=None)
Expand All @@ -42,11 +46,14 @@ async def delete_paths(
location_id: LocationID,
paths: set[Path],
) -> AsyncJobGet:
task_uuid = await get_celery_client(app).send_task(
remote_delete_paths.__name__,
task_name = remote_delete_paths.__name__
task_uuid = await get_celery_client(app).submit_task(
task_metadata=TaskMetadata(
name=task_name,
),
task_context=job_id_data.model_dump(),
user_id=job_id_data.user_id,
location_id=location_id,
paths=paths,
)
return AsyncJobGet(job_id=task_uuid)
return AsyncJobGet(job_id=task_uuid, job_name=task_name)
Original file line number Diff line number Diff line change
Expand Up @@ -20,28 +20,32 @@ async def copy_folders_from_project(
job_id_data: AsyncJobNameData,
body: FoldersBody,
) -> AsyncJobGet:
task_uuid = await get_celery_client(app).send_task(
deep_copy_files_from_project.__name__,
task_name = deep_copy_files_from_project.__name__
task_uuid = await get_celery_client(app).submit_task(
task_metadata=TaskMetadata(
name=task_name,
),
task_context=job_id_data.model_dump(),
user_id=job_id_data.user_id,
body=body,
)

return AsyncJobGet(job_id=task_uuid)
return AsyncJobGet(job_id=task_uuid, job_name=task_name)


@router.expose()
async def start_export_data(
app: FastAPI, job_id_data: AsyncJobNameData, paths_to_export: list[PathToExport]
) -> AsyncJobGet:
task_uuid = await get_celery_client(app).send_task(
export_data.__name__,
task_context=job_id_data.model_dump(),
task_name = export_data.__name__
task_uuid = await get_celery_client(app).submit_task(
task_metadata=TaskMetadata(
name=task_name,
ephemeral=False,
queue=TasksQueue.CPU_BOUND,
),
task_context=job_id_data.model_dump(),
user_id=job_id_data.user_id,
paths_to_export=paths_to_export,
)
return AsyncJobGet(job_id=task_uuid)
return AsyncJobGet(job_id=task_uuid, job_name=task_name)
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,18 @@
from datetime import timedelta
from typing import Final

from celery.result import AsyncResult # type: ignore[import-untyped]
from models_library.progress_bar import ProgressReport
from pydantic import ValidationError
from servicelib.redis._client import RedisClientSDK

from ..models import TaskContext, TaskID, TaskMetadata, TaskUUID, build_task_id_prefix
from ..models import (
Task,
TaskContext,
TaskID,
TaskMetadata,
TaskUUID,
build_task_id_prefix,
)

_CELERY_TASK_INFO_PREFIX: Final[str] = "celery-task-info-"
_CELERY_TASK_ID_KEY_ENCODING = "utf-8"
Expand All @@ -26,26 +33,64 @@ class RedisTaskInfoStore:
def __init__(self, redis_client_sdk: RedisClientSDK) -> None:
self._redis_client_sdk = redis_client_sdk

async def exists(self, task_id: TaskID) -> bool:
async def create_task(
self,
task_id: TaskID,
task_metadata: TaskMetadata,
expiry: timedelta,
) -> None:
task_key = _build_key(task_id)
await self._redis_client_sdk.redis.hset(
name=task_key,
key=_CELERY_TASK_METADATA_KEY,
value=task_metadata.model_dump_json(),
) # type: ignore
await self._redis_client_sdk.redis.expire(
task_key,
expiry,
)

async def exists_task(self, task_id: TaskID) -> bool:
n = await self._redis_client_sdk.redis.exists(_build_key(task_id))
assert isinstance(n, int) # nosec
return n > 0

async def get_metadata(self, task_id: TaskID) -> TaskMetadata | None:
result = await self._redis_client_sdk.redis.hget(_build_key(task_id), _CELERY_TASK_METADATA_KEY) # type: ignore
return TaskMetadata.model_validate_json(result) if result else None
async def get_task_metadata(self, task_id: TaskID) -> TaskMetadata | None:
raw_result = await self._redis_client_sdk.redis.hget(_build_key(task_id), _CELERY_TASK_METADATA_KEY) # type: ignore
if not raw_result:
return None

try:
return TaskMetadata.model_validate_json(raw_result)
except ValidationError as exc:
_logger.debug(
"Failed to deserialize task metadata for task %s: %s", task_id, f"{exc}"
)
return None

async def get_progress(self, task_id: TaskID) -> ProgressReport | None:
result = await self._redis_client_sdk.redis.hget(_build_key(task_id), _CELERY_TASK_PROGRESS_KEY) # type: ignore
return ProgressReport.model_validate_json(result) if result else None
async def get_task_progress(self, task_id: TaskID) -> ProgressReport | None:
raw_result = await self._redis_client_sdk.redis.hget(_build_key(task_id), _CELERY_TASK_PROGRESS_KEY) # type: ignore
if not raw_result:
return None

async def get_uuids(self, task_context: TaskContext) -> set[TaskUUID]:
try:
return ProgressReport.model_validate_json(raw_result)
except ValidationError as exc:
_logger.debug(
"Failed to deserialize task progress for task %s: %s", task_id, f"{exc}"
)
return None

async def list_tasks(self, task_context: TaskContext) -> list[Task]:
search_key = (
_CELERY_TASK_INFO_PREFIX
+ build_task_id_prefix(task_context)
+ _CELERY_TASK_ID_KEY_SEPARATOR
)
keys = set()
search_key_len = len(search_key)

keys: list[str] = []
pipe = self._redis_client_sdk.redis.pipeline()
async for key in self._redis_client_sdk.redis.scan_iter(
match=search_key + "*", count=_CELERY_TASK_SCAN_COUNT_PER_BATCH
):
Expand All @@ -55,27 +100,35 @@ async def get_uuids(self, task_context: TaskContext) -> set[TaskUUID]:
if isinstance(key, bytes)
else key
)
keys.add(TaskUUID(_key.removeprefix(search_key)))
return keys
keys.append(_key)
pipe.hget(_key, _CELERY_TASK_METADATA_KEY)

async def remove(self, task_id: TaskID) -> None:
await self._redis_client_sdk.redis.delete(_build_key(task_id))
AsyncResult(task_id).forget()
results = await pipe.execute()

async def set_metadata(
self, task_id: TaskID, task_metadata: TaskMetadata, expiry: timedelta
) -> None:
await self._redis_client_sdk.redis.hset(
name=_build_key(task_id),
key=_CELERY_TASK_METADATA_KEY,
value=task_metadata.model_dump_json(),
) # type: ignore
await self._redis_client_sdk.redis.expire(
_build_key(task_id),
expiry,
)
tasks = []
for key, raw_metadata in zip(keys, results, strict=True):
if raw_metadata is None:
continue

try:
task_metadata = TaskMetadata.model_validate_json(raw_metadata)
tasks.append(
Task(
uuid=TaskUUID(key[search_key_len:]),
metadata=task_metadata,
)
)
except ValidationError as exc:
_logger.debug(
"Failed to deserialize task metadata for key %s: %s", key, f"{exc}"
)

return tasks

async def remove_task(self, task_id: TaskID) -> None:
await self._redis_client_sdk.redis.delete(_build_key(task_id))

async def set_progress(self, task_id: TaskID, report: ProgressReport) -> None:
async def set_task_progress(self, task_id: TaskID, report: ProgressReport) -> None:
await self._redis_client_sdk.redis.hset(
name=_build_key(task_id),
key=_CELERY_TASK_PROGRESS_KEY,
Expand Down
Loading
Loading