Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import datetime
from typing import Annotated, TypeAlias

from pydantic import Field
from pydantic import Field, HttpUrl

from ..functions import (
Function,
Expand Down Expand Up @@ -116,6 +116,10 @@

class RegisteredSolverFunctionGet(RegisteredSolverFunction, OutputSchema):
uid: Annotated[FunctionID, Field(alias="uuid")]
created_at: Annotated[datetime.datetime, Field(alias="creationDate")]
modified_at: Annotated[datetime.datetime, Field(alias="lastChangeDate")]
access_rights: FunctionAccessRights | None = None
thumbnail: HttpUrl | None = None


class RegisteredProjectFunctionGet(RegisteredProjectFunction, OutputSchema):
Expand All @@ -124,7 +128,7 @@ class RegisteredProjectFunctionGet(RegisteredProjectFunction, OutputSchema):
created_at: Annotated[datetime.datetime, Field(alias="creationDate")]
modified_at: Annotated[datetime.datetime, Field(alias="lastChangeDate")]
access_rights: FunctionAccessRights | None = None
thumbnail: str | None = None
thumbnail: HttpUrl | None = None


class SolverFunctionToRegister(SolverFunction, InputSchema): ...
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,27 @@
import logging
from typing import Any

from aiohttp import web
from models_library.api_schemas_webserver.functions import (
Function,
FunctionToRegister,
RegisteredFunction,
RegisteredFunctionGet,
RegisteredFunctionUpdate,
RegisteredProjectFunctionGet,
)
from models_library.api_schemas_webserver.users import MyFunctionPermissionsGet
from models_library.functions import FunctionClass, RegisteredProjectFunction
from models_library.functions import (
FunctionAccessRights,
FunctionClass,
FunctionID,
RegisteredProjectFunction,
RegisteredSolverFunction,
)
from models_library.products import ProductName
from models_library.projects import ProjectID
from models_library.rest_pagination import Page
from models_library.rest_pagination_utils import paginate_data
from models_library.services_types import ServiceKey, ServiceVersion
from models_library.users import UserID
from pydantic import TypeAdapter
from servicelib.aiohttp import status
Expand All @@ -29,6 +39,8 @@
from ...security.decorators import permission_required
from ...utils_aiohttp import create_json_response_from_page, envelope_json_response
from .. import _functions_service
from .._services_metadata import proxy as _services_metadata_proxy
from .._services_metadata.proxy import ServiceMetadata
from ._functions_rest_exceptions import handle_rest_requests_exceptions
from ._functions_rest_schemas import (
FunctionGetQueryParams,
Expand All @@ -39,30 +51,81 @@
routes = web.RouteTableDef()


async def _add_extras_to_project_function(
function: RegisteredProjectFunction,
logger = logging.getLogger(__name__)


async def _build_function_access_rights(
app: web.Application,
user_id: UserID,
product_name: ProductName,
) -> dict:
assert isinstance(function, RegisteredProjectFunction) # nosec
function_id: FunctionID,
) -> FunctionAccessRights:
access_rights = await _functions_service.get_function_user_permissions(
app=app,
user_id=user_id,
product_name=product_name,
function_id=function_id,
)

return FunctionAccessRights(
read=access_rights.read,
write=access_rights.write,
execute=access_rights.execute,
)


async def _build_project_function_extras_dict(
app: web.Application,
*,
user_id: UserID,
function: RegisteredProjectFunction,
) -> dict[str, Any]:
project_dict = await _projects_service.get_project_for_user(
app=app,
project_uuid=f"{function.project_id}",
user_id=user_id,
)

function_with_extras = function.model_dump(mode="json") | {
"access_rights": await _functions_service.get_function_user_permissions(
app,
user_id=user_id,
product_name=product_name,
function_id=function.uid,
return {
"thumbnail": project_dict.get("thumbnail") or None,
}


async def _build_solver_function_extras_dict(
app: web.Application,
*,
function: RegisteredSolverFunction,
) -> dict[str, Any]:
services_metadata = await _services_metadata_proxy.get_service_metadata(
app,
key=function.solver_key,
version=function.solver_version,
)
return {
"thumbnail": (
f"{services_metadata.thumbnail}" if services_metadata.thumbnail else None
),
"thumbnail": project_dict.get("thumbnail", None),
}
return function_with_extras


async def _build_function_extras(
app: web.Application, user_id: UserID, *, function: RegisteredFunction
) -> dict[str, Any]:
extras: dict[str, Any] = {}
if function.function_class == FunctionClass.PROJECT:
assert isinstance(function, RegisteredProjectFunction)
extras |= await _build_project_function_extras_dict(
function=function,
app=app,
user_id=user_id,
)
elif function.function_class == FunctionClass.SOLVER:
assert isinstance(function, RegisteredSolverFunction)
extras |= await _build_solver_function_extras_dict(
app,
function=function,
)
return extras


@routes.post(f"/{VTAG}/functions", name="register_function")
Expand Down Expand Up @@ -117,55 +180,61 @@ async def list_functions(request: web.Request) -> web.Response:
)

chunk: list[RegisteredFunctionGet] = []
projects_map: dict[str, ProjectDBGet | None] = (

projects_cache: dict[ProjectID, ProjectDBGet] = {}
service_metadata_cache: dict[tuple[ServiceKey, ServiceVersion], ServiceMetadata] = (
{}
) # ProjectDBGet has to be renamed at some point!
)

if query_params.include_extras:
project_ids = []
for function in functions:
if function.function_class == FunctionClass.PROJECT:
assert isinstance(function, RegisteredProjectFunction)
project_ids.append(function.project_id)

projects_map = {
f"{p.uuid}": p
for p in await _projects_service.batch_get_projects(
if any(
function.function_class == FunctionClass.PROJECT for function in functions
):
project_uuids = [
function.project_id
for function in functions
if function.function_class == FunctionClass.PROJECT
]
projects_cache |= await _projects_service.batch_get_projects(
request.app,
project_uuids=project_ids,
project_uuids=project_uuids,
)
}

for function in functions:
if (
query_params.include_extras
and function.function_class == FunctionClass.PROJECT
if any(
function.function_class == FunctionClass.SOLVER for function in functions
):
assert isinstance(function, RegisteredProjectFunction) # nosec
if project := projects_map.get(f"{function.project_id}"):
chunk.append(
TypeAdapter(RegisteredProjectFunctionGet).validate_python(
function.model_dump(mode="json")
| {
"access_rights": await _functions_service.get_function_user_permissions(
request.app,
user_id=req_ctx.user_id,
product_name=req_ctx.product_name,
function_id=function.uid,
),
"thumbnail": (
f"{project.thumbnail}" if project.thumbnail else None
),
}
)
)
else:
chunk.append(
TypeAdapter(RegisteredFunctionGet).validate_python(
function.model_dump(mode="json")
service_keys_and_versions = {
(function.solver_key, function.solver_version)
for function in functions
if function.function_class == FunctionClass.SOLVER
}
service_metadata_cache |= (
await _services_metadata_proxy.batch_get_service_metadata(
app=request.app, keys_and_versions=service_keys_and_versions
)
)

for function in functions:
access_rights = await _build_function_access_rights(
request.app,
user_id=req_ctx.user_id,
product_name=req_ctx.product_name,
function_id=function.uid,
)

extras = (
await _build_function_extras(
request.app, req_ctx.user_id, function=function
)
if query_params.include_extras
else {}
)

chunk.append(
TypeAdapter(RegisteredFunctionGet).validate_python(
function.model_dump() | {"access_rights": access_rights, **extras}
)
)

page = Page[RegisteredFunctionGet].model_validate(
paginate_data(
chunk=chunk,
Expand Down Expand Up @@ -194,33 +263,29 @@ async def get_function(request: web.Request) -> web.Response:
)

req_ctx = AuthenticatedRequestContext.model_validate(request)
registered_function: RegisteredFunction = await _functions_service.get_function(
function = await _functions_service.get_function(
app=request.app,
function_id=function_id,
user_id=req_ctx.user_id,
product_name=req_ctx.product_name,
)

if (
query_params.include_extras
and registered_function.function_class == FunctionClass.PROJECT
):
function_with_extras = await _add_extras_to_project_function(
function=registered_function,
app=request.app,
user_id=req_ctx.user_id,
product_name=req_ctx.product_name,
)
access_rights = await _build_function_access_rights(
request.app,
user_id=req_ctx.user_id,
product_name=req_ctx.product_name,
function_id=function_id,
)

return envelope_json_response(
TypeAdapter(RegisteredProjectFunctionGet).validate_python(
function_with_extras
)
)
extras = (
await _build_function_extras(request.app, req_ctx.user_id, function=function)
if query_params.include_extras
else {}
)

return envelope_json_response(
TypeAdapter(RegisteredFunctionGet).validate_python(
registered_function.model_dump(mode="json")
function.model_dump() | {"access_rights": access_rights, **extras}
)
)

Expand All @@ -245,34 +310,30 @@ async def update_function(request: web.Request) -> web.Response:
)
req_ctx = AuthenticatedRequestContext.model_validate(request)

updated_function = await _functions_service.update_function(
function = await _functions_service.update_function(
request.app,
user_id=req_ctx.user_id,
product_name=req_ctx.product_name,
function_id=function_id,
function=function_update,
)

if (
query_params.include_extras
and updated_function.function_class == FunctionClass.PROJECT
):
function_with_extras = await _add_extras_to_project_function(
function=updated_function,
app=request.app,
user_id=req_ctx.user_id,
product_name=req_ctx.product_name,
)
access_rights = await _build_function_access_rights(
request.app,
user_id=req_ctx.user_id,
product_name=req_ctx.product_name,
function_id=function_id,
)

return envelope_json_response(
TypeAdapter(RegisteredProjectFunctionGet).validate_python(
function_with_extras
)
)
extras = (
await _build_function_extras(request.app, req_ctx.user_id, function=function)
if query_params.include_extras
else {}
)

return envelope_json_response(
TypeAdapter(RegisteredFunctionGet).validate_python(
updated_function.model_dump(mode="json")
function.model_dump() | {"access_rights": access_rights, **extras}
)
)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from common_library.errors_classes import OsparcErrorMixin


class ServiceMetadataNotFoundError(OsparcErrorMixin, Exception):
msg_template = "Service metadata for key {key} and version {version} not found"
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from models_library.services_types import ServiceKey, ServiceVersion
from pydantic import BaseModel, HttpUrl


class ServiceMetadata(BaseModel):
key: ServiceKey
version: ServiceVersion
thumbnail: HttpUrl | None
Loading
Loading