11import logging
22from typing import Annotated , Any
33
4- from fastapi import APIRouter , Depends , FastAPI , status
4+ from celery .exceptions import CeleryError # type: ignore[import-untyped]
5+ from fastapi import APIRouter , Depends , FastAPI , HTTPException , status
56from models_library .api_schemas_long_running_tasks .base import TaskProgress
67from models_library .api_schemas_long_running_tasks .tasks import (
78 TaskGet ,
1415)
1516from models_library .products import ProductName
1617from models_library .users import UserID
18+ from servicelib .celery .models import TaskFilter , TaskUUID
1719from servicelib .fastapi .dependencies import get_app
1820
1921from ...models .schemas .base import ApiServerEnvelope
2022from ...models .schemas .errors import ErrorGet
21- from ...services_rpc .async_jobs import AsyncJobClient
2223from ..dependencies .authentication import get_current_user_id , get_product_name
23- from ..dependencies .celery import ASYNC_JOB_CLIENT_NAME
24- from ..dependencies .tasks import get_async_jobs_client
24+ from ..dependencies .celery import ASYNC_JOB_CLIENT_NAME , get_task_manager_from_app
2525from ._constants import (
2626 FMSG_CHANGELOG_NEW_IN_VERSION ,
2727 create_route_description ,
3131_logger = logging .getLogger (__name__ )
3232
3333
34- def _get_job_filter (user_id : UserID , product_name : ProductName ) -> AsyncJobFilter :
35- return AsyncJobFilter (
34+ def _get_task_filter (user_id : UserID , product_name : ProductName ) -> TaskFilter :
35+ job_filter = AsyncJobFilter (
3636 user_id = user_id , product_name = product_name , client_name = ASYNC_JOB_CLIENT_NAME
3737 )
38+ return TaskFilter .model_validate (job_filter .model_dump ())
3839
3940
4041_DEFAULT_TASK_STATUS_CODES : dict [int | str , dict [str , Any ]] = {
@@ -61,26 +62,34 @@ async def list_tasks(
6162 app : Annotated [FastAPI , Depends (get_app )],
6263 user_id : Annotated [UserID , Depends (get_current_user_id )],
6364 product_name : Annotated [ProductName , Depends (get_product_name )],
64- async_jobs : Annotated [AsyncJobClient , Depends (get_async_jobs_client )],
6565):
66- user_async_jobs = await async_jobs .list_jobs (
67- job_filter = _get_job_filter (user_id , product_name ),
68- filter_ = "" ,
69- )
66+
67+ task_manager = get_task_manager_from_app (app )
68+
69+ try :
70+ tasks = await task_manager .list_tasks (
71+ task_filter = _get_task_filter (user_id , product_name ),
72+ )
73+ except CeleryError as exc :
74+ raise HTTPException (
75+ status_code = status .HTTP_500_INTERNAL_SERVER_ERROR ,
76+ detail = "Encountered issue when listing tasks" ,
77+ ) from exc
78+
7079 app_router = app .router
7180 data = [
7281 TaskGet (
73- task_id = f"{ job . job_id } " ,
74- task_name = job . job_name ,
82+ task_id = f"{ task . uuid } " ,
83+ task_name = task . metadata . name ,
7584 status_href = app_router .url_path_for (
76- "get_task_status" , task_id = f"{ job . job_id } "
85+ "get_task_status" , task_id = f"{ task . uuid } "
7786 ),
78- abort_href = app_router .url_path_for ("cancel_task" , task_id = f"{ job . job_id } " ),
87+ abort_href = app_router .url_path_for ("cancel_task" , task_id = f"{ task . uuid } " ),
7988 result_href = app_router .url_path_for (
80- "get_task_result" , task_id = f"{ job . job_id } "
89+ "get_task_result" , task_id = f"{ task . uuid } "
8190 ),
8291 )
83- for job in user_async_jobs
92+ for task in tasks
8493 ]
8594 return ApiServerEnvelope (data = data )
8695
@@ -99,20 +108,29 @@ async def list_tasks(
99108)
100109async def get_task_status (
101110 task_id : AsyncJobId ,
111+ app : Annotated [FastAPI , Depends (get_app )],
102112 user_id : Annotated [UserID , Depends (get_current_user_id )],
103113 product_name : Annotated [ProductName , Depends (get_product_name )],
104- async_jobs : Annotated [AsyncJobClient , Depends (get_async_jobs_client )],
105114):
106- async_job_rpc_status = await async_jobs .status (
107- job_id = task_id ,
108- job_filter = _get_job_filter (user_id , product_name ),
109- )
110- _task_id = f"{ async_job_rpc_status .job_id } "
115+ task_manager = get_task_manager_from_app (app )
116+
117+ try :
118+ task_status = await task_manager .get_task_status (
119+ task_filter = _get_task_filter (user_id , product_name ),
120+ task_uuid = TaskUUID (f"{ task_id } " ),
121+ )
122+ except CeleryError as exc :
123+ raise HTTPException (
124+ status_code = status .HTTP_500_INTERNAL_SERVER_ERROR ,
125+ detail = "Encountered issue when getting task status" ,
126+ ) from exc
127+
111128 return TaskStatus (
112129 task_progress = TaskProgress (
113- task_id = _task_id , percent = async_job_rpc_status .progress .percent_value
130+ task_id = f"{ task_status .task_uuid } " ,
131+ percent = task_status .progress_report .percent_value ,
114132 ),
115- done = async_job_rpc_status . done ,
133+ done = task_status . is_done ,
116134 started = None ,
117135 )
118136
@@ -131,14 +149,22 @@ async def get_task_status(
131149)
132150async def cancel_task (
133151 task_id : AsyncJobId ,
152+ app : Annotated [FastAPI , Depends (get_app )],
134153 user_id : Annotated [UserID , Depends (get_current_user_id )],
135154 product_name : Annotated [ProductName , Depends (get_product_name )],
136- async_jobs : Annotated [AsyncJobClient , Depends (get_async_jobs_client )],
137155):
138- await async_jobs .cancel (
139- job_id = task_id ,
140- job_filter = _get_job_filter (user_id , product_name ),
141- )
156+ task_manager = get_task_manager_from_app (app )
157+
158+ try :
159+ await task_manager .cancel_task (
160+ task_filter = _get_task_filter (user_id , product_name ),
161+ task_uuid = TaskUUID (f"{ task_id } " ),
162+ )
163+ except CeleryError as exc :
164+ raise HTTPException (
165+ status_code = status .HTTP_500_INTERNAL_SERVER_ERROR ,
166+ detail = "Encountered issue when cancelling task" ,
167+ ) from exc
142168
143169
144170@router .get (
@@ -165,12 +191,34 @@ async def cancel_task(
165191)
166192async def get_task_result (
167193 task_id : AsyncJobId ,
194+ app : Annotated [FastAPI , Depends (get_app )],
168195 user_id : Annotated [UserID , Depends (get_current_user_id )],
169196 product_name : Annotated [ProductName , Depends (get_product_name )],
170- async_jobs : Annotated [AsyncJobClient , Depends (get_async_jobs_client )],
171197):
172- async_job_rpc_result = await async_jobs .result (
173- job_id = task_id ,
174- job_filter = _get_job_filter (user_id , product_name ),
175- )
176- return TaskResult (result = async_job_rpc_result .result , error = None )
198+ task_manager = get_task_manager_from_app (app )
199+ task_filter = _get_task_filter (user_id , product_name )
200+
201+ try :
202+ # First check if task exists and is done
203+ task_status = await task_manager .get_task_status (
204+ task_filter = task_filter ,
205+ task_uuid = TaskUUID (f"{ task_id } " ),
206+ )
207+
208+ if not task_status .is_done :
209+ raise HTTPException (
210+ status_code = status .HTTP_404_NOT_FOUND ,
211+ detail = "Task result not available yet" ,
212+ )
213+
214+ result = await task_manager .get_task_result (
215+ task_filter = task_filter ,
216+ task_uuid = TaskUUID (f"{ task_id } " ),
217+ )
218+ return TaskResult (result = result , error = None )
219+
220+ except CeleryError as exc :
221+ raise HTTPException (
222+ status_code = status .HTTP_500_INTERNAL_SERVER_ERROR ,
223+ detail = "Encountered issue when getting task result" ,
224+ ) from exc
0 commit comments