Skip to content

Commit 650986d

Browse files
committed
start fixing find_cache method
1 parent 0ebd500 commit 650986d

File tree

4 files changed

+66
-38
lines changed

4 files changed

+66
-38
lines changed

services/web/server/src/simcore_service_webserver/functions/_controller/_functions_rpc.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from aiohttp import web
44
from models_library.functions import (
55
BatchCreateRegisteredFunctionJobs,
6+
BatchGetCachedRegisteredFunctionJobs,
67
BatchUpdateRegisteredFunctionJobs,
78
Function,
89
FunctionAccessRights,
@@ -490,17 +491,23 @@ async def find_cached_function_jobs(
490491
function_id: FunctionID,
491492
inputs: FunctionInputsList,
492493
status_filter: list[FunctionJobStatus] | None,
493-
) -> list[RegisteredFunctionJob | None]:
494-
jobs = await _functions_service.find_cached_function_jobs(
495-
app=app,
496-
user_id=user_id,
497-
product_name=product_name,
498-
function_id=function_id,
499-
inputs=inputs,
500-
status_filter=status_filter,
501-
)
502-
assert len(jobs) == len(inputs) # nosec
503-
return jobs
494+
) -> BatchGetCachedRegisteredFunctionJobs:
495+
retrieved_cached_function_jobs = (
496+
await _functions_service.batch_find_cached_function_jobs(
497+
app=app,
498+
user_id=user_id,
499+
product_name=product_name,
500+
function_id=function_id,
501+
inputs=inputs,
502+
status_filter=status_filter,
503+
)
504+
)
505+
assert len(retrieved_cached_function_jobs.found_items) + len(
506+
retrieved_cached_function_jobs.missing_identifiers
507+
) == len(
508+
inputs
509+
) # nosec
510+
return retrieved_cached_function_jobs
504511

505512

506513
@router.expose(reraise_if_error_type=(FunctionIDNotFoundError,))

services/web/server/src/simcore_service_webserver/functions/_function_jobs_repository.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from aiohttp import web
88
from models_library.functions import (
99
BatchCreateRegisteredFunctionJobsDB,
10+
BatchGetCachedRegisteredFunctionJobsDB,
1011
BatchUpdateRegisteredFunctionJobsDB,
1112
FunctionClass,
1213
FunctionClassSpecificData,
@@ -318,7 +319,7 @@ async def delete_function_job(
318319
)
319320

320321

321-
async def find_cached_function_jobs(
322+
async def batch_find_cached_function_jobs(
322323
app: web.Application,
323324
connection: AsyncConnection | None = None,
324325
*,
@@ -327,7 +328,7 @@ async def find_cached_function_jobs(
327328
product_name: ProductName,
328329
inputs: FunctionInputsList,
329330
status_filter: list[FunctionJobStatus] | None = None,
330-
) -> list[RegisteredFunctionJobDB | None]:
331+
) -> BatchGetCachedRegisteredFunctionJobsDB:
331332
async with pass_or_acquire_connection(get_asyncpg_engine(app), connection) as conn:
332333
# Get user groups for access check
333334
user_groups = await list_all_user_groups_ids(app, user_id=user_id)
@@ -378,8 +379,16 @@ async def find_cached_function_jobs(
378379
for row in results
379380
}
380381

381-
# Return results in the same order as inputs, with None for missing jobs
382-
return [jobs_by_input.get(json_input, None) for json_input in json_inputs]
382+
return BatchGetCachedRegisteredFunctionJobsDB(
383+
found_items=[
384+
jobs_by_input[input_]
385+
for input_ in json_inputs
386+
if input_ in jobs_by_input
387+
],
388+
missing_identifiers=[
389+
input_ for input_ in inputs if json.dumps(input_) not in jobs_by_input
390+
],
391+
)
383392

384393

385394
async def get_function_job_status(

services/web/server/src/simcore_service_webserver/functions/_functions_service.py

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from models_library.basic_types import IDStr
55
from models_library.functions import (
66
BatchCreateRegisteredFunctionJobs,
7+
BatchGetCachedRegisteredFunctionJobs,
78
BatchUpdateRegisteredFunctionJobs,
89
Function,
910
FunctionClass,
@@ -448,30 +449,34 @@ async def update_function(
448449
return _decode_function(updated_function)
449450

450451

451-
async def find_cached_function_jobs(
452+
async def batch_find_cached_function_jobs(
452453
app: web.Application,
453454
*,
454455
user_id: UserID,
455456
product_name: ProductName,
456457
function_id: FunctionID,
457458
inputs: FunctionInputsList,
458459
status_filter: list[FunctionJobStatus] | None = None,
459-
) -> list[RegisteredFunctionJob | None]:
460-
returned_function_jobs = await _function_jobs_repository.find_cached_function_jobs(
461-
app=app,
462-
user_id=user_id,
463-
product_name=product_name,
464-
function_id=function_id,
465-
inputs=inputs,
466-
status_filter=status_filter,
460+
) -> BatchGetCachedRegisteredFunctionJobs:
461+
returned_function_jobs = (
462+
await _function_jobs_repository.batch_find_cached_function_jobs(
463+
app=app,
464+
user_id=user_id,
465+
product_name=product_name,
466+
function_id=function_id,
467+
inputs=inputs,
468+
status_filter=status_filter,
469+
)
467470
)
468-
assert len(returned_function_jobs) == len(inputs) # nosec
471+
assert len(returned_function_jobs.found_items) + len(
472+
returned_function_jobs.missing_identifiers
473+
) == len(
474+
inputs
475+
) # nosec
469476

470477
def _map_db_model_to_domain_model(
471-
job: RegisteredFunctionJobDB | None,
472-
) -> RegisteredFunctionJob | None:
473-
if job is None:
474-
return None
478+
job: RegisteredFunctionJobDB,
479+
) -> RegisteredFunctionJob:
475480
if job.function_class == FunctionClass.PROJECT:
476481
return RegisteredProjectFunctionJob(
477482
uid=job.uuid,
@@ -502,7 +507,13 @@ def _map_db_model_to_domain_model(
502507
)
503508
raise UnsupportedFunctionJobClassError(function_job_class=job.function_class)
504509

505-
return [_map_db_model_to_domain_model(job) for job in returned_function_jobs]
510+
return BatchGetCachedRegisteredFunctionJobs(
511+
found_items=[
512+
_map_db_model_to_domain_model(job)
513+
for job in returned_function_jobs.found_items
514+
],
515+
missing_identifiers=returned_function_jobs.missing_identifiers,
516+
)
506517

507518

508519
async def get_function_input_schema(

services/web/server/tests/unit/with_dbs/04/functions/wb-api-server/test_function_jobs_controller_rpc.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -464,12 +464,13 @@ async def test_find_cached_function_jobs(
464464
)
465465

466466
# Assert the cached jobs contain the registered job
467-
assert cached_jobs is not None
468-
assert len(cached_jobs) == 2
469-
job0 = cached_jobs[0]
467+
assert cached_jobs.found_items is not None
468+
assert len(cached_jobs.found_items) == 1
469+
job0 = cached_jobs.found_items[0]
470470
assert job0 is not None
471471
assert job0.inputs == {"input1": 1}
472-
assert cached_jobs[1] is None
472+
assert len(cached_jobs.missing_identifiers) == 1
473+
assert cached_jobs.missing_identifiers[0] == {"input1": 10}
473474

474475
cached_jobs = await webserver_rpc_client.functions.find_cached_function_jobs(
475476
function_id=registered_function.uid,
@@ -479,8 +480,8 @@ async def test_find_cached_function_jobs(
479480
)
480481

481482
# Assert the cached jobs does not contain the registered job for the other user
482-
assert len(cached_jobs) == 2
483-
assert all(elm is None for elm in cached_jobs)
483+
assert len(cached_jobs.missing_identifiers) == 2
484+
assert len(cached_jobs.found_items) == 0
484485

485486

486487
@pytest.mark.parametrize(
@@ -541,8 +542,8 @@ async def test_find_cached_function_jobs_with_status(
541542
inputs=[input_],
542543
status_filter=[status],
543544
)
544-
assert len(cached_jobs) == 1
545-
cached_job = cached_jobs[0]
545+
assert len(cached_jobs.found_items) == 1
546+
cached_job = cached_jobs.found_items[0]
546547
assert cached_job is not None
547548
assert cached_job.inputs == input_
548549
cached_job_status = await webserver_rpc_client.functions.get_function_job_status(

0 commit comments

Comments
 (0)