22from collections .abc import Callable
33from typing import Annotated , Final , Literal
44
5+ from celery_library .task_manager import CeleryTaskManager
56from fastapi import APIRouter , Depends , Header , Request , status
67from fastapi_pagination .api import create_page
78from fastapi_pagination .bases import AbstractPage
1617 RegisteredFunctionJob ,
1718 RegisteredFunctionJobCollection ,
1819)
20+ from models_library .api_schemas_rpc_async_jobs .async_jobs import AsyncJobFilter
1921from models_library .products import ProductName
2022from models_library .projects import ProjectID
2123from models_library .projects_nodes_io import NodeID
2224from models_library .users import UserID
25+ from servicelib .celery .models import TaskFilter , TaskMetadata
2326from servicelib .fastapi .dependencies import get_reverse_url_mapper
27+ from servicelib .long_running_tasks .models import TaskGet
2428
2529from ..._service_function_jobs import FunctionJobService
2630from ..._service_functions import FunctionService
31+ from ...celery ._worker_tasks ._functions_tasks import run_function
2732from ...models .pagination import Page , PaginationParams
2833from ...models .schemas .errors import ErrorGet
2934from ...models .schemas .jobs import JobPricingSpecification
3035from ...services_rpc .wb_api_server import WbApiRpcClient
31- from ..dependencies .authentication import get_current_user_id , get_product_name
36+ from ..dependencies .authentication import (
37+ Identity ,
38+ get_current_identity ,
39+ get_current_user_id ,
40+ get_product_name ,
41+ )
42+ from ..dependencies .celery import ASYNC_JOB_CLIENT_NAME , get_task_manager_from_app
3243from ..dependencies .services import (
3344 get_function_job_service ,
3445 get_function_service ,
@@ -304,7 +315,7 @@ async def validate_function_inputs(
304315
305316@function_router .post (
306317 "/{function_id:uuid}:run" ,
307- response_model = RegisteredFunctionJob ,
318+ response_model = TaskGet ,
308319 responses = {** _COMMON_FUNCTION_ERROR_RESPONSES },
309320 description = create_route_description (
310321 base = "Run function" ,
@@ -313,16 +324,15 @@ async def validate_function_inputs(
313324)
314325async def run_function ( # noqa: PLR0913
315326 request : Request ,
327+ user_identity : Annotated [Identity , Depends (get_current_identity )],
316328 to_run_function : Annotated [RegisteredFunction , Depends (get_function )],
317329 url_for : Annotated [Callable , Depends (get_reverse_url_mapper )],
318330 function_inputs : FunctionInputs ,
319331 function_service : Annotated [FunctionService , Depends (get_function_service )],
320- function_jobs_service : Annotated [
321- FunctionJobService , Depends (get_function_job_service )
322- ],
332+ task_manager : Annotated [CeleryTaskManager , Depends (get_task_manager_from_app )],
323333 x_simcore_parent_project_uuid : Annotated [ProjectID | Literal ["null" ], Header ()],
324334 x_simcore_parent_node_id : Annotated [NodeID | Literal ["null" ], Header ()],
325- ) -> RegisteredFunctionJob :
335+ ) -> TaskGet :
326336 parent_project_uuid = (
327337 x_simcore_parent_project_uuid
328338 if isinstance (x_simcore_parent_project_uuid , ProjectID )
@@ -336,7 +346,20 @@ async def run_function( # noqa: PLR0913
336346 pricing_spec = JobPricingSpecification .create_from_headers (request .headers )
337347 job_links = await function_service .get_function_job_links (to_run_function , url_for )
338348
339- return await function_jobs_service .run_function (
349+ job_filter = AsyncJobFilter (
350+ user_id = user_identity .user_id ,
351+ product_name = user_identity .product_name ,
352+ client_name = ASYNC_JOB_CLIENT_NAME ,
353+ )
354+ task_filter = TaskFilter .model_validate (job_filter .model_dump ())
355+ task_name = run_function .__name__
356+
357+ task_uuid = await task_manager .submit_task (
358+ TaskMetadata (
359+ name = task_name ,
360+ ),
361+ task_filter = task_filter ,
362+ user_identity = user_identity ,
340363 function = to_run_function ,
341364 function_inputs = function_inputs ,
342365 pricing_spec = pricing_spec ,
@@ -345,6 +368,14 @@ async def run_function( # noqa: PLR0913
345368 x_simcore_parent_node_id = parent_node_id ,
346369 )
347370
371+ return TaskGet (
372+ task_id = f"{ task_uuid } " ,
373+ task_name = task_name ,
374+ status_href = url_for ("get_task_status" , task_id = task_uuid ),
375+ result_href = url_for ("get_task_result" , task_id = task_uuid ),
376+ abort_href = url_for ("cancel_task" , task_id = task_uuid ),
377+ )
378+
348379
349380@function_router .delete (
350381 "/{function_id:uuid}" ,
0 commit comments