diff --git a/packages/models-library/src/models_library/api_schemas_webserver/functions.py b/packages/models-library/src/models_library/api_schemas_webserver/functions.py index d7883745754..a52966a8d18 100644 --- a/packages/models-library/src/models_library/api_schemas_webserver/functions.py +++ b/packages/models-library/src/models_library/api_schemas_webserver/functions.py @@ -1,7 +1,7 @@ import datetime from typing import Annotated, TypeAlias -from pydantic import Field +from pydantic import Field, HttpUrl from ..functions import ( Function, @@ -116,6 +116,10 @@ class RegisteredSolverFunctionGet(RegisteredSolverFunction, OutputSchema): uid: Annotated[FunctionID, Field(alias="uuid")] + created_at: Annotated[datetime.datetime, Field(alias="creationDate")] + modified_at: Annotated[datetime.datetime, Field(alias="lastChangeDate")] + access_rights: FunctionAccessRights | None = None + thumbnail: HttpUrl | None = None class RegisteredProjectFunctionGet(RegisteredProjectFunction, OutputSchema): @@ -124,7 +128,7 @@ class RegisteredProjectFunctionGet(RegisteredProjectFunction, OutputSchema): created_at: Annotated[datetime.datetime, Field(alias="creationDate")] modified_at: Annotated[datetime.datetime, Field(alias="lastChangeDate")] access_rights: FunctionAccessRights | None = None - thumbnail: str | None = None + thumbnail: HttpUrl | None = None class SolverFunctionToRegister(SolverFunction, InputSchema): ... diff --git a/services/web/server/src/simcore_service_webserver/api/v0/openapi.yaml b/services/web/server/src/simcore_service_webserver/api/v0/openapi.yaml index fc91d1d393f..a8b1de9be4d 100644 --- a/services/web/server/src/simcore_service_webserver/api/v0/openapi.yaml +++ b/services/web/server/src/simcore_service_webserver/api/v0/openapi.yaml @@ -16078,6 +16078,9 @@ components: thumbnail: anyOf: - type: string + maxLength: 2083 + minLength: 1 + format: uri - type: 'null' title: Thumbnail type: object @@ -16131,14 +16134,14 @@ components: type: string format: uuid title: Uuid - createdAt: + creationDate: type: string format: date-time - title: Createdat - modifiedAt: + title: Creationdate + lastChangeDate: type: string format: date-time - title: Modifiedat + title: Lastchangedate solverKey: type: string pattern: ^simcore/services/((comp|dynamic|frontend))/([a-z0-9][a-z0-9_.-]*/)*([a-z0-9-_]+[a-z0-9])$ @@ -16147,14 +16150,26 @@ components: type: string pattern: ^(0|[1-9]\d*)(\.(0|[1-9]\d*)){2}(-(0|[1-9]\d*|\d*[-a-zA-Z][-\da-zA-Z]*)(\.(0|[1-9]\d*|\d*[-a-zA-Z][-\da-zA-Z]*))*)?(\+[-\da-zA-Z]+(\.[-\da-zA-Z-]+)*)?$ title: Solverversion + accessRights: + anyOf: + - $ref: '#/components/schemas/FunctionAccessRights' + - type: 'null' + thumbnail: + anyOf: + - type: string + maxLength: 2083 + minLength: 1 + format: uri + - type: 'null' + title: Thumbnail type: object required: - inputSchema - outputSchema - defaultInputs - uuid - - createdAt - - modifiedAt + - creationDate + - lastChangeDate - solverKey - solverVersion title: RegisteredSolverFunctionGet diff --git a/services/web/server/src/simcore_service_webserver/functions/_controller/_functions_rest.py b/services/web/server/src/simcore_service_webserver/functions/_controller/_functions_rest.py index d450ef5f5f1..54363cbdb12 100644 --- a/services/web/server/src/simcore_service_webserver/functions/_controller/_functions_rest.py +++ b/services/web/server/src/simcore_service_webserver/functions/_controller/_functions_rest.py @@ -1,3 +1,5 @@ +from typing import Any + from aiohttp import web from models_library.api_schemas_webserver.functions import ( Function, @@ -5,10 +7,15 @@ RegisteredFunction, RegisteredFunctionGet, RegisteredFunctionUpdate, - RegisteredProjectFunctionGet, ) from models_library.api_schemas_webserver.users import MyFunctionPermissionsGet -from models_library.functions import FunctionClass, RegisteredProjectFunction +from models_library.functions import ( + FunctionAccessRights, + FunctionClass, + FunctionID, + RegisteredProjectFunction, + RegisteredSolverFunction, +) from models_library.products import ProductName from models_library.rest_pagination import Page from models_library.rest_pagination_utils import paginate_data @@ -29,6 +36,8 @@ from ...security.decorators import permission_required from ...utils_aiohttp import create_json_response_from_page, envelope_json_response from .. import _functions_service +from .._services_metadata import proxy as _services_metadata_proxy +from .._services_metadata.proxy import ServiceMetadata from ._functions_rest_exceptions import handle_rest_requests_exceptions from ._functions_rest_schemas import ( FunctionGetQueryParams, @@ -39,30 +48,71 @@ routes = web.RouteTableDef() -async def _add_extras_to_project_function( - function: RegisteredProjectFunction, +async def _build_function_access_rights( app: web.Application, user_id: UserID, product_name: ProductName, -) -> dict: - assert isinstance(function, RegisteredProjectFunction) # nosec - - project_dict = await _projects_service.get_project_for_user( + function_id: FunctionID, +) -> FunctionAccessRights: + access_rights = await _functions_service.get_function_user_permissions( app=app, - project_uuid=f"{function.project_id}", user_id=user_id, + product_name=product_name, + function_id=function_id, ) - function_with_extras = function.model_dump(mode="json") | { - "access_rights": await _functions_service.get_function_user_permissions( - app, - user_id=user_id, - product_name=product_name, - function_id=function.uid, - ), - "thumbnail": project_dict.get("thumbnail", None), - } - return function_with_extras + return FunctionAccessRights( + read=access_rights.read, + write=access_rights.write, + execute=access_rights.execute, + ) + + +def _build_project_function_extras_dict( + project: ProjectDBGet, +) -> dict[str, Any]: + extras: dict[str, Any] = {} + if thumbnail := project.thumbnail: + extras["thumbnail"] = thumbnail + return extras + + +def _build_solver_function_extras_dict( + service_metadata: ServiceMetadata, +) -> dict[str, Any]: + + extras: dict[str, Any] = {} + if thumbnail := service_metadata.thumbnail: + extras["thumbnail"] = thumbnail + return extras + + +async def _build_function_extras( + app: web.Application, *, function: RegisteredFunction +) -> dict[str, Any]: + extras: dict[str, Any] = {} + match function.function_class: + case FunctionClass.PROJECT: + assert isinstance(function, RegisteredProjectFunction) + projects = await _projects_service.batch_get_projects( + app=app, + project_uuids=[function.project_id], + ) + if project := projects.get(function.project_id): + extras |= _build_project_function_extras_dict( + project=project, + ) + case FunctionClass.SOLVER: + assert isinstance(function, RegisteredSolverFunction) + services_metadata = await _services_metadata_proxy.get_service_metadata( + app, + key=function.solver_key, + version=function.solver_version, + ) + extras |= _build_solver_function_extras_dict( + service_metadata=services_metadata, + ) + return extras @routes.post(f"/{VTAG}/functions", name="register_function") @@ -117,54 +167,70 @@ async def list_functions(request: web.Request) -> web.Response: ) chunk: list[RegisteredFunctionGet] = [] - projects_map: dict[str, ProjectDBGet | None] = ( - {} - ) # ProjectDBGet has to be renamed at some point! + + extras_map: dict[FunctionID, dict[str, Any]] = {} if query_params.include_extras: - project_ids = [] - for function in functions: - if function.function_class == FunctionClass.PROJECT: - assert isinstance(function, RegisteredProjectFunction) - project_ids.append(function.project_id) - - projects_map = { - f"{p.uuid}": p - for p in await _projects_service.batch_get_projects( + if any( + function.function_class == FunctionClass.PROJECT for function in functions + ): + project_uuids = [ + function.project_id + for function in functions + if function.function_class == FunctionClass.PROJECT + ] + projects_cache = await _projects_service.batch_get_projects( request.app, - project_uuids=project_ids, + project_uuids=project_uuids, ) - } + for function in functions: + if function.function_class == FunctionClass.PROJECT: + project = projects_cache.get(function.project_id) + if not project: + continue + extras_map[function.uid] = _build_project_function_extras_dict( + project=project + ) - for function in functions: - if ( - query_params.include_extras - and function.function_class == FunctionClass.PROJECT + if any( + function.function_class == FunctionClass.SOLVER for function in functions ): - assert isinstance(function, RegisteredProjectFunction) # nosec - if project := projects_map.get(f"{function.project_id}"): - chunk.append( - TypeAdapter(RegisteredProjectFunctionGet).validate_python( - function.model_dump(mode="json") - | { - "access_rights": await _functions_service.get_function_user_permissions( - request.app, - user_id=req_ctx.user_id, - product_name=req_ctx.product_name, - function_id=function.uid, - ), - "thumbnail": ( - f"{project.thumbnail}" if project.thumbnail else None - ), - } - ) - ) - else: - chunk.append( - TypeAdapter(RegisteredFunctionGet).validate_python( - function.model_dump(mode="json") + service_keys_and_versions = { + (function.solver_key, function.solver_version) + for function in functions + if function.function_class == FunctionClass.SOLVER + } + service_metadata_cache = ( + await _services_metadata_proxy.batch_get_service_metadata( + app=request.app, keys_and_versions=service_keys_and_versions ) ) + for function in functions: + if function.function_class == FunctionClass.SOLVER: + service_metadata = service_metadata_cache.get( + (function.solver_key, function.solver_version) + ) + if not service_metadata: + continue + extras_map[function.uid] = _build_solver_function_extras_dict( + service_metadata=service_metadata + ) + + for function in functions: + access_rights = await _build_function_access_rights( + request.app, + user_id=req_ctx.user_id, + product_name=req_ctx.product_name, + function_id=function.uid, + ) + + extras = extras_map.get(function.uid, {}) + + chunk.append( + TypeAdapter(RegisteredFunctionGet).validate_python( + function.model_dump() | {"access_rights": access_rights, **extras} + ) + ) page = Page[RegisteredFunctionGet].model_validate( paginate_data( @@ -194,33 +260,29 @@ async def get_function(request: web.Request) -> web.Response: ) req_ctx = AuthenticatedRequestContext.model_validate(request) - registered_function: RegisteredFunction = await _functions_service.get_function( + function = await _functions_service.get_function( app=request.app, function_id=function_id, user_id=req_ctx.user_id, product_name=req_ctx.product_name, ) - if ( - query_params.include_extras - and registered_function.function_class == FunctionClass.PROJECT - ): - function_with_extras = await _add_extras_to_project_function( - function=registered_function, - app=request.app, - user_id=req_ctx.user_id, - product_name=req_ctx.product_name, - ) + access_rights = await _build_function_access_rights( + request.app, + user_id=req_ctx.user_id, + product_name=req_ctx.product_name, + function_id=function_id, + ) - return envelope_json_response( - TypeAdapter(RegisteredProjectFunctionGet).validate_python( - function_with_extras - ) - ) + extras = ( + await _build_function_extras(request.app, function=function) + if query_params.include_extras + else {} + ) return envelope_json_response( TypeAdapter(RegisteredFunctionGet).validate_python( - registered_function.model_dump(mode="json") + function.model_dump() | {"access_rights": access_rights, **extras} ) ) @@ -245,7 +307,7 @@ async def update_function(request: web.Request) -> web.Response: ) req_ctx = AuthenticatedRequestContext.model_validate(request) - updated_function = await _functions_service.update_function( + function = await _functions_service.update_function( request.app, user_id=req_ctx.user_id, product_name=req_ctx.product_name, @@ -253,26 +315,22 @@ async def update_function(request: web.Request) -> web.Response: function=function_update, ) - if ( - query_params.include_extras - and updated_function.function_class == FunctionClass.PROJECT - ): - function_with_extras = await _add_extras_to_project_function( - function=updated_function, - app=request.app, - user_id=req_ctx.user_id, - product_name=req_ctx.product_name, - ) + access_rights = await _build_function_access_rights( + request.app, + user_id=req_ctx.user_id, + product_name=req_ctx.product_name, + function_id=function_id, + ) - return envelope_json_response( - TypeAdapter(RegisteredProjectFunctionGet).validate_python( - function_with_extras - ) - ) + extras = ( + await _build_function_extras(request.app, function=function) + if query_params.include_extras + else {} + ) return envelope_json_response( TypeAdapter(RegisteredFunctionGet).validate_python( - updated_function.model_dump(mode="json") + function.model_dump() | {"access_rights": access_rights, **extras} ) ) diff --git a/services/web/server/src/simcore_service_webserver/functions/_services_metadata/__init__.py b/services/web/server/src/simcore_service_webserver/functions/_services_metadata/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/services/web/server/src/simcore_service_webserver/functions/_services_metadata/_errors.py b/services/web/server/src/simcore_service_webserver/functions/_services_metadata/_errors.py new file mode 100644 index 00000000000..b1845546751 --- /dev/null +++ b/services/web/server/src/simcore_service_webserver/functions/_services_metadata/_errors.py @@ -0,0 +1,5 @@ +from common_library.errors_classes import OsparcErrorMixin + + +class ServiceMetadataNotFoundError(OsparcErrorMixin, Exception): + msg_template = "Service metadata for key {key} and version {version} not found" diff --git a/services/web/server/src/simcore_service_webserver/functions/_services_metadata/_models.py b/services/web/server/src/simcore_service_webserver/functions/_services_metadata/_models.py new file mode 100644 index 00000000000..e9876b7770f --- /dev/null +++ b/services/web/server/src/simcore_service_webserver/functions/_services_metadata/_models.py @@ -0,0 +1,8 @@ +from models_library.services_types import ServiceKey, ServiceVersion +from pydantic import BaseModel, HttpUrl + + +class ServiceMetadata(BaseModel): + key: ServiceKey + version: ServiceVersion + thumbnail: HttpUrl | None diff --git a/services/web/server/src/simcore_service_webserver/functions/_services_metadata/_proxy.py b/services/web/server/src/simcore_service_webserver/functions/_services_metadata/_proxy.py new file mode 100644 index 00000000000..3c91989a28a --- /dev/null +++ b/services/web/server/src/simcore_service_webserver/functions/_services_metadata/_proxy.py @@ -0,0 +1,26 @@ +from collections.abc import Iterable + +from aiohttp import web +from models_library.services_types import ServiceKey, ServiceVersion + +from . import _repository +from ._models import ServiceMetadata + + +async def batch_get_service_metadata( + app: web.Application, + *, + keys_and_versions: Iterable[tuple[ServiceKey, ServiceVersion]], +) -> dict[tuple[ServiceKey, ServiceVersion], ServiceMetadata]: + return await _repository.batch_service_metadata( + app, keys_and_versions=keys_and_versions + ) + + +async def get_service_metadata( + app: web.Application, + *, + key: ServiceKey, + version: ServiceVersion, +) -> ServiceMetadata: + return await _repository.get_service_metadata(app, key=key, version=version) diff --git a/services/web/server/src/simcore_service_webserver/functions/_services_metadata/_repository.py b/services/web/server/src/simcore_service_webserver/functions/_services_metadata/_repository.py new file mode 100644 index 00000000000..99610f93287 --- /dev/null +++ b/services/web/server/src/simcore_service_webserver/functions/_services_metadata/_repository.py @@ -0,0 +1,72 @@ +from collections.abc import Iterable + +from aiohttp import web +from models_library.services_types import ServiceKey, ServiceVersion +from simcore_postgres_database.models.services import ( + services_meta_data, +) +from simcore_postgres_database.utils_repos import pass_or_acquire_connection +from simcore_service_webserver.functions._services_metadata._errors import ( + ServiceMetadataNotFoundError, +) +from sqlalchemy import select, tuple_ +from sqlalchemy.ext.asyncio import AsyncConnection + +from ...db.plugin import get_asyncpg_engine +from ._models import ServiceMetadata + + +async def batch_service_metadata( + app: web.Application, + connection: AsyncConnection | None = None, + *, + keys_and_versions: Iterable[tuple[ServiceKey, ServiceVersion]], +) -> dict[tuple[ServiceKey, ServiceVersion], ServiceMetadata]: + keys_and_versions = list(keys_and_versions) + if not keys_and_versions: + return {} + + query = select( + services_meta_data.c.key, + services_meta_data.c.version, + services_meta_data.c.thumbnail, + ).where( + tuple_(services_meta_data.c.key, services_meta_data.c.version).in_( + keys_and_versions + ) + ) + + async with pass_or_acquire_connection(get_asyncpg_engine(app), connection) as conn: + result = await conn.execute(query) + rows = result.fetchall() + + return { + (row.key, row.version): ServiceMetadata.model_validate( + row, from_attributes=True + ) + for row in rows + } + + +async def get_service_metadata( + app: web.Application, + connection: AsyncConnection | None = None, + *, + key: ServiceKey, + version: ServiceVersion, +) -> ServiceMetadata: + query = select( + services_meta_data.c.key, + services_meta_data.c.version, + services_meta_data.c.thumbnail, + ).where( + tuple_(services_meta_data.c.key, services_meta_data.c.version) == (key, version) + ) + + async with pass_or_acquire_connection(get_asyncpg_engine(app), connection) as conn: + result = await conn.execute(query) + row = result.one_or_none() + if row is None: + raise ServiceMetadataNotFoundError(key=key, version=version) + + return ServiceMetadata.model_validate(row, from_attributes=True) diff --git a/services/web/server/src/simcore_service_webserver/functions/_services_metadata/proxy.py b/services/web/server/src/simcore_service_webserver/functions/_services_metadata/proxy.py new file mode 100644 index 00000000000..3556c7c36ff --- /dev/null +++ b/services/web/server/src/simcore_service_webserver/functions/_services_metadata/proxy.py @@ -0,0 +1,9 @@ +from ._models import ServiceMetadata +from ._proxy import batch_get_service_metadata, get_service_metadata + +__all__: tuple[str, ...] = ( + "ServiceMetadata", + "batch_get_service_metadata", + "get_service_metadata", +) +# nopycln: file diff --git a/services/web/server/src/simcore_service_webserver/projects/_projects_repository.py b/services/web/server/src/simcore_service_webserver/projects/_projects_repository.py index 2231fd16482..83c357c3657 100644 --- a/services/web/server/src/simcore_service_webserver/projects/_projects_repository.py +++ b/services/web/server/src/simcore_service_webserver/projects/_projects_repository.py @@ -1,5 +1,5 @@ import logging -from collections.abc import Callable +from collections.abc import Callable, Iterable from datetime import datetime from typing import cast @@ -172,10 +172,10 @@ async def batch_get_projects( app: web.Application, connection: AsyncConnection | None = None, *, - project_uuids: list[ProjectID], -) -> list[ProjectDBGet]: + project_uuids: Iterable[ProjectID], +) -> dict[ProjectID, ProjectDBGet]: if not project_uuids: - return [] + return {} async with pass_or_acquire_connection(get_asyncpg_engine(app), connection) as conn: query = ( sql.select(projects) @@ -183,7 +183,10 @@ async def batch_get_projects( .where(projects.c.uuid.in_([f"{uuid}" for uuid in project_uuids])) ) result = await conn.stream(query) - return [ProjectDBGet.model_validate(row) async for row in result] + return { + ProjectID(row.uuid): ProjectDBGet.model_validate(row) + async for row in result + } def _select_trashed_by_primary_gid_query() -> sql.Select: diff --git a/services/web/server/src/simcore_service_webserver/projects/_projects_service.py b/services/web/server/src/simcore_service_webserver/projects/_projects_service.py index 13c33253241..da9277430e3 100644 --- a/services/web/server/src/simcore_service_webserver/projects/_projects_service.py +++ b/services/web/server/src/simcore_service_webserver/projects/_projects_service.py @@ -12,7 +12,7 @@ import datetime import logging from collections import defaultdict -from collections.abc import Generator +from collections.abc import Generator, Iterable from contextlib import suppress from decimal import Decimal from pprint import pformat @@ -365,8 +365,8 @@ async def batch_get_project_name( async def batch_get_projects( app: web.Application, *, - project_uuids: list[ProjectID], -) -> list[ProjectDBGet]: + project_uuids: Iterable[ProjectID], +) -> dict[ProjectID, ProjectDBGet]: return await _projects_repository.batch_get_projects( app=app, project_uuids=project_uuids, diff --git a/services/web/server/tests/unit/with_dbs/04/functions_rpc/test_functions_controller_rest.py b/services/web/server/tests/unit/with_dbs/04/functions_rpc/test_functions_controller_rest.py index 24e412431eb..b6410071111 100644 --- a/services/web/server/tests/unit/with_dbs/04/functions_rpc/test_functions_controller_rest.py +++ b/services/web/server/tests/unit/with_dbs/04/functions_rpc/test_functions_controller_rest.py @@ -15,7 +15,7 @@ FunctionClass, JSONFunctionInputSchema, JSONFunctionOutputSchema, - RegisteredProjectFunctionGet, + RegisteredFunctionGet, ) from models_library.api_schemas_webserver.users import MyFunctionPermissionsGet from pydantic import TypeAdapter @@ -27,11 +27,11 @@ pytest_simcore_core_services_selection = ["rabbit"] -@pytest.fixture -def mock_function() -> dict[str, Any]: - return { - "title": "Test Function", - "description": "A test function", +@pytest.fixture(params=[FunctionClass.PROJECT, FunctionClass.SOLVER]) +def mocked_function(request) -> dict[str, Any]: + function_dict = { + "title": f"Test {request.param} Function", + "description": f"A test {request.param} function", "inputSchema": JSONFunctionInputSchema( schema_content={ "type": "object", @@ -44,11 +44,19 @@ def mock_function() -> dict[str, Any]: "properties": {"output1": {"type": "string"}}, }, ).model_dump(mode="json"), - "projectId": str(uuid4()), - "functionClass": FunctionClass.PROJECT, + "functionClass": request.param, "defaultInputs": None, } + match request.param: + case FunctionClass.PROJECT: + function_dict["projectId"] = f"{uuid4()}" + case FunctionClass.SOLVER: + function_dict["solverKey"] = "simcore/services/dynamic/test" + function_dict["solverVersion"] = "1.0.0" + + return function_dict + @pytest.mark.parametrize( "user_role,add_user_function_api_access_rights,expected_register,expected_get,expected_list,expected_update,expected_delete,expected_get2", @@ -79,7 +87,7 @@ def mock_function() -> dict[str, Any]: async def test_function_workflow( client: TestClient, logged_user: UserInfoDict, - mock_function: dict[str, Any], + mocked_function: dict[str, Any], expected_register: HTTPStatus, expected_get: HTTPStatus, expected_list: HTTPStatus, @@ -91,12 +99,12 @@ async def test_function_workflow( ) -> None: # Register a new function url = client.app.router["register_function"].url_for() - response = await client.post(url, json=mock_function) + response = await client.post(url, json=mocked_function) data, error = await assert_status(response, expected_status_code=expected_register) if error: returned_function_uid = uuid4() else: - returned_function = RegisteredProjectFunctionGet.model_validate(data) + returned_function = TypeAdapter(RegisteredFunctionGet).validate_python(data) assert returned_function.uid is not None returned_function_uid = returned_function.uid @@ -107,7 +115,7 @@ async def test_function_workflow( response = await client.get(url) data, error = await assert_status(response, expected_get) if not error: - retrieved_function = RegisteredProjectFunctionGet.model_validate(data) + retrieved_function = TypeAdapter(RegisteredFunctionGet).validate_python(data) assert retrieved_function.uid == returned_function.uid # List existing functions @@ -115,9 +123,9 @@ async def test_function_workflow( response = await client.get(url) data, error = await assert_status(response, expected_list) if not error: - retrieved_functions = TypeAdapter( - list[RegisteredProjectFunctionGet] - ).validate_python(data) + retrieved_functions = TypeAdapter(list[RegisteredFunctionGet]).validate_python( + data + ) assert len(retrieved_functions) == 1 assert retrieved_functions[0].uid == returned_function_uid @@ -132,7 +140,7 @@ async def test_function_workflow( ) data, error = await assert_status(response, expected_update) if not error: - updated_function = RegisteredProjectFunctionGet.model_validate(data) + updated_function = TypeAdapter(RegisteredFunctionGet).validate_python(data) assert updated_function.title == new_title assert updated_function.description == new_description