Skip to content

Commit 162c8ee

Browse files
committed
test passing
1 parent 676aa32 commit 162c8ee

File tree

2 files changed

+87
-41
lines changed

2 files changed

+87
-41
lines changed

services/api-server/src/simcore_service_api_server/api/routes/tasks.py

Lines changed: 84 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import logging
22
from typing import Annotated, Any
33

4-
from fastapi import APIRouter, Depends, FastAPI, status
4+
from celery.exceptions import CeleryError # type: ignore[import-untyped]
5+
from fastapi import APIRouter, Depends, FastAPI, HTTPException, status
56
from models_library.api_schemas_long_running_tasks.base import TaskProgress
67
from models_library.api_schemas_long_running_tasks.tasks import (
78
TaskGet,
@@ -14,14 +15,13 @@
1415
)
1516
from models_library.products import ProductName
1617
from models_library.users import UserID
18+
from servicelib.celery.models import TaskFilter, TaskUUID
1719
from servicelib.fastapi.dependencies import get_app
1820

1921
from ...models.schemas.base import ApiServerEnvelope
2022
from ...models.schemas.errors import ErrorGet
21-
from ...services_rpc.async_jobs import AsyncJobClient
2223
from ..dependencies.authentication import get_current_user_id, get_product_name
23-
from ..dependencies.celery import ASYNC_JOB_CLIENT_NAME
24-
from ..dependencies.tasks import get_async_jobs_client
24+
from ..dependencies.celery import ASYNC_JOB_CLIENT_NAME, get_task_manager_from_app
2525
from ._constants import (
2626
FMSG_CHANGELOG_NEW_IN_VERSION,
2727
create_route_description,
@@ -31,10 +31,11 @@
3131
_logger = logging.getLogger(__name__)
3232

3333

34-
def _get_job_filter(user_id: UserID, product_name: ProductName) -> AsyncJobFilter:
35-
return AsyncJobFilter(
34+
def _get_task_filter(user_id: UserID, product_name: ProductName) -> TaskFilter:
35+
job_filter = AsyncJobFilter(
3636
user_id=user_id, product_name=product_name, client_name=ASYNC_JOB_CLIENT_NAME
3737
)
38+
return TaskFilter.model_validate(job_filter.model_dump())
3839

3940

4041
_DEFAULT_TASK_STATUS_CODES: dict[int | str, dict[str, Any]] = {
@@ -61,26 +62,34 @@ async def list_tasks(
6162
app: Annotated[FastAPI, Depends(get_app)],
6263
user_id: Annotated[UserID, Depends(get_current_user_id)],
6364
product_name: Annotated[ProductName, Depends(get_product_name)],
64-
async_jobs: Annotated[AsyncJobClient, Depends(get_async_jobs_client)],
6565
):
66-
user_async_jobs = await async_jobs.list_jobs(
67-
job_filter=_get_job_filter(user_id, product_name),
68-
filter_="",
69-
)
66+
67+
task_manager = get_task_manager_from_app(app)
68+
69+
try:
70+
tasks = await task_manager.list_tasks(
71+
task_filter=_get_task_filter(user_id, product_name),
72+
)
73+
except CeleryError as exc:
74+
raise HTTPException(
75+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
76+
detail="Encountered issue when listing tasks",
77+
) from exc
78+
7079
app_router = app.router
7180
data = [
7281
TaskGet(
73-
task_id=f"{job.job_id}",
74-
task_name=job.job_name,
82+
task_id=f"{task.uuid}",
83+
task_name=task.metadata.name,
7584
status_href=app_router.url_path_for(
76-
"get_task_status", task_id=f"{job.job_id}"
85+
"get_task_status", task_id=f"{task.uuid}"
7786
),
78-
abort_href=app_router.url_path_for("cancel_task", task_id=f"{job.job_id}"),
87+
abort_href=app_router.url_path_for("cancel_task", task_id=f"{task.uuid}"),
7988
result_href=app_router.url_path_for(
80-
"get_task_result", task_id=f"{job.job_id}"
89+
"get_task_result", task_id=f"{task.uuid}"
8190
),
8291
)
83-
for job in user_async_jobs
92+
for task in tasks
8493
]
8594
return ApiServerEnvelope(data=data)
8695

@@ -99,20 +108,29 @@ async def list_tasks(
99108
)
100109
async def get_task_status(
101110
task_id: AsyncJobId,
111+
app: Annotated[FastAPI, Depends(get_app)],
102112
user_id: Annotated[UserID, Depends(get_current_user_id)],
103113
product_name: Annotated[ProductName, Depends(get_product_name)],
104-
async_jobs: Annotated[AsyncJobClient, Depends(get_async_jobs_client)],
105114
):
106-
async_job_rpc_status = await async_jobs.status(
107-
job_id=task_id,
108-
job_filter=_get_job_filter(user_id, product_name),
109-
)
110-
_task_id = f"{async_job_rpc_status.job_id}"
115+
task_manager = get_task_manager_from_app(app)
116+
117+
try:
118+
task_status = await task_manager.get_task_status(
119+
task_filter=_get_task_filter(user_id, product_name),
120+
task_uuid=TaskUUID(f"{task_id}"),
121+
)
122+
except CeleryError as exc:
123+
raise HTTPException(
124+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
125+
detail="Encountered issue when getting task status",
126+
) from exc
127+
111128
return TaskStatus(
112129
task_progress=TaskProgress(
113-
task_id=_task_id, percent=async_job_rpc_status.progress.percent_value
130+
task_id=f"{task_status.task_uuid}",
131+
percent=task_status.progress_report.percent_value,
114132
),
115-
done=async_job_rpc_status.done,
133+
done=task_status.is_done,
116134
started=None,
117135
)
118136

@@ -131,14 +149,22 @@ async def get_task_status(
131149
)
132150
async def cancel_task(
133151
task_id: AsyncJobId,
152+
app: Annotated[FastAPI, Depends(get_app)],
134153
user_id: Annotated[UserID, Depends(get_current_user_id)],
135154
product_name: Annotated[ProductName, Depends(get_product_name)],
136-
async_jobs: Annotated[AsyncJobClient, Depends(get_async_jobs_client)],
137155
):
138-
await async_jobs.cancel(
139-
job_id=task_id,
140-
job_filter=_get_job_filter(user_id, product_name),
141-
)
156+
task_manager = get_task_manager_from_app(app)
157+
158+
try:
159+
await task_manager.cancel_task(
160+
task_filter=_get_task_filter(user_id, product_name),
161+
task_uuid=TaskUUID(f"{task_id}"),
162+
)
163+
except CeleryError as exc:
164+
raise HTTPException(
165+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
166+
detail="Encountered issue when cancelling task",
167+
) from exc
142168

143169

144170
@router.get(
@@ -165,12 +191,34 @@ async def cancel_task(
165191
)
166192
async def get_task_result(
167193
task_id: AsyncJobId,
194+
app: Annotated[FastAPI, Depends(get_app)],
168195
user_id: Annotated[UserID, Depends(get_current_user_id)],
169196
product_name: Annotated[ProductName, Depends(get_product_name)],
170-
async_jobs: Annotated[AsyncJobClient, Depends(get_async_jobs_client)],
171197
):
172-
async_job_rpc_result = await async_jobs.result(
173-
job_id=task_id,
174-
job_filter=_get_job_filter(user_id, product_name),
175-
)
176-
return TaskResult(result=async_job_rpc_result.result, error=None)
198+
task_manager = get_task_manager_from_app(app)
199+
task_filter = _get_task_filter(user_id, product_name)
200+
201+
try:
202+
# First check if task exists and is done
203+
task_status = await task_manager.get_task_status(
204+
task_filter=task_filter,
205+
task_uuid=TaskUUID(f"{task_id}"),
206+
)
207+
208+
if not task_status.is_done:
209+
raise HTTPException(
210+
status_code=status.HTTP_404_NOT_FOUND,
211+
detail="Task result not available yet",
212+
)
213+
214+
result = await task_manager.get_task_result(
215+
task_filter=task_filter,
216+
task_uuid=TaskUUID(f"{task_id}"),
217+
)
218+
return TaskResult(result=result, error=None)
219+
220+
except CeleryError as exc:
221+
raise HTTPException(
222+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
223+
detail="Encountered issue when getting task result",
224+
) from exc

services/api-server/tests/unit/celery/test_functions.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from celery import Celery, Task
66
from celery.contrib.testing.worker import TestWorkController
77
from celery_library.task import register_task
8+
from celery_library.types import register_pydantic_types
89
from faker import Faker
910
from fastapi import FastAPI, status
1011
from httpx import AsyncClient, BasicAuth
@@ -111,6 +112,7 @@ async def run_function(
111112
), f"Signature mismatch: {inspect.signature(run_function_task)} != {inspect.signature(run_function)}"
112113

113114
def _(celery_app: Celery) -> None:
115+
register_pydantic_types(RegisteredProjectFunctionJob)
114116
register_task(celery_app, run_function)
115117

116118
return _
@@ -148,8 +150,4 @@ async def test_with_fake_run_function(
148150

149151
# Poll until task completion and get result
150152
result = await poll_task_until_done(client, auth, task.task_id)
151-
152-
# Verify the result is a RegisteredProjectFunctionJob
153-
assert result is not None
154-
assert isinstance(result, dict)
155-
# Add more specific assertions based on your expected result structure
153+
RegisteredProjectFunctionJob.model_validate(result.result)

0 commit comments

Comments
 (0)