Skip to content

Commit 981736f

Browse files
committed
add mehod for patching a registered function job to function job service
1 parent 10df578 commit 981736f

File tree

4 files changed

+228
-33
lines changed

4 files changed

+228
-33
lines changed

packages/models-library/src/models_library/functions_errors.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,3 +163,8 @@ class FunctionJobCollectionsExecuteApiAccessDeniedError(FunctionBaseError):
163163
class FunctionJobPatchModelIncompatibleError(FunctionBaseError):
164164
msg_template = "Incompatible patch model for Function '{function_id}' in product '{product_name}'."
165165
status_code: int = 422
166+
167+
168+
class FunctionJobCacheNotFoundError(FunctionBaseError):
169+
msg_template: str = "No cached function job found."
170+
status_code: int = 404 # Not Found

services/api-server/src/simcore_service_api_server/_service_function_jobs.py

Lines changed: 176 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from dataclasses import dataclass
2+
from typing import NamedTuple, overload
23

34
import jsonschema
45
from common_library.exclude import as_dict_exclude_none
@@ -16,11 +17,16 @@
1617
RegisteredFunction,
1718
RegisteredFunctionJob,
1819
RegisteredFunctionJobCollection,
20+
RegisteredProjectFunctionJobPatch,
21+
RegisteredSolverFunctionJobPatch,
1922
SolverFunctionJob,
23+
SolverJobID,
24+
TaskID,
2025
)
2126
from models_library.functions_errors import (
2227
FunctionExecuteAccessDeniedError,
2328
FunctionInputsValidationError,
29+
FunctionJobCacheNotFoundError,
2430
FunctionsExecuteApiAccessDeniedError,
2531
UnsupportedFunctionClassError,
2632
UnsupportedFunctionFunctionJobClassCombinationError,
@@ -43,6 +49,11 @@
4349
from .services_rpc.wb_api_server import WbApiRpcClient
4450

4551

52+
class RegisteredFunctionJobData(NamedTuple):
53+
function_job_id: FunctionJobID
54+
job_inputs: JobInputs
55+
56+
4657
def join_inputs(
4758
default_inputs: FunctionInputs | None,
4859
function_inputs: FunctionInputs | None,
@@ -162,12 +173,29 @@ async def inspect_function_job(
162173
job_status=new_job_status,
163174
)
164175

165-
async def run_function_pre_check(
176+
async def create_function_job_inputs(
166177
self,
167178
*,
168179
function: RegisteredFunction,
169180
function_inputs: FunctionInputs,
170181
) -> JobInputs:
182+
joined_inputs = join_inputs(
183+
function.default_inputs,
184+
function_inputs,
185+
)
186+
return JobInputs(
187+
values=joined_inputs or {},
188+
)
189+
190+
async def get_cached_function_job(
191+
self,
192+
*,
193+
function: RegisteredFunction,
194+
function_inputs: FunctionInputs,
195+
job_inputs: JobInputs,
196+
) -> RegisteredFunctionJob:
197+
"""raises FunctionJobCacheNotFoundError if no cached job is found"""
198+
171199
user_api_access_rights = (
172200
await self._web_rpc_client.get_functions_user_api_access_rights(
173201
user_id=self.user_id, product_name=self.product_name
@@ -190,21 +218,137 @@ async def run_function_pre_check(
190218
function_id=function.uid,
191219
)
192220

193-
joined_inputs = join_inputs(
194-
function.default_inputs,
195-
function_inputs,
196-
)
221+
if cached_function_jobs := await self._web_rpc_client.find_cached_function_jobs(
222+
function_id=function.uid,
223+
inputs=job_inputs.values,
224+
user_id=self.user_id,
225+
product_name=self.product_name,
226+
):
227+
for cached_function_job in cached_function_jobs:
228+
job_status = await self.inspect_function_job(
229+
function=function,
230+
function_job=cached_function_job,
231+
)
232+
if job_status.status == RunningState.SUCCESS:
233+
return cached_function_job
234+
235+
raise FunctionJobCacheNotFoundError()
236+
237+
async def create_registered_function_job(
238+
self,
239+
*,
240+
function: RegisteredFunction,
241+
function_inputs: FunctionInputs,
242+
job_inputs: JobInputs,
243+
) -> FunctionJobID:
197244

198245
if function.input_schema is not None:
199246
is_valid, validation_str = await self.validate_function_inputs(
200247
function_id=function.uid,
201-
inputs=joined_inputs,
248+
inputs=job_inputs.values,
202249
)
203250
if not is_valid:
204251
raise FunctionInputsValidationError(error=validation_str)
205252

206-
return JobInputs(
207-
values=joined_inputs or {},
253+
if function.function_class == FunctionClass.PROJECT:
254+
job = await self._web_rpc_client.register_function_job(
255+
function_job=ProjectFunctionJob(
256+
function_uid=function.uid,
257+
title=f"Function job of function {function.uid}",
258+
description=function.description,
259+
inputs=job_inputs.values,
260+
outputs=None,
261+
project_job_id=None,
262+
job_creation_task_id=None,
263+
),
264+
user_id=self.user_id,
265+
product_name=self.product_name,
266+
)
267+
268+
elif function.function_class == FunctionClass.SOLVER:
269+
job = await self._web_rpc_client.register_function_job(
270+
function_job=SolverFunctionJob(
271+
function_uid=function.uid,
272+
title=f"Function job of function {function.uid}",
273+
description=function.description,
274+
inputs=job_inputs.values,
275+
outputs=None,
276+
solver_job_id=None,
277+
job_creation_task_id=None,
278+
),
279+
user_id=self.user_id,
280+
product_name=self.product_name,
281+
)
282+
else:
283+
raise UnsupportedFunctionClassError(
284+
function_class=function.function_class,
285+
)
286+
287+
return job.uid
288+
289+
@overload
290+
async def patch_registered_function_job(
291+
self,
292+
*,
293+
user_id: UserID,
294+
product_name: ProductName,
295+
function_job_id: FunctionJobID,
296+
function_class: FunctionClass,
297+
job_creation_task_id: TaskID | None,
298+
project_job_id: ProjectID | None,
299+
) -> RegisteredFunctionJob: ...
300+
301+
@overload
302+
async def patch_registered_function_job(
303+
self,
304+
*,
305+
user_id: UserID,
306+
product_name: ProductName,
307+
function_job_id: FunctionJobID,
308+
function_class: FunctionClass,
309+
job_creation_task_id: TaskID | None,
310+
solver_job_id: SolverJobID | None,
311+
) -> RegisteredFunctionJob: ...
312+
313+
async def patch_registered_function_job(
314+
self,
315+
*,
316+
user_id: UserID,
317+
product_name: ProductName,
318+
function_job_id: FunctionJobID,
319+
function_class: FunctionClass,
320+
job_creation_task_id: TaskID | None,
321+
project_job_id: ProjectID | None = None,
322+
solver_job_id: SolverJobID | None = None,
323+
) -> RegisteredFunctionJob:
324+
# Only allow one of project_job_id or solver_job_id depending on function_class
325+
if function_class == FunctionClass.PROJECT:
326+
patch = RegisteredProjectFunctionJobPatch(
327+
title=None,
328+
description=None,
329+
inputs=None,
330+
outputs=None,
331+
job_creation_task_id=job_creation_task_id,
332+
project_job_id=project_job_id,
333+
)
334+
elif function_class == FunctionClass.SOLVER:
335+
patch = RegisteredSolverFunctionJobPatch(
336+
title=None,
337+
description=None,
338+
inputs=None,
339+
outputs=None,
340+
job_creation_task_id=job_creation_task_id,
341+
solver_job_id=solver_job_id,
342+
)
343+
else:
344+
raise UnsupportedFunctionClassError(
345+
function_class=function_class,
346+
)
347+
return await self._web_rpc_client.patch_registered_function_job(
348+
user_id=user_id,
349+
product_name=product_name,
350+
function_job_id=function_job_id,
351+
registered_function_job_patch=patch,
208352
)
209353

210354
async def run_function(
@@ -217,20 +361,7 @@ async def run_function(
217361
x_simcore_parent_project_uuid: NodeID | None,
218362
x_simcore_parent_node_id: NodeID | None,
219363
) -> RegisteredFunctionJob:
220-
221-
if cached_function_jobs := await self._web_rpc_client.find_cached_function_jobs(
222-
function_id=function.uid,
223-
inputs=job_inputs.values,
224-
user_id=self.user_id,
225-
product_name=self.product_name,
226-
):
227-
for cached_function_job in cached_function_jobs:
228-
job_status = await self.inspect_function_job(
229-
function=function,
230-
function_job=cached_function_job,
231-
)
232-
if job_status.status == RunningState.SUCCESS:
233-
return cached_function_job
364+
"""N.B. this function does not check access rights. Use get_cached_function_job for that"""
234365

235366
if function.function_class == FunctionClass.PROJECT:
236367
study_job = await self._job_service.create_studies_job(
@@ -306,7 +437,7 @@ async def map_function(
306437
) -> RegisteredFunctionJobCollection:
307438

308439
job_inputs = [
309-
await self.run_function_pre_check(
440+
await self.create_registered_function_job(
310441
function=function,
311442
function_inputs=inputs,
312443
)
@@ -335,3 +466,25 @@ async def map_function(
335466
user_id=self.user_id,
336467
product_name=self.product_name,
337468
)
469+
function_jobs = [
470+
await self.run_function(
471+
function=function,
472+
job_inputs=inputs,
473+
pricing_spec=pricing_spec,
474+
job_links=job_links,
475+
x_simcore_parent_project_uuid=x_simcore_parent_project_uuid,
476+
x_simcore_parent_node_id=x_simcore_parent_node_id,
477+
)
478+
for inputs in job_inputs
479+
]
480+
481+
function_job_collection_description = f"Function job collection of map of function {function.uid} with {len(function_inputs_list)} inputs"
482+
return await self._web_rpc_client.register_function_job_collection(
483+
function_job_collection=FunctionJobCollection(
484+
title="Function job collection of function map",
485+
description=function_job_collection_description,
486+
job_ids=[function_job.uid for function_job in function_jobs],
487+
),
488+
user_id=self.user_id,
489+
product_name=self.product_name,
490+
)

services/api-server/src/simcore_service_api_server/api/routes/functions_routes.py

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,13 @@
1717
RegisteredFunctionJobCollection,
1818
)
1919
from models_library.api_schemas_rpc_async_jobs.async_jobs import AsyncJobFilter
20+
from models_library.functions_errors import FunctionJobCacheNotFoundError
2021
from models_library.products import ProductName
2122
from models_library.projects import ProjectID
2223
from models_library.projects_nodes_io import NodeID
2324
from models_library.users import UserID
2425
from servicelib.celery.models import TaskFilter, TaskMetadata, TasksQueue
2526
from servicelib.fastapi.dependencies import get_reverse_url_mapper
26-
from servicelib.long_running_tasks.models import TaskGet
2727

2828
from ..._service_function_jobs import FunctionJobService
2929
from ..._service_functions import FunctionService
@@ -314,7 +314,7 @@ async def validate_function_inputs(
314314

315315
@function_router.post(
316316
"/{function_id:uuid}:run",
317-
response_model=TaskGet,
317+
response_model=RegisteredFunctionJob,
318318
responses={**_COMMON_FUNCTION_ERROR_RESPONSES},
319319
description=create_route_description(
320320
base="Run function",
@@ -333,7 +333,7 @@ async def run_function( # noqa: PLR0913
333333
],
334334
x_simcore_parent_project_uuid: Annotated[ProjectID | Literal["null"], Header()],
335335
x_simcore_parent_node_id: Annotated[NodeID | Literal["null"], Header()],
336-
) -> TaskGet:
336+
) -> RegisteredFunctionJob:
337337
task_manager = get_task_manager(request.app)
338338
parent_project_uuid = (
339339
x_simcore_parent_project_uuid
@@ -348,10 +348,28 @@ async def run_function( # noqa: PLR0913
348348
pricing_spec = JobPricingSpecification.create_from_headers(request.headers)
349349
job_links = await function_service.get_function_job_links(to_run_function, url_for)
350350

351-
job_inputs = await function_job_service.run_function_pre_check(
351+
job_inputs = await function_job_service.create_function_job_inputs(
352352
function=to_run_function, function_inputs=function_inputs
353353
)
354+
try:
355+
# checks access rights
356+
return await function_job_service.get_cached_function_job(
357+
function=to_run_function,
358+
function_inputs=function_inputs,
359+
job_inputs=job_inputs,
360+
)
361+
except FunctionJobCacheNotFoundError:
362+
pass
363+
364+
pre_registered_function_job_id = (
365+
await function_job_service.create_registered_function_job(
366+
function=to_run_function,
367+
function_inputs=function_inputs,
368+
job_inputs=job_inputs,
369+
)
370+
)
354371

372+
# run function in celery task
355373
job_filter = AsyncJobFilter(
356374
user_id=user_identity.user_id,
357375
product_name=user_identity.product_name,
@@ -376,12 +394,14 @@ async def run_function( # noqa: PLR0913
376394
x_simcore_parent_node_id=parent_node_id,
377395
)
378396

379-
return TaskGet(
380-
task_id=f"{task_uuid}",
381-
task_name=task_name,
382-
status_href=url_for("get_task_status", task_id=task_uuid),
383-
result_href=url_for("get_task_result", task_id=task_uuid),
384-
abort_href=url_for("cancel_task", task_id=task_uuid),
397+
return await function_job_service.patch_registered_function_job(
398+
user_id=user_identity.user_id,
399+
product_name=user_identity.product_name,
400+
function_job_id=pre_registered_function_job_id,
401+
registered_function_job_patch=RegisteredFunctionJobPatch(
402+
status=RunningState.RUNNING,
403+
task_id=task_uuid,
404+
),
385405
)
386406

387407

services/api-server/src/simcore_service_api_server/services_rpc/wb_api_server.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
FunctionOutputs,
2929
FunctionUserAccessRights,
3030
FunctionUserApiAccessRights,
31+
RegisteredFunctionJobPatch,
3132
)
3233
from models_library.licenses import LicensedItemID
3334
from models_library.products import ProductName
@@ -489,6 +490,22 @@ async def register_function_job(
489490
function_job=function_job,
490491
)
491492

493+
async def patch_registered_function_job(
494+
self,
495+
*,
496+
user_id: UserID,
497+
product_name: ProductName,
498+
function_job_id: FunctionJobID,
499+
registered_function_job_patch: RegisteredFunctionJobPatch,
500+
) -> RegisteredFunctionJob:
501+
return await functions_rpc_interface.patch_registered_function_job(
502+
self._client,
503+
user_id=user_id,
504+
product_name=product_name,
505+
function_job_uuid=function_job_id,
506+
registered_function_job_patch=registered_function_job_patch,
507+
)
508+
492509
async def get_function_input_schema(
493510
self, *, user_id: UserID, product_name: ProductName, function_id: FunctionID
494511
) -> FunctionInputSchema:

0 commit comments

Comments
 (0)