Skip to content

Commit 80a7cca

Browse files
committed
transform run function endpoint
1 parent 8e49543 commit 80a7cca

File tree

3 files changed

+45
-11
lines changed

3 files changed

+45
-11
lines changed

services/api-server/src/simcore_service_api_server/api/dependencies/celery.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Final
2+
13
from celery_library.common import create_app, create_task_manager
24
from celery_library.task_manager import CeleryTaskManager
35
from celery_library.types import register_celery_types, register_pydantic_types
@@ -13,6 +15,8 @@
1315
from ...models.api_resources import JobLinks
1416
from ...models.schemas.jobs import JobPricingSpecification
1517

18+
ASYNC_JOB_CLIENT_NAME: Final[str] = "API_SERVER"
19+
1620

1721
def setup_task_manager(app: FastAPI, celery_settings: CelerySettings) -> None:
1822
async def on_startup() -> None:

services/api-server/src/simcore_service_api_server/api/routes/functions_routes.py

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from collections.abc import Callable
33
from typing import Annotated, Final, Literal
44

5+
from celery_library.task_manager import CeleryTaskManager
56
from fastapi import APIRouter, Depends, Header, Request, status
67
from fastapi_pagination.api import create_page
78
from fastapi_pagination.bases import AbstractPage
@@ -16,19 +17,29 @@
1617
RegisteredFunctionJob,
1718
RegisteredFunctionJobCollection,
1819
)
20+
from models_library.api_schemas_rpc_async_jobs.async_jobs import AsyncJobFilter
1921
from models_library.products import ProductName
2022
from models_library.projects import ProjectID
2123
from models_library.projects_nodes_io import NodeID
2224
from models_library.users import UserID
25+
from servicelib.celery.models import TaskFilter, TaskMetadata
2326
from servicelib.fastapi.dependencies import get_reverse_url_mapper
27+
from servicelib.long_running_tasks.models import TaskGet
2428

2529
from ..._service_function_jobs import FunctionJobService
2630
from ..._service_functions import FunctionService
31+
from ...celery._worker_tasks._functions_tasks import run_function
2732
from ...models.pagination import Page, PaginationParams
2833
from ...models.schemas.errors import ErrorGet
2934
from ...models.schemas.jobs import JobPricingSpecification
3035
from ...services_rpc.wb_api_server import WbApiRpcClient
31-
from ..dependencies.authentication import get_current_user_id, get_product_name
36+
from ..dependencies.authentication import (
37+
Identity,
38+
get_current_identity,
39+
get_current_user_id,
40+
get_product_name,
41+
)
42+
from ..dependencies.celery import ASYNC_JOB_CLIENT_NAME, get_task_manager_from_app
3243
from ..dependencies.services import (
3344
get_function_job_service,
3445
get_function_service,
@@ -304,7 +315,7 @@ async def validate_function_inputs(
304315

305316
@function_router.post(
306317
"/{function_id:uuid}:run",
307-
response_model=RegisteredFunctionJob,
318+
response_model=TaskGet,
308319
responses={**_COMMON_FUNCTION_ERROR_RESPONSES},
309320
description=create_route_description(
310321
base="Run function",
@@ -313,16 +324,15 @@ async def validate_function_inputs(
313324
)
314325
async def run_function( # noqa: PLR0913
315326
request: Request,
327+
user_identity: Annotated[Identity, Depends(get_current_identity)],
316328
to_run_function: Annotated[RegisteredFunction, Depends(get_function)],
317329
url_for: Annotated[Callable, Depends(get_reverse_url_mapper)],
318330
function_inputs: FunctionInputs,
319331
function_service: Annotated[FunctionService, Depends(get_function_service)],
320-
function_jobs_service: Annotated[
321-
FunctionJobService, Depends(get_function_job_service)
322-
],
332+
task_manager: Annotated[CeleryTaskManager, Depends(get_task_manager_from_app)],
323333
x_simcore_parent_project_uuid: Annotated[ProjectID | Literal["null"], Header()],
324334
x_simcore_parent_node_id: Annotated[NodeID | Literal["null"], Header()],
325-
) -> RegisteredFunctionJob:
335+
) -> TaskGet:
326336
parent_project_uuid = (
327337
x_simcore_parent_project_uuid
328338
if isinstance(x_simcore_parent_project_uuid, ProjectID)
@@ -336,7 +346,20 @@ async def run_function( # noqa: PLR0913
336346
pricing_spec = JobPricingSpecification.create_from_headers(request.headers)
337347
job_links = await function_service.get_function_job_links(to_run_function, url_for)
338348

339-
return await function_jobs_service.run_function(
349+
job_filter = AsyncJobFilter(
350+
user_id=user_identity.user_id,
351+
product_name=user_identity.product_name,
352+
client_name=ASYNC_JOB_CLIENT_NAME,
353+
)
354+
task_filter = TaskFilter.model_validate(job_filter.model_dump())
355+
task_name = run_function.__name__
356+
357+
task_uuid = await task_manager.submit_task(
358+
TaskMetadata(
359+
name=task_name,
360+
),
361+
task_filter=task_filter,
362+
user_identity=user_identity,
340363
function=to_run_function,
341364
function_inputs=function_inputs,
342365
pricing_spec=pricing_spec,
@@ -345,6 +368,14 @@ async def run_function( # noqa: PLR0913
345368
x_simcore_parent_node_id=parent_node_id,
346369
)
347370

371+
return TaskGet(
372+
task_id=f"{task_uuid}",
373+
task_name=task_name,
374+
status_href=url_for("get_task_status", task_id=task_uuid),
375+
result_href=url_for("get_task_result", task_id=task_uuid),
376+
abort_href=url_for("cancel_task", task_id=task_uuid),
377+
)
378+
348379

349380
@function_router.delete(
350381
"/{function_id:uuid}",

services/api-server/src/simcore_service_api_server/api/routes/tasks.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import logging
2-
from typing import Annotated, Any, Final
2+
from typing import Annotated, Any
33

44
from fastapi import APIRouter, Depends, FastAPI, status
55
from models_library.api_schemas_long_running_tasks.base import TaskProgress
@@ -20,21 +20,20 @@
2020
from ...models.schemas.errors import ErrorGet
2121
from ...services_rpc.async_jobs import AsyncJobClient
2222
from ..dependencies.authentication import get_current_user_id, get_product_name
23+
from ..dependencies.celery import ASYNC_JOB_CLIENT_NAME
2324
from ..dependencies.tasks import get_async_jobs_client
2425
from ._constants import (
2526
FMSG_CHANGELOG_NEW_IN_VERSION,
2627
create_route_description,
2728
)
2829

29-
_ASYNC_JOB_CLIENT_NAME: Final[str] = "API_SERVER"
30-
3130
router = APIRouter()
3231
_logger = logging.getLogger(__name__)
3332

3433

3534
def _get_job_filter(user_id: UserID, product_name: ProductName) -> AsyncJobFilter:
3635
return AsyncJobFilter(
37-
user_id=user_id, product_name=product_name, client_name=_ASYNC_JOB_CLIENT_NAME
36+
user_id=user_id, product_name=product_name, client_name=ASYNC_JOB_CLIENT_NAME
3837
)
3938

4039

0 commit comments

Comments
 (0)