Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
RegisteredFunctionJobCollection,
)
from models_library.functions import (
FunctionJobStatus,
FunctionOutputs,
FunctionUserAccessRights,
FunctionUserApiAccessRights,
)
Expand Down Expand Up @@ -300,6 +302,82 @@ async def get_function_job(
return TypeAdapter(RegisteredFunctionJob).validate_python(result)


@log_decorator(_logger, level=logging.DEBUG)
async def get_function_job_status(
rabbitmq_rpc_client: RabbitMQRPCClient,
*,
user_id: UserID,
function_job_id: FunctionJobID,
product_name: ProductName,
) -> FunctionJobStatus:
result = await rabbitmq_rpc_client.request(
WEBSERVER_RPC_NAMESPACE,
TypeAdapter(RPCMethodName).validate_python("get_function_job_status"),
function_job_id=function_job_id,
user_id=user_id,
product_name=product_name,
)
return TypeAdapter(FunctionJobStatus).validate_python(result)


@log_decorator(_logger, level=logging.DEBUG)
async def get_function_job_outputs(
rabbitmq_rpc_client: RabbitMQRPCClient,
*,
user_id: UserID,
function_job_id: FunctionJobID,
product_name: ProductName,
) -> FunctionOutputs:
result = await rabbitmq_rpc_client.request(
WEBSERVER_RPC_NAMESPACE,
TypeAdapter(RPCMethodName).validate_python("get_function_job_outputs"),
function_job_id=function_job_id,
user_id=user_id,
product_name=product_name,
)
return TypeAdapter(FunctionOutputs).validate_python(result)


@log_decorator(_logger, level=logging.DEBUG)
async def update_function_job_status(
rabbitmq_rpc_client: RabbitMQRPCClient,
*,
user_id: UserID,
product_name: ProductName,
function_job_id: FunctionJobID,
job_status: FunctionJobStatus,
) -> FunctionJobStatus:
result = await rabbitmq_rpc_client.request(
WEBSERVER_RPC_NAMESPACE,
TypeAdapter(RPCMethodName).validate_python("update_function_job_status"),
function_job_id=function_job_id,
job_status=job_status,
user_id=user_id,
product_name=product_name,
)
return TypeAdapter(FunctionJobStatus).validate_python(result)


@log_decorator(_logger, level=logging.DEBUG)
async def update_function_job_outputs(
rabbitmq_rpc_client: RabbitMQRPCClient,
*,
user_id: UserID,
product_name: ProductName,
function_job_id: FunctionJobID,
outputs: FunctionOutputs,
) -> FunctionOutputs:
result = await rabbitmq_rpc_client.request(
WEBSERVER_RPC_NAMESPACE,
TypeAdapter(RPCMethodName).validate_python("update_function_job_outputs"),
function_job_id=function_job_id,
outputs=outputs,
user_id=user_id,
product_name=product_name,
)
return TypeAdapter(FunctionOutputs).validate_python(result)


@log_decorator(_logger, level=logging.DEBUG)
async def delete_function_job(
rabbitmq_rpc_client: RabbitMQRPCClient,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,20 @@
from fastapi_pagination.api import create_page
from models_library.api_schemas_long_running_tasks.tasks import TaskGet
from models_library.api_schemas_webserver.functions import (
Function,
FunctionClass,
FunctionJob,
FunctionJobID,
FunctionJobStatus,
FunctionOutputs,
RegisteredFunctionJob,
)
from models_library.functions import RegisteredFunction
from models_library.functions_errors import (
UnsupportedFunctionClassError,
UnsupportedFunctionFunctionJobClassCombinationError,
)
from models_library.products import ProductName
from models_library.projects_state import RunningState
from models_library.users import UserID
from servicelib.fastapi.dependencies import get_app
from sqlalchemy.ext.asyncio import AsyncEngine
Expand Down Expand Up @@ -184,6 +185,12 @@ async def function_job_status(
user_id=user_id,
product_name=product_name,
)
old_job_status = await wb_api_rpc.get_function_job_status(
function_job_id=function_job.uid, user_id=user_id, product_name=product_name
)

if old_job_status.status in (RunningState.SUCCESS, RunningState.FAILED):
return old_job_status

if (
function.function_class == FunctionClass.PROJECT
Expand All @@ -195,9 +202,7 @@ async def function_job_status(
user_id=user_id,
director2_api=director2_api,
)
return FunctionJobStatus(status=job_status.state)

if (function.function_class == FunctionClass.SOLVER) and (
elif (function.function_class == FunctionClass.SOLVER) and (
function_job.function_class == FunctionClass.SOLVER
):
job_status = await solvers_jobs.inspect_job(
Expand All @@ -207,11 +212,19 @@ async def function_job_status(
user_id=user_id,
director2_api=director2_api,
)
return FunctionJobStatus(status=job_status.state)
else:
raise UnsupportedFunctionFunctionJobClassCombinationError(
function_class=function.function_class,
function_job_class=function_job.function_class,
)

new_job_status = FunctionJobStatus(status=job_status.state)

raise UnsupportedFunctionFunctionJobClassCombinationError(
function_class=function.function_class,
function_job_class=function_job.function_class,
return await wb_api_rpc.update_function_job_status(
function_job_id=function_job.uid,
user_id=user_id,
product_name=product_name,
job_status=new_job_status,
)


Expand All @@ -220,7 +233,7 @@ async def get_function_from_functionjobid(
function_job_id: FunctionJobID,
user_id: Annotated[UserID, Depends(get_current_user_id)],
product_name: Annotated[ProductName, Depends(get_product_name)],
) -> tuple[Function, FunctionJob]:
) -> tuple[RegisteredFunction, RegisteredFunctionJob]:
function_job = await get_function_job(
wb_api_rpc=wb_api_rpc,
function_job_id=function_job_id,
Expand Down Expand Up @@ -250,7 +263,7 @@ async def get_function_from_functionjobid(
changelog=CHANGE_LOGS["function_job_outputs"],
),
)
async def function_job_outputs(
async def get_function_job_outputs(
function_job_id: FunctionJobID,
webserver_api: Annotated[AuthSession, Depends(get_webserver_session)],
user_id: Annotated[UserID, Depends(get_current_user_id)],
Expand All @@ -266,11 +279,18 @@ async def function_job_outputs(
product_name=product_name,
)

old_job_outputs = await wb_api_rpc.get_function_job_outputs(
function_job_id=function_job.uid, user_id=user_id, product_name=product_name
)

if old_job_outputs is not None:
return old_job_outputs

if (
function.function_class == FunctionClass.PROJECT
and function_job.function_class == FunctionClass.PROJECT
):
return dict(
new_outputs = dict(
(
await studies_jobs.get_study_job_outputs(
study_id=function.project_id,
Expand All @@ -281,12 +301,11 @@ async def function_job_outputs(
)
).results
)

if (
elif (
function.function_class == FunctionClass.SOLVER
and function_job.function_class == FunctionClass.SOLVER
):
return dict(
new_outputs = dict(
(
await solvers_jobs_read.get_job_outputs(
solver_key=function.solver_key,
Expand All @@ -299,7 +318,15 @@ async def function_job_outputs(
)
).results
)
raise UnsupportedFunctionClassError(function_class=function.function_class)
else:
raise UnsupportedFunctionClassError(function_class=function.function_class)

return await wb_api_rpc.update_function_job_outputs(
function_job_id=function_job.uid,
user_id=user_id,
product_name=product_name,
outputs=new_outputs,
)


@function_job_router.post(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
)
from models_library.api_schemas_webserver.licensed_items import LicensedItemRpcGetPage
from models_library.functions import (
FunctionJobStatus,
FunctionOutputs,
FunctionUserAccessRights,
FunctionUserApiAccessRights,
)
Expand Down Expand Up @@ -472,6 +474,66 @@ async def get_function_output_schema(
function_id=function_id,
)

async def get_function_job_status(
self,
*,
user_id: UserID,
product_name: ProductName,
function_job_id: FunctionJobID,
) -> FunctionJobStatus:
return await functions_rpc_interface.get_function_job_status(
self._client,
user_id=user_id,
product_name=product_name,
function_job_id=function_job_id,
)

async def get_function_job_outputs(
self,
*,
user_id: UserID,
product_name: ProductName,
function_job_id: FunctionJobID,
) -> FunctionOutputs:
return await functions_rpc_interface.get_function_job_outputs(
self._client,
user_id=user_id,
product_name=product_name,
function_job_id=function_job_id,
)

async def update_function_job_status(
self,
*,
function_job_id: FunctionJobID,
user_id: UserID,
product_name: ProductName,
job_status: FunctionJobStatus,
) -> FunctionJobStatus:
return await functions_rpc_interface.update_function_job_status(
self._client,
function_job_id=function_job_id,
user_id=user_id,
product_name=product_name,
job_status=job_status,
)

async def update_function_job_outputs(
self,
*,
function_job_id: FunctionJobID,
user_id: UserID,
product_name: ProductName,
outputs: FunctionOutputs,
) -> FunctionOutputs:
return await functions_rpc_interface.update_function_job_outputs(
self._client,
function_job_id=function_job_id,
user_id=user_id,
product_name=product_name,
outputs=outputs,
)

async def find_cached_function_jobs(
self,
*,
Expand Down
21 changes: 21 additions & 0 deletions services/api-server/tests/unit/api_functions/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,3 +266,24 @@ def _mock(
)

return _mock


@pytest.fixture()
def mock_handler_in_study_jobs_rest_interface(
mock_wb_api_server_rpc: MockerFixture,
) -> Callable[[str, Any, Exception | None], None]:
def _mock(
handler_name: str = "",
return_value: Any = None,
exception: Exception | None = None,
) -> None:
from simcore_service_api_server.api.routes.functions_routes import studies_jobs

mock_wb_api_server_rpc.patch.object(
studies_jobs,
handler_name,
return_value=return_value,
side_effect=exception,
)

return _mock
Loading
Loading