Skip to content

Commit a0fab96

Browse files
committed
propagate task creation all the way out
1 parent e91465d commit a0fab96

File tree

4 files changed

+69
-88
lines changed

4 files changed

+69
-88
lines changed

services/api-server/src/simcore_service_api_server/_service_function_jobs.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,8 @@
3434
from ._service_jobs import JobService
3535
from .models.api_resources import JobLinks
3636
from .models.domain.functions import (
37+
FunctionJobPatch,
3738
PreRegisteredFunctionJobData,
38-
ProjectFunctionJobPatch,
39-
SolverFunctionJobPatch,
4039
)
4140
from .models.schemas.jobs import JobInputs, JobPricingSpecification
4241
from .services_http.webserver import AuthSession
@@ -209,14 +208,13 @@ async def patch_registered_function_job(
209208
user_id: UserID,
210209
product_name: ProductName,
211210
patches: Annotated[
212-
list[ProjectFunctionJobPatch] | list[SolverFunctionJobPatch],
211+
list[FunctionJobPatch],
213212
Field(max_length=50, min_length=1),
214213
],
215214
) -> list[RegisteredFunctionJob]:
216215
patch_inputs = []
217216
for patch in patches:
218217
if patch.function_class == FunctionClass.PROJECT:
219-
assert isinstance(patch, ProjectFunctionJobPatch) # nosec
220218
patch_inputs.append(
221219
RegisteredProjectFunctionJobPatch(
222220
title=None,
@@ -228,7 +226,6 @@ async def patch_registered_function_job(
228226
)
229227
)
230228
elif patch.function_class == FunctionClass.SOLVER:
231-
assert isinstance(patch, SolverFunctionJobPatch) # nosec
232229
patch_inputs.append(
233230
RegisteredSolverFunctionJobPatch(
234231
title=None,
@@ -279,7 +276,8 @@ async def run_function(
279276
user_id=self.user_id,
280277
product_name=self.product_name,
281278
patches=[
282-
ProjectFunctionJobPatch(
279+
FunctionJobPatch(
280+
function_class=FunctionClass.PROJECT,
283281
function_job_id=pre_registered_function_job_data.function_job_id,
284282
job_creation_task_id=None,
285283
project_job_id=study_job.id,
@@ -309,7 +307,8 @@ async def run_function(
309307
user_id=self.user_id,
310308
product_name=self.product_name,
311309
patches=[
312-
SolverFunctionJobPatch(
310+
FunctionJobPatch(
311+
function_class=FunctionClass.SOLVER,
313312
function_job_id=pre_registered_function_job_data.function_job_id,
314313
job_creation_task_id=None,
315314
solver_job_id=solver_job.id,

services/api-server/src/simcore_service_api_server/_service_function_jobs_task_client.py

Lines changed: 24 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -38,17 +38,16 @@
3838
)
3939
from sqlalchemy.ext.asyncio import AsyncEngine
4040

41+
from ._meta import APP_NAME
4142
from ._service_function_jobs import FunctionJobService
4243
from ._service_functions import FunctionService
4344
from ._service_jobs import JobService
4445
from .api.dependencies.authentication import Identity
45-
from .exceptions.function_errors import (
46-
FunctionJobCacheNotFoundError,
47-
)
4846
from .models.api_resources import JobLinks
4947
from .models.domain.celery_models import ApiServerOwnerMetadata
48+
from .models.domain.functions import FunctionJobPatch
5049
from .models.schemas.functions import FunctionJobCreationTaskStatus
51-
from .models.schemas.jobs import JobInputs, JobPricingSpecification
50+
from .models.schemas.jobs import JobPricingSpecification
5251
from .services_http.webserver import AuthSession
5352
from .services_rpc.storage import StorageService
5453
from .services_rpc.wb_api_server import WbApiRpcClient
@@ -238,29 +237,6 @@ async def inspect_function_job(
238237
check_write_permissions=False,
239238
)
240239

241-
async def get_cached_function_job(
242-
self,
243-
*,
244-
function: RegisteredFunction,
245-
job_inputs: JobInputs,
246-
) -> RegisteredFunctionJob:
247-
"""Raises FunctionJobCacheNotFoundError if no cached job is found"""
248-
if cached_function_jobs := await self._web_rpc_client.find_cached_function_jobs(
249-
function_id=function.uid,
250-
inputs=job_inputs.values,
251-
user_id=self.user_id,
252-
product_name=self.product_name,
253-
):
254-
for cached_function_job in cached_function_jobs:
255-
job_status = await self.inspect_function_job(
256-
function=function,
257-
function_job=cached_function_job,
258-
)
259-
if job_status.status == RunningState.SUCCESS:
260-
return cached_function_job
261-
262-
raise FunctionJobCacheNotFoundError
263-
264240
async def function_job_outputs(
265241
self,
266242
*,
@@ -319,7 +295,7 @@ async def function_job_outputs(
319295
check_write_permissions=False,
320296
)
321297

322-
async def create_function_job_creation_task(
298+
async def create_function_job_creation_tasks(
323299
self,
324300
*,
325301
function: RegisteredFunction,
@@ -329,7 +305,7 @@ async def create_function_job_creation_task(
329305
job_links: JobLinks,
330306
parent_project_uuid: ProjectID | None = None,
331307
parent_node_id: NodeID | None = None,
332-
) -> RegisteredFunctionJob:
308+
) -> list[RegisteredFunctionJob]:
333309
inputs = [
334310
self._function_job_service.create_function_job_inputs(
335311
function=function, function_inputs=input_
@@ -359,9 +335,10 @@ async def create_function_job_creation_task(
359335
)
360336

361337
# run function in celery task
362-
363338
owner_metadata = ApiServerOwnerMetadata(
364-
user_id=user_identity.user_id, product_name=user_identity.product_name
339+
user_id=user_identity.user_id,
340+
product_name=user_identity.product_name,
341+
owner=APP_NAME,
365342
)
366343
task_uuids = [
367344
await self._celery_task_manager.submit_task(
@@ -382,10 +359,22 @@ async def create_function_job_creation_task(
382359
for pre_registered_function_job_data in pre_registered_function_job_data_list
383360
]
384361

385-
return await self._function_job_service.patch_registered_function_job(
362+
patched_jobs = await self._function_job_service.patch_registered_function_job(
386363
user_id=user_identity.user_id,
387364
product_name=user_identity.product_name,
388-
function_job_id=pre_registered_function_job_data_list.function_job_id,
389-
function_class=function.function_class,
390-
job_creation_task_id=TaskID(task_uuid),
365+
patches=[
366+
FunctionJobPatch(
367+
function_class=function.function_class,
368+
function_job_id=pre_registered_function_job_data.function_job_id,
369+
job_creation_task_id=TaskID(task_uuid),
370+
project_job_id=None,
371+
solver_job_id=None,
372+
)
373+
for task_uuid, pre_registered_function_job_data in zip(
374+
task_uuids, pre_registered_function_job_data_list
375+
)
376+
],
391377
)
378+
patched_jobs_iter = iter(patched_jobs)
379+
_ = lambda job: job if job is not None else next(patched_jobs_iter)
380+
return [_(job) for job in cached_jobs]

services/api-server/src/simcore_service_api_server/api/routes/functions_routes.py

Lines changed: 14 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,12 @@
1616
RegisteredFunctionJob,
1717
RegisteredFunctionJobCollection,
1818
)
19-
from models_library.functions import FunctionJobCollection, FunctionJobID
19+
from models_library.functions import FunctionJobCollection
2020
from models_library.products import ProductName
2121
from models_library.projects import ProjectID
2222
from models_library.projects_nodes_io import NodeID
2323
from models_library.users import UserID
2424
from servicelib.fastapi.dependencies import get_reverse_url_mapper
25-
from servicelib.utils import limited_gather
2625

2726
from ..._service_function_jobs import FunctionJobService
2827
from ..._service_function_jobs_task_client import FunctionJobTaskClientService
@@ -350,15 +349,17 @@ async def run_function(
350349
)
351350
job_links = await function_service.get_function_job_links(to_run_function, url_for)
352351

353-
return await function_job_task_client_service.create_function_job_creation_task(
352+
jobs = await function_job_task_client_service.create_function_job_creation_tasks(
354353
function=to_run_function,
355-
function_inputs=function_inputs,
354+
function_inputs=[function_inputs],
356355
user_identity=user_identity,
357356
pricing_spec=pricing_spec,
358357
job_links=job_links,
359358
parent_project_uuid=parent_project_uuid,
360359
parent_node_id=parent_node_id,
361360
)
361+
assert len(jobs) == 1 # nosec
362+
return jobs[0]
362363

363364

364365
@function_router.delete(
@@ -429,42 +430,23 @@ async def map_function(
429430
)
430431
job_links = await function_service.get_function_job_links(to_run_function, url_for)
431432

432-
async def _run_single_function(function_inputs: FunctionInputs) -> FunctionJobID:
433-
result = (
434-
await function_job_task_client_service.create_function_job_creation_task(
435-
function=to_run_function,
436-
function_inputs=function_inputs,
437-
user_identity=user_identity,
438-
pricing_spec=pricing_spec,
439-
job_links=job_links,
440-
parent_project_uuid=parent_project_uuid,
441-
parent_node_id=parent_node_id,
442-
)
443-
)
444-
return result.uid
445-
446-
# Run all tasks concurrently, allowing them to complete even if some fail
447-
results = await limited_gather(
448-
*[
449-
_run_single_function(function_inputs)
450-
for function_inputs in function_inputs_list
451-
],
452-
reraise=False,
453-
limit=1,
433+
jobs = await function_job_task_client_service.create_function_job_creation_tasks(
434+
function=to_run_function,
435+
function_inputs=function_inputs_list,
436+
user_identity=user_identity,
437+
pricing_spec=pricing_spec,
438+
job_links=job_links,
439+
parent_project_uuid=parent_project_uuid,
440+
parent_node_id=parent_node_id,
454441
)
455442

456-
# Check if any tasks raised exceptions and raise the first one found
457-
for result in results:
458-
if isinstance(result, BaseException):
459-
raise result
460-
461443
# At this point, all results are FunctionJobID since we've checked for exceptions
462444
function_job_collection_description = f"Function job collection of map of function {to_run_function.uid} with {len(function_inputs_list)} inputs"
463445
return await web_api_rpc_client.register_function_job_collection(
464446
function_job_collection=FunctionJobCollection(
465447
title="Function job collection of function map",
466448
description=function_job_collection_description,
467-
job_ids=results, # type: ignore
449+
job_ids=jobs, # type: ignore
468450
),
469451
user_id=user_identity.user_id,
470452
product_name=user_identity.product_name,
Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from typing import NamedTuple
2-
31
from models_library.functions import (
42
FunctionClass,
53
FunctionJobID,
@@ -9,7 +7,7 @@
97
TaskID,
108
)
119
from models_library.projects import ProjectID
12-
from pydantic import BaseModel
10+
from pydantic import BaseModel, model_validator
1311

1412
from ...models.pagination import Page
1513
from ...models.schemas.jobs import JobInputs
@@ -28,15 +26,28 @@ class PageRegisteredFunctionJobWithorWithoutStatus(
2826
pass
2927

3028

31-
class ProjectFunctionJobPatch(NamedTuple):
32-
function_class = FunctionClass.PROJECT
33-
function_job_id: FunctionJobID
34-
job_creation_task_id: TaskID | None
35-
project_job_id: ProjectID | None
36-
37-
38-
class SolverFunctionJobPatch(NamedTuple):
39-
function_class = FunctionClass.SOLVER
29+
class FunctionJobPatch(BaseModel):
30+
function_class: FunctionClass
4031
function_job_id: FunctionJobID
41-
job_creation_task_id: TaskID | None
42-
solver_job_id: SolverJobID | None
32+
job_creation_task_id: TaskID | None = None
33+
project_job_id: ProjectID | None = None
34+
solver_job_id: SolverJobID | None = None
35+
36+
@model_validator(mode="after")
37+
def validate_function_class_consistency(self) -> "FunctionJobPatch":
38+
"""Validate consistency between function_class and job IDs."""
39+
if (
40+
self.solver_job_id is not None
41+
and self.function_class != FunctionClass.SOLVER
42+
):
43+
msg = f"solver_job_id must be None when function_class is {self.function_class}, expected {FunctionClass.SOLVER}"
44+
raise ValueError(msg)
45+
46+
if (
47+
self.project_job_id is not None
48+
and self.function_class != FunctionClass.PROJECT
49+
):
50+
msg = f"project_job_id must be None when function_class is {self.function_class}, expected {FunctionClass.PROJECT}"
51+
raise ValueError(msg)
52+
53+
return self

0 commit comments

Comments
 (0)