@@ -14,43 +14,24 @@
from models_library.api_schemas_rpc_async_jobs.exceptions import (
JobAbortedError,
JobError,
JobMissingError,
JobNotDoneError,
JobSchedulerError,
)
from servicelib.logging_utils import log_catch
from servicelib.rabbitmq import RPCRouter

from ...modules.celery import get_celery_client
from ...modules.celery.client import CeleryTaskQueueClient
from ...modules.celery.models import TaskError, TaskState

_logger = logging.getLogger(__name__)
router = RPCRouter()


async def _assert_job_exists(
*,
job_id: AsyncJobId,
job_id_data: AsyncJobNameData,
celery_client: CeleryTaskQueueClient,
) -> None:
"""Raises JobMissingError if job doesn't exist"""
job_ids = await celery_client.get_task_uuids(
task_context=job_id_data.model_dump(),
)
if job_id not in job_ids:
raise JobMissingError(job_id=f"{job_id}")


@router.expose(reraise_if_error_type=(JobSchedulerError, JobMissingError))
@router.expose(reraise_if_error_type=(JobSchedulerError,))
async def cancel(app: FastAPI, job_id: AsyncJobId, job_id_data: AsyncJobNameData):
assert app # nosec
assert job_id_data # nosec
try:
await _assert_job_exists(
job_id=job_id, job_id_data=job_id_data, celery_client=get_celery_client(app)
)
await get_celery_client(app).abort_task(
task_context=job_id_data.model_dump(),
task_uuid=job_id,
@@ -59,17 +40,14 @@ async def cancel(app: FastAPI, job_id: AsyncJobId, job_id_data: AsyncJobNameData
raise JobSchedulerError(exc=f"{exc}") from exc


@router.expose(reraise_if_error_type=(JobSchedulerError, JobMissingError))
@router.expose(reraise_if_error_type=(JobSchedulerError,))
async def status(
app: FastAPI, job_id: AsyncJobId, job_id_data: AsyncJobNameData
) -> AsyncJobStatus:
assert app # nosec
assert job_id_data # nosec

try:
await _assert_job_exists(
job_id=job_id, job_id_data=job_id_data, celery_client=get_celery_client(app)
)
task_status = await get_celery_client(app).get_task_status(
task_context=job_id_data.model_dump(),
task_uuid=job_id,
@@ -90,7 +68,6 @@ async def status(
JobNotDoneError,
JobAbortedError,
JobSchedulerError,
JobMissingError,
)
)
async def result(
@@ -101,9 +78,6 @@ async def result(
assert job_id_data # nosec

try:
await _assert_job_exists(
job_id=job_id, job_id_data=job_id_data, celery_client=get_celery_client(app)
)
_status = await get_celery_client(app).get_task_status(
task_context=job_id_data.model_dump(),
task_uuid=job_id,
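
Taken together, the hunks above drop the per-call existence check (_assert_job_exists) and the JobMissingError plumbing, so each RPC handler now talks to the Celery client directly. A minimal sketch of the resulting cancel handler, assembled only from lines visible in this diff; the except CeleryError clause is an assumption, since the middle of that hunk is collapsed here:

@router.expose(reraise_if_error_type=(JobSchedulerError,))
async def cancel(app: FastAPI, job_id: AsyncJobId, job_id_data: AsyncJobNameData):
    assert app  # nosec
    assert job_id_data  # nosec
    try:
        # Abort directly; an unknown job_id is no longer rejected up front.
        await get_celery_client(app).abort_task(
            task_context=job_id_data.model_dump(),
            task_uuid=job_id,
        )
    except CeleryError as exc:  # assumed: the collapsed lines catch scheduler errors here
        raise JobSchedulerError(exc=f"{exc}") from exc

Per the visible lines, status and result likewise go straight to get_task_status inside their try blocks instead of asserting that the job exists first.
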
services/storage/tests/unit/test__worker_tasks_paths.py (32 additions, 20 deletions)
@@ -21,6 +21,7 @@
from pydantic import ByteSize, TypeAdapter
from pytest_simcore.helpers.storage_utils import FileIDDict, ProjectWithFilesParams
from simcore_service_storage.api._worker_tasks._paths import compute_path_size
from simcore_service_storage.modules.celery.models import TaskId
from simcore_service_storage.modules.celery.utils import set_fastapi_app
from simcore_service_storage.simcore_s3_dsm import SimcoreS3DataManager

@@ -48,15 +49,20 @@ def _filter_and_group_paths_one_level_deeper(


async def _assert_compute_path_size(
*,
celery_task: Task,
task_id: TaskId,
location_id: LocationID,
user_id: UserID,
*,
path: Path,
expected_total_size: int,
) -> ByteSize:
response = await compute_path_size(
celery_task, user_id=user_id, location_id=location_id, path=path
celery_task,
task_id=task_id,
user_id=user_id,
location_id=location_id,
path=path,
)
assert isinstance(response, ByteSize)
assert response == expected_total_size
@@ -111,9 +117,10 @@ async def test_path_compute_size(
expected_total_size = project_params.allowed_file_sizes[0] * total_num_files
path = Path(project["uuid"])
await _assert_compute_path_size(
fake_celery_task,
location_id,
user_id,
celery_task=fake_celery_task,
task_id=TaskId("fake_task"),
location_id=location_id,
user_id=user_id,
path=path,
expected_total_size=expected_total_size,
)
@@ -128,9 +135,10 @@ async def test_path_compute_size(
selected_node_s3_keys
)
await _assert_compute_path_size(
fake_celery_task,
location_id,
user_id,
celery_task=fake_celery_task,
task_id=TaskId("fake_task"),
location_id=location_id,
user_id=user_id,
path=path,
expected_total_size=expected_total_size,
)
@@ -146,9 +154,10 @@ async def test_path_compute_size(
selected_node_s3_keys
)
await _assert_compute_path_size(
fake_celery_task,
location_id,
user_id,
celery_task=fake_celery_task,
task_id=TaskId("fake_task"),
location_id=location_id,
user_id=user_id,
path=path,
expected_total_size=expected_total_size,
)
@@ -164,9 +173,10 @@ async def test_path_compute_size(
selected_node_s3_keys
)
workspace_total_size = await _assert_compute_path_size(
fake_celery_task,
location_id,
user_id,
celery_task=fake_celery_task,
task_id=TaskId("fake_task"),
location_id=location_id,
user_id=user_id,
path=path,
expected_total_size=expected_total_size,
)
@@ -188,9 +198,10 @@ async def test_path_compute_size(
selected_node_s3_keys
)
accumulated_subfolder_size += await _assert_compute_path_size(
fake_celery_task,
location_id,
user_id,
celery_task=fake_celery_task,
task_id=TaskId("fake_task"),
location_id=location_id,
user_id=user_id,
path=workspace_subfolder,
expected_total_size=expected_total_size,
)
@@ -208,9 +219,10 @@ async def test_path_compute_size_inexistent_path(
fake_datcore_tokens: tuple[str, str],
):
await _assert_compute_path_size(
fake_celery_task,
location_id,
user_id,
celery_task=fake_celery_task,
task_id=TaskId("fake_task"),
location_id=location_id,
user_id=user_id,
path=Path(faker.file_path(absolute=False)),
expected_total_size=0,
)
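
For reference, piecing the added lines in this file together, the test helper after this change reads roughly as follows; every parameter is keyword-only and the new task_id is forwarded to compute_path_size (the trailing return is inferred from the ByteSize return annotation and from callers that accumulate the result):

async def _assert_compute_path_size(
    *,
    celery_task: Task,
    task_id: TaskId,
    location_id: LocationID,
    user_id: UserID,
    path: Path,
    expected_total_size: int,
) -> ByteSize:
    # The worker task now also receives the Celery task id.
    response = await compute_path_size(
        celery_task,
        task_id=task_id,
        user_id=user_id,
        location_id=location_id,
        path=path,
    )
    assert isinstance(response, ByteSize)
    assert response == expected_total_size
    return response

Call sites pass task_id=TaskId("fake_task") throughout, so the tests exercise the new parameter without depending on a real Celery task.
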
services/storage/tests/unit/test_data_export.py (14 additions, 7 deletions)
@@ -22,7 +22,6 @@
from models_library.api_schemas_rpc_async_jobs.exceptions import (
JobAbortedError,
JobError,
JobMissingError,
JobNotDoneError,
JobSchedulerError,
)
@@ -343,7 +342,6 @@ async def test_abort_data_export_success(
@pytest.mark.parametrize(
"mock_celery_client, expected_exception_type",
[
({"abort_task_object": None, "get_task_uuids_object": []}, JobMissingError),
(
{
"abort_task_object": CeleryError("error"),
@@ -377,6 +375,14 @@ async def test_abort_data_export_error(
@pytest.mark.parametrize(
"mock_celery_client",
[
{
"get_task_status_object": TaskStatus(
task_uuid=TaskUUID(_faker.uuid4()),
task_state=TaskState.PENDING,
progress_report=ProgressReport(actual_value=0),
),
"get_task_uuids_object": [],
},
{
"get_task_status_object": TaskStatus(
task_uuid=TaskUUID(_faker.uuid4()),
@@ -411,10 +417,6 @@ async def test_get_data_export_status(
@pytest.mark.parametrize(
"mock_celery_client, expected_exception_type",
[
(
{"get_task_status_object": None, "get_task_uuids_object": []},
JobMissingError,
),
(
{
"get_task_status_object": CeleryError("error"),
@@ -528,9 +530,14 @@ async def test_get_data_export_result_success(
),
(
{
"get_task_status_object": TaskStatus(
task_uuid=TaskUUID(_faker.uuid4()),
task_state=TaskState.PENDING,
progress_report=ProgressReport(actual_value=0.0),
),
"get_task_uuids_object": [],
},
JobMissingError,
JobNotDoneError,
),
],
indirect=["mock_celery_client"],
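
With the existence check removed, the final parametrize case shown above now expects JobNotDoneError instead of JobMissingError for a PENDING task whose UUID is absent from get_task_uuids_object. A rough, assumption-laden sketch of how the result handler would produce that outcome; is_done and the JobNotDoneError keyword are not shown in this diff and merely mirror the removed JobMissingError(job_id=...) call:

_status = await get_celery_client(app).get_task_status(
    task_context=job_id_data.model_dump(),
    task_uuid=job_id,
)
if not _status.is_done:  # assumed helper on TaskStatus; a PENDING task lands here
    raise JobNotDoneError(job_id=f"{job_id}")  # keyword assumed
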