diff --git a/services/api-server/src/simcore_service_api_server/_service_function_jobs.py b/services/api-server/src/simcore_service_api_server/_service_function_jobs.py index e1b3728e7eb1..e89d910d47a0 100644 --- a/services/api-server/src/simcore_service_api_server/_service_function_jobs.py +++ b/services/api-server/src/simcore_service_api_server/_service_function_jobs.py @@ -405,7 +405,6 @@ async def patch_registered_function_job( async def run_function( self, *, - job_creation_task_id: TaskID | None, function: RegisteredFunction, pre_registered_function_job_data: PreRegisteredFunctionJobData, pricing_spec: JobPricingSpecification | None, @@ -434,7 +433,7 @@ async def run_function( product_name=self.product_name, function_job_id=pre_registered_function_job_data.function_job_id, function_class=FunctionClass.PROJECT, - job_creation_task_id=job_creation_task_id, + job_creation_task_id=None, project_job_id=study_job.id, ) @@ -459,7 +458,7 @@ async def run_function( product_name=self.product_name, function_job_id=pre_registered_function_job_data.function_job_id, function_class=FunctionClass.SOLVER, - job_creation_task_id=job_creation_task_id, + job_creation_task_id=None, solver_job_id=solver_job.id, ) @@ -467,29 +466,6 @@ async def run_function( function_class=function.function_class, ) - async def map_function( - self, - *, - job_creation_task_id: TaskID | None, - function: RegisteredFunction, - pre_registered_function_job_data_list: list[PreRegisteredFunctionJobData], - job_links: JobLinks, - pricing_spec: JobPricingSpecification | None, - x_simcore_parent_project_uuid: ProjectID | None, - x_simcore_parent_node_id: NodeID | None, - ) -> None: - - for data in pre_registered_function_job_data_list: - await self.run_function( - job_creation_task_id=job_creation_task_id, - function=function, - pre_registered_function_job_data=data, - pricing_spec=pricing_spec, - job_links=job_links, - x_simcore_parent_project_uuid=x_simcore_parent_project_uuid, - x_simcore_parent_node_id=x_simcore_parent_node_id, - ) - async def function_job_outputs( self, *, diff --git a/services/api-server/src/simcore_service_api_server/api/routes/functions_routes.py b/services/api-server/src/simcore_service_api_server/api/routes/functions_routes.py index 026b7d100c48..8ac3589e9f46 100644 --- a/services/api-server/src/simcore_service_api_server/api/routes/functions_routes.py +++ b/services/api-server/src/simcore_service_api_server/api/routes/functions_routes.py @@ -26,15 +26,14 @@ from models_library.users import UserID from servicelib.celery.models import TaskFilter, TaskID, TaskMetadata, TasksQueue from servicelib.fastapi.dependencies import get_reverse_url_mapper +from servicelib.utils import limited_gather from ..._service_function_jobs import FunctionJobService from ..._service_functions import FunctionService -from ...celery_worker.worker_tasks.functions_tasks import function_map as map_task from ...celery_worker.worker_tasks.functions_tasks import ( run_function as run_function_task, ) from ...exceptions.function_errors import FunctionJobCacheNotFoundError -from ...models.domain.functions import PreRegisteredFunctionJobData from ...models.pagination import Page, PaginationParams from ...models.schemas.errors import ErrorGet from ...models.schemas.jobs import JobPricingSpecification @@ -454,86 +453,42 @@ async def map_function( x_simcore_parent_node_id: Annotated[NodeID | Literal["null"], Header()], ) -> RegisteredFunctionJobCollection: - task_manager = get_task_manager(request.app) - parent_project_uuid = ( - x_simcore_parent_project_uuid - if isinstance(x_simcore_parent_project_uuid, ProjectID) - else None - ) - parent_node_id = ( - x_simcore_parent_node_id - if isinstance(x_simcore_parent_node_id, NodeID) - else None - ) - pricing_spec = JobPricingSpecification.create_from_headers(request.headers) - job_links = await function_service.get_function_job_links(to_run_function, url_for) - - job_inputs_list = [ - await function_jobs_service.create_function_job_inputs( - function=to_run_function, function_inputs=function_inputs + async def _run_single_function(function_inputs: FunctionInputs) -> FunctionJobID: + result = await run_function( + request=request, + user_identity=user_identity, + to_run_function=to_run_function, + url_for=url_for, + function_inputs=function_inputs, + function_service=function_service, + function_job_service=function_jobs_service, + x_simcore_parent_project_uuid=x_simcore_parent_project_uuid, + x_simcore_parent_node_id=x_simcore_parent_node_id, ) - for function_inputs in function_inputs_list - ] - - job_ids: list[FunctionJobID] = [] - pre_registered_function_job_data_list: list[PreRegisteredFunctionJobData] = [] - - for job_inputs in job_inputs_list: - try: - cached_job = await function_jobs_service.get_cached_function_job( - function=to_run_function, - job_inputs=job_inputs, - ) - job_ids.append(cached_job.uid) - except FunctionJobCacheNotFoundError: - data = await function_jobs_service.pre_register_function_job( - function=to_run_function, - job_inputs=job_inputs, - ) - pre_registered_function_job_data_list.append(data) - job_ids.append(data.function_job_id) - - # run map in celery task - job_filter = AsyncJobFilter( - user_id=user_identity.user_id, - product_name=user_identity.product_name, - client_name=ASYNC_JOB_CLIENT_NAME, - ) - task_filter = TaskFilter.model_validate(job_filter.model_dump()) - task_name = map_task.__name__ - - task_uuid = await task_manager.submit_task( - TaskMetadata( - name=task_name, - ephemeral=True, - queue=TasksQueue.API_WORKER_QUEUE, - ), - task_filter=task_filter, - user_identity=user_identity, - function=to_run_function, - pre_registered_function_job_data_list=pre_registered_function_job_data_list, - pricing_spec=pricing_spec, - job_links=job_links, - x_simcore_parent_project_uuid=parent_project_uuid, - x_simcore_parent_node_id=parent_node_id, + return result.uid + + # Run all tasks concurrently, allowing them to complete even if some fail + results = await limited_gather( + *[ + _run_single_function(function_inputs) + for function_inputs in function_inputs_list + ], + reraise=False, + limit=10, ) - # patch pre-registered function jobs - for data in pre_registered_function_job_data_list: - await function_jobs_service.patch_registered_function_job( - user_id=user_identity.user_id, - product_name=user_identity.product_name, - function_job_id=data.function_job_id, - function_class=to_run_function.function_class, - job_creation_task_id=TaskID(task_uuid), - ) + # Check if any tasks raised exceptions and raise the first one found + for result in results: + if isinstance(result, BaseException): + raise result - function_job_collection_description = f"Function job collection of map of function {to_run_function.uid} with {len(pre_registered_function_job_data_list)} inputs" + # At this point, all results are FunctionJobID since we've checked for exceptions + function_job_collection_description = f"Function job collection of map of function {to_run_function.uid} with {len(function_inputs_list)} inputs" return await web_api_rpc_client.register_function_job_collection( function_job_collection=FunctionJobCollection( title="Function job collection of function map", description=function_job_collection_description, - job_ids=job_ids, + job_ids=results, # type: ignore ), user_id=user_identity.user_id, product_name=user_identity.product_name, diff --git a/services/api-server/src/simcore_service_api_server/celery_worker/worker_tasks/functions_tasks.py b/services/api-server/src/simcore_service_api_server/celery_worker/worker_tasks/functions_tasks.py index e8cde1c5dc55..fdfc8f1382b1 100644 --- a/services/api-server/src/simcore_service_api_server/celery_worker/worker_tasks/functions_tasks.py +++ b/services/api-server/src/simcore_service_api_server/celery_worker/worker_tasks/functions_tasks.py @@ -120,7 +120,6 @@ async def run_function( ) return await function_job_service.run_function( - job_creation_task_id=task_id, function=function, pre_registered_function_job_data=pre_registered_function_job_data, pricing_spec=pricing_spec, @@ -128,32 +127,3 @@ async def run_function( x_simcore_parent_project_uuid=x_simcore_parent_project_uuid, x_simcore_parent_node_id=x_simcore_parent_node_id, ) - - -async def function_map( - task: Task, - task_id: TaskID, - *, - user_identity: Identity, - function: RegisteredFunction, - pre_registered_function_job_data_list: list[PreRegisteredFunctionJobData], - job_links: JobLinks, - pricing_spec: JobPricingSpecification | None, - x_simcore_parent_project_uuid: ProjectID | None, - x_simcore_parent_node_id: NodeID | None, -) -> None: - assert task_id # nosec - app = get_app_server(task.app).app - function_job_service = await _assemble_function_job_service( - app=app, user_identity=user_identity - ) - - return await function_job_service.map_function( - job_creation_task_id=task_id, - function=function, - pre_registered_function_job_data_list=pre_registered_function_job_data_list, - pricing_spec=pricing_spec, - job_links=job_links, - x_simcore_parent_project_uuid=x_simcore_parent_project_uuid, - x_simcore_parent_node_id=x_simcore_parent_node_id, - ) diff --git a/services/api-server/src/simcore_service_api_server/celery_worker/worker_tasks/tasks.py b/services/api-server/src/simcore_service_api_server/celery_worker/worker_tasks/tasks.py index d51d4a4cdba5..b212a04dc283 100644 --- a/services/api-server/src/simcore_service_api_server/celery_worker/worker_tasks/tasks.py +++ b/services/api-server/src/simcore_service_api_server/celery_worker/worker_tasks/tasks.py @@ -18,7 +18,7 @@ from ...models.api_resources import JobLinks from ...models.domain.functions import PreRegisteredFunctionJobData from ...models.schemas.jobs import JobInputs, JobPricingSpecification -from .functions_tasks import function_map, run_function +from .functions_tasks import run_function _logger = logging.getLogger(__name__) @@ -43,4 +43,3 @@ def setup_worker_tasks(app: Celery) -> None: with log_context(_logger, logging.INFO, msg="worker task registration"): register_task(app, run_function) - register_task(app, function_map) diff --git a/services/api-server/tests/unit/api_functions/celery/test_functions_celery.py b/services/api-server/tests/unit/api_functions/celery/test_functions_celery.py index 9aa57bcabcc3..c28685ee2f6c 100644 --- a/services/api-server/tests/unit/api_functions/celery/test_functions_celery.py +++ b/services/api-server/tests/unit/api_functions/celery/test_functions_celery.py @@ -150,9 +150,9 @@ async def _patch_registered_function_job_side_effect( registered_function_job_patch = kwargs["registered_function_job_patch"] assert isinstance(registered_function_job_patch, RegisteredProjectFunctionJobPatch) job_creation_task_id = registered_function_job_patch.job_creation_task_id - assert job_creation_task_id is not None + uid = kwargs["function_job_uuid"] return mock_registered_project_function_job.model_copy( - update={"job_creation_task_id": job_creation_task_id} + update={"job_creation_task_id": job_creation_task_id, "uid": uid} ) @@ -551,3 +551,135 @@ def _default_side_effect( ].job_creation_task_id await wait_for_task_result(client, auth, f"{task_id}") assert side_effect_checks["headers_checked"] is True + + +@pytest.mark.parametrize("mocked_app_dependencies", [None]) +async def test_map_function( + app: FastAPI, + with_api_server_celery_worker: TestWorkController, + client: AsyncClient, + mock_handler_in_functions_rpc_interface: Callable[ + [str, Any, Exception | None, Callable | None], MockType + ], + mock_registered_project_function: RegisteredProjectFunction, + mock_registered_project_function_job: RegisteredFunctionJob, + auth: httpx.BasicAuth, + user_id: UserID, + mocked_webserver_rest_api_base: respx.MockRouter, + mocked_directorv2_rest_api_base: respx.MockRouter, + mocked_webserver_rpc_api: dict[str, MockType], + create_respx_mock_from_capture, + project_tests_dir: Path, +) -> None: + + # arrange + _capture = "run_study_function_parent_info.json" + + def _default_side_effect( + request: httpx.Request, + path_params: dict[str, Any], + capture: HttpApiCallCaptureModel, + ) -> Any: + return capture.response_body + + create_respx_mock_from_capture( + respx_mocks=[mocked_webserver_rest_api_base, mocked_directorv2_rest_api_base], + capture_path=project_tests_dir / "mocks" / _capture, + side_effects_callbacks=[_default_side_effect] * 50, + ) + + mock_handler_in_functions_rpc_interface( + "get_function_user_permissions", + FunctionUserAccessRights( + user_id=user_id, + execute=True, + read=True, + write=True, + ), + None, + None, + ) + mock_handler_in_functions_rpc_interface( + "get_function", mock_registered_project_function, None, None + ) + mock_handler_in_functions_rpc_interface("find_cached_function_jobs", [], None, None) + + _generated_function_job_ids: list[FunctionJobID] = [] + + async def _register_function_job_side_effect( + generated_function_job_ids: list[FunctionJobID], *args, **kwargs + ): + uid = FunctionJobID(_faker.uuid4()) + generated_function_job_ids.append(uid) + return mock_registered_project_function_job.model_copy(update={"uid": uid}) + + mock_handler_in_functions_rpc_interface( + "register_function_job", + None, + None, + partial(_register_function_job_side_effect, _generated_function_job_ids), + ) + mock_handler_in_functions_rpc_interface( + "get_functions_user_api_access_rights", + FunctionUserApiAccessRights( + user_id=user_id, + execute_functions=True, + write_functions=True, + read_functions=True, + ), + None, + None, + ) + + async def _register_function_job_collection_side_effect(*args, **kwargs): + job_collection = kwargs["function_job_collection"] + return RegisteredFunctionJobCollection( + uid=FunctionJobID(_faker.uuid4()), + title="Test Collection", + description="A test function job collection", + job_ids=job_collection.job_ids, + created_at=datetime.datetime.now(datetime.UTC), + ) + + mock_handler_in_functions_rpc_interface( + "register_function_job_collection", + None, + None, + _register_function_job_collection_side_effect, + ) + + patch_mock = mock_handler_in_functions_rpc_interface( + "patch_registered_function_job", + None, + None, + partial( + _patch_registered_function_job_side_effect, + mock_registered_project_function_job, + ), + ) + + # act + _inputs = [{}, {}] + response = await client.post( + f"{API_VTAG}/functions/{mock_registered_project_function.uid}:map", + json=_inputs, + auth=auth, + headers={ + X_SIMCORE_PARENT_PROJECT_UUID: "null", + X_SIMCORE_PARENT_NODE_ID: "null", + }, + ) + + # assert + assert response.status_code == status.HTTP_200_OK + job_collection = FunctionJobCollection.model_validate(response.json()) + assert ( + job_collection.job_ids == _generated_function_job_ids + ), "Job ID did not preserve order or were incorrectly propagated" + celery_task_ids = { + elm.kwargs["registered_function_job_patch"].job_creation_task_id + for elm in patch_mock.call_args_list + } + assert len(celery_task_ids) == len(_inputs) + for task_id in celery_task_ids: + await wait_for_task_result(client, auth, f"{task_id}")