Skip to content

Commit a454ee7

Browse files
committed
further test fixes
1 parent 9c725bc commit a454ee7

File tree

4 files changed

+135
-91
lines changed

4 files changed

+135
-91
lines changed

services/api-server/src/simcore_service_api_server/_service_function_jobs_task_client.py

Lines changed: 71 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
import contextlib
12
from dataclasses import dataclass
23
from typing import Final
34

45
from common_library.exclude import as_dict_exclude_none
6+
from models_library.api_schemas_rpc_async_jobs.async_jobs import AsyncJobFilter
57
from models_library.functions import (
68
FunctionClass,
79
FunctionID,
@@ -22,22 +24,27 @@
2224
UnsupportedFunctionFunctionJobClassCombinationError,
2325
)
2426
from models_library.products import ProductName
27+
from models_library.projects import ProjectID
28+
from models_library.projects_nodes_io import NodeID
2529
from models_library.projects_state import RunningState
2630
from models_library.rest_pagination import PageMetaInfoLimitOffset, PageOffsetInt
2731
from models_library.rpc_pagination import PageLimitInt
2832
from models_library.users import UserID
29-
from servicelib.celery.models import TaskUUID
33+
from servicelib.celery.models import TaskFilter, TaskMetadata, TasksQueue, TaskUUID
3034
from servicelib.celery.task_manager import TaskManager
3135
from sqlalchemy.ext.asyncio import AsyncEngine
3236

3337
from ._service_function_jobs import FunctionJobService
3438
from ._service_functions import FunctionService
3539
from ._service_jobs import JobService
40+
from .api.dependencies.authentication import Identity
41+
from .api.dependencies.celery import ASYNC_JOB_CLIENT_NAME
3642
from .api.routes.tasks import _get_task_filter
3743
from .exceptions.function_errors import (
3844
FunctionJobCacheNotFoundError,
3945
)
40-
from .models.schemas.jobs import JobInputs
46+
from .models.api_resources import JobLinks
47+
from .models.schemas.jobs import JobInputs, JobPricingSpecification
4148
from .services_http.webserver import AuthSession
4249
from .services_rpc.storage import StorageService
4350
from .services_rpc.wb_api_server import WbApiRpcClient
@@ -319,3 +326,65 @@ async def function_job_outputs(
319326
outputs=new_outputs,
320327
check_write_permissions=False,
321328
)
329+
330+
async def create_function_job_creation_task(
331+
self,
332+
*,
333+
function: RegisteredFunction,
334+
function_inputs: FunctionInputs,
335+
user_identity: Identity,
336+
pricing_spec: JobPricingSpecification | None,
337+
job_links: JobLinks,
338+
parent_project_uuid: ProjectID | None = None,
339+
parent_node_id: NodeID | None = None,
340+
) -> RegisteredFunctionJob:
341+
342+
job_inputs = await self._function_job_service.create_function_job_inputs(
343+
function=function, function_inputs=function_inputs
344+
)
345+
346+
# check if results are cached
347+
with contextlib.suppress(FunctionJobCacheNotFoundError):
348+
return await self.get_cached_function_job(
349+
function=function,
350+
job_inputs=job_inputs,
351+
)
352+
353+
pre_registered_function_job_data = (
354+
await self._function_job_service.pre_register_function_job(
355+
function=function,
356+
job_inputs=job_inputs,
357+
)
358+
)
359+
360+
# run function in celery task
361+
job_filter = AsyncJobFilter(
362+
user_id=user_identity.user_id,
363+
product_name=user_identity.product_name,
364+
client_name=ASYNC_JOB_CLIENT_NAME,
365+
)
366+
task_filter = TaskFilter.model_validate(job_filter.model_dump())
367+
368+
task_uuid = await self._celery_task_manager.submit_task(
369+
TaskMetadata(
370+
name="run_function",
371+
ephemeral=True,
372+
queue=TasksQueue.API_WORKER_QUEUE,
373+
),
374+
task_filter=task_filter,
375+
user_identity=user_identity,
376+
function=function,
377+
pre_registered_function_job_data=pre_registered_function_job_data,
378+
pricing_spec=pricing_spec,
379+
job_links=job_links,
380+
x_simcore_parent_project_uuid=parent_project_uuid,
381+
x_simcore_parent_node_id=parent_node_id,
382+
)
383+
384+
return await self._function_job_service.patch_registered_function_job(
385+
user_id=user_identity.user_id,
386+
product_name=user_identity.product_name,
387+
function_job_id=pre_registered_function_job_data.function_job_id,
388+
function_class=function.function_class,
389+
job_creation_task_id=TaskID(task_uuid),
390+
)

services/api-server/src/simcore_service_api_server/api/routes/functions_routes.py

Lines changed: 27 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import contextlib
2-
31
# pylint: disable=too-many-positional-arguments
42
from collections.abc import Callable
53
from typing import Annotated, Final, Literal
@@ -18,23 +16,17 @@
1816
RegisteredFunctionJob,
1917
RegisteredFunctionJobCollection,
2018
)
21-
from models_library.api_schemas_rpc_async_jobs.async_jobs import AsyncJobFilter
2219
from models_library.functions import FunctionJobCollection, FunctionJobID
2320
from models_library.products import ProductName
2421
from models_library.projects import ProjectID
2522
from models_library.projects_nodes_io import NodeID
2623
from models_library.users import UserID
27-
from servicelib.celery.models import TaskFilter, TaskID, TaskMetadata, TasksQueue
2824
from servicelib.fastapi.dependencies import get_reverse_url_mapper
2925
from servicelib.utils import limited_gather
3026

3127
from ..._service_function_jobs import FunctionJobService
3228
from ..._service_function_jobs_task_client import FunctionJobTaskClientService
3329
from ..._service_functions import FunctionService
34-
from ...celery_worker.worker_tasks.functions_tasks import (
35-
run_function as run_function_task,
36-
)
37-
from ...exceptions.function_errors import FunctionJobCacheNotFoundError
3830
from ...models.pagination import Page, PaginationParams
3931
from ...models.schemas.errors import ErrorGet
4032
from ...models.schemas.jobs import JobPricingSpecification
@@ -45,7 +37,6 @@
4537
get_current_user_id,
4638
get_product_name,
4739
)
48-
from ..dependencies.celery import ASYNC_JOB_CLIENT_NAME, get_task_manager
4940
from ..dependencies.services import (
5041
get_function_job_service,
5142
get_function_job_task_client_service,
@@ -344,7 +335,6 @@ async def run_function(
344335
x_simcore_parent_node_id: Annotated[NodeID | Literal["null"], Header()],
345336
) -> RegisteredFunctionJob:
346337
# preprocess inputs
347-
task_manager = get_task_manager(request.app)
348338
parent_project_uuid = (
349339
x_simcore_parent_project_uuid
350340
if isinstance(x_simcore_parent_project_uuid, ProjectID)
@@ -357,55 +347,15 @@ async def run_function(
357347
)
358348
pricing_spec = JobPricingSpecification.create_from_headers(request.headers)
359349
job_links = await function_service.get_function_job_links(to_run_function, url_for)
360-
job_inputs = await function_job_service.create_function_job_inputs(
361-
function=to_run_function, function_inputs=function_inputs
362-
)
363-
364-
# check if results are cached
365-
with contextlib.suppress(FunctionJobCacheNotFoundError):
366-
return await function_job_task_client_service.get_cached_function_job(
367-
function=to_run_function,
368-
job_inputs=job_inputs,
369-
)
370-
371-
pre_registered_function_job_data = (
372-
await function_job_service.pre_register_function_job(
373-
function=to_run_function,
374-
job_inputs=job_inputs,
375-
)
376-
)
377350

378-
# run function in celery task
379-
job_filter = AsyncJobFilter(
380-
user_id=user_identity.user_id,
381-
product_name=user_identity.product_name,
382-
client_name=ASYNC_JOB_CLIENT_NAME,
383-
)
384-
task_filter = TaskFilter.model_validate(job_filter.model_dump())
385-
task_name = run_function_task.__name__
386-
387-
task_uuid = await task_manager.submit_task(
388-
TaskMetadata(
389-
name=task_name,
390-
ephemeral=True,
391-
queue=TasksQueue.API_WORKER_QUEUE,
392-
),
393-
task_filter=task_filter,
394-
user_identity=user_identity,
351+
return await function_job_task_client_service.create_function_job_creation_task(
395352
function=to_run_function,
396-
pre_registered_function_job_data=pre_registered_function_job_data,
353+
function_inputs=function_inputs,
354+
user_identity=user_identity,
397355
pricing_spec=pricing_spec,
398356
job_links=job_links,
399-
x_simcore_parent_project_uuid=parent_project_uuid,
400-
x_simcore_parent_node_id=parent_node_id,
401-
)
402-
403-
return await function_job_service.patch_registered_function_job(
404-
user_id=user_identity.user_id,
405-
product_name=user_identity.product_name,
406-
function_job_id=pre_registered_function_job_data.function_job_id,
407-
function_class=to_run_function.function_class,
408-
job_creation_task_id=TaskID(task_uuid),
357+
parent_project_uuid=parent_project_uuid,
358+
parent_node_id=parent_node_id,
409359
)
410360

411361

@@ -463,19 +413,30 @@ async def map_function(
463413
x_simcore_parent_project_uuid: Annotated[ProjectID | Literal["null"], Header()],
464414
x_simcore_parent_node_id: Annotated[NodeID | Literal["null"], Header()],
465415
) -> RegisteredFunctionJobCollection:
416+
parent_project_uuid = (
417+
x_simcore_parent_project_uuid
418+
if isinstance(x_simcore_parent_project_uuid, ProjectID)
419+
else None
420+
)
421+
parent_node_id = (
422+
x_simcore_parent_node_id
423+
if isinstance(x_simcore_parent_node_id, NodeID)
424+
else None
425+
)
426+
pricing_spec = JobPricingSpecification.create_from_headers(request.headers)
427+
job_links = await function_service.get_function_job_links(to_run_function, url_for)
466428

467429
async def _run_single_function(function_inputs: FunctionInputs) -> FunctionJobID:
468-
result = await run_function(
469-
request=request,
470-
user_identity=user_identity,
471-
to_run_function=to_run_function,
472-
url_for=url_for,
473-
function_inputs=function_inputs,
474-
function_service=function_service,
475-
function_job_service=function_jobs_service,
476-
function_job_task_client_service=function_job_task_client_service,
477-
x_simcore_parent_project_uuid=x_simcore_parent_project_uuid,
478-
x_simcore_parent_node_id=x_simcore_parent_node_id,
430+
result = (
431+
await function_job_task_client_service.create_function_job_creation_task(
432+
function=to_run_function,
433+
function_inputs=function_inputs,
434+
user_identity=user_identity,
435+
pricing_spec=pricing_spec,
436+
job_links=job_links,
437+
parent_project_uuid=parent_project_uuid,
438+
parent_node_id=parent_node_id,
439+
)
479440
)
480441
return result.uid
481442

services/api-server/tests/unit/api_functions/test_api_routers_function_jobs.py

Lines changed: 36 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,11 @@
3636
from servicelib.celery.models import TaskFilter, TaskState, TaskStatus, TaskUUID
3737
from simcore_service_api_server._meta import API_VTAG
3838
from simcore_service_api_server._service_function_jobs_task_client import (
39+
_JOB_CREATION_TASK_NOT_YET_SCHEDULED_STATUS,
3940
_JOB_CREATION_TASK_STATUS_PREFIX,
4041
FunctionJobTaskClientService,
4142
)
42-
from simcore_service_api_server.api.routes import function_jobs_routes
43+
from simcore_service_api_server.api.dependencies import services as service_dependencies
4344
from simcore_service_api_server.models.schemas.jobs import JobStatus
4445

4546
_faker = Faker()
@@ -275,7 +276,9 @@ async def test_get_function_job_status(
275276
mocked_app_dependencies: None,
276277
client: AsyncClient,
277278
mocker: MockerFixture,
278-
mock_handler_in_functions_rpc_interface: Callable[[str, Any], None],
279+
mock_handler_in_functions_rpc_interface: Callable[
280+
[str, Any, Exception | None, Callable | None], MockType
281+
],
279282
mock_registered_project_function_job: RegisteredProjectFunctionJob,
280283
mock_registered_project_function: RegisteredProjectFunction,
281284
mock_method_in_jobs_service: Callable[[str, Any], None],
@@ -286,12 +289,12 @@ async def test_get_function_job_status(
286289
celery_task_state: TaskState,
287290
) -> None:
288291

289-
_expected_return_status = (
290-
status.HTTP_500_INTERNAL_SERVER_ERROR
291-
if job_status not in ("SUCCESS", "FAILED")
292-
and (project_job_id is None and job_creation_task_id is None)
293-
else status.HTTP_200_OK
294-
)
292+
_expected_return_status = status.HTTP_200_OK
293+
# status.HTTP_500_INTERNAL_SERVER_ERROR
294+
# if job_status not in ("SUCCESS", "FAILED")
295+
# and (project_job_id is None and job_creation_task_id is None)
296+
# else status.HTTP_200_OK
297+
# )
295298

296299
def _mock_task_manager(*args, **kwargs) -> CeleryTaskManager:
297300
async def _get_task_status(
@@ -318,7 +321,7 @@ async def _get_task_status(
318321
obj.get_task_status = _get_task_status
319322
return obj
320323

321-
mocker.patch.object(function_jobs_routes, "get_task_manager", _mock_task_manager)
324+
mocker.patch.object(service_dependencies, "get_task_manager", _mock_task_manager)
322325

323326
mock_handler_in_functions_rpc_interface(
324327
"get_function_job",
@@ -329,13 +332,14 @@ async def _get_task_status(
329332
"job_creation_task_id": job_creation_task_id,
330333
}
331334
),
335+
None,
336+
None,
332337
)
333338
mock_handler_in_functions_rpc_interface(
334-
"get_function", mock_registered_project_function
339+
"get_function", mock_registered_project_function, None, None
335340
)
336341
mock_handler_in_functions_rpc_interface(
337-
"get_function_job_status",
338-
FunctionJobStatus(status=job_status),
342+
"get_function_job_status", FunctionJobStatus(status=job_status), None, None
339343
)
340344
mock_method_in_jobs_service(
341345
"inspect_study_job",
@@ -347,25 +351,35 @@ async def _get_task_status(
347351
state=RunningState(value=job_status),
348352
),
349353
)
354+
355+
async def _update_function_job_status_side_effect(*args, **kwargs):
356+
return kwargs["job_status"]
357+
350358
mock_handler_in_functions_rpc_interface(
351359
"update_function_job_status",
352-
FunctionJobStatus(status=job_status),
360+
None,
361+
None,
362+
_update_function_job_status_side_effect,
353363
)
354364

355365
response = await client.get(
356366
f"{API_VTAG}/function_jobs/{mock_registered_project_function_job.uid}/status",
357367
auth=auth,
358368
)
359369
assert response.status_code == _expected_return_status
360-
if response.status_code == status.HTTP_200_OK:
361-
data = response.json()
362-
if project_job_id is not None or job_status in ("SUCCESS", "FAILED"):
363-
assert data["status"] == job_status
364-
else:
365-
assert (
366-
data["status"]
367-
== f"{_JOB_CREATION_TASK_STATUS_PREFIX}{celery_task_state}"
368-
)
370+
data = response.json()
371+
if (project_job_id is not None and job_creation_task_id is not None) or (
372+
job_status in ("SUCCESS", "FAILED")
373+
):
374+
assert data["status"] == job_status
375+
elif project_job_id is None and job_creation_task_id is None:
376+
assert data["status"] == _JOB_CREATION_TASK_NOT_YET_SCHEDULED_STATUS
377+
elif project_job_id is None and job_creation_task_id is not None:
378+
assert (
379+
data["status"] == f"{_JOB_CREATION_TASK_STATUS_PREFIX}{celery_task_state}"
380+
)
381+
else:
382+
pytest.fail("Unexpected combination of parameters")
369383

370384

371385
@pytest.mark.parametrize(

services/api-server/tests/unit/api_functions/test_api_routers_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,7 @@ async def test_run_map_function_not_allowed(
344344
"""Test that running a function is not allowed."""
345345

346346
mocker.patch(
347-
"simcore_service_api_server.api.routes.functions_routes.get_task_manager",
347+
"simcore_service_api_server.api.dependencies.services.get_task_manager",
348348
return_value=mocker.MagicMock(spec=CeleryTaskManager),
349349
)
350350

0 commit comments

Comments
 (0)