Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -315,25 +315,25 @@ async def delete_function_job(


@log_decorator(_logger, level=logging.DEBUG)
async def find_cached_function_job(
async def find_cached_function_jobs(
rabbitmq_rpc_client: RabbitMQRPCClient,
*,
user_id: UserID,
product_name: ProductName,
function_id: FunctionID,
inputs: FunctionInputs,
) -> RegisteredFunctionJob | None:
) -> list[RegisteredFunctionJob] | None:
result = await rabbitmq_rpc_client.request(
WEBSERVER_RPC_NAMESPACE,
TypeAdapter(RPCMethodName).validate_python("find_cached_function_job"),
TypeAdapter(RPCMethodName).validate_python("find_cached_function_jobs"),
function_id=function_id,
inputs=inputs,
user_id=user_id,
product_name=product_name,
)
if result is None:
return None
return TypeAdapter(RegisteredFunctionJob).validate_python(result)
return TypeAdapter(list[RegisteredFunctionJob]).validate_python(result)


@log_decorator(_logger, level=logging.DEBUG)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
UnsupportedFunctionClassError,
)
from models_library.products import ProductName
from models_library.projects_state import RunningState
from models_library.users import UserID
from servicelib.fastapi.dependencies import get_reverse_url_mapper
from simcore_service_api_server._service_jobs import JobService
Expand Down Expand Up @@ -351,6 +352,8 @@ async def run_function( # noqa: PLR0913
job_service: Annotated[JobService, Depends(get_job_service)],
) -> RegisteredFunctionJob:

from .function_jobs_routes import function_job_status

to_run_function = await wb_api_rpc.get_function(
function_id=function_id, user_id=user_id, product_name=product_name
)
Expand All @@ -371,13 +374,22 @@ async def run_function( # noqa: PLR0913
if not is_valid:
raise FunctionInputsValidationError(error=validation_str)

if cached_function_job := await wb_api_rpc.find_cached_function_job(
if cached_function_jobs := await wb_api_rpc.find_cached_function_jobs(
function_id=to_run_function.uid,
inputs=joined_inputs,
user_id=user_id,
product_name=product_name,
):
return cached_function_job
for cached_function_job in cached_function_jobs:
job_status = await function_job_status(
wb_api_rpc=wb_api_rpc,
director2_api=director2_api,
function_job_id=cached_function_job.uid,
user_id=user_id,
product_name=product_name,
)
if job_status.status == RunningState.SUCCESS:
return cached_function_job

if to_run_function.function_class == FunctionClass.PROJECT:
study_job = await studies_jobs.create_study_job(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -468,15 +468,15 @@ async def get_function_output_schema(
function_id=function_id,
)

async def find_cached_function_job(
async def find_cached_function_jobs(
self,
*,
user_id: UserID,
product_name: ProductName,
function_id: FunctionID,
inputs: FunctionInputs,
) -> RegisteredFunctionJob | None:
return await functions_rpc_interface.find_cached_function_job(
) -> list[RegisteredFunctionJob] | None:
return await functions_rpc_interface.find_cached_function_jobs(
self._client,
user_id=user_id,
product_name=product_name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -303,15 +303,15 @@ async def update_function_description(


@router.expose()
async def find_cached_function_job(
async def find_cached_function_jobs(
app: web.Application,
*,
user_id: UserID,
product_name: ProductName,
function_id: FunctionID,
inputs: FunctionInputs,
) -> FunctionJob | None:
return await _functions_service.find_cached_function_job(
) -> list[RegisteredFunctionJob] | None:
return await _functions_service.find_cached_function_jobs(
app=app,
user_id=user_id,
product_name=product_name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -687,15 +687,15 @@ async def delete_function_job(
)


async def find_cached_function_job(
async def find_cached_function_jobs(
app: web.Application,
connection: AsyncConnection | None = None,
*,
user_id: UserID,
function_id: FunctionID,
product_name: ProductName,
inputs: FunctionInputs,
) -> RegisteredFunctionJobDB | None:
) -> list[RegisteredFunctionJobDB] | None:

async with transaction_context(get_asyncpg_engine(app), connection) as conn:
result = await conn.stream(
Expand All @@ -704,19 +704,13 @@ async def find_cached_function_job(
cast(function_jobs_table.c.inputs, Text) == json.dumps(inputs),
),
)

rows = await result.all()

if rows is None or len(rows) == 0:
return None

assert len(rows) == 1, (
"More than one function job found with the same function id and inputs."
f" Function id: {function_id}, Inputs: {inputs}"
) # nosec

row = rows[0]
if rows is None or len(rows) == 0:
return None

jobs = []
for row in rows:
job = RegisteredFunctionJobDB.model_validate(dict(row))
try:
await check_user_permissions(
Expand All @@ -729,13 +723,14 @@ async def find_cached_function_job(
permissions=["read"],
)
except FunctionJobReadAccessDeniedError:
# If the user does not have read access, return None
return None
continue

if job.inputs == inputs:
return job
jobs.append(job)

return None
if len(jobs) > 0:
return jobs

return None


async def get_function_job_collection(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -340,48 +340,60 @@ async def update_function_description(


@router.expose()
async def find_cached_function_job(
async def find_cached_function_jobs(
app: web.Application,
*,
user_id: UserID,
product_name: ProductName,
function_id: FunctionID,
inputs: FunctionInputs,
) -> FunctionJob | None:
returned_function_job = await _functions_repository.find_cached_function_job(
) -> list[RegisteredFunctionJob] | None:
returned_function_jobs = await _functions_repository.find_cached_function_jobs(
app=app,
user_id=user_id,
product_name=product_name,
function_id=function_id,
inputs=inputs,
)
if returned_function_job is None:
if returned_function_jobs is None or len(returned_function_jobs) == 0:
return None

if returned_function_job.function_class == FunctionClass.PROJECT:
return RegisteredProjectFunctionJob(
uid=returned_function_job.uuid,
title=returned_function_job.title,
description=returned_function_job.description,
function_uid=returned_function_job.function_uuid,
inputs=returned_function_job.inputs,
outputs=None,
project_job_id=returned_function_job.class_specific_data["project_job_id"],
)
if returned_function_job.function_class == FunctionClass.SOLVER:
return RegisteredSolverFunctionJob(
uid=returned_function_job.uuid,
title=returned_function_job.title,
description=returned_function_job.description,
function_uid=returned_function_job.function_uuid,
inputs=returned_function_job.inputs,
outputs=None,
solver_job_id=returned_function_job.class_specific_data["solver_job_id"],
)

raise UnsupportedFunctionJobClassError(
function_job_class=returned_function_job.function_class
)
to_return_function_jobs: list[RegisteredFunctionJob] = []
for returned_function_job in returned_function_jobs:
if returned_function_job.function_class == FunctionClass.PROJECT:
to_return_function_jobs.append(
RegisteredProjectFunctionJob(
uid=returned_function_job.uuid,
title=returned_function_job.title,
description=returned_function_job.description,
function_uid=returned_function_job.function_uuid,
inputs=returned_function_job.inputs,
outputs=None,
project_job_id=returned_function_job.class_specific_data[
"project_job_id"
],
)
)
elif returned_function_job.function_class == FunctionClass.SOLVER:
to_return_function_jobs.append(
RegisteredSolverFunctionJob(
uid=returned_function_job.uuid,
title=returned_function_job.title,
description=returned_function_job.description,
function_uid=returned_function_job.function_uuid,
inputs=returned_function_job.inputs,
outputs=None,
solver_job_id=returned_function_job.class_specific_data[
"solver_job_id"
],
)
)
else:
raise UnsupportedFunctionJobClassError(
function_job_class=returned_function_job.function_class
)

return to_return_function_jobs


@router.expose(reraise_if_error_type=(FunctionIDNotFoundError,))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1125,3 +1125,74 @@ async def test_list_function_job_collections_filtered_function_id(
assert collections[1].uid in [
collection.uid for collection in registered_collections
]


@pytest.mark.parametrize(
"user_role",
[UserRole.USER],
)
async def test_find_cached_function_jobs(
client: TestClient,
rpc_client: RabbitMQRPCClient,
logged_user: UserInfoDict,
other_logged_user: UserInfoDict,
osparc_product_name: ProductName,
mock_function: ProjectFunction,
clean_functions: None,
):

# Register the function first
registered_function = await functions_rpc.register_function(
rabbitmq_rpc_client=rpc_client,
function=mock_function,
user_id=logged_user["id"],
product_name=osparc_product_name,
)

registered_function_jobs = []
for value in range(5):
function_job = ProjectFunctionJob(
function_uid=registered_function.uid,
title="Test Function Job",
description="A test function job",
project_job_id=uuid4(),
inputs={"input1": value if value < 4 else 1},
outputs={"output1": "result1"},
)

# Register the function job
registered_job = await functions_rpc.register_function_job(
rabbitmq_rpc_client=rpc_client,
function_job=function_job,
user_id=logged_user["id"],
product_name=osparc_product_name,
)
registered_function_jobs.append(registered_job)

# Find cached function jobs
cached_jobs = await functions_rpc.find_cached_function_jobs(
rabbitmq_rpc_client=rpc_client,
function_id=registered_function.uid,
inputs={"input1": 1},
user_id=logged_user["id"],
product_name=osparc_product_name,
)

# Assert the cached jobs contain the registered job
assert cached_jobs is not None
assert len(cached_jobs) == 2
assert {job.uid for job in cached_jobs} == {
registered_function_jobs[1].uid,
registered_function_jobs[4].uid,
}

cached_jobs = await functions_rpc.find_cached_function_jobs(
rabbitmq_rpc_client=rpc_client,
function_id=registered_function.uid,
inputs={"input1": 1},
user_id=other_logged_user["id"],
product_name=osparc_product_name,
)

# Assert the cached jobs does not contain the registered job for the other user
assert cached_jobs is None
Loading