Skip to content

Commit 95dde06

Browse files
🎨 Run each job creation as a single task in api-worker (#8306)
1 parent ce96a9b commit 95dde06

File tree

5 files changed

+166
-134
lines changed

5 files changed

+166
-134
lines changed

services/api-server/src/simcore_service_api_server/_service_function_jobs.py

Lines changed: 2 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,6 @@ async def patch_registered_function_job(
405405
async def run_function(
406406
self,
407407
*,
408-
job_creation_task_id: TaskID | None,
409408
function: RegisteredFunction,
410409
pre_registered_function_job_data: PreRegisteredFunctionJobData,
411410
pricing_spec: JobPricingSpecification | None,
@@ -434,7 +433,7 @@ async def run_function(
434433
product_name=self.product_name,
435434
function_job_id=pre_registered_function_job_data.function_job_id,
436435
function_class=FunctionClass.PROJECT,
437-
job_creation_task_id=job_creation_task_id,
436+
job_creation_task_id=None,
438437
project_job_id=study_job.id,
439438
)
440439

@@ -459,37 +458,14 @@ async def run_function(
459458
product_name=self.product_name,
460459
function_job_id=pre_registered_function_job_data.function_job_id,
461460
function_class=FunctionClass.SOLVER,
462-
job_creation_task_id=job_creation_task_id,
461+
job_creation_task_id=None,
463462
solver_job_id=solver_job.id,
464463
)
465464

466465
raise UnsupportedFunctionClassError(
467466
function_class=function.function_class,
468467
)
469468

470-
async def map_function(
471-
self,
472-
*,
473-
job_creation_task_id: TaskID | None,
474-
function: RegisteredFunction,
475-
pre_registered_function_job_data_list: list[PreRegisteredFunctionJobData],
476-
job_links: JobLinks,
477-
pricing_spec: JobPricingSpecification | None,
478-
x_simcore_parent_project_uuid: ProjectID | None,
479-
x_simcore_parent_node_id: NodeID | None,
480-
) -> None:
481-
482-
for data in pre_registered_function_job_data_list:
483-
await self.run_function(
484-
job_creation_task_id=job_creation_task_id,
485-
function=function,
486-
pre_registered_function_job_data=data,
487-
pricing_spec=pricing_spec,
488-
job_links=job_links,
489-
x_simcore_parent_project_uuid=x_simcore_parent_project_uuid,
490-
x_simcore_parent_node_id=x_simcore_parent_node_id,
491-
)
492-
493469
async def function_job_outputs(
494470
self,
495471
*,

services/api-server/src/simcore_service_api_server/api/routes/functions_routes.py

Lines changed: 29 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,14 @@
2626
from models_library.users import UserID
2727
from servicelib.celery.models import TaskFilter, TaskID, TaskMetadata, TasksQueue
2828
from servicelib.fastapi.dependencies import get_reverse_url_mapper
29+
from servicelib.utils import limited_gather
2930

3031
from ..._service_function_jobs import FunctionJobService
3132
from ..._service_functions import FunctionService
32-
from ...celery_worker.worker_tasks.functions_tasks import function_map as map_task
3333
from ...celery_worker.worker_tasks.functions_tasks import (
3434
run_function as run_function_task,
3535
)
3636
from ...exceptions.function_errors import FunctionJobCacheNotFoundError
37-
from ...models.domain.functions import PreRegisteredFunctionJobData
3837
from ...models.pagination import Page, PaginationParams
3938
from ...models.schemas.errors import ErrorGet
4039
from ...models.schemas.jobs import JobPricingSpecification
@@ -454,86 +453,42 @@ async def map_function(
454453
x_simcore_parent_node_id: Annotated[NodeID | Literal["null"], Header()],
455454
) -> RegisteredFunctionJobCollection:
456455

457-
task_manager = get_task_manager(request.app)
458-
parent_project_uuid = (
459-
x_simcore_parent_project_uuid
460-
if isinstance(x_simcore_parent_project_uuid, ProjectID)
461-
else None
462-
)
463-
parent_node_id = (
464-
x_simcore_parent_node_id
465-
if isinstance(x_simcore_parent_node_id, NodeID)
466-
else None
467-
)
468-
pricing_spec = JobPricingSpecification.create_from_headers(request.headers)
469-
job_links = await function_service.get_function_job_links(to_run_function, url_for)
470-
471-
job_inputs_list = [
472-
await function_jobs_service.create_function_job_inputs(
473-
function=to_run_function, function_inputs=function_inputs
456+
async def _run_single_function(function_inputs: FunctionInputs) -> FunctionJobID:
457+
result = await run_function(
458+
request=request,
459+
user_identity=user_identity,
460+
to_run_function=to_run_function,
461+
url_for=url_for,
462+
function_inputs=function_inputs,
463+
function_service=function_service,
464+
function_job_service=function_jobs_service,
465+
x_simcore_parent_project_uuid=x_simcore_parent_project_uuid,
466+
x_simcore_parent_node_id=x_simcore_parent_node_id,
474467
)
475-
for function_inputs in function_inputs_list
476-
]
477-
478-
job_ids: list[FunctionJobID] = []
479-
pre_registered_function_job_data_list: list[PreRegisteredFunctionJobData] = []
480-
481-
for job_inputs in job_inputs_list:
482-
try:
483-
cached_job = await function_jobs_service.get_cached_function_job(
484-
function=to_run_function,
485-
job_inputs=job_inputs,
486-
)
487-
job_ids.append(cached_job.uid)
488-
except FunctionJobCacheNotFoundError:
489-
data = await function_jobs_service.pre_register_function_job(
490-
function=to_run_function,
491-
job_inputs=job_inputs,
492-
)
493-
pre_registered_function_job_data_list.append(data)
494-
job_ids.append(data.function_job_id)
495-
496-
# run map in celery task
497-
job_filter = AsyncJobFilter(
498-
user_id=user_identity.user_id,
499-
product_name=user_identity.product_name,
500-
client_name=ASYNC_JOB_CLIENT_NAME,
501-
)
502-
task_filter = TaskFilter.model_validate(job_filter.model_dump())
503-
task_name = map_task.__name__
504-
505-
task_uuid = await task_manager.submit_task(
506-
TaskMetadata(
507-
name=task_name,
508-
ephemeral=True,
509-
queue=TasksQueue.API_WORKER_QUEUE,
510-
),
511-
task_filter=task_filter,
512-
user_identity=user_identity,
513-
function=to_run_function,
514-
pre_registered_function_job_data_list=pre_registered_function_job_data_list,
515-
pricing_spec=pricing_spec,
516-
job_links=job_links,
517-
x_simcore_parent_project_uuid=parent_project_uuid,
518-
x_simcore_parent_node_id=parent_node_id,
468+
return result.uid
469+
470+
# Run all tasks concurrently, allowing them to complete even if some fail
471+
results = await limited_gather(
472+
*[
473+
_run_single_function(function_inputs)
474+
for function_inputs in function_inputs_list
475+
],
476+
reraise=False,
477+
limit=10,
519478
)
520479

521-
# patch pre-registered function jobs
522-
for data in pre_registered_function_job_data_list:
523-
await function_jobs_service.patch_registered_function_job(
524-
user_id=user_identity.user_id,
525-
product_name=user_identity.product_name,
526-
function_job_id=data.function_job_id,
527-
function_class=to_run_function.function_class,
528-
job_creation_task_id=TaskID(task_uuid),
529-
)
480+
# Check if any tasks raised exceptions and raise the first one found
481+
for result in results:
482+
if isinstance(result, BaseException):
483+
raise result
530484

531-
function_job_collection_description = f"Function job collection of map of function {to_run_function.uid} with {len(pre_registered_function_job_data_list)} inputs"
485+
# At this point, all results are FunctionJobID since we've checked for exceptions
486+
function_job_collection_description = f"Function job collection of map of function {to_run_function.uid} with {len(function_inputs_list)} inputs"
532487
return await web_api_rpc_client.register_function_job_collection(
533488
function_job_collection=FunctionJobCollection(
534489
title="Function job collection of function map",
535490
description=function_job_collection_description,
536-
job_ids=job_ids,
491+
job_ids=results, # type: ignore
537492
),
538493
user_id=user_identity.user_id,
539494
product_name=user_identity.product_name,

services/api-server/src/simcore_service_api_server/celery_worker/worker_tasks/functions_tasks.py

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -120,40 +120,10 @@ async def run_function(
120120
)
121121

122122
return await function_job_service.run_function(
123-
job_creation_task_id=task_id,
124123
function=function,
125124
pre_registered_function_job_data=pre_registered_function_job_data,
126125
pricing_spec=pricing_spec,
127126
job_links=job_links,
128127
x_simcore_parent_project_uuid=x_simcore_parent_project_uuid,
129128
x_simcore_parent_node_id=x_simcore_parent_node_id,
130129
)
131-
132-
133-
async def function_map(
134-
task: Task,
135-
task_id: TaskID,
136-
*,
137-
user_identity: Identity,
138-
function: RegisteredFunction,
139-
pre_registered_function_job_data_list: list[PreRegisteredFunctionJobData],
140-
job_links: JobLinks,
141-
pricing_spec: JobPricingSpecification | None,
142-
x_simcore_parent_project_uuid: ProjectID | None,
143-
x_simcore_parent_node_id: NodeID | None,
144-
) -> None:
145-
assert task_id # nosec
146-
app = get_app_server(task.app).app
147-
function_job_service = await _assemble_function_job_service(
148-
app=app, user_identity=user_identity
149-
)
150-
151-
return await function_job_service.map_function(
152-
job_creation_task_id=task_id,
153-
function=function,
154-
pre_registered_function_job_data_list=pre_registered_function_job_data_list,
155-
pricing_spec=pricing_spec,
156-
job_links=job_links,
157-
x_simcore_parent_project_uuid=x_simcore_parent_project_uuid,
158-
x_simcore_parent_node_id=x_simcore_parent_node_id,
159-
)

services/api-server/src/simcore_service_api_server/celery_worker/worker_tasks/tasks.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from ...models.api_resources import JobLinks
1919
from ...models.domain.functions import PreRegisteredFunctionJobData
2020
from ...models.schemas.jobs import JobInputs, JobPricingSpecification
21-
from .functions_tasks import function_map, run_function
21+
from .functions_tasks import run_function
2222

2323
_logger = logging.getLogger(__name__)
2424

@@ -43,4 +43,3 @@ def setup_worker_tasks(app: Celery) -> None:
4343

4444
with log_context(_logger, logging.INFO, msg="worker task registration"):
4545
register_task(app, run_function)
46-
register_task(app, function_map)

services/api-server/tests/unit/api_functions/celery/test_functions_celery.py

Lines changed: 134 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,9 +150,9 @@ async def _patch_registered_function_job_side_effect(
150150
registered_function_job_patch = kwargs["registered_function_job_patch"]
151151
assert isinstance(registered_function_job_patch, RegisteredProjectFunctionJobPatch)
152152
job_creation_task_id = registered_function_job_patch.job_creation_task_id
153-
assert job_creation_task_id is not None
153+
uid = kwargs["function_job_uuid"]
154154
return mock_registered_project_function_job.model_copy(
155-
update={"job_creation_task_id": job_creation_task_id}
155+
update={"job_creation_task_id": job_creation_task_id, "uid": uid}
156156
)
157157

158158

@@ -551,3 +551,135 @@ def _default_side_effect(
551551
].job_creation_task_id
552552
await wait_for_task_result(client, auth, f"{task_id}")
553553
assert side_effect_checks["headers_checked"] is True
554+
555+
556+
@pytest.mark.parametrize("mocked_app_dependencies", [None])
557+
async def test_map_function(
558+
app: FastAPI,
559+
with_api_server_celery_worker: TestWorkController,
560+
client: AsyncClient,
561+
mock_handler_in_functions_rpc_interface: Callable[
562+
[str, Any, Exception | None, Callable | None], MockType
563+
],
564+
mock_registered_project_function: RegisteredProjectFunction,
565+
mock_registered_project_function_job: RegisteredFunctionJob,
566+
auth: httpx.BasicAuth,
567+
user_id: UserID,
568+
mocked_webserver_rest_api_base: respx.MockRouter,
569+
mocked_directorv2_rest_api_base: respx.MockRouter,
570+
mocked_webserver_rpc_api: dict[str, MockType],
571+
create_respx_mock_from_capture,
572+
project_tests_dir: Path,
573+
) -> None:
574+
575+
# arrange
576+
_capture = "run_study_function_parent_info.json"
577+
578+
def _default_side_effect(
579+
request: httpx.Request,
580+
path_params: dict[str, Any],
581+
capture: HttpApiCallCaptureModel,
582+
) -> Any:
583+
return capture.response_body
584+
585+
create_respx_mock_from_capture(
586+
respx_mocks=[mocked_webserver_rest_api_base, mocked_directorv2_rest_api_base],
587+
capture_path=project_tests_dir / "mocks" / _capture,
588+
side_effects_callbacks=[_default_side_effect] * 50,
589+
)
590+
591+
mock_handler_in_functions_rpc_interface(
592+
"get_function_user_permissions",
593+
FunctionUserAccessRights(
594+
user_id=user_id,
595+
execute=True,
596+
read=True,
597+
write=True,
598+
),
599+
None,
600+
None,
601+
)
602+
mock_handler_in_functions_rpc_interface(
603+
"get_function", mock_registered_project_function, None, None
604+
)
605+
mock_handler_in_functions_rpc_interface("find_cached_function_jobs", [], None, None)
606+
607+
_generated_function_job_ids: list[FunctionJobID] = []
608+
609+
async def _register_function_job_side_effect(
610+
generated_function_job_ids: list[FunctionJobID], *args, **kwargs
611+
):
612+
uid = FunctionJobID(_faker.uuid4())
613+
generated_function_job_ids.append(uid)
614+
return mock_registered_project_function_job.model_copy(update={"uid": uid})
615+
616+
mock_handler_in_functions_rpc_interface(
617+
"register_function_job",
618+
None,
619+
None,
620+
partial(_register_function_job_side_effect, _generated_function_job_ids),
621+
)
622+
mock_handler_in_functions_rpc_interface(
623+
"get_functions_user_api_access_rights",
624+
FunctionUserApiAccessRights(
625+
user_id=user_id,
626+
execute_functions=True,
627+
write_functions=True,
628+
read_functions=True,
629+
),
630+
None,
631+
None,
632+
)
633+
634+
async def _register_function_job_collection_side_effect(*args, **kwargs):
635+
job_collection = kwargs["function_job_collection"]
636+
return RegisteredFunctionJobCollection(
637+
uid=FunctionJobID(_faker.uuid4()),
638+
title="Test Collection",
639+
description="A test function job collection",
640+
job_ids=job_collection.job_ids,
641+
created_at=datetime.datetime.now(datetime.UTC),
642+
)
643+
644+
mock_handler_in_functions_rpc_interface(
645+
"register_function_job_collection",
646+
None,
647+
None,
648+
_register_function_job_collection_side_effect,
649+
)
650+
651+
patch_mock = mock_handler_in_functions_rpc_interface(
652+
"patch_registered_function_job",
653+
None,
654+
None,
655+
partial(
656+
_patch_registered_function_job_side_effect,
657+
mock_registered_project_function_job,
658+
),
659+
)
660+
661+
# act
662+
_inputs = [{}, {}]
663+
response = await client.post(
664+
f"{API_VTAG}/functions/{mock_registered_project_function.uid}:map",
665+
json=_inputs,
666+
auth=auth,
667+
headers={
668+
X_SIMCORE_PARENT_PROJECT_UUID: "null",
669+
X_SIMCORE_PARENT_NODE_ID: "null",
670+
},
671+
)
672+
673+
# assert
674+
assert response.status_code == status.HTTP_200_OK
675+
job_collection = FunctionJobCollection.model_validate(response.json())
676+
assert (
677+
job_collection.job_ids == _generated_function_job_ids
678+
), "Job ID did not preserve order or were incorrectly propagated"
679+
celery_task_ids = {
680+
elm.kwargs["registered_function_job_patch"].job_creation_task_id
681+
for elm in patch_mock.call_args_list
682+
}
683+
assert len(celery_task_ids) == len(_inputs)
684+
for task_id in celery_task_ids:
685+
await wait_for_task_result(client, auth, f"{task_id}")

0 commit comments

Comments
 (0)