Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,43 +14,24 @@
from models_library.api_schemas_rpc_async_jobs.exceptions import (
JobAbortedError,
JobError,
JobMissingError,
JobNotDoneError,
JobSchedulerError,
)
from servicelib.logging_utils import log_catch
from servicelib.rabbitmq import RPCRouter

from ...modules.celery import get_celery_client
from ...modules.celery.client import CeleryTaskQueueClient
from ...modules.celery.models import TaskError, TaskState

_logger = logging.getLogger(__name__)
router = RPCRouter()


async def _assert_job_exists(
*,
job_id: AsyncJobId,
job_id_data: AsyncJobNameData,
celery_client: CeleryTaskQueueClient,
) -> None:
"""Raises JobMissingError if job doesn't exist"""
job_ids = await celery_client.get_task_uuids(
task_context=job_id_data.model_dump(),
)
if job_id not in job_ids:
raise JobMissingError(job_id=f"{job_id}")


@router.expose(reraise_if_error_type=(JobSchedulerError, JobMissingError))
@router.expose(reraise_if_error_type=(JobSchedulerError,))
async def cancel(app: FastAPI, job_id: AsyncJobId, job_id_data: AsyncJobNameData):
assert app # nosec
assert job_id_data # nosec
try:
await _assert_job_exists(
job_id=job_id, job_id_data=job_id_data, celery_client=get_celery_client(app)
)
await get_celery_client(app).abort_task(
task_context=job_id_data.model_dump(),
task_uuid=job_id,
Expand All @@ -59,17 +40,14 @@ async def cancel(app: FastAPI, job_id: AsyncJobId, job_id_data: AsyncJobNameData
raise JobSchedulerError(exc=f"{exc}") from exc


@router.expose(reraise_if_error_type=(JobSchedulerError, JobMissingError))
@router.expose(reraise_if_error_type=(JobSchedulerError,))
async def status(
app: FastAPI, job_id: AsyncJobId, job_id_data: AsyncJobNameData
) -> AsyncJobStatus:
assert app # nosec
assert job_id_data # nosec

try:
await _assert_job_exists(
job_id=job_id, job_id_data=job_id_data, celery_client=get_celery_client(app)
)
task_status = await get_celery_client(app).get_task_status(
task_context=job_id_data.model_dump(),
task_uuid=job_id,
Expand All @@ -90,7 +68,6 @@ async def status(
JobNotDoneError,
JobAbortedError,
JobSchedulerError,
JobMissingError,
)
)
async def result(
Expand All @@ -101,9 +78,6 @@ async def result(
assert job_id_data # nosec

try:
await _assert_job_exists(
job_id=job_id, job_id_data=job_id_data, celery_client=get_celery_client(app)
)
_status = await get_celery_client(app).get_task_status(
task_context=job_id_data.model_dump(),
task_uuid=job_id,
Expand Down
21 changes: 14 additions & 7 deletions services/storage/tests/unit/test_data_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from models_library.api_schemas_rpc_async_jobs.exceptions import (
JobAbortedError,
JobError,
JobMissingError,
JobNotDoneError,
JobSchedulerError,
)
Expand Down Expand Up @@ -343,7 +342,6 @@ async def test_abort_data_export_success(
@pytest.mark.parametrize(
"mock_celery_client, expected_exception_type",
[
({"abort_task_object": None, "get_task_uuids_object": []}, JobMissingError),
(
{
"abort_task_object": CeleryError("error"),
Expand Down Expand Up @@ -377,6 +375,14 @@ async def test_abort_data_export_error(
@pytest.mark.parametrize(
"mock_celery_client",
[
{
"get_task_status_object": TaskStatus(
task_uuid=TaskUUID(_faker.uuid4()),
task_state=TaskState.PENDING,
progress_report=ProgressReport(actual_value=0),
),
"get_task_uuids_object": [],
},
{
"get_task_status_object": TaskStatus(
task_uuid=TaskUUID(_faker.uuid4()),
Expand Down Expand Up @@ -411,10 +417,6 @@ async def test_get_data_export_status(
@pytest.mark.parametrize(
"mock_celery_client, expected_exception_type",
[
(
{"get_task_status_object": None, "get_task_uuids_object": []},
JobMissingError,
),
(
{
"get_task_status_object": CeleryError("error"),
Expand Down Expand Up @@ -528,9 +530,14 @@ async def test_get_data_export_result_success(
),
(
{
"get_task_status_object": TaskStatus(
task_uuid=TaskUUID(_faker.uuid4()),
task_state=TaskState.PENDING,
progress_report=ProgressReport(actual_value=0.0),
),
"get_task_uuids_object": [],
},
JobMissingError,
JobNotDoneError,
),
],
indirect=["mock_celery_client"],
Expand Down
Loading