diff --git a/packages/models-library/src/models_library/batch_operations.py b/packages/models-library/src/models_library/batch_operations.py index a757ea45a447..ff0b34b687a4 100644 --- a/packages/models-library/src/models_library/batch_operations.py +++ b/packages/models-library/src/models_library/batch_operations.py @@ -85,3 +85,37 @@ class BatchGetEnvelope(BaseModel, Generic[ResourceT, IdentifierT]): description="List of identifiers for items that were not found", ), ] = DEFAULT_FACTORY + + +class BatchCreateEnvelope(BaseModel, Generic[SchemaT]): + """Generic envelope model for batch-create operations. + + This model represents the result of a strict batch create operation, + containing the list of created items. The operation is expected to be "strict" + in the sense that it either creates all requested items or fails entirely. + """ + + created_items: Annotated[ + list[SchemaT], + Field( + min_length=1, + description="List of successfully created items", + ), + ] + + +class BatchUpdateEnvelope(BaseModel, Generic[SchemaT]): + """Generic envelope model for batch-update operations. + + This model represents the result of a strict batch update operation, + containing the list of updated items. The operation is expected to be "strict" + in the sense that it either updates all requested items or fails entirely. 
See https://google.aip.dev/234 + """ + + updated_items: Annotated[ + list[SchemaT], + Field( + min_length=1, + description="List of successfully updated items", + ), + ] diff --git a/packages/models-library/src/models_library/functions.py b/packages/models-library/src/models_library/functions.py index f8a0f68d9255..317a9df5b2bf 100644 --- a/packages/models-library/src/models_library/functions.py +++ b/packages/models-library/src/models_library/functions.py @@ -1,12 +1,13 @@ import datetime from collections.abc import Mapping from enum import Enum -from typing import Annotated, Any, Literal, TypeAlias +from typing import Annotated, Any, Final, Literal, TypeAlias from uuid import UUID from models_library import projects from models_library.basic_regex import UUID_RE_BASE from models_library.basic_types import ConstrainedStr +from models_library.batch_operations import BatchCreateEnvelope from models_library.groups import GroupID from models_library.products import ProductName from models_library.services_types import ServiceKey, ServiceVersion @@ -14,6 +15,7 @@ from models_library.utils.enums import StrAutoEnum from pydantic import BaseModel, ConfigDict, Field +from .batch_operations import BatchGetEnvelope, BatchUpdateEnvelope from .projects import ProjectID from .utils.change_case import snake_to_camel @@ -23,6 +25,7 @@ FileID: TypeAlias = UUID InputTypes: TypeAlias = FileID | float | int | bool | str | list +_MAX_LIST_LENGTH: Final[int] = 50 class FunctionSchemaClass(str, Enum): @@ -80,9 +83,10 @@ class FunctionClass(str, Enum): FunctionInputsList: TypeAlias = Annotated[ list[FunctionInputs], - Field(max_length=50), + Field(max_length=_MAX_LIST_LENGTH), ] + FunctionOutputs: TypeAlias = dict[str, Any] | None FunctionOutputsLogfile: TypeAlias = Any @@ -238,6 +242,9 @@ class RegisteredPythonCodeFunctionJobPatch(BaseModel): ProjectFunctionJob | PythonCodeFunctionJob | SolverFunctionJob, Field(discriminator="function_class"), ] +FunctionJobList: TypeAlias = Annotated[ 
+ list[FunctionJob], Field(max_length=_MAX_LIST_LENGTH) +] class RegisteredFunctionJobBase(FunctionJobBase): @@ -264,6 +271,21 @@ class RegisteredPythonCodeFunctionJob(PythonCodeFunctionJob, RegisteredFunctionJ Field(discriminator="function_class"), ] + +class BatchCreateRegisteredFunctionJobs(BatchCreateEnvelope[RegisteredFunctionJob]): + pass + + +class BatchUpdateRegisteredFunctionJobs(BatchUpdateEnvelope[RegisteredFunctionJob]): + pass + + +class BatchGetCachedRegisteredFunctionJobs( + BatchGetEnvelope[RegisteredFunctionJob, FunctionInputs] +): + pass + + RegisteredFunctionJobPatch = Annotated[ RegisteredProjectFunctionJobPatch | RegisteredPythonCodeFunctionJobPatch @@ -272,6 +294,20 @@ class RegisteredPythonCodeFunctionJob(PythonCodeFunctionJob, RegisteredFunctionJ ] +class FunctionJobPatchRequest(BaseModel): + uid: FunctionJobID + patch: RegisteredFunctionJobPatch + + +FunctionJobPatchRequestList: TypeAlias = Annotated[ + list[FunctionJobPatchRequest], + Field( + max_length=_MAX_LIST_LENGTH, + description="List of function job patch requests", + ), +] + + class FunctionJobStatus(BaseModel): status: str @@ -340,6 +376,20 @@ class RegisteredFunctionJobDB(FunctionJobDB): created: datetime.datetime +class BatchGetCachedRegisteredFunctionJobsDB( + BatchGetEnvelope[RegisteredFunctionJobDB, FunctionInputs] +): + pass + + +class BatchCreateRegisteredFunctionJobsDB(BatchCreateEnvelope[RegisteredFunctionJobDB]): + pass + + +class BatchUpdateRegisteredFunctionJobsDB(BatchUpdateEnvelope[RegisteredFunctionJobDB]): + pass + + class RegisteredFunctionJobWithStatusDB(FunctionJobDB): uuid: FunctionJobID created: datetime.datetime diff --git a/packages/models-library/src/models_library/functions_errors.py b/packages/models-library/src/models_library/functions_errors.py index 629c3dc1c7e9..1f2762f0d4c2 100644 --- a/packages/models-library/src/models_library/functions_errors.py +++ b/packages/models-library/src/models_library/functions_errors.py @@ -170,3 +170,8 @@ class 
FunctionJobCollectionsExecuteApiAccessDeniedError(FunctionBaseError): class FunctionJobPatchModelIncompatibleError(FunctionBaseError): msg_template = "Incompatible patch model for Function '{function_id}' in product '{product_name}'." status_code: int = 422 + + +class FunctionUnrecoverableError(FunctionBaseError): + msg_template = "Unrecoverable error." + status_code: int = 500 diff --git a/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/webserver/functions/functions_rpc_interface.py b/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/webserver/functions/functions_rpc_interface.py index ad4a7295351b..6a6650067bfc 100644 --- a/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/webserver/functions/functions_rpc_interface.py +++ b/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/webserver/functions/functions_rpc_interface.py @@ -21,11 +21,11 @@ from models_library.functions import ( FunctionClass, FunctionGroupAccessRights, + FunctionInputsList, FunctionJobStatus, FunctionOutputs, FunctionUserAccessRights, FunctionUserApiAccessRights, - RegisteredFunctionJobPatch, RegisteredFunctionJobWithStatus, ) from models_library.products import ProductName @@ -359,29 +359,6 @@ async def register_function_job( ) # Validates the result as a RegisteredFunctionJob -@log_decorator(_logger, level=logging.DEBUG) -async def patch_registered_function_job( - rabbitmq_rpc_client: RabbitMQRPCClient, - *, - user_id: UserID, - product_name: ProductName, - function_job_uuid: FunctionJobID, - registered_function_job_patch: RegisteredFunctionJobPatch, -) -> RegisteredFunctionJob: - result = await rabbitmq_rpc_client.request( - DEFAULT_WEBSERVER_RPC_NAMESPACE, - TypeAdapter(RPCMethodName).validate_python("patch_registered_function_job"), - user_id=user_id, - product_name=product_name, - function_job_uuid=function_job_uuid, - registered_function_job_patch=registered_function_job_patch, - timeout_s=_FUNCTION_RPC_TIMEOUT_SEC, - ) - return 
TypeAdapter(RegisteredFunctionJob).validate_python( - result - ) # Validates the result as a RegisteredFunctionJob - - @log_decorator(_logger, level=logging.DEBUG) async def get_function_job( rabbitmq_rpc_client: RabbitMQRPCClient, @@ -512,20 +489,20 @@ async def find_cached_function_jobs( user_id: UserID, product_name: ProductName, function_id: FunctionID, - inputs: FunctionInputs, -) -> list[RegisteredFunctionJob] | None: + inputs: FunctionInputsList, + status_filter: list[FunctionJobStatus] | None = None, +) -> list[RegisteredFunctionJob | None]: result = await rabbitmq_rpc_client.request( DEFAULT_WEBSERVER_RPC_NAMESPACE, TypeAdapter(RPCMethodName).validate_python("find_cached_function_jobs"), function_id=function_id, inputs=inputs, + status_filter=status_filter, user_id=user_id, product_name=product_name, timeout_s=_FUNCTION_RPC_TIMEOUT_SEC, ) - if result is None: - return None - return TypeAdapter(list[RegisteredFunctionJob]).validate_python(result) + return TypeAdapter(list[RegisteredFunctionJob | None]).validate_python(result) @log_decorator(_logger, level=logging.DEBUG) diff --git a/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/webserver/v1/functions.py b/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/webserver/v1/functions.py index f96c6fb1952e..4fd249c2ec82 100644 --- a/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/webserver/v1/functions.py +++ b/packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/webserver/v1/functions.py @@ -7,7 +7,6 @@ FunctionID, FunctionInputs, FunctionInputSchema, - FunctionJob, FunctionJobCollection, FunctionJobCollectionID, FunctionJobCollectionsListFilters, @@ -18,13 +17,19 @@ RegisteredFunctionJobCollection, ) from models_library.functions import ( + BatchCreateRegisteredFunctionJobs, + BatchUpdateRegisteredFunctionJobs, FunctionClass, FunctionGroupAccessRights, + FunctionInputsList, + FunctionJob, + FunctionJobList, + FunctionJobPatchRequest, + 
FunctionJobPatchRequestList, FunctionJobStatus, FunctionOutputs, FunctionUserAccessRights, FunctionUserApiAccessRights, - RegisteredFunctionJobPatch, RegisteredFunctionJobWithStatus, ) from models_library.products import ProductName @@ -329,13 +334,29 @@ async def register_function_job( ), ) + async def batch_register_function_jobs( + self, + *, + product_name: ProductName, + user_id: UserID, + function_jobs: FunctionJobList, + ) -> BatchCreateRegisteredFunctionJobs: + """Register multiple function jobs in a single batch operation.""" + return TypeAdapter(BatchCreateRegisteredFunctionJobs).validate_python( + await self._request( + "batch_register_function_jobs", + product_name=product_name, + user_id=user_id, + function_jobs=function_jobs, + ), + ) + async def patch_registered_function_job( self, *, product_name: ProductName, user_id: UserID, - function_job_uuid: FunctionJobID, - registered_function_job_patch: RegisteredFunctionJobPatch, + function_job_patch_request: FunctionJobPatchRequest, ) -> RegisteredFunctionJob: """Patch a registered function job.""" return TypeAdapter(RegisteredFunctionJob).validate_python( @@ -343,8 +364,24 @@ async def patch_registered_function_job( "patch_registered_function_job", product_name=product_name, user_id=user_id, - function_job_uuid=function_job_uuid, - registered_function_job_patch=registered_function_job_patch, + function_job_patch_request=function_job_patch_request, ), ) + + async def batch_patch_registered_function_job( + self, + *, + product_name: ProductName, + user_id: UserID, + function_job_patch_requests: FunctionJobPatchRequestList, + ) -> BatchUpdateRegisteredFunctionJobs: + """Patch multiple registered function jobs in a single batch operation.""" + return BatchUpdateRegisteredFunctionJobs.model_validate( + await self._request( + "batch_patch_registered_function_jobs", + product_name=product_name, + user_id=user_id, + function_job_patch_requests=function_job_patch_requests, ), ) @@ -462,16 +499,18 @@ async def find_cached_function_jobs( product_name: ProductName, user_id: UserID, 
function_id: FunctionID, - inputs: FunctionInputs, - ) -> list[RegisteredFunctionJob] | None: + inputs: FunctionInputsList, + cached_job_statuses: list[FunctionJobStatus] | None = None, + ) -> list[RegisteredFunctionJob | None]: """Find cached function jobs.""" - return TypeAdapter(list[RegisteredFunctionJob] | None).validate_python( + return TypeAdapter(list[RegisteredFunctionJob | None]).validate_python( await self._request( "find_cached_function_jobs", product_name=product_name, user_id=user_id, function_id=function_id, inputs=inputs, + cached_job_statuses=cached_job_statuses, ), ) diff --git a/services/api-server/src/simcore_service_api_server/_service_function_jobs.py b/services/api-server/src/simcore_service_api_server/_service_function_jobs.py index e71ef46de027..e4a7e99317dd 100644 --- a/services/api-server/src/simcore_service_api_server/_service_function_jobs.py +++ b/services/api-server/src/simcore_service_api_server/_service_function_jobs.py @@ -1,42 +1,44 @@ from dataclasses import dataclass -from typing import overload import jsonschema from common_library.exclude import as_dict_exclude_none from models_library.functions import ( + BatchUpdateRegisteredFunctionJobs, FunctionClass, FunctionID, FunctionInputs, FunctionJobCollectionID, FunctionJobID, + FunctionJobList, + FunctionJobPatchRequest, + FunctionJobPatchRequestList, FunctionSchemaClass, ProjectFunctionJob, RegisteredFunction, RegisteredFunctionJob, - RegisteredFunctionJobPatch, RegisteredProjectFunctionJobPatch, RegisteredSolverFunctionJobPatch, SolverFunctionJob, - SolverJobID, - TaskID, ) from models_library.functions_errors import ( FunctionInputsValidationError, UnsupportedFunctionClassError, ) from models_library.products import ProductName -from models_library.projects import ProjectID from models_library.projects_nodes_io import NodeID from models_library.rest_pagination import PageMetaInfoLimitOffset, PageOffsetInt from models_library.rpc_pagination import PageLimitInt from 
models_library.users import UserID -from pydantic import ValidationError +from pydantic import TypeAdapter, ValidationError from simcore_service_api_server._service_functions import FunctionService from simcore_service_api_server.services_rpc.storage import StorageService from ._service_jobs import JobService from .models.api_resources import JobLinks -from .models.domain.functions import PreRegisteredFunctionJobData +from .models.domain.functions import ( + FunctionJobPatch, + PreRegisteredFunctionJobData, +) from .models.schemas.jobs import JobInputs, JobPricingSpecification from .services_http.webserver import AuthSession from .services_rpc.wb_api_server import WbApiRpcClient @@ -90,14 +92,9 @@ async def list_function_jobs( **pagination_kwargs, ) - async def validate_function_inputs( - self, *, function_id: FunctionID, inputs: FunctionInputs + async def validate_function_inputs( # pylint: disable=no-self-use + self, *, function: RegisteredFunction, job_inputs: list[JobInputs] ) -> tuple[bool, str]: - function = await self._web_rpc_client.get_function( - function_id=function_id, - user_id=self.user_id, - product_name=self.product_name, - ) if ( function.input_schema is None @@ -107,9 +104,11 @@ async def validate_function_inputs( if function.input_schema.schema_class == FunctionSchemaClass.json_schema: try: - jsonschema.validate( - instance=inputs, schema=function.input_schema.schema_content - ) + for input_ in job_inputs: + jsonschema.validate( + instance=input_.values, + schema=function.input_schema.schema_content, + ) except ValidationError as err: return False, str(err) return True, "Inputs are valid" @@ -119,7 +118,7 @@ async def validate_function_inputs( f"Unsupported function schema class {function.input_schema.schema_class}", ) - async def create_function_job_inputs( # pylint: disable=no-self-use + def create_function_job_inputs( # pylint: disable=no-self-use self, *, function: RegisteredFunction, @@ -133,135 +132,127 @@ async def 
create_function_job_inputs( # pylint: disable=no-self-use values=joined_inputs or {}, ) - async def pre_register_function_job( + async def batch_pre_register_function_jobs( self, *, function: RegisteredFunction, - job_inputs: JobInputs, - ) -> PreRegisteredFunctionJobData: + job_input_list: list[JobInputs], + ) -> list[PreRegisteredFunctionJobData]: if function.input_schema is not None: is_valid, validation_str = await self.validate_function_inputs( - function_id=function.uid, - inputs=job_inputs.values, + function=function, + job_inputs=job_input_list, ) if not is_valid: raise FunctionInputsValidationError(error=validation_str) + function_jobs: list[ProjectFunctionJob | SolverFunctionJob] if function.function_class == FunctionClass.PROJECT: - job = await self._web_rpc_client.register_function_job( - function_job=ProjectFunctionJob( + function_jobs = [ + ProjectFunctionJob( function_uid=function.uid, title=f"Function job of function {function.uid}", description=function.description, - inputs=job_inputs.values, + inputs=input_.values, outputs=None, project_job_id=None, job_creation_task_id=None, - ), - user_id=self.user_id, - product_name=self.product_name, + ) + for input_ in job_input_list + ] + batch_registered_jobs = ( + await self._web_rpc_client.batch_register_function_jobs( + function_jobs=TypeAdapter(FunctionJobList).validate_python( + function_jobs + ), + user_id=self.user_id, + product_name=self.product_name, + ) ) + jobs = batch_registered_jobs.created_items elif function.function_class == FunctionClass.SOLVER: - job = await self._web_rpc_client.register_function_job( - function_job=SolverFunctionJob( + function_jobs = [ + SolverFunctionJob( function_uid=function.uid, title=f"Function job of function {function.uid}", description=function.description, - inputs=job_inputs.values, + inputs=input_.values, outputs=None, solver_job_id=None, job_creation_task_id=None, - ), - user_id=self.user_id, - product_name=self.product_name, + ) + for input_ in 
job_input_list + ] + batch_registered_jobs = ( + await self._web_rpc_client.batch_register_function_jobs( + function_jobs=TypeAdapter(FunctionJobList).validate_python( + function_jobs + ), + user_id=self.user_id, + product_name=self.product_name, + ) ) + jobs = batch_registered_jobs.created_items else: raise UnsupportedFunctionClassError( function_class=function.function_class, ) - return PreRegisteredFunctionJobData( - function_job_id=job.uid, - job_inputs=job_inputs, - ) - - @overload - async def patch_registered_function_job( - self, - *, - user_id: UserID, - product_name: ProductName, - function_job_id: FunctionJobID, - function_class: FunctionClass, - job_creation_task_id: TaskID | None, - ) -> RegisteredFunctionJob: ... - - @overload - async def patch_registered_function_job( - self, - *, - user_id: UserID, - product_name: ProductName, - function_job_id: FunctionJobID, - function_class: FunctionClass, - job_creation_task_id: TaskID | None, - project_job_id: ProjectID | None, - ) -> RegisteredFunctionJob: ... - - @overload - async def patch_registered_function_job( - self, - *, - user_id: UserID, - product_name: ProductName, - function_job_id: FunctionJobID, - function_class: FunctionClass, - job_creation_task_id: TaskID | None, - solver_job_id: SolverJobID | None, - ) -> RegisteredFunctionJob: ... 
+ return [ + PreRegisteredFunctionJobData( + function_job_id=job.uid, + job_inputs=input_, + ) + for job, input_ in zip(jobs, job_input_list) + ] - async def patch_registered_function_job( + async def batch_patch_registered_function_job( self, *, user_id: UserID, product_name: ProductName, - function_job_id: FunctionJobID, - function_class: FunctionClass, - job_creation_task_id: TaskID | None, - project_job_id: ProjectID | None = None, - solver_job_id: SolverJobID | None = None, - ) -> RegisteredFunctionJob: - # Only allow one of project_job_id or solver_job_id depending on function_class - patch: RegisteredFunctionJobPatch - if function_class == FunctionClass.PROJECT: - patch = RegisteredProjectFunctionJobPatch( - title=None, - description=None, - inputs=None, - outputs=None, - job_creation_task_id=job_creation_task_id, - project_job_id=project_job_id, - ) - elif function_class == FunctionClass.SOLVER: - patch = RegisteredSolverFunctionJobPatch( - title=None, - description=None, - inputs=None, - outputs=None, - job_creation_task_id=job_creation_task_id, - solver_job_id=solver_job_id, - ) - else: - raise UnsupportedFunctionClassError( - function_class=function_class, - ) - return await self._web_rpc_client.patch_registered_function_job( + function_job_patches: list[FunctionJobPatch], + ) -> BatchUpdateRegisteredFunctionJobs: + patch_inputs: FunctionJobPatchRequestList = [] + for patch in function_job_patches: + if patch.function_class == FunctionClass.PROJECT: + patch_inputs.append( + FunctionJobPatchRequest( + uid=patch.function_job_id, + patch=RegisteredProjectFunctionJobPatch( + title=None, + description=None, + inputs=None, + outputs=None, + job_creation_task_id=patch.job_creation_task_id, + project_job_id=None, + ), + ) + ) + elif patch.function_class == FunctionClass.SOLVER: + patch_inputs.append( + FunctionJobPatchRequest( + uid=patch.function_job_id, + patch=RegisteredSolverFunctionJobPatch( + title=None, + description=None, + inputs=None, + outputs=None, + 
job_creation_task_id=patch.job_creation_task_id, + solver_job_id=None, + ), + ) + ) + else: + raise UnsupportedFunctionClassError( + function_class=patch.function_class, + ) + return await self._web_rpc_client.batch_patch_registered_function_job( user_id=user_id, product_name=product_name, - function_job_id=function_job_id, - registered_function_job_patch=patch, + function_job_patch_requests=patch_inputs, ) async def run_function( @@ -290,14 +281,22 @@ async def run_function( job_id=study_job.id, pricing_spec=pricing_spec, ) - return await self.patch_registered_function_job( + registered_job = await self._web_rpc_client.patch_registered_function_job( user_id=self.user_id, product_name=self.product_name, - function_job_id=pre_registered_function_job_data.function_job_id, - function_class=FunctionClass.PROJECT, - job_creation_task_id=None, - project_job_id=study_job.id, + function_job_patch_request=FunctionJobPatchRequest( + uid=pre_registered_function_job_data.function_job_id, + patch=RegisteredProjectFunctionJobPatch( + title=None, + description=None, + inputs=None, + outputs=None, + job_creation_task_id=None, + project_job_id=study_job.id, + ), + ), ) + return registered_job if function.function_class == FunctionClass.SOLVER: solver_job = await self._job_service.create_solver_job( @@ -315,14 +314,22 @@ async def run_function( job_id=solver_job.id, pricing_spec=pricing_spec, ) - return await self.patch_registered_function_job( + registered_job = await self._web_rpc_client.patch_registered_function_job( user_id=self.user_id, product_name=self.product_name, - function_job_id=pre_registered_function_job_data.function_job_id, - function_class=FunctionClass.SOLVER, - job_creation_task_id=None, - solver_job_id=solver_job.id, + function_job_patch_request=FunctionJobPatchRequest( + uid=pre_registered_function_job_data.function_job_id, + patch=RegisteredSolverFunctionJobPatch( + title=None, + description=None, + inputs=None, + outputs=None, + job_creation_task_id=None, + 
solver_job_id=solver_job.id, + ), + ), ) + return registered_job raise UnsupportedFunctionClassError( function_class=function.function_class, diff --git a/services/api-server/src/simcore_service_api_server/_service_function_jobs_task_client.py b/services/api-server/src/simcore_service_api_server/_service_function_jobs_task_client.py index 80a4af8763ac..79af528deeb5 100644 --- a/services/api-server/src/simcore_service_api_server/_service_function_jobs_task_client.py +++ b/services/api-server/src/simcore_service_api_server/_service_function_jobs_task_client.py @@ -1,5 +1,4 @@ # pylint: disable=too-many-instance-attributes -import contextlib import logging from dataclasses import dataclass @@ -10,6 +9,7 @@ FunctionClass, FunctionID, FunctionInputs, + FunctionInputsList, FunctionJobCollectionID, FunctionJobID, FunctionJobStatus, @@ -30,13 +30,12 @@ from models_library.rest_pagination import PageMetaInfoLimitOffset, PageOffsetInt from models_library.rpc_pagination import PageLimitInt from models_library.users import UserID +from pydantic import TypeAdapter from servicelib.celery.models import ExecutionMetadata, TasksQueue, TaskUUID from servicelib.celery.task_manager import TaskManager -from simcore_service_api_server.models.schemas.functions import ( - FunctionJobCreationTaskStatus, -) from sqlalchemy.ext.asyncio import AsyncEngine +from ._meta import APP_NAME from ._service_function_jobs import FunctionJobService from ._service_functions import FunctionService from ._service_jobs import JobService @@ -45,9 +44,10 @@ SolverJobOutputRequestButNotSucceededError, StudyJobOutputRequestButNotSucceededError, ) -from .exceptions.function_errors import FunctionJobCacheNotFoundError from .models.api_resources import JobLinks from .models.domain.celery_models import ApiServerOwnerMetadata +from .models.domain.functions import FunctionJobPatch +from .models.schemas.functions import FunctionJobCreationTaskStatus from .models.schemas.jobs import JobInputs, JobPricingSpecification 
from .services_http.webserver import AuthSession from .services_rpc.storage import StorageService @@ -238,30 +238,7 @@ async def inspect_function_job( check_write_permissions=False, ) - async def get_cached_function_job( - self, - *, - function: RegisteredFunction, - job_inputs: JobInputs, - ) -> RegisteredFunctionJob: - """Raises FunctionJobCacheNotFoundError if no cached job is found""" - if cached_function_jobs := await self._web_rpc_client.find_cached_function_jobs( - function_id=function.uid, - inputs=job_inputs.values, - user_id=self.user_id, - product_name=self.product_name, - ): - for cached_function_job in cached_function_jobs: - job_status = await self.inspect_function_job( - function=function, - function_job=cached_function_job, - ) - if job_status.status == RunningState.SUCCESS: - return cached_function_job - - raise FunctionJobCacheNotFoundError - - async def function_job_outputs( # noqa: PLR0911 # pylint: disable=too-many-return-statements + async def function_job_outputs( # pylint: disable=too-many-return-statements self, *, function: RegisteredFunction, @@ -325,61 +302,83 @@ async def function_job_outputs( # noqa: PLR0911 # pylint: disable=too-many-retu check_write_permissions=False, ) - async def create_function_job_creation_task( + async def create_function_job_creation_tasks( self, *, function: RegisteredFunction, - function_inputs: FunctionInputs, + function_inputs: FunctionInputsList, user_identity: Identity, pricing_spec: JobPricingSpecification | None, job_links: JobLinks, parent_project_uuid: ProjectID | None = None, parent_node_id: NodeID | None = None, - ) -> RegisteredFunctionJob: - job_inputs = await self._function_job_service.create_function_job_inputs( - function=function, function_inputs=function_inputs + ) -> list[RegisteredFunctionJob]: + inputs = [ + join_inputs(function.default_inputs, input_) for input_ in function_inputs + ] + + cached_jobs = await self._web_rpc_client.find_cached_function_jobs( + user_id=user_identity.user_id, 
+ product_name=user_identity.product_name, + function_id=function.uid, + inputs=TypeAdapter(FunctionInputsList).validate_python(inputs), + status_filter=[FunctionJobStatus(status=RunningState.SUCCESS)], ) - # check if results are cached - with contextlib.suppress(FunctionJobCacheNotFoundError): - return await self.get_cached_function_job( - function=function, - job_inputs=job_inputs, - ) + uncached_inputs = [ + input_ for input_, job in zip(inputs, cached_jobs) if job is None + ] - pre_registered_function_job_data = ( - await self._function_job_service.pre_register_function_job( + pre_registered_function_job_data_list = ( + await self._function_job_service.batch_pre_register_function_jobs( function=function, - job_inputs=job_inputs, + job_input_list=[JobInputs(values=_ or {}) for _ in uncached_inputs], ) ) # run function in celery task - owner_metadata = ApiServerOwnerMetadata( - user_id=user_identity.user_id, product_name=user_identity.product_name - ) - - task_uuid = await self._celery_task_manager.submit_task( - ExecutionMetadata( - name="run_function", - ephemeral=False, - queue=TasksQueue.API_WORKER_QUEUE, - ), - owner_metadata=owner_metadata, - user_identity=user_identity, - function=function, - pre_registered_function_job_data=pre_registered_function_job_data, - pricing_spec=pricing_spec, - job_links=job_links, - x_simcore_parent_project_uuid=parent_project_uuid, - x_simcore_parent_node_id=parent_node_id, + user_id=user_identity.user_id, + product_name=user_identity.product_name, + owner=APP_NAME, ) + task_uuids = [ + await self._celery_task_manager.submit_task( + ExecutionMetadata( + name="run_function", + ephemeral=False, + queue=TasksQueue.API_WORKER_QUEUE, + ), + owner_metadata=owner_metadata, + user_identity=user_identity, + function=function, + pre_registered_function_job_data=pre_registered_function_job_data, + pricing_spec=pricing_spec, + job_links=job_links, + x_simcore_parent_project_uuid=parent_project_uuid, + 
x_simcore_parent_node_id=parent_node_id, + ) + for pre_registered_function_job_data in pre_registered_function_job_data_list + ] - return await self._function_job_service.patch_registered_function_job( + patched_jobs = await self._function_job_service.batch_patch_registered_function_job( user_id=user_identity.user_id, product_name=user_identity.product_name, - function_job_id=pre_registered_function_job_data.function_job_id, - function_class=function.function_class, - job_creation_task_id=TaskID(task_uuid), + function_job_patches=[ + FunctionJobPatch( + function_class=function.function_class, + function_job_id=pre_registered_function_job_data.function_job_id, + job_creation_task_id=TaskID(task_uuid), + project_job_id=None, + solver_job_id=None, + ) + for task_uuid, pre_registered_function_job_data in zip( + task_uuids, pre_registered_function_job_data_list + ) + ], + ) + patched_jobs_iter = iter(patched_jobs.updated_items) + resolve_cached_jobs = lambda job: ( + job if job is not None else next(patched_jobs_iter) ) + return [resolve_cached_jobs(job) for job in cached_jobs] diff --git a/services/api-server/src/simcore_service_api_server/api/routes/function_jobs_routes.py b/services/api-server/src/simcore_service_api_server/api/routes/function_jobs_routes.py index 8bedd5efb1e9..0ce21bfd5d9a 100644 --- a/services/api-server/src/simcore_service_api_server/api/routes/function_jobs_routes.py +++ b/services/api-server/src/simcore_service_api_server/api/routes/function_jobs_routes.py @@ -172,9 +172,10 @@ async def register_function_job( user_id: Annotated[UserID, Depends(get_current_user_id)], product_name: Annotated[ProductName, Depends(get_product_name)], ) -> RegisteredFunctionJob: - return await wb_api_rpc.register_function_job( + registered_job = await wb_api_rpc.register_function_job( function_job=function_job, user_id=user_id, product_name=product_name ) + return registered_job @function_job_router.get( diff --git 
a/services/api-server/src/simcore_service_api_server/api/routes/functions_routes.py b/services/api-server/src/simcore_service_api_server/api/routes/functions_routes.py index f009a0079f57..878ceef4fe20 100644 --- a/services/api-server/src/simcore_service_api_server/api/routes/functions_routes.py +++ b/services/api-server/src/simcore_service_api_server/api/routes/functions_routes.py @@ -16,20 +16,19 @@ RegisteredFunctionJob, RegisteredFunctionJobCollection, ) -from models_library.functions import FunctionJobCollection, FunctionJobID +from models_library.functions import FunctionJobCollection from models_library.products import ProductName from models_library.projects import ProjectID from models_library.projects_nodes_io import NodeID from models_library.users import UserID from servicelib.fastapi.dependencies import get_reverse_url_mapper -from servicelib.utils import limited_gather from ..._service_function_jobs import FunctionJobService from ..._service_function_jobs_task_client import FunctionJobTaskClientService from ..._service_functions import FunctionService from ...models.pagination import Page, PaginationParams from ...models.schemas.errors import ErrorGet -from ...models.schemas.jobs import JobPricingSpecification +from ...models.schemas.jobs import JobInputs, JobPricingSpecification from ...services_rpc.wb_api_server import WbApiRpcClient from ..dependencies.authentication import ( Identity, @@ -297,15 +296,16 @@ async def get_function_outputschema( ), ) async def validate_function_inputs( - function_id: FunctionID, + function_id: FunctionID, # pylint: disable=unused-argument inputs: FunctionInputs, + function: Annotated[RegisteredFunction, Depends(get_function)], function_job_service: Annotated[ FunctionJobService, Depends(get_function_job_service) ], ) -> tuple[bool, str]: return await function_job_service.validate_function_inputs( - function_id=function_id, - inputs=inputs, + function=function, + job_inputs=[JobInputs(values=inputs or {})], ) @@ 
-349,15 +349,17 @@ async def run_function( ) job_links = await function_service.get_function_job_links(to_run_function, url_for) - return await function_job_task_client_service.create_function_job_creation_task( + jobs = await function_job_task_client_service.create_function_job_creation_tasks( function=to_run_function, - function_inputs=function_inputs, + function_inputs=[function_inputs], user_identity=user_identity, pricing_spec=pricing_spec, job_links=job_links, parent_project_uuid=parent_project_uuid, parent_node_id=parent_node_id, ) + assert len(jobs) == 1 # nosec + return jobs[0] @function_router.delete( @@ -428,42 +430,23 @@ async def map_function( ) job_links = await function_service.get_function_job_links(to_run_function, url_for) - async def _run_single_function(function_inputs: FunctionInputs) -> FunctionJobID: - result = ( - await function_job_task_client_service.create_function_job_creation_task( - function=to_run_function, - function_inputs=function_inputs, - user_identity=user_identity, - pricing_spec=pricing_spec, - job_links=job_links, - parent_project_uuid=parent_project_uuid, - parent_node_id=parent_node_id, - ) - ) - return result.uid - - # Run all tasks concurrently, allowing them to complete even if some fail - results = await limited_gather( - *[ - _run_single_function(function_inputs) - for function_inputs in function_inputs_list - ], - reraise=False, - limit=1, + jobs = await function_job_task_client_service.create_function_job_creation_tasks( + function=to_run_function, + function_inputs=function_inputs_list, + user_identity=user_identity, + pricing_spec=pricing_spec, + job_links=job_links, + parent_project_uuid=parent_project_uuid, + parent_node_id=parent_node_id, ) - # Check if any tasks raised exceptions and raise the first one found - for result in results: - if isinstance(result, BaseException): - raise result - # At this point, all results are FunctionJobID since we've checked for exceptions function_job_collection_description = 
f"Function job collection of map of function {to_run_function.uid} with {len(function_inputs_list)} inputs" return await web_api_rpc_client.register_function_job_collection( function_job_collection=FunctionJobCollection( title="Function job collection of function map", description=function_job_collection_description, - job_ids=results, # type: ignore + job_ids=[job.uid for job in jobs], ), user_id=user_identity.user_id, product_name=user_identity.product_name, diff --git a/services/api-server/src/simcore_service_api_server/main.py b/services/api-server/src/simcore_service_api_server/main.py index ace337b472e1..0dfca9daa02a 100644 --- a/services/api-server/src/simcore_service_api_server/main.py +++ b/services/api-server/src/simcore_service_api_server/main.py @@ -7,11 +7,10 @@ from fastapi import FastAPI from servicelib.fastapi.logging_lifespan import create_logging_shutdown_event from servicelib.tracing import TracingConfig +from simcore_service_api_server._meta import APP_NAME from simcore_service_api_server.core.application import create_app from simcore_service_api_server.core.settings import ApplicationSettings -from ._meta import APP_NAME - _logger = logging.getLogger(__name__) _NOISY_LOGGERS: Final[tuple[str, ...]] = ( diff --git a/services/api-server/src/simcore_service_api_server/models/domain/functions.py b/services/api-server/src/simcore_service_api_server/models/domain/functions.py index 1f75f7441f29..1fcb734cd6ae 100644 --- a/services/api-server/src/simcore_service_api_server/models/domain/functions.py +++ b/services/api-server/src/simcore_service_api_server/models/domain/functions.py @@ -1,9 +1,13 @@ from models_library.functions import ( + FunctionClass, FunctionJobID, RegisteredFunctionJob, RegisteredFunctionJobWithStatus, + SolverJobID, + TaskID, ) -from pydantic import BaseModel +from models_library.projects import ProjectID +from pydantic import BaseModel, model_validator from ...models.pagination import Page from ...models.schemas.jobs import 
JobInputs @@ -20,3 +24,30 @@ class PageRegisteredFunctionJobWithorWithoutStatus( # This class is created specifically to provide a name for this in openapi.json. # When using an alias the python-client generates too long file name pass + + +class FunctionJobPatch(BaseModel): + function_class: FunctionClass + function_job_id: FunctionJobID + job_creation_task_id: TaskID | None = None + project_job_id: ProjectID | None = None + solver_job_id: SolverJobID | None = None + + @model_validator(mode="after") + def validate_function_class_consistency(self) -> "FunctionJobPatch": + """Validate consistency between function_class and job IDs.""" + if ( + self.solver_job_id is not None + and self.function_class != FunctionClass.SOLVER + ): + msg = f"solver_job_id must be None when function_class is {self.function_class}, expected {FunctionClass.SOLVER}" + raise ValueError(msg) + + if ( + self.project_job_id is not None + and self.function_class != FunctionClass.PROJECT + ): + msg = f"project_job_id must be None when function_class is {self.function_class}, expected {FunctionClass.PROJECT}" + raise ValueError(msg) + + return self diff --git a/services/api-server/src/simcore_service_api_server/services_rpc/wb_api_server.py b/services/api-server/src/simcore_service_api_server/services_rpc/wb_api_server.py index 006419c238b4..16b4262ebaeb 100644 --- a/services/api-server/src/simcore_service_api_server/services_rpc/wb_api_server.py +++ b/services/api-server/src/simcore_service_api_server/services_rpc/wb_api_server.py @@ -12,7 +12,6 @@ FunctionID, FunctionInputs, FunctionInputSchema, - FunctionJob, FunctionJobCollection, FunctionJobCollectionID, FunctionJobCollectionsListFilters, @@ -24,11 +23,17 @@ ) from models_library.api_schemas_webserver.licensed_items import LicensedItemRpcGetPage from models_library.functions import ( + BatchCreateRegisteredFunctionJobs, + BatchUpdateRegisteredFunctionJobs, + FunctionInputsList, + FunctionJob, + FunctionJobList, + FunctionJobPatchRequest, + 
FunctionJobPatchRequestList, FunctionJobStatus, FunctionOutputs, FunctionUserAccessRights, FunctionUserApiAccessRights, - RegisteredFunctionJobPatch, RegisteredFunctionJobWithStatus, ) from models_library.licenses import LicensedItemID @@ -481,7 +486,11 @@ async def delete_function_job( ) async def register_function_job( - self, *, user_id: UserID, function_job: FunctionJob, product_name: ProductName + self, + *, + user_id: UserID, + function_job: FunctionJob, + product_name: ProductName, ) -> RegisteredFunctionJob: return await self._rpc_client.functions.register_function_job( user_id=user_id, @@ -489,19 +498,43 @@ async def register_function_job( function_job=function_job, ) + async def batch_register_function_jobs( + self, + *, + user_id: UserID, + function_jobs: FunctionJobList, + product_name: ProductName, + ) -> BatchCreateRegisteredFunctionJobs: + return await self._rpc_client.functions.batch_register_function_jobs( + user_id=user_id, + product_name=product_name, + function_jobs=function_jobs, + ) + async def patch_registered_function_job( self, *, user_id: UserID, product_name: ProductName, - function_job_id: FunctionJobID, - registered_function_job_patch: RegisteredFunctionJobPatch, + function_job_patch_request: FunctionJobPatchRequest, ) -> RegisteredFunctionJob: return await self._rpc_client.functions.patch_registered_function_job( user_id=user_id, product_name=product_name, - function_job_uuid=function_job_id, - registered_function_job_patch=registered_function_job_patch, + function_job_patch_request=function_job_patch_request, + ) + + async def batch_patch_registered_function_job( + self, + *, + product_name: ProductName, + user_id: UserID, + function_job_patch_requests: FunctionJobPatchRequestList, + ) -> BatchUpdateRegisteredFunctionJobs: + return await self._rpc_client.functions.batch_patch_registered_function_job( + product_name=product_name, + user_id=user_id, + function_job_patch_requests=function_job_patch_requests, ) async def 
get_function_input_schema( @@ -588,13 +621,15 @@ async def find_cached_function_jobs( user_id: UserID, product_name: ProductName, function_id: FunctionID, - inputs: FunctionInputs, - ) -> list[RegisteredFunctionJob] | None: + inputs: FunctionInputsList, + status_filter: list[FunctionJobStatus] | None, + ) -> list[RegisteredFunctionJob | None]: return await self._rpc_client.functions.find_cached_function_jobs( user_id=user_id, product_name=product_name, function_id=function_id, inputs=inputs, + cached_job_statuses=status_filter, ) async def get_function_job_collection( diff --git a/services/api-server/tests/unit/api_functions/celery/test_functions_celery.py b/services/api-server/tests/unit/api_functions/celery/test_functions_celery.py index 7d5d197d1774..4659668812ba 100644 --- a/services/api-server/tests/unit/api_functions/celery/test_functions_celery.py +++ b/services/api-server/tests/unit/api_functions/celery/test_functions_celery.py @@ -24,10 +24,17 @@ from httpx import AsyncClient, BasicAuth, HTTPStatusError from models_library.api_schemas_long_running_tasks.tasks import TaskResult, TaskStatus from models_library.functions import ( + BatchCreateRegisteredFunctionJobs, + BatchUpdateRegisteredFunctionJobs, FunctionClass, FunctionID, + FunctionInputsList, FunctionJobCollection, FunctionJobID, + FunctionJobList, + FunctionJobPatchRequest, + FunctionJobPatchRequestList, + FunctionJobStatus, FunctionUserAccessRights, FunctionUserApiAccessRights, RegisteredFunction, @@ -36,7 +43,9 @@ RegisteredProjectFunction, RegisteredProjectFunctionJob, RegisteredProjectFunctionJobPatch, + RegisteredSolverFunctionJobPatch, ) +from models_library.products import ProductName from models_library.projects import ProjectID from models_library.users import UserID from pytest_mock import MockType @@ -150,15 +159,72 @@ def _(celery_app: Celery) -> None: return _ -async def _patch_registered_function_job_side_effect( - mock_registered_project_function_job: RegisteredFunctionJob, *args, 
**kwargs +async def _batch_patch_registered_function_job( + mock_registered_project_function_job: RegisteredFunctionJob, + product_name: ProductName, + user_id: UserID, + function_job_patch_requests: FunctionJobPatchRequestList, +): + jobs = [] + for patch_request in function_job_patch_requests: + patch = patch_request.patch + assert isinstance( + patch, (RegisteredProjectFunctionJobPatch, RegisteredSolverFunctionJobPatch) + ) + jobs.append( + mock_registered_project_function_job.model_copy( + update={ + "job_creation_task_id": patch.job_creation_task_id, + "uid": patch_request.uid, + } + ) + ) + return BatchUpdateRegisteredFunctionJobs(updated_items=jobs) + + +async def _patch_registered_function_job( + mock_registered_project_function_job: RegisteredFunctionJob, + product_name: ProductName, + user_id: UserID, + function_job_patch_request: FunctionJobPatchRequest, +): + patch = function_job_patch_request.patch + assert isinstance( + patch, (RegisteredProjectFunctionJobPatch, RegisteredSolverFunctionJobPatch) + ) + job = mock_registered_project_function_job.model_copy( + update={ + "job_creation_task_id": patch.job_creation_task_id, + "uid": function_job_patch_request.uid, + } + ) + return job + + +async def _find_cached_function_jobs_side_effect( + *, + user_id: UserID, + product_name: ProductName, + function_id: FunctionID, + inputs: FunctionInputsList, + status_filter: list[FunctionJobStatus] | None, +): + return [None] * len(inputs) + + +async def _batch_register_function_jobs( + registered_function_job: RegisteredFunctionJob, + user_id: UserID, + function_jobs: FunctionJobList, + product_name: ProductName, ): - registered_function_job_patch = kwargs["registered_function_job_patch"] - assert isinstance(registered_function_job_patch, RegisteredProjectFunctionJobPatch) - job_creation_task_id = registered_function_job_patch.job_creation_task_id - uid = kwargs["function_job_uuid"] - return mock_registered_project_function_job.model_copy( - 
update={"job_creation_task_id": job_creation_task_id, "uid": uid} + return BatchCreateRegisteredFunctionJobs( + created_items=[ + registered_function_job.model_copy( + update={"uid": FunctionJobID(_faker.uuid4())} + ) + for _ in function_jobs + ] ) @@ -216,16 +282,20 @@ async def test_with_fake_run_function( "get_function", return_value=fake_registered_project_function ) mock_handler_in_functions_rpc_interface( - "find_cached_function_jobs", return_value=[] + "find_cached_function_jobs", side_effect=_find_cached_function_jobs_side_effect ) mock_handler_in_functions_rpc_interface( - "register_function_job", return_value=fake_registered_project_function_job + "batch_register_function_jobs", + side_effect=partial( + _batch_register_function_jobs, + fake_registered_project_function_job, + ), ) mock_handler_in_functions_rpc_interface( - "patch_registered_function_job", + "batch_patch_registered_function_job", side_effect=partial( - _patch_registered_function_job_side_effect, + _batch_patch_registered_function_job, fake_registered_project_function_job, ), ) @@ -377,10 +447,14 @@ def _default_side_effect( "get_function", return_value=fake_registered_project_function ) mock_handler_in_functions_rpc_interface( - "find_cached_function_jobs", return_value=[] + "find_cached_function_jobs", side_effect=_find_cached_function_jobs_side_effect ) mock_handler_in_functions_rpc_interface( - "register_function_job", return_value=fake_registered_project_function_job + "batch_register_function_jobs", + side_effect=partial( + _batch_register_function_jobs, + fake_registered_project_function_job, + ), ) mock_handler_in_functions_rpc_interface( "get_functions_user_api_access_rights", @@ -391,10 +465,17 @@ def _default_side_effect( read_functions=True, ), ) + mock_handler_in_functions_rpc_interface( + "batch_patch_registered_function_job", + side_effect=partial( + _batch_patch_registered_function_job, + fake_registered_project_function_job, + ), + ) mock_handler_in_functions_rpc_interface( 
"patch_registered_function_job", side_effect=partial( - _patch_registered_function_job_side_effect, + _patch_registered_function_job, fake_registered_project_function_job, ), ) @@ -497,10 +578,14 @@ def _default_side_effect( "get_function", return_value=fake_registered_project_function ) mock_handler_in_functions_rpc_interface( - "find_cached_function_jobs", return_value=[] + "find_cached_function_jobs", side_effect=_find_cached_function_jobs_side_effect ) mock_handler_in_functions_rpc_interface( - "register_function_job", return_value=fake_registered_project_function_job + "batch_register_function_jobs", + side_effect=partial( + _batch_register_function_jobs, + fake_registered_project_function_job, + ), ) mock_handler_in_functions_rpc_interface( "get_functions_user_api_access_rights", @@ -523,13 +608,19 @@ def _default_side_effect( ) patch_mock = mock_handler_in_functions_rpc_interface( + "batch_patch_registered_function_job", + side_effect=partial( + _batch_patch_registered_function_job, + fake_registered_project_function_job, + ), + ) + mock_handler_in_functions_rpc_interface( "patch_registered_function_job", side_effect=partial( - _patch_registered_function_job_side_effect, + _patch_registered_function_job, fake_registered_project_function_job, ), ) - mock_handler_in_projects_rpc_interface("mark_project_as_job", return_value=None) # ACT @@ -551,9 +642,10 @@ def _default_side_effect( if expected_status_code == status.HTTP_200_OK: FunctionJobCollection.model_validate(response.json()) - task_id = patch_mock.call_args.kwargs[ - "registered_function_job_patch" - ].job_creation_task_id + task_id = patch_mock.call_args.kwargs["function_job_patch_requests"][ + 0 + ].patch.job_creation_task_id + assert task_id is not None await _wait_for_task_result(client, auth, f"{task_id}") assert side_effect_checks["headers_checked"] is True @@ -604,24 +696,32 @@ def _default_side_effect( "get_function", return_value=fake_registered_project_function ) 
mock_handler_in_functions_rpc_interface( - "find_cached_function_jobs", return_value=[] + "find_cached_function_jobs", side_effect=_find_cached_function_jobs_side_effect ) mock_handler_in_projects_rpc_interface("mark_project_as_job", return_value=None) _generated_function_job_ids: list[FunctionJobID] = [] - async def _register_function_job_side_effect( - generated_function_job_ids: list[FunctionJobID], *args, **kwargs + async def _batch_register_function_jobs_side_effect( + generated_function_job_ids: list, + user_id: UserID, + function_jobs: FunctionJobList, + product_name: ProductName, ): - uid = FunctionJobID(_faker.uuid4()) - generated_function_job_ids.append(uid) - return fake_registered_project_function_job.model_copy(update={"uid": uid}) + registered_jobs = [] + for _ in function_jobs: + uid = FunctionJobID(_faker.uuid4()) + generated_function_job_ids.append(uid) + registered_jobs.append( + fake_registered_project_function_job.model_copy(update={"uid": uid}) + ) + return BatchCreateRegisteredFunctionJobs(created_items=registered_jobs) mock_handler_in_functions_rpc_interface( - "register_function_job", + "batch_register_function_jobs", side_effect=partial( - _register_function_job_side_effect, _generated_function_job_ids + _batch_register_function_jobs_side_effect, _generated_function_job_ids ), ) mock_handler_in_functions_rpc_interface( @@ -650,9 +750,17 @@ async def _register_function_job_collection_side_effect(*args, **kwargs): ) patch_mock = mock_handler_in_functions_rpc_interface( + "batch_patch_registered_function_job", + side_effect=partial( + _batch_patch_registered_function_job, + fake_registered_project_function_job, + ), + ) + + mock_handler_in_functions_rpc_interface( "patch_registered_function_job", side_effect=partial( - _patch_registered_function_job_side_effect, + _patch_registered_function_job, fake_registered_project_function_job, ), ) @@ -675,10 +783,17 @@ async def _register_function_job_collection_side_effect(*args, **kwargs): assert ( 
job_collection.job_ids == _generated_function_job_ids ), "Job ID did not preserve order or were incorrectly propagated" - celery_task_ids = { - elm.kwargs["registered_function_job_patch"].job_creation_task_id - for elm in patch_mock.call_args_list - } + + celery_task_ids = set() + for args in patch_mock.call_args_list: + inputs = args.kwargs["function_job_patch_requests"] + celery_task_ids = celery_task_ids.union( + { + input_.patch.job_creation_task_id + for input_ in inputs + if input_.patch.job_creation_task_id + } + ) assert len(celery_task_ids) == len(_inputs) for task_id in celery_task_ids: await _wait_for_task_result(client, auth, f"{task_id}") diff --git a/services/api-server/tests/unit/api_functions/test_api_routers_function_jobs.py b/services/api-server/tests/unit/api_functions/test_api_routers_function_jobs.py index e3dda6a18ef8..8b4dcea44ea0 100644 --- a/services/api-server/tests/unit/api_functions/test_api_routers_function_jobs.py +++ b/services/api-server/tests/unit/api_functions/test_api_routers_function_jobs.py @@ -20,6 +20,7 @@ RegisteredProjectFunctionJob, ) from models_library.functions import ( + FunctionJob, FunctionJobStatus, RegisteredProjectFunction, RegisteredProjectFunctionJobWithStatus, @@ -65,15 +66,22 @@ async def test_delete_function_job( async def test_register_function_job( client: AsyncClient, - mock_handler_in_functions_rpc_interface: Callable[[str, Any], None], + mock_handler_in_functions_rpc_interface: Callable, fake_project_function_job: ProjectFunctionJob, fake_registered_project_function_job: RegisteredProjectFunctionJob, auth: httpx.BasicAuth, ) -> None: """Test the register_function_job endpoint.""" + async def _register_function_job_side_effect( + user_id: UserID, + function_job: FunctionJob, + product_name: ProductName, + ): + return fake_registered_project_function_job + mock_handler_in_functions_rpc_interface( - "register_function_job", fake_registered_project_function_job + "register_function_job", 
side_effect=_register_function_job_side_effect ) response = await client.post( diff --git a/services/api-server/tests/unit/api_functions/test_api_routers_functions.py b/services/api-server/tests/unit/api_functions/test_api_routers_functions.py index a3ae47a0fcd2..52d228debb93 100644 --- a/services/api-server/tests/unit/api_functions/test_api_routers_functions.py +++ b/services/api-server/tests/unit/api_functions/test_api_routers_functions.py @@ -19,6 +19,8 @@ from httpx import AsyncClient from models_library.api_schemas_long_running_tasks.tasks import TaskGet from models_library.functions import ( + FunctionJobList, + FunctionJobPatchRequest, FunctionUserAccessRights, FunctionUserApiAccessRights, ProjectFunction, @@ -26,11 +28,14 @@ RegisteredFunctionJob, RegisteredProjectFunction, RegisteredProjectFunctionJob, + RegisteredProjectFunctionJobPatch, + RegisteredSolverFunctionJobPatch, ) from models_library.functions_errors import ( FunctionIDNotFoundError, FunctionReadAccessDeniedError, ) +from models_library.products import ProductName from models_library.rest_pagination import PageMetaInfoLimitOffset from models_library.users import UserID from pydantic import EmailStr @@ -407,7 +412,7 @@ async def test_run_project_function( mocked_webserver_rpc_api: dict[str, MockType], app: FastAPI, client: AsyncClient, - mock_handler_in_functions_rpc_interface: Callable[[str, Any], None], + mock_handler_in_functions_rpc_interface: Callable, fake_registered_project_function: RegisteredProjectFunction, fake_registered_project_function_job: RegisteredFunctionJob, auth: httpx.BasicAuth, @@ -471,8 +476,16 @@ def _default_side_effect( "get_function", fake_registered_project_function ) mock_handler_in_functions_rpc_interface("find_cached_function_jobs", []) + + async def _register_function_job_side_effect( + user_id: UserID, + function_jobs: FunctionJobList, + product_name: ProductName, + ): + return [fake_registered_project_function_job] * len(function_jobs) + 
mock_handler_in_functions_rpc_interface( - "register_function_job", fake_registered_project_function_job + "register_function_job", side_effect=_register_function_job_side_effect ) mock_handler_in_functions_rpc_interface( "get_functions_user_api_access_rights", @@ -483,8 +496,27 @@ def _default_side_effect( read_functions=True, ), ) + + async def _patch_registered_function_job( + product_name: ProductName, + user_id: UserID, + function_job_patch_request: FunctionJobPatchRequest, + ): + patch = function_job_patch_request.patch + assert isinstance( + patch, (RegisteredProjectFunctionJobPatch, RegisteredSolverFunctionJobPatch) + ) + job = fake_registered_project_function_job.model_copy( + update={ + "job_creation_task_id": patch.job_creation_task_id, + "uid": function_job_patch_request.uid, + } + ) + return job + mock_handler_in_functions_rpc_interface( - "patch_registered_function_job", fake_registered_project_function_job + "patch_registered_function_job", + side_effect=_patch_registered_function_job, ) pre_registered_function_job_data = PreRegisteredFunctionJobData( diff --git a/services/dynamic-scheduler/src/simcore_service_dynamic_scheduler/main.py b/services/dynamic-scheduler/src/simcore_service_dynamic_scheduler/main.py index a08411cc2c5c..786de55895e0 100644 --- a/services/dynamic-scheduler/src/simcore_service_dynamic_scheduler/main.py +++ b/services/dynamic-scheduler/src/simcore_service_dynamic_scheduler/main.py @@ -9,11 +9,10 @@ create_logging_lifespan, ) from servicelib.tracing import TracingConfig +from simcore_service_dynamic_scheduler._meta import APP_NAME from simcore_service_dynamic_scheduler.core.application import create_app from simcore_service_dynamic_scheduler.core.settings import ApplicationSettings -from ._meta import APP_NAME - _logger = logging.getLogger(__name__) _NOISY_LOGGERS: Final[tuple[str, ...]] = ( diff --git a/services/web/server/src/simcore_service_webserver/functions/_controller/_functions_rpc.py 
b/services/web/server/src/simcore_service_webserver/functions/_controller/_functions_rpc.py index e5abbd51f39e..bfdc9528c19c 100644 --- a/services/web/server/src/simcore_service_webserver/functions/_controller/_functions_rpc.py +++ b/services/web/server/src/simcore_service_webserver/functions/_controller/_functions_rpc.py @@ -2,18 +2,23 @@ from aiohttp import web from models_library.functions import ( + BatchCreateRegisteredFunctionJobs, + BatchUpdateRegisteredFunctionJobs, Function, FunctionAccessRights, FunctionClass, FunctionGroupAccessRights, FunctionID, - FunctionInputs, FunctionInputSchema, + FunctionInputsList, FunctionJob, FunctionJobCollection, FunctionJobCollectionID, FunctionJobCollectionsListFilters, FunctionJobID, + FunctionJobList, + FunctionJobPatchRequest, + FunctionJobPatchRequestList, FunctionJobStatus, FunctionOutputs, FunctionOutputSchema, @@ -22,7 +27,6 @@ RegisteredFunction, RegisteredFunctionJob, RegisteredFunctionJobCollection, - RegisteredFunctionJobPatch, RegisteredFunctionJobWithStatus, ) from models_library.functions_errors import ( @@ -100,6 +104,24 @@ async def register_function_job( ) +@router.expose( + reraise_if_error_type=( + UnsupportedFunctionJobClassError, + FunctionJobsWriteApiAccessDeniedError, + ) +) +async def batch_register_function_jobs( + app: web.Application, + *, + user_id: UserID, + product_name: ProductName, + function_jobs: FunctionJobList, +) -> BatchCreateRegisteredFunctionJobs: + return await _functions_service.batch_register_function_jobs( + app=app, user_id=user_id, product_name=product_name, function_jobs=function_jobs + ) + + @router.expose( reraise_if_error_type=( UnsupportedFunctionJobClassError, @@ -112,16 +134,37 @@ async def patch_registered_function_job( *, user_id: UserID, product_name: ProductName, - function_job_uuid: FunctionJobID, - registered_function_job_patch: RegisteredFunctionJobPatch, + function_job_patch_request: FunctionJobPatchRequest, ) -> RegisteredFunctionJob: return await 
_functions_service.patch_registered_function_job( app=app, user_id=user_id, product_name=product_name, - function_job_uuid=function_job_uuid, - registered_function_job_patch=registered_function_job_patch, + function_job_patch_request=function_job_patch_request, + ) + + +@router.expose( + reraise_if_error_type=( + UnsupportedFunctionJobClassError, + FunctionJobsWriteApiAccessDeniedError, + FunctionJobPatchModelIncompatibleError, + ) +) +async def batch_patch_registered_function_jobs( + app: web.Application, + *, + user_id: UserID, + product_name: ProductName, + function_job_patch_requests: FunctionJobPatchRequestList, +) -> BatchUpdateRegisteredFunctionJobs: + + return await _functions_service.batch_patch_registered_function_jobs( + app=app, + user_id=user_id, + product_name=product_name, + function_job_patch_requests=function_job_patch_requests, ) @@ -445,15 +488,19 @@ async def find_cached_function_jobs( user_id: UserID, product_name: ProductName, function_id: FunctionID, - inputs: FunctionInputs, -) -> list[RegisteredFunctionJob] | None: - return await _functions_service.find_cached_function_jobs( + inputs: FunctionInputsList, + cached_job_statuses: list[FunctionJobStatus] | None, +) -> list[RegisteredFunctionJob | None]: + retrieved_cached_function_jobs = await _functions_service.find_cached_function_jobs( app=app, user_id=user_id, product_name=product_name, function_id=function_id, inputs=inputs, + cached_job_statuses=cached_job_statuses, ) + assert len(retrieved_cached_function_jobs) == len(inputs) # nosec + return retrieved_cached_function_jobs @router.expose(reraise_if_error_type=(FunctionIDNotFoundError,)) diff --git a/services/web/server/src/simcore_service_webserver/functions/_function_jobs_repository.py b/services/web/server/src/simcore_service_webserver/functions/_function_jobs_repository.py index 7aeff55cba42..155a344efcd2 100644 --- a/services/web/server/src/simcore_service_webserver/functions/_function_jobs_repository.py +++ 
b/services/web/server/src/simcore_service_webserver/functions/_function_jobs_repository.py @@ -1,25 +1,32 @@ # pylint: disable=too-many-arguments import json +import logging import sqlalchemy from aiohttp import web from models_library.functions import ( + BatchCreateRegisteredFunctionJobsDB, + BatchUpdateRegisteredFunctionJobsDB, FunctionClass, + FunctionClassSpecificData, FunctionID, - FunctionInputs, - FunctionJobClassSpecificData, + FunctionInputsList, FunctionJobCollectionID, + FunctionJobDB, FunctionJobID, + FunctionJobPatchRequest, FunctionJobStatus, FunctionOutputs, FunctionsApiAccessRights, RegisteredFunctionJobDB, + RegisteredFunctionJobPatch, RegisteredFunctionJobWithStatusDB, ) from models_library.functions_errors import ( FunctionJobIDNotFoundError, - FunctionJobReadAccessDeniedError, + FunctionJobPatchModelIncompatibleError, + UnsupportedFunctionJobClassError, ) from models_library.products import ProductName from models_library.rest_pagination import PageMetaInfoLimitOffset @@ -38,9 +45,8 @@ pass_or_acquire_connection, transaction_context, ) -from sqlalchemy import Text, cast +from sqlalchemy import Text, cast, func from sqlalchemy.ext.asyncio import AsyncConnection -from sqlalchemy.sql import func from ..db.plugin import get_asyncpg_engine from ..groups.api import list_all_user_groups_ids @@ -54,21 +60,17 @@ _FUNCTION_JOBS_TABLE_COLS, ) +_logger = logging.getLogger(__name__) -async def create_function_job( # noqa: PLR0913 + +async def create_function_jobs( # noqa: PLR0913 app: web.Application, connection: AsyncConnection | None = None, *, user_id: UserID, product_name: ProductName, - function_class: FunctionClass, - function_uid: FunctionID, - title: str, - description: str, - inputs: FunctionInputs, - outputs: FunctionOutputs, - class_specific_data: FunctionJobClassSpecificData, -) -> RegisteredFunctionJobDB: + function_jobs: list[FunctionJobDB], +) -> BatchCreateRegisteredFunctionJobsDB: async with transaction_context(get_asyncpg_engine(app), 
connection) as transaction: await check_user_api_access_rights( app, @@ -79,50 +81,63 @@ async def create_function_job( # noqa: PLR0913 FunctionsApiAccessRights.WRITE_FUNCTION_JOBS, ], ) + + # Prepare values for batch insert + values_to_insert = [ + { + "function_uuid": job.function_uuid, + "inputs": job.inputs, + "outputs": job.outputs, + "function_class": job.function_class, + "class_specific_data": job.class_specific_data, + "title": job.title, + "description": job.description, + "status": "created", + } + for job in function_jobs + ] + + # Batch insert all function jobs in a single query result = await transaction.execute( function_jobs_table.insert() - .values( - function_uuid=function_uid, - inputs=inputs, - outputs=outputs, - function_class=function_class, - class_specific_data=class_specific_data, - title=title, - description=description, - status="created", - ) + .values(values_to_insert) .returning(*_FUNCTION_JOBS_TABLE_COLS) ) - row = result.one() - registered_function_job = RegisteredFunctionJobDB.model_validate(row) + # Get all created jobs + created_jobs = TypeAdapter(list[RegisteredFunctionJobDB]).validate_python( + list(result) + ) + # Get user primary group and set permissions for all jobs user_primary_group_id = await users_service.get_user_primary_group_id( app, user_id=user_id ) + job_uuids = [job.uuid for job in created_jobs] + await _internal_set_group_permissions( app, connection=transaction, permission_group_id=user_primary_group_id, product_name=product_name, object_type="function_job", - object_ids=[registered_function_job.uuid], + object_ids=job_uuids, read=True, write=True, execute=True, ) - return registered_function_job + return BatchCreateRegisteredFunctionJobsDB(created_items=created_jobs) -async def patch_function_job( +async def patch_function_jobs( app: web.Application, connection: AsyncConnection | None = None, *, user_id: UserID, product_name: ProductName, - registered_function_job_db: RegisteredFunctionJobDB, -) -> 
RegisteredFunctionJobDB: + function_job_patch_requests: list[FunctionJobPatchRequest], +) -> BatchUpdateRegisteredFunctionJobsDB: async with transaction_context(get_asyncpg_engine(app), connection) as transaction: await check_user_api_access_rights( @@ -134,23 +149,41 @@ async def patch_function_job( FunctionsApiAccessRights.WRITE_FUNCTION_JOBS, ], ) - result = await transaction.execute( - function_jobs_table.update() - .where(function_jobs_table.c.uuid == f"{registered_function_job_db.uuid}") - .values( - inputs=registered_function_job_db.inputs, - outputs=registered_function_job_db.outputs, - function_class=registered_function_job_db.function_class, - class_specific_data=registered_function_job_db.class_specific_data, - title=registered_function_job_db.title, - description=registered_function_job_db.description, - status="created", + updated_jobs = [] + for patch_request in function_job_patch_requests: + + job = await get_function_job( + app, + connection=transaction, + user_id=user_id, + product_name=product_name, + function_job_id=patch_request.uid, ) - .returning(*_FUNCTION_JOBS_TABLE_COLS) - ) - row = result.one() + if job.function_class != patch_request.patch.function_class: + raise FunctionJobPatchModelIncompatibleError( + function_id=job.function_uuid, product_name=product_name + ) - return RegisteredFunctionJobDB.model_validate(row) + class_specific_data = _update_class_specific_data( + class_specific_data=job.class_specific_data, patch=patch_request.patch + ) + update_values = { + "inputs": patch_request.patch.inputs, + "outputs": patch_request.patch.outputs, + "class_specific_data": class_specific_data, + "title": patch_request.patch.title, + "description": patch_request.patch.description, + } + + result = await transaction.execute( + function_jobs_table.update() + .where(function_jobs_table.c.uuid == f"{patch_request.uid}") + .values(**{k: v for k, v in update_values.items() if v is not None}) + .returning(*_FUNCTION_JOBS_TABLE_COLS) + ) + 
updated_jobs.append(RegisteredFunctionJobDB.model_validate(result.one())) + + return BatchUpdateRegisteredFunctionJobsDB(updated_items=updated_jobs) async def list_function_jobs_with_status( @@ -291,36 +324,60 @@ async def find_cached_function_jobs( user_id: UserID, function_id: FunctionID, product_name: ProductName, - inputs: FunctionInputs, -) -> list[RegisteredFunctionJobDB] | None: + inputs: FunctionInputsList, + cached_job_statuses: list[FunctionJobStatus] | None = None, +) -> list[RegisteredFunctionJobDB | None]: async with pass_or_acquire_connection(get_asyncpg_engine(app), connection) as conn: - jobs: list[RegisteredFunctionJobDB] = [] - async for row in await conn.stream( - function_jobs_table.select().where( - function_jobs_table.c.function_uuid == function_id, - cast(function_jobs_table.c.inputs, Text) == json.dumps(inputs), + # Get user groups for access check + user_groups = await list_all_user_groups_ids(app, user_id=user_id) + + # Create access subquery + access_subquery = ( + function_jobs_access_rights_table.select() + .with_only_columns(function_jobs_access_rights_table.c.function_job_uuid) + .where( + function_jobs_access_rights_table.c.group_id.in_(user_groups), + function_jobs_access_rights_table.c.product_name == product_name, + function_jobs_access_rights_table.c.read, ) - ): - job = RegisteredFunctionJobDB.model_validate(row) - try: - await check_user_permissions( - app, - connection=conn, - user_id=user_id, - product_name=product_name, - object_id=job.uuid, - object_type="function_job", - permissions=["read"], + ) + + # Create list of JSON dumped inputs for comparison + json_inputs = [json.dumps(inp) for inp in inputs] + + # Build filter conditions + filter_conditions = sqlalchemy.and_( + function_jobs_table.c.function_uuid == function_id, + cast(function_jobs_table.c.inputs, Text).in_(json_inputs), + function_jobs_table.c.uuid.in_(access_subquery), + ( + function_jobs_table.c.status.in_( + [status.status for status in cached_job_statuses] 
) - except FunctionJobReadAccessDeniedError: - continue + if cached_job_statuses is not None + else sqlalchemy.sql.true() + ), + ) - jobs.append(job) + # Use DISTINCT ON to get only one job per input (the most recent one) + results = await conn.execute( + function_jobs_table.select() + .distinct(cast(function_jobs_table.c.inputs, Text)) + .where(filter_conditions) + .order_by( + cast(function_jobs_table.c.inputs, Text), + function_jobs_table.c.created.desc(), + ) + ) - if len(jobs) > 0: - return jobs + # Create a mapping from JSON inputs to jobs + _ensure_str = lambda x: x if isinstance(x, str) else json.dumps(x) + jobs_by_input: dict[str, RegisteredFunctionJobDB] = { + _ensure_str(row.inputs): RegisteredFunctionJobDB.model_validate(row) + for row in results + } - return None + return [jobs_by_input.get(input_, None) for input_ in json_inputs] async def get_function_job_status( @@ -461,3 +518,36 @@ async def update_function_job_outputs( raise FunctionJobIDNotFoundError(function_job_id=function_job_id) return TypeAdapter(FunctionOutputs).validate_python(row.outputs) + + +def _update_class_specific_data( + class_specific_data: dict, + patch: RegisteredFunctionJobPatch, +) -> FunctionClassSpecificData: + if patch.function_class == FunctionClass.PROJECT: + return FunctionClassSpecificData( + project_job_id=( + f"{patch.project_job_id}" + if patch.project_job_id + else class_specific_data.get("project_job_id") + ), + job_creation_task_id=( + f"{patch.job_creation_task_id}" + if patch.job_creation_task_id + else class_specific_data.get("job_creation_task_id") + ), + ) + if patch.function_class == FunctionClass.SOLVER: + return FunctionClassSpecificData( + solver_job_id=( + f"{patch.solver_job_id}" + if patch.solver_job_id + else class_specific_data.get("solver_job_id") + ), + job_creation_task_id=( + f"{patch.job_creation_task_id}" + if patch.job_creation_task_id + else class_specific_data.get("job_creation_task_id") + ), + ) + raise 
UnsupportedFunctionJobClassError(function_job_class=patch.function_class) diff --git a/services/web/server/src/simcore_service_webserver/functions/_functions_service.py b/services/web/server/src/simcore_service_webserver/functions/_functions_service.py index b7749f9e8e1d..35ddbc54cc6b 100644 --- a/services/web/server/src/simcore_service_webserver/functions/_functions_service.py +++ b/services/web/server/src/simcore_service_webserver/functions/_functions_service.py @@ -3,14 +3,16 @@ from aiohttp import web from models_library.basic_types import IDStr from models_library.functions import ( + BatchCreateRegisteredFunctionJobs, + BatchUpdateRegisteredFunctionJobs, Function, FunctionClass, FunctionClassSpecificData, FunctionDB, FunctionGroupAccessRights, FunctionID, - FunctionInputs, FunctionInputSchema, + FunctionInputsList, FunctionJob, FunctionJobClassSpecificData, FunctionJobCollection, @@ -18,6 +20,9 @@ FunctionJobCollectionsListFilters, FunctionJobDB, FunctionJobID, + FunctionJobList, + FunctionJobPatchRequest, + FunctionJobPatchRequestList, FunctionJobStatus, FunctionOutputs, FunctionOutputSchema, @@ -29,7 +34,6 @@ RegisteredFunctionJob, RegisteredFunctionJobCollection, RegisteredFunctionJobDB, - RegisteredFunctionJobPatch, RegisteredFunctionJobWithStatus, RegisteredFunctionJobWithStatusDB, RegisteredProjectFunction, @@ -40,7 +44,6 @@ RegisteredSolverFunctionJobWithStatus, ) from models_library.functions_errors import ( - FunctionJobPatchModelIncompatibleError, UnsupportedFunctionClassError, UnsupportedFunctionJobClassError, ) @@ -49,6 +52,7 @@ from models_library.rest_ordering import OrderBy from models_library.rest_pagination import PageMetaInfoLimitOffset from models_library.users import UserID +from pydantic import TypeAdapter from servicelib.rabbitmq import RPCRouter from . 
import ( @@ -92,20 +96,38 @@ async def register_function_job( product_name: ProductName, function_job: FunctionJob, ) -> RegisteredFunctionJob: - encoded_function_job = _encode_functionjob(function_job) - created_function_job_db = await _function_jobs_repository.create_function_job( + encoded_function_jobs = _encode_functionjob(function_job) + created_function_jobs_db = await _function_jobs_repository.create_function_jobs( app=app, user_id=user_id, product_name=product_name, - function_class=encoded_function_job.function_class, - title=encoded_function_job.title, - description=encoded_function_job.description, - function_uid=encoded_function_job.function_uuid, - inputs=encoded_function_job.inputs, - outputs=encoded_function_job.outputs, - class_specific_data=encoded_function_job.class_specific_data, + function_jobs=[encoded_function_jobs], + ) + created_items = created_function_jobs_db.created_items + assert len(created_items) == 1 # nosec + return _decode_functionjob(created_items[0]) + + +async def batch_register_function_jobs( + app: web.Application, + *, + user_id: UserID, + product_name: ProductName, + function_jobs: FunctionJobList, +) -> BatchCreateRegisteredFunctionJobs: + TypeAdapter(FunctionJobList).validate_python(function_jobs) + encoded_function_jobs = [_encode_functionjob(job) for job in function_jobs] + created_function_jobs_db = await _function_jobs_repository.create_function_jobs( + app=app, + user_id=user_id, + product_name=product_name, + function_jobs=encoded_function_jobs, + ) + return BatchCreateRegisteredFunctionJobs( + created_items=[ + _decode_functionjob(job) for job in created_function_jobs_db.created_items + ] ) - return _decode_functionjob(created_function_job_db) async def patch_registered_function_job( @@ -113,30 +135,39 @@ async def patch_registered_function_job( *, user_id: UserID, product_name: ProductName, - function_job_uuid: FunctionJobID, - registered_function_job_patch: RegisteredFunctionJobPatch, + function_job_patch_request: 
FunctionJobPatchRequest, ) -> RegisteredFunctionJob: - job = await _function_jobs_repository.get_function_job( + + result = await _function_jobs_repository.patch_function_jobs( app=app, user_id=user_id, product_name=product_name, - function_job_id=function_job_uuid, + function_job_patch_requests=[function_job_patch_request], ) - if job.function_class != registered_function_job_patch.function_class: - raise FunctionJobPatchModelIncompatibleError( - function_id=job.function_uuid, - product_name=product_name, - ) + assert len(result.updated_items) == 1 # nosec + return _decode_functionjob(result.updated_items[0]) + - patched_job = _patch_functionjob(job, registered_function_job_patch) +async def batch_patch_registered_function_jobs( + app: web.Application, + *, + user_id: UserID, + product_name: ProductName, + function_job_patch_requests: FunctionJobPatchRequestList, +) -> BatchUpdateRegisteredFunctionJobs: + TypeAdapter(FunctionJobPatchRequestList).validate_python( + function_job_patch_requests + ) - result = await _function_jobs_repository.patch_function_job( + result = await _function_jobs_repository.patch_function_jobs( app=app, user_id=user_id, product_name=product_name, - registered_function_job_db=patched_job, + function_job_patch_requests=function_job_patch_requests, + ) + return BatchUpdateRegisteredFunctionJobs( + updated_items=[_decode_functionjob(job) for job in result.updated_items] ) - return _decode_functionjob(result) async def register_function_job_collection( @@ -423,62 +454,55 @@ async def find_cached_function_jobs( user_id: UserID, product_name: ProductName, function_id: FunctionID, - inputs: FunctionInputs, -) -> list[RegisteredFunctionJob] | None: + inputs: FunctionInputsList, + cached_job_statuses: list[FunctionJobStatus] | None = None, +) -> list[RegisteredFunctionJob | None]: returned_function_jobs = await _function_jobs_repository.find_cached_function_jobs( app=app, user_id=user_id, product_name=product_name, function_id=function_id, 
inputs=inputs, - ) - if returned_function_jobs is None or len(returned_function_jobs) == 0: - return None - - to_return_function_jobs: list[RegisteredFunctionJob] = [] - for returned_function_job in returned_function_jobs: - if returned_function_job.function_class == FunctionClass.PROJECT: - to_return_function_jobs.append( - RegisteredProjectFunctionJob( - uid=returned_function_job.uuid, - title=returned_function_job.title, - description=returned_function_job.description, - function_uid=returned_function_job.function_uuid, - inputs=returned_function_job.inputs, - outputs=None, - project_job_id=returned_function_job.class_specific_data[ - "project_job_id" - ], - job_creation_task_id=returned_function_job.class_specific_data.get( - "job_creation_task_id" - ), - created_at=returned_function_job.created, - ) - ) - elif returned_function_job.function_class == FunctionClass.SOLVER: - to_return_function_jobs.append( - RegisteredSolverFunctionJob( - uid=returned_function_job.uuid, - title=returned_function_job.title, - description=returned_function_job.description, - function_uid=returned_function_job.function_uuid, - inputs=returned_function_job.inputs, - outputs=None, - solver_job_id=returned_function_job.class_specific_data.get( - "solver_job_id" - ), - job_creation_task_id=returned_function_job.class_specific_data.get( - "job_creation_task_id" - ), - created_at=returned_function_job.created, - ) + cached_job_statuses=cached_job_statuses, + ) + assert len(returned_function_jobs) == len(inputs) # nosec + + def _map_db_model_to_domain_model( + job: RegisteredFunctionJobDB | None, + ) -> RegisteredFunctionJob | None: + if job is None: + return None + if job.function_class == FunctionClass.PROJECT: + return RegisteredProjectFunctionJob( + uid=job.uuid, + title=job.title, + description=job.description, + function_uid=job.function_uuid, + inputs=job.inputs, + outputs=None, + project_job_id=job.class_specific_data["project_job_id"], + 
job_creation_task_id=job.class_specific_data.get( + "job_creation_task_id" + ), + created_at=job.created, ) - else: - raise UnsupportedFunctionJobClassError( - function_job_class=returned_function_job.function_class + if job.function_class == FunctionClass.SOLVER: + return RegisteredSolverFunctionJob( + uid=job.uuid, + title=job.title, + description=job.description, + function_uid=job.function_uuid, + inputs=job.inputs, + outputs=None, + solver_job_id=job.class_specific_data.get("solver_job_id"), + job_creation_task_id=job.class_specific_data.get( + "job_creation_task_id" + ), + created_at=job.created, ) + raise UnsupportedFunctionJobClassError(function_job_class=job.function_class) - return to_return_function_jobs + return [_map_db_model_to_domain_model(job) for job in returned_function_jobs] async def get_function_input_schema( @@ -930,73 +954,3 @@ def _decode_functionjob_wso( raise UnsupportedFunctionJobClassError( function_job_class=functionjob_db.function_class ) - - -def _patch_functionjob( - function_job_db: RegisteredFunctionJobDB, - patch: RegisteredFunctionJobPatch, -) -> RegisteredFunctionJobDB: - if function_job_db.function_class == FunctionClass.PROJECT: - assert patch.function_class == FunctionClass.PROJECT # nosec - return RegisteredFunctionJobDB( - function_class=FunctionClass.PROJECT, - function_uuid=function_job_db.function_uuid, - title=patch.title or function_job_db.title, - uuid=function_job_db.uuid, - description=patch.description or function_job_db.description, - inputs=patch.inputs or function_job_db.inputs, - outputs=patch.outputs or function_job_db.outputs, - created=function_job_db.created, - class_specific_data=FunctionClassSpecificData( - project_job_id=( - f"{patch.project_job_id}" - if patch.project_job_id - else function_job_db.class_specific_data.get("project_job_id") - ), - job_creation_task_id=( - f"{patch.job_creation_task_id}" - if patch.job_creation_task_id - else function_job_db.class_specific_data.get("job_creation_task_id") 
- ), - ), - ) - if function_job_db.function_class == FunctionClass.SOLVER: - assert patch.function_class == FunctionClass.SOLVER # nosec - return RegisteredFunctionJobDB( - function_class=FunctionClass.SOLVER, - function_uuid=function_job_db.function_uuid, - title=patch.title or function_job_db.title, - uuid=function_job_db.uuid, - description=patch.description or function_job_db.description, - inputs=patch.inputs or function_job_db.inputs, - outputs=patch.outputs or function_job_db.outputs, - created=function_job_db.created, - class_specific_data=FunctionClassSpecificData( - solver_job_id=( - f"{patch.solver_job_id}" - if patch.solver_job_id - else function_job_db.class_specific_data.get("solver_job_id") - ), - job_creation_task_id=( - f"{patch.job_creation_task_id}" - if patch.job_creation_task_id - else function_job_db.class_specific_data.get("job_creation_task_id") - ), - ), - ) - if function_job_db.function_class == FunctionClass.PYTHON_CODE: - assert patch.function_class == FunctionClass.PYTHON_CODE # nosec - return RegisteredFunctionJobDB( - function_class=FunctionClass.PYTHON_CODE, - function_uuid=function_job_db.function_uuid, - title=patch.title or function_job_db.title, - uuid=function_job_db.uuid, - description=patch.description or function_job_db.description, - inputs=patch.inputs or function_job_db.inputs, - outputs=patch.outputs or function_job_db.outputs, - created=function_job_db.created, - class_specific_data=function_job_db.class_specific_data, - ) - raise UnsupportedFunctionJobClassError( - function_job_class=function_job_db.function_class - ) diff --git a/services/web/server/tests/unit/with_dbs/04/functions/wb-api-server/test_function_job_collections_controller_rpc.py b/services/web/server/tests/unit/with_dbs/04/functions/wb-api-server/test_function_job_collections_controller_rpc.py index 32e04bc8936d..2cb6dfc9a0ad 100644 --- a/services/web/server/tests/unit/with_dbs/04/functions/wb-api-server/test_function_job_collections_controller_rpc.py +++ 
b/services/web/server/tests/unit/with_dbs/04/functions/wb-api-server/test_function_job_collections_controller_rpc.py @@ -17,6 +17,7 @@ Function, FunctionClass, FunctionJobCollectionsListFilters, + FunctionJobList, ) from models_library.functions_errors import ( FunctionJobCollectionReadAccessDeniedError, @@ -25,6 +26,7 @@ FunctionJobIDNotFoundError, ) from models_library.products import ProductName +from pydantic import TypeAdapter from pytest_simcore.helpers.webserver_users import UserInfoDict from servicelib.rabbitmq.rpc_interfaces.webserver.v1 import WebServerRpcClient @@ -53,19 +55,9 @@ async def test_function_job_collection( ) assert registered_function.uid is not None - registered_function_job = ProjectFunctionJob( - function_uid=registered_function.uid, - title="Test Function Job", - description="A test function job", - project_job_id=uuid4(), - inputs={"input1": "value1"}, - outputs={"output1": "result1"}, - job_creation_task_id=None, - ) - # Register the function job - function_job_ids = [] - for _ in range(3): - registered_function_job = ProjectFunctionJob( + # Register the function jobs + function_jobs = [ + ProjectFunctionJob( function_uid=registered_function.uid, title="Test Function Job", description="A test function job", @@ -74,14 +66,20 @@ async def test_function_job_collection( outputs={"output1": "result1"}, job_creation_task_id=None, ) - # Register the function job - registered_job = await webserver_rpc_client.functions.register_function_job( - function_job=registered_function_job, + for _ in range(3) + ] + # Register the function jobs + registered_jobs_batch_create = ( + await webserver_rpc_client.functions.batch_register_function_jobs( + function_jobs=TypeAdapter(FunctionJobList).validate_python(function_jobs), user_id=logged_user["id"], product_name=osparc_product_name, ) - assert registered_job.uid is not None - function_job_ids.append(registered_job.uid) + ) + registered_jobs = registered_jobs_batch_create.created_items + assert 
len(registered_jobs) == 3 + assert all(job.uid is not None for job in registered_jobs) + function_job_ids = [job.uid for job in registered_jobs] function_job_collection = FunctionJobCollection( title="Test Function Job Collection", @@ -202,11 +200,16 @@ async def test_create_function_job_collection_same_function_job_uuid( job_creation_task_id=None, ) # Register the function job - registered_job = await webserver_rpc_client.functions.register_function_job( - function_job=registered_function_job, - user_id=logged_user["id"], - product_name=osparc_product_name, + registered_jobs_batch_create = ( + await webserver_rpc_client.functions.batch_register_function_jobs( + function_jobs=[registered_function_job], + user_id=logged_user["id"], + product_name=osparc_product_name, + ) ) + registered_jobs = registered_jobs_batch_create.created_items + assert len(registered_jobs) == 1 + registered_job = registered_jobs[0] assert registered_job.uid is not None function_job_ids = [registered_job.uid] * 3 @@ -261,9 +264,8 @@ async def test_list_function_job_collections( assert registered_function.uid is not None # Create a function job collection - function_job_ids = [] - for _ in range(3): - registered_function_job = ProjectFunctionJob( + function_jobs = [ + ProjectFunctionJob( function_uid=registered_function.uid, title="Test Function Job", description="A test function job", @@ -272,19 +274,24 @@ async def test_list_function_job_collections( outputs={"output1": "result1"}, job_creation_task_id=None, ) - # Register the function job - registered_job = await webserver_rpc_client.functions.register_function_job( - function_job=registered_function_job, + for _ in range(3) + ] + # Register the function jobs + registered_jobs_batch_create = ( + await webserver_rpc_client.functions.batch_register_function_jobs( + function_jobs=TypeAdapter(FunctionJobList).validate_python(function_jobs), user_id=logged_user["id"], product_name=osparc_product_name, ) - assert registered_job.uid is not None - 
function_job_ids.append(registered_job.uid) + ) + registered_jobs = registered_jobs_batch_create.created_items + assert len(registered_jobs) == 3 + assert all(job.uid is not None for job in registered_jobs) function_job_collection = FunctionJobCollection( title="Test Function Job Collection", description="A test function job collection", - job_ids=function_job_ids, + job_ids=[job.uid for job in registered_jobs], ) # Register the function job collection @@ -357,9 +364,8 @@ async def test_list_function_job_collections_filtered_function_id( else: function_id = other_registered_function.uid # Create a function job collection - function_job_ids = [] - for _ in range(3): - registered_function_job = ProjectFunctionJob( + function_jobs = [ + ProjectFunctionJob( function_uid=function_id, title="Test Function Job", description="A test function job", @@ -368,14 +374,22 @@ async def test_list_function_job_collections_filtered_function_id( outputs={"output1": "result1"}, job_creation_task_id=None, ) - # Register the function job - registered_job = await webserver_rpc_client.functions.register_function_job( - function_job=registered_function_job, + for _ in range(3) + ] + # Register the function job + registered_jobs_batch_create = ( + await webserver_rpc_client.functions.batch_register_function_jobs( + function_jobs=TypeAdapter(FunctionJobList).validate_python( + function_jobs + ), user_id=logged_user["id"], product_name=osparc_product_name, ) - assert registered_job.uid is not None - function_job_ids.append(registered_job.uid) + ) + registered_jobs = registered_jobs_batch_create.created_items + assert len(registered_jobs) == 3 + assert all(job.uid for job in registered_jobs) + function_job_ids = [job.uid for job in registered_jobs] function_job_collection = FunctionJobCollection( title="Test Function Job Collection", diff --git a/services/web/server/tests/unit/with_dbs/04/functions/wb-api-server/test_function_jobs_controller_rpc.py 
b/services/web/server/tests/unit/with_dbs/04/functions/wb-api-server/test_function_jobs_controller_rpc.py index 0a22743775b5..5c411b9d4305 100644 --- a/services/web/server/tests/unit/with_dbs/04/functions/wb-api-server/test_function_jobs_controller_rpc.py +++ b/services/web/server/tests/unit/with_dbs/04/functions/wb-api-server/test_function_jobs_controller_rpc.py @@ -15,6 +15,8 @@ Function, FunctionClass, FunctionJobCollection, + FunctionJobList, + FunctionJobPatchRequest, FunctionJobStatus, RegisteredFunctionJob, RegisteredFunctionJobPatch, @@ -31,6 +33,7 @@ ) from models_library.products import ProductName from models_library.projects import ProjectID +from pydantic import TypeAdapter from pytest_simcore.helpers.webserver_users import UserInfoDict from servicelib.celery.models import TaskKey from servicelib.rabbitmq.rpc_interfaces.webserver.v1 import WebServerRpcClient @@ -73,11 +76,16 @@ async def test_register_get_delete_function_job( ) # Register the function job - registered_job = await webserver_rpc_client.functions.register_function_job( - function_job=function_job, - user_id=logged_user["id"], - product_name=osparc_product_name, + registered_jobs_batch_create = ( + await webserver_rpc_client.functions.batch_register_function_jobs( + function_jobs=[function_job], + user_id=logged_user["id"], + product_name=osparc_product_name, + ) ) + registered_jobs = registered_jobs_batch_create.created_items + assert len(registered_jobs) == 1 + registered_job = registered_jobs[0] # Assert the registered job matches the input job assert registered_job.function_uid == function_job.function_uid @@ -191,11 +199,16 @@ async def test_list_function_jobs( ) # Register the function job - registered_job = await webserver_rpc_client.functions.register_function_job( - function_job=function_job, - user_id=logged_user["id"], - product_name=osparc_product_name, + registered_jobs_batch_create = ( + await webserver_rpc_client.functions.batch_register_function_jobs( + 
function_jobs=[function_job], + user_id=logged_user["id"], + product_name=osparc_product_name, + ) ) + registered_jobs = registered_jobs_batch_create.created_items + assert len(registered_jobs) == 1 + registered_job = registered_jobs[0] # List function jobs jobs, _ = await webserver_rpc_client.functions.list_function_jobs( @@ -420,9 +433,8 @@ async def test_find_cached_function_jobs( product_name=osparc_product_name, ) - registered_function_jobs = [] - for value in range(5): - function_job = ProjectFunctionJob( + function_jobs = [ + ProjectFunctionJob( function_uid=registered_function.uid, title="Test Function Job", description="A test function job", @@ -431,40 +443,111 @@ async def test_find_cached_function_jobs( outputs={"output1": "result1"}, job_creation_task_id=None, ) + for value in range(5) + ] - # Register the function job - registered_job = await webserver_rpc_client.functions.register_function_job( - function_job=function_job, - user_id=logged_user["id"], - product_name=osparc_product_name, - ) - registered_function_jobs.append(registered_job) + # Register the function job + await webserver_rpc_client.functions.batch_register_function_jobs( + function_jobs=TypeAdapter(FunctionJobList).validate_python(function_jobs), + user_id=logged_user["id"], + product_name=osparc_product_name, + ) # Find cached function jobs cached_jobs = await webserver_rpc_client.functions.find_cached_function_jobs( function_id=registered_function.uid, - inputs={"input1": 1}, + inputs=[{"input1": 1}, {"input1": 10}], user_id=logged_user["id"], product_name=osparc_product_name, ) # Assert the cached jobs contain the registered job - assert cached_jobs is not None assert len(cached_jobs) == 2 - assert {job.uid for job in cached_jobs} == { - registered_function_jobs[1].uid, - registered_function_jobs[4].uid, - } + job0 = cached_jobs[0] + assert job0 is not None + assert job0.inputs == {"input1": 1} + assert cached_jobs[1] is None cached_jobs = await 
webserver_rpc_client.functions.find_cached_function_jobs( function_id=registered_function.uid, - inputs={"input1": 1}, + inputs=[{"input1": 1}, {"input1": 10}], user_id=other_logged_user["id"], product_name=osparc_product_name, ) # Assert the cached jobs does not contain the registered job for the other user - assert cached_jobs is None + assert len(cached_jobs) == 2 + assert all(job is None for job in cached_jobs) + + +@pytest.mark.parametrize( + "user_role", + [UserRole.USER], +) +async def test_find_cached_function_jobs_with_status( + client: TestClient, + webserver_rpc_client: WebServerRpcClient, + add_user_function_api_access_rights: None, + logged_user: UserInfoDict, + other_logged_user: UserInfoDict, + osparc_product_name: ProductName, + create_fake_function_obj: Callable[[FunctionClass], Function], + clean_functions: None, +): + # Register the function first + job_statuses = [ + FunctionJobStatus(status="RUNNING"), + FunctionJobStatus(status="FAILED"), + ] + registered_function = await webserver_rpc_client.functions.register_function( + function=create_fake_function_obj(FunctionClass.PROJECT), + user_id=logged_user["id"], + product_name=osparc_product_name, + ) + input_ = {"input1": 1.0} + + for status in job_statuses: + function_job = ProjectFunctionJob( + function_uid=registered_function.uid, + title="Test Function Job", + description="A test function job", + project_job_id=uuid4(), + inputs=input_, + outputs={"output1": "result1"}, + job_creation_task_id=None, + ) + + # Register the function job + registered_job = await webserver_rpc_client.functions.register_function_job( + function_job=function_job, + user_id=logged_user["id"], + product_name=osparc_product_name, + ) + await webserver_rpc_client.functions.update_function_job_status( + user_id=logged_user["id"], + product_name=osparc_product_name, + function_job_id=registered_job.uid, + job_status=status, + ) + + status = job_statuses[0] + cached_jobs = await 
webserver_rpc_client.functions.find_cached_function_jobs( + function_id=registered_function.uid, + product_name=osparc_product_name, + user_id=logged_user["id"], + inputs=[input_], + cached_job_statuses=[status], + ) + assert len(cached_jobs) == 1 + cached_job = cached_jobs[0] + assert cached_job is not None + assert cached_job.inputs == input_ + cached_job_status = await webserver_rpc_client.functions.get_function_job_status( + product_name=osparc_product_name, + function_job_id=cached_job.uid, + user_id=logged_user["id"], + ) + assert status == cached_job_status @pytest.mark.parametrize( @@ -544,9 +627,10 @@ async def test_patch_registered_function_jobs( registered_job = await webserver_rpc_client.functions.patch_registered_function_job( user_id=logged_user["id"], - function_job_uuid=registered_job.uid, product_name=osparc_product_name, - registered_function_job_patch=patch, + function_job_patch_request=FunctionJobPatchRequest( + uid=registered_job.uid, patch=patch + ), ) assert registered_job.title == patch.title assert registered_job.description == patch.description @@ -600,7 +684,7 @@ async def test_incompatible_patch_model_error( create_fake_function_obj: Callable[[FunctionClass], Function], clean_functions: None, function_job: RegisteredFunctionJob, - patch: RegisteredFunctionJobPatch, + patch: RegisteredProjectFunctionJobPatch | RegisteredSolverFunctionJobPatch, ): function = create_fake_function_obj(function_job.function_class) @@ -619,9 +703,10 @@ async def test_incompatible_patch_model_error( registered_job = ( await webserver_rpc_client.functions.patch_registered_function_job( user_id=logged_user["id"], - function_job_uuid=registered_job.uid, product_name=osparc_product_name, - registered_function_job_patch=patch, + function_job_patch_request=FunctionJobPatchRequest( + uid=registered_job.uid, patch=patch + ), ) ) @@ -674,7 +759,6 @@ async def test_update_function_job_status_output( user_id=logged_user["id"], product_name=osparc_product_name, ) - 
old_job_status = await webserver_rpc_client.functions.get_function_job_status( function_job_id=registered_job.uid, user_id=logged_user["id"], @@ -764,7 +848,6 @@ async def test_update_function_job_outputs( user_id=logged_user["id"], product_name=osparc_product_name, ) - received_outputs = await webserver_rpc_client.functions.get_function_job_outputs( function_job_id=registered_job.uid, user_id=logged_user["id"],