Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,6 @@
import logging

from celery.exceptions import CeleryError # type: ignore[import-untyped]
from celery_library.errors import (
TransferrableCeleryError,
decode_celery_transferrable_error,
)
from fastapi import FastAPI
from models_library.api_schemas_rpc_async_jobs.async_jobs import (
AsyncJobGet,
AsyncJobId,
Expand All @@ -22,21 +17,27 @@
JobSchedulerError,
)
from servicelib.celery.models import TaskState
from servicelib.celery.task_manager import TaskManager
from servicelib.logging_utils import log_catch
from servicelib.rabbitmq import RPCRouter

from ...modules.celery import get_task_manager_from_app
from ..errors import (
TransferrableCeleryError,
decode_celery_transferrable_error,
)

_logger = logging.getLogger(__name__)
router = RPCRouter()


@router.expose(reraise_if_error_type=(JobSchedulerError,))
async def cancel(app: FastAPI, job_id: AsyncJobId, job_id_data: AsyncJobNameData):
assert app # nosec
async def cancel(
task_manager: TaskManager, job_id: AsyncJobId, job_id_data: AsyncJobNameData
):
assert task_manager # nosec
assert job_id_data # nosec
try:
await get_task_manager_from_app(app).cancel_task(
await task_manager.cancel_task(
task_context=job_id_data.model_dump(),
task_uuid=job_id,
)
Expand All @@ -46,13 +47,13 @@ async def cancel(app: FastAPI, job_id: AsyncJobId, job_id_data: AsyncJobNameData

@router.expose(reraise_if_error_type=(JobSchedulerError,))
async def status(
app: FastAPI, job_id: AsyncJobId, job_id_data: AsyncJobNameData
task_manager: TaskManager, job_id: AsyncJobId, job_id_data: AsyncJobNameData
) -> AsyncJobStatus:
assert app # nosec
assert task_manager # nosec
assert job_id_data # nosec

try:
task_status = await get_task_manager_from_app(app).get_task_status(
task_status = await task_manager.get_task_status(
task_context=job_id_data.model_dump(),
task_uuid=job_id,
)
Expand All @@ -75,20 +76,20 @@ async def status(
)
)
async def result(
app: FastAPI, job_id: AsyncJobId, job_id_data: AsyncJobNameData
task_manager: TaskManager, job_id: AsyncJobId, job_id_data: AsyncJobNameData
) -> AsyncJobResult:
assert app # nosec
assert task_manager # nosec
assert job_id # nosec
assert job_id_data # nosec

try:
_status = await get_task_manager_from_app(app).get_task_status(
_status = await task_manager.get_task_status(
task_context=job_id_data.model_dump(),
task_uuid=job_id,
)
if not _status.is_done:
raise JobNotDoneError(job_id=job_id)
_result = await get_task_manager_from_app(app).get_task_result(
_result = await task_manager.get_task_result(
task_context=job_id_data.model_dump(),
task_uuid=job_id,
)
Expand Down Expand Up @@ -122,12 +123,12 @@ async def result(

@router.expose(reraise_if_error_type=(JobSchedulerError,))
async def list_jobs(
app: FastAPI, filter_: str, job_id_data: AsyncJobNameData
task_manager: TaskManager, filter_: str, job_id_data: AsyncJobNameData
) -> list[AsyncJobGet]:
_ = filter_
assert app # nosec
assert task_manager # nosec
try:
tasks = await get_task_manager_from_app(app).list_tasks(
tasks = await task_manager.list_tasks(
task_context=job_id_data.model_dump(),
)
except CeleryError as exc:
Expand Down
4 changes: 4 additions & 0 deletions packages/celery-library/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from celery_library.common import create_task_manager
from celery_library.signals import on_worker_init, on_worker_shutdown
from celery_library.task_manager import CeleryTaskManager
from celery_library.types import register_celery_types
from pytest_simcore.helpers.monkeypatch_envs import setenvs_from_dict
from pytest_simcore.helpers.typing_env import EnvVarsDict
from servicelib.celery.app_server import BaseAppServer
Expand All @@ -25,6 +26,7 @@
"pytest_simcore.docker_compose",
"pytest_simcore.docker_swarm",
"pytest_simcore.environment_configs",
"pytest_simcore.rabbit_service",
"pytest_simcore.redis_service",
"pytest_simcore.repository_paths",
]
Expand Down Expand Up @@ -123,6 +125,8 @@ async def celery_task_manager(
celery_settings: CelerySettings,
with_celery_worker: TestWorkController,
) -> CeleryTaskManager:
register_celery_types()

return await create_task_manager(
celery_app,
celery_settings,
Expand Down
Loading
Loading