diff --git a/aiida_restapi/__init__.py b/aiida_restapi/__init__.py index d3618b6e..27f6e987 100644 --- a/aiida_restapi/__init__.py +++ b/aiida_restapi/__init__.py @@ -1,5 +1,3 @@ -"""AiiDA REST API for data queries and workflow managment.""" +"""AiiDA REST API for data queries and workflow management.""" __version__ = '0.1.0a1' - -from .main import app # noqa: F401 diff --git a/aiida_restapi/routers/comments.py b/aiida_restapi/cli/__init__.py similarity index 100% rename from aiida_restapi/routers/comments.py rename to aiida_restapi/cli/__init__.py diff --git a/aiida_restapi/cli/main.py b/aiida_restapi/cli/main.py new file mode 100644 index 00000000..f72e2a37 --- /dev/null +++ b/aiida_restapi/cli/main.py @@ -0,0 +1,31 @@ +import os + +import click +import uvicorn + + +@click.group() +def cli() -> None: + """AiiDA REST API management CLI.""" + + +@cli.command() +@click.option('--host', default='127.0.0.1', show_default=True) +@click.option('--port', default=8000, show_default=True, type=int) +@click.option('--read-only', is_flag=True) +@click.option('--watch', is_flag=True) +def start(read_only: bool, watch: bool, host: str, port: int) -> None: + """Start the AiiDA REST API service.""" + + os.environ['AIIDA_RESTAPI_READ_ONLY'] = '1' if read_only else '0' + + click.echo(f'Starting REST API (read_only={read_only}, watch={watch}) on {host}:{port}') + + uvicorn.run( + 'aiida_restapi.main:create_app', + host=host, + port=port, + reload=watch, + reload_dirs=['aiida_restapi'], + factory=True, + ) diff --git a/aiida_restapi/common/__init__.py b/aiida_restapi/common/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/aiida_restapi/common/pagination.py b/aiida_restapi/common/pagination.py new file mode 100644 index 00000000..2c0f57b0 --- /dev/null +++ b/aiida_restapi/common/pagination.py @@ -0,0 +1,19 @@ +"""Pagination utilities.""" + +from __future__ import annotations + +import typing as t + +import pydantic as pdt +from aiida.orm import Entity + 
+ResultType = t.TypeVar('ResultType', bound=Entity.Model) + +__all__ = ('PaginatedResults',) + + +class PaginatedResults(pdt.BaseModel, t.Generic[ResultType]): + total: int + page: int + page_size: int + results: list[ResultType] diff --git a/aiida_restapi/common/query.py b/aiida_restapi/common/query.py new file mode 100644 index 00000000..8f496635 --- /dev/null +++ b/aiida_restapi/common/query.py @@ -0,0 +1,90 @@ +"""REST API query utilities.""" + +from __future__ import annotations + +import json +import typing as t + +import pydantic as pdt +from fastapi import HTTPException, Query + + +class QueryParams(pdt.BaseModel): + filters: dict[str, t.Any] = pdt.Field( + default_factory=dict, + description='AiiDA QueryBuilder filters', + examples=[ + {'node_type': {'==': 'data.core.int.Int.'}}, + {'attributes.value': {'>': 42}}, + ], + ) + order_by: str | list[str] | dict[str, t.Any] | None = pdt.Field( + None, + description='Fields to sort by', + examples=[ + {'attributes.value': 'desc'}, + ], + ) + page_size: pdt.PositiveInt = pdt.Field( + 10, + description='Number of results per page', + examples=[10], + ) + page: pdt.PositiveInt = pdt.Field( + 1, + description='Page number', + examples=[1], + ) + + +def query_params( + filters: str | None = Query( + None, + description='AiiDA QueryBuilder filters as JSON string', + ), + order_by: str | None = Query( + None, + description='Comma-separated list of fields to sort by', + ), + page_size: pdt.PositiveInt = Query( + 10, + description='Number of results per page', + ), + page: pdt.PositiveInt = Query( + 1, + description='Page number', + ), +) -> QueryParams: + """Parse query parameters into a structured object. + + :param filters: AiiDA QueryBuilder filters as JSON string. + :param order_by: Comma-separated string of fields to sort by. + :param page_size: Number of results per page. + :param page: Page number. + :return: Structured query parameters. + :raises HTTPException: If filters cannot be parsed as JSON. 
+ """ + query_filters: dict[str, t.Any] = {} + query_order_by: str | list[str] | dict[str, t.Any] | None = None + if filters: + try: + query_filters = json.loads(filters) + except Exception as exception: + raise HTTPException( + status_code=400, + detail=f'Could not parse filters as JSON: {exception}', + ) from exception + if order_by: + try: + query_order_by = json.loads(order_by) + except Exception as exception: + raise HTTPException( + status_code=400, + detail=f'Could not parse order_by as JSON: {exception}', + ) from exception + return QueryParams( + filters=query_filters, + order_by=query_order_by, + page_size=page_size, + page=page, + ) diff --git a/aiida_restapi/common/types.py b/aiida_restapi/common/types.py new file mode 100644 index 00000000..0b4c2014 --- /dev/null +++ b/aiida_restapi/common/types.py @@ -0,0 +1,13 @@ +"""Common type variables.""" + +from __future__ import annotations + +import typing as t + +from aiida import orm + +EntityType = t.TypeVar('EntityType', bound='orm.Entity') +EntityModelType = t.TypeVar('EntityModelType', bound='orm.Entity.Model') + +NodeType = t.TypeVar('NodeType', bound='orm.Node') +NodeModelType = t.TypeVar('NodeModelType', bound='orm.Node.Model') diff --git a/aiida_restapi/config.py b/aiida_restapi/config.py index 0316437d..5145211c 100644 --- a/aiida_restapi/config.py +++ b/aiida_restapi/config.py @@ -1,5 +1,7 @@ """Configuration of API""" +from aiida_restapi import __version__ + # to get a string like this run: # openssl rand -hex 32 SECRET_KEY = '09d25e094faa6ca2556c818166b7a9563b93f7099f6f0f4caa6cf63b88e8d3e7' @@ -18,3 +20,8 @@ 'disabled': False, } } + +API_CONFIG = { + 'PREFIX': '/api/v0', + 'VERSION': __version__, +} diff --git a/aiida_restapi/aiida_db_mappings.py b/aiida_restapi/graphql/aiida_db_mappings.py similarity index 100% rename from aiida_restapi/aiida_db_mappings.py rename to aiida_restapi/graphql/aiida_db_mappings.py diff --git a/aiida_restapi/graphql/comments.py b/aiida_restapi/graphql/comments.py 
index 49467ac7..c0003718 100644 --- a/aiida_restapi/graphql/comments.py +++ b/aiida_restapi/graphql/comments.py @@ -6,7 +6,7 @@ import graphene as gr from aiida.orm import Comment -from aiida_restapi.filter_syntax import parse_filter_str +from aiida_restapi.graphql.filter_syntax import parse_filter_str from .orm_factories import ( ENTITY_DICT_TYPE, diff --git a/aiida_restapi/graphql/computers.py b/aiida_restapi/graphql/computers.py index bc877186..520e1522 100644 --- a/aiida_restapi/graphql/computers.py +++ b/aiida_restapi/graphql/computers.py @@ -6,7 +6,7 @@ import graphene as gr from aiida.orm import Computer -from aiida_restapi.filter_syntax import parse_filter_str +from aiida_restapi.graphql.filter_syntax import parse_filter_str from aiida_restapi.graphql.plugins import QueryPlugin from .nodes import NodesQuery diff --git a/aiida_restapi/filter_syntax.py b/aiida_restapi/graphql/filter_syntax.py similarity index 98% rename from aiida_restapi/filter_syntax.py rename to aiida_restapi/graphql/filter_syntax.py index ba7010f2..6861fe30 100644 --- a/aiida_restapi/filter_syntax.py +++ b/aiida_restapi/graphql/filter_syntax.py @@ -13,8 +13,8 @@ from lark import Lark, Token, Tree -from . import static -from .utils import parse_date +from .. 
import static +from ..utils import parse_date FILTER_GRAMMAR = resources.open_text(static, 'filter_grammar.lark') diff --git a/aiida_restapi/graphql/groups.py b/aiida_restapi/graphql/groups.py index 21f7b121..b811c825 100644 --- a/aiida_restapi/graphql/groups.py +++ b/aiida_restapi/graphql/groups.py @@ -6,7 +6,7 @@ import graphene as gr from aiida.orm import Group -from aiida_restapi.filter_syntax import parse_filter_str +from aiida_restapi.graphql.filter_syntax import parse_filter_str from aiida_restapi.graphql.nodes import NodesQuery from aiida_restapi.graphql.plugins import QueryPlugin diff --git a/aiida_restapi/graphql/logs.py b/aiida_restapi/graphql/logs.py index eebd2e6f..aaaace5e 100644 --- a/aiida_restapi/graphql/logs.py +++ b/aiida_restapi/graphql/logs.py @@ -6,7 +6,7 @@ import graphene as gr from aiida.orm import Log -from aiida_restapi.filter_syntax import parse_filter_str +from aiida_restapi.graphql.filter_syntax import parse_filter_str from .orm_factories import ( ENTITY_DICT_TYPE, diff --git a/aiida_restapi/graphql/nodes.py b/aiida_restapi/graphql/nodes.py index b0e0fae1..8766c7da 100644 --- a/aiida_restapi/graphql/nodes.py +++ b/aiida_restapi/graphql/nodes.py @@ -6,7 +6,7 @@ import graphene as gr from aiida import orm -from aiida_restapi.filter_syntax import parse_filter_str +from aiida_restapi.graphql.filter_syntax import parse_filter_str from aiida_restapi.graphql.plugins import QueryPlugin from .comments import CommentsQuery diff --git a/aiida_restapi/graphql/orm_factories.py b/aiida_restapi/graphql/orm_factories.py index fade6fa0..1c091216 100644 --- a/aiida_restapi/graphql/orm_factories.py +++ b/aiida_restapi/graphql/orm_factories.py @@ -12,7 +12,7 @@ from graphql import GraphQLError from pydantic import Json -from aiida_restapi.aiida_db_mappings import ORM_MAPPING, get_model_from_orm +from aiida_restapi.graphql.aiida_db_mappings import ORM_MAPPING, get_model_from_orm from .config import ENTITY_LIMIT from .utils import JSON, 
selected_field_names_naive diff --git a/aiida_restapi/graphql/users.py b/aiida_restapi/graphql/users.py index cc0a8077..473d907c 100644 --- a/aiida_restapi/graphql/users.py +++ b/aiida_restapi/graphql/users.py @@ -6,7 +6,7 @@ import graphene as gr from aiida.orm import User -from aiida_restapi.filter_syntax import parse_filter_str +from aiida_restapi.graphql.filter_syntax import parse_filter_str from .nodes import NodesQuery from .orm_factories import ( diff --git a/aiida_restapi/identifiers.py b/aiida_restapi/identifiers.py deleted file mode 100644 index e00a4668..00000000 --- a/aiida_restapi/identifiers.py +++ /dev/null @@ -1,170 +0,0 @@ -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Utility functions to work with node "full types" which identify node types. - -A node's `full_type` is defined as a string that uniquely defines the node type. A valid `full_type` is constructed by -concatenating the `node_type` and `process_type` of a node with the `FULL_TYPE_CONCATENATOR`. Each segment of the full -type can optionally be terminated by a single `LIKE_OPERATOR_CHARACTER` to indicate that the `node_type` or -`process_type` should start with that value but can be followed by any amount of other characters. A full type is -invalid if it does not contain exactly one `FULL_TYPE_CONCATENATOR` character. Additionally, each segment can contain -at most one occurrence of the `LIKE_OPERATOR_CHARACTER` and it has to be at the end of the segment. 
- -Examples of valid full types: - - 'data.bool.Bool.|' - 'process.calculation.calcfunction.%|%' - 'process.calculation.calcjob.CalcJobNode.|aiida.calculations:arithmetic.add' - 'process.calculation.calcfunction.CalcFunctionNode.|aiida.workflows:codtools.primitive_structure_from_cif' - -Examples of invalid full types: - - 'data.bool' # Only a single segment without concatenator - 'data.|bool.Bool.|process.' # More than one concatenator - 'process.calculation%.calcfunction.|aiida.calculations:arithmetic.add' # Like operator not at end of segment - 'process.calculation%.calcfunction.%|aiida.calculations:arithmetic.add' # More than one operator in segment - -""" - -from typing import Any - -from aiida.common.escaping import escape_for_sql_like - -FULL_TYPE_CONCATENATOR = '|' -LIKE_OPERATOR_CHARACTER = '%' -DEFAULT_NAMESPACE_LABEL = '~no-entry-point~' - - -def validate_full_type(full_type: str) -> None: - """Validate that the `full_type` is a valid full type unique node identifier. - - :param full_type: a `Node` full type - :raises ValueError: if the `full_type` is invalid - :raises TypeError: if the `full_type` is not a string type - """ - from aiida.common.lang import type_check - - type_check(full_type, str) - - if FULL_TYPE_CONCATENATOR not in full_type: - raise ValueError( - f'full type `{full_type}` does not include the required concatenator symbol `{FULL_TYPE_CONCATENATOR}`.' - ) - elif full_type.count(FULL_TYPE_CONCATENATOR) > 1: - raise ValueError( - f'full type `{full_type}` includes the concatenator symbol `{FULL_TYPE_CONCATENATOR}` more than once.' - ) - - -def construct_full_type(node_type: str, process_type: str) -> str: - """Return the full type, which fully identifies the type of any `Node` with the given `node_type` and - `process_type`. 
- - :param node_type: the `node_type` of the `Node` - :param process_type: the `process_type` of the `Node` - :return: the full type, which is a unique identifier - """ - if node_type is None: - node_type = '' - - if process_type is None: - process_type = '' - - return f'{node_type}{FULL_TYPE_CONCATENATOR}{process_type}' - - -def get_full_type_filters(full_type: str) -> dict[str, Any]: - """Return the `QueryBuilder` filters that will return all `Nodes` identified by the given `full_type`. - - :param full_type: the `full_type` node type identifier - :return: dictionary of filters to be passed for the `filters` keyword in `QueryBuilder.append` - :raises ValueError: if the `full_type` is invalid - :raises TypeError: if the `full_type` is not a string type - """ - validate_full_type(full_type) - - filters: dict[str, Any] = {} - node_type, process_type = full_type.split(FULL_TYPE_CONCATENATOR) - - for entry in (node_type, process_type): - if entry.count(LIKE_OPERATOR_CHARACTER) > 1: - raise ValueError(f'full type component `{entry}` contained more than one like-operator character') - - if LIKE_OPERATOR_CHARACTER in entry and entry[-1] != LIKE_OPERATOR_CHARACTER: - raise ValueError(f'like-operator character in full type component `{entry}` is not at the end') - - if LIKE_OPERATOR_CHARACTER in node_type: - # Remove the trailing `LIKE_OPERATOR_CHARACTER`, escape the string and reattach the character - node_type = node_type[:-1] - node_type = escape_for_sql_like(node_type) + LIKE_OPERATOR_CHARACTER - filters['node_type'] = {'like': node_type} - else: - filters['node_type'] = {'==': node_type} - - if LIKE_OPERATOR_CHARACTER in process_type: - # Remove the trailing `LIKE_OPERATOR_CHARACTER` () - # If that was the only specification, just ignore this filter (looking for any process_type) - # If there was more: escape the string and reattach the character - process_type = process_type[:-1] - if process_type: - process_type = escape_for_sql_like(process_type) + 
LIKE_OPERATOR_CHARACTER - filters['process_type'] = {'like': process_type} - elif process_type: - filters['process_type'] = {'==': process_type} - else: - # A `process_type=''` is used to represents both `process_type='' and `process_type=None`. - # This is because there is no simple way to single out null `process_types`, and therefore - # we consider them together with empty-string process_types. - # Moreover, the existence of both is most likely a bug of migrations and thus both share - # this same "erroneous" origin. - filters['process_type'] = {'or': [{'==': ''}, {'==': None}]} - - return filters - - -def load_entry_point_from_full_type(full_type: str) -> Any: - """Return the loaded entry point for the given `full_type` unique node type identifier. - - :param full_type: the `full_type` unique node type identifier - :raises ValueError: if the `full_type` is invalid - :raises TypeError: if the `full_type` is not a string type - :raises `~aiida.common.exceptions.EntryPointError`: if the corresponding entry point cannot be loaded - """ - from aiida.common import EntryPointError - from aiida.common.utils import strip_prefix - from aiida.plugins.entry_point import ( - is_valid_entry_point_string, - load_entry_point, - load_entry_point_from_string, - ) - - data_prefix = 'data.' 
- - validate_full_type(full_type) - - node_type, process_type = full_type.split(FULL_TYPE_CONCATENATOR) - - if is_valid_entry_point_string(process_type): - try: - return load_entry_point_from_string(process_type) - except EntryPointError: - raise EntryPointError(f'could not load entry point `{process_type}`') - - elif node_type.startswith(data_prefix): - base_name = strip_prefix(node_type, data_prefix) - entry_point_name = base_name.rsplit('.', 2)[0] - - try: - return load_entry_point('aiida.data', entry_point_name) - except EntryPointError: - raise EntryPointError(f'could not load entry point `{process_type}`') - - # Here we are dealing with a `ProcessNode` with a `process_type` that is not an entry point string. - # Which means it is most likely a full module path (the fallback option) and we cannot necessarily load the - # class from this. We could try with `importlib` but not sure that we should - raise EntryPointError('entry point of the given full type cannot be loaded') diff --git a/aiida_restapi/main.py b/aiida_restapi/main.py index 5e75b8c3..d85bf9ca 100644 --- a/aiida_restapi/main.py +++ b/aiida_restapi/main.py @@ -1,16 +1,45 @@ """Declaration of FastAPI application.""" -from fastapi import FastAPI +import os +from fastapi import APIRouter, FastAPI +from fastapi.responses import RedirectResponse + +from aiida_restapi.config import API_CONFIG from aiida_restapi.graphql import main -from aiida_restapi.routers import auth, computers, daemon, groups, nodes, process, users - -app = FastAPI() -app.include_router(auth.router) -app.include_router(computers.router) -app.include_router(daemon.router) -app.include_router(nodes.router) -app.include_router(groups.router) -app.include_router(users.router) -app.include_router(process.router) -app.add_route('/graphql', main.app, name='graphql') +from aiida_restapi.routers import auth, computers, daemon, groups, nodes, querybuilder, server, submit, users + + +def create_app() -> FastAPI: + """Create the FastAPI 
application and include the routers. + + :return: The FastAPI application. + :rtype: FastAPI + """ + + read_only = os.getenv('AIIDA_RESTAPI_READ_ONLY') == '1' + + app = FastAPI() + + api_router = APIRouter(prefix=API_CONFIG['PREFIX']) + + api_router.add_route( + '/', + lambda _: RedirectResponse(url=api_router.url_path_for('endpoints')), + ) + + for module in (auth, server, users, computers, groups, nodes, querybuilder, submit, daemon): + if read_router := getattr(module, 'read_router', None): + api_router.include_router(read_router) + if not read_only and (write_router := getattr(module, 'write_router', None)): + api_router.include_router(write_router) + + api_router.add_route( + '/graphql', + main.app, + methods=['POST'], + ) + + app.include_router(api_router) + + return app diff --git a/aiida_restapi/models.py b/aiida_restapi/models.py deleted file mode 100644 index ecd291ec..00000000 --- a/aiida_restapi/models.py +++ /dev/null @@ -1,340 +0,0 @@ -"""Schemas for AiiDA REST API. - -Models in this module mirror those in -`aiida.backends.djsite.db.models` and `aiida.backends.sqlalchemy.models` -""" -# pylint: disable=too-few-public-methods - -import inspect -from datetime import datetime -from pathlib import Path -from typing import Any, ClassVar, Dict, List, Optional, Type, TypeVar -from uuid import UUID - -from aiida import orm -from fastapi import Form -from pydantic import BaseModel, ConfigDict, Field - -# Template type for subclasses of `AiidaModel` -ModelType = TypeVar('ModelType', bound='AiidaModel') - - -def as_form(cls: Type[BaseModel]) -> Type[BaseModel]: - """ - Adds an as_form class method to decorated models. The as_form class method - can be used with FastAPI endpoints - - Note: Taken from https://github.com/tiangolo/fastapi/issues/2387 - """ - new_parameters = [] - - for field_name, model_field in cls.model_fields.items(): - new_parameters.append( - inspect.Parameter( - name=field_name, - kind=inspect.Parameter.POSITIONAL_ONLY, - default=Form(...) 
if model_field.is_required() else Form(model_field.default), - annotation=model_field.annotation, - ) - ) - - async def as_form_func(**data: Dict[str, Any]) -> Any: - return cls(**data) - - sig = inspect.signature(as_form_func) - sig = sig.replace(parameters=new_parameters) - as_form_func.__signature__ = sig # type: ignore - setattr(cls, 'as_form', as_form_func) - return cls - - -class AiidaModel(BaseModel): - """A mapping of an AiiDA entity to a pydantic model.""" - - _orm_entity: ClassVar[Type[orm.entities.Entity]] = orm.entities.Entity - model_config = ConfigDict(from_attributes=True, extra='forbid') - - @classmethod - def get_projectable_properties(cls) -> List[str]: - """Return projectable properties.""" - return list(cls.schema()['properties'].keys()) - - @classmethod - def get_entities( - cls: Type[ModelType], - *, - page_size: Optional[int] = None, - page: int = 0, - project: Optional[List[str]] = None, - order_by: Optional[List[str]] = None, - ) -> List[ModelType]: - """Return a list of entities (with pagination). 
- - :param project: properties to project (default: all available) - :param page_size: the page size (default: infinite) - :param page: the page to return, if page_size set - """ - if project is None: - project = cls.get_projectable_properties() - else: - assert set(cls.get_projectable_properties()).issuperset( - project - ), f'projection not subset of projectable properties: {project!r}' - query = orm.QueryBuilder().append(cls._orm_entity, tag='fields', project=project) - if page_size is not None: - query.offset(page_size * (page - 1)) - query.limit(page_size) - if order_by is not None: - assert set(cls.get_projectable_properties()).issuperset( - order_by - ), f'order_by not subset of projectable properties: {project!r}' - query.order_by({'fields': order_by}) - return [cls(**result['fields']) for result in query.dict()] - - -class Comment(AiidaModel): - """AiiDA Comment model.""" - - _orm_entity = orm.Comment - - id: Optional[int] = Field(None, description='Unique comment id (pk)') - uuid: str = Field(description='Unique comment uuid') - ctime: Optional[datetime] = Field(None, description='Creation time') - mtime: Optional[datetime] = Field(None, description='Last modification time') - content: Optional[str] = Field(None, description='Comment content') - dbnode_id: Optional[int] = Field(None, description='Unique node id (pk)') - user_id: Optional[int] = Field(None, description='Unique user id (pk)') - - -class User(AiidaModel): - """AiiDA User model.""" - - _orm_entity = orm.User - model_config = ConfigDict(extra='allow') - - id: Optional[int] = Field(None, description='Unique user id (pk)') - email: str = Field(description='Email address of the user') - first_name: Optional[str] = Field(None, description='First name of the user') - last_name: Optional[str] = Field(None, description='Last name of the user') - institution: Optional[str] = Field(None, description='Host institution or workplace of the user') - - -class Computer(AiidaModel): - """AiiDA Computer 
Model.""" - - _orm_entity = orm.Computer - - id: Optional[int] = Field(None, description='Unique computer id (pk)') - uuid: Optional[str] = Field(None, description='Unique id for computer') - label: str = Field(description='Used to identify a computer. Must be unique') - hostname: Optional[str] = Field(None, description='Label that identifies the computer within the network') - scheduler_type: Optional[str] = Field( - None, - description='The scheduler (and plugin) that the computer uses to manage jobs', - ) - transport_type: Optional[str] = Field( - None, - description='The transport (and plugin) \ - required to copy files and communicate to and from the computer', - ) - metadata: Optional[dict] = Field( - None, - description='General settings for these communication and management protocols', - ) - - description: Optional[str] = Field(None, description='Description of node') - - -class Node(AiidaModel): - """AiiDA Node Model.""" - - _orm_entity = orm.Node - - id: Optional[int] = Field(None, description='Unique id (pk)') - uuid: Optional[UUID] = Field(None, description='Unique uuid') - node_type: Optional[str] = Field(None, description='Node type') - process_type: Optional[str] = Field(None, description='Process type') - label: str = Field(description='Label of node') - description: Optional[str] = Field(None, description='Description of node') - ctime: Optional[datetime] = Field(None, description='Creation time') - mtime: Optional[datetime] = Field(None, description='Last modification time') - user_id: Optional[int] = Field(None, description='Created by user id (pk)') - dbcomputer_id: Optional[int] = Field(None, description='Associated computer id (pk)') - attributes: Optional[Dict] = Field( - None, - description='Variable attributes of the node', - ) - extras: Optional[Dict] = Field( - None, - description='Variable extras (unsealed) of the node', - ) - repository_metadata: Optional[Dict] = Field( - None, - description='Metadata about file repository associated 
with this node', - ) - - -@as_form -class Node_Post(AiidaModel): - """AiiDA model for posting Nodes.""" - - entry_point: str = Field(description='Entry_point') - process_type: Optional[str] = Field(None, description='Process type') - label: Optional[str] = Field(None, description='Label of node') - description: Optional[str] = Field(None, description='Description of node') - user_id: Optional[int] = Field(None, description='Created by user id (pk)') - dbcomputer_id: Optional[int] = Field(None, description='Associated computer id (pk)') - attributes: Optional[Dict] = Field( - None, - description='Variable attributes of the node', - ) - extras: Optional[Dict] = Field( - None, - description='Variable extras (unsealed) of the node', - ) - - @classmethod - def create_new_node( - cls: Type[ModelType], - orm_class: orm.Node, - node_dict: dict, - ) -> orm.Node: - """Create and Store new Node""" - attributes = node_dict.pop('attributes', {}) - extras = node_dict.pop('extras', {}) - repository_metadata = node_dict.pop('repository_metadata', {}) - - if issubclass(orm_class, orm.BaseType): - orm_object = orm_class( - attributes['value'], - **node_dict, - ) - elif issubclass(orm_class, orm.Dict): - orm_object = orm_class( - dict=attributes, - **node_dict, - ) - elif issubclass(orm_class, orm.InstalledCode): - orm_object = orm_class( - computer=orm.load_computer(pk=node_dict.get('dbcomputer_id')), - filepath_executable=attributes['filepath_executable'], - ) - orm_object.label = node_dict.get('label') - elif issubclass(orm_class, orm.PortableCode): - orm_object = orm_class( - computer=orm.load_computer(pk=node_dict.get('dbcomputer_id')), - filepath_executable=attributes['filepath_executable'], - filepath_files=attributes['filepath_files'], - ) - orm_object.label = node_dict.get('label') - else: - orm_object = orm_class(**node_dict) - orm_object.base.attributes.set_many(attributes) - - orm_object.base.extras.set_many(extras) - orm_object.base.repository.repository_metadata = 
repository_metadata - orm_object.store() - return orm_object - - @classmethod - def create_new_node_with_file( - cls: Type[ModelType], - orm_class: orm.Node, - node_dict: dict, - file: Path, - ) -> orm.Node: - """Create and Store new Node with file""" - attributes = node_dict.pop('attributes', {}) - extras = node_dict.pop('extras', {}) - - orm_object = orm_class(file=file, **node_dict, **attributes) - - orm_object.base.extras.set_many(extras) - orm_object.store() - return orm_object - - -class Group(AiidaModel): - """AiiDA Group model.""" - - _orm_entity = orm.Group - - id: int = Field(description='Unique id (pk)') - uuid: UUID = Field(description='Universally unique id') - label: str = Field(description='Label of group') - type_string: str = Field(description='type of the group') - description: Optional[str] = Field(None, description='Description of group') - extras: Optional[Dict] = Field(None, description='extra data about for the group') - time: datetime = Field(description='Created time') - user_id: int = Field(description='Created by user id (pk)') - - @classmethod - def from_orm(cls, orm_entity: orm.Group) -> orm.Group: - """Convert from ORM object. - - Args: - obj: The ORM entity to convert - - Returns: - The converted Group object - """ - query = ( - orm.QueryBuilder() - .append( - cls._orm_entity, - filters={'pk': orm_entity.id}, - tag='fields', - project=['user_id', 'time'], - ) - .limit(1) - ) - orm_entity.user_id = query.dict()[0]['fields']['user_id'] - - return super().from_orm(orm_entity) - - -class Group_Post(AiidaModel): - """AiiDA Group Post model.""" - - _orm_entity = orm.Group - - label: str = Field(description='Used to access the group. 
Must be unique.') - type_string: Optional[str] = Field(None, description='Type of the group') - description: Optional[str] = Field(None, description='Short description of the group.') - - -class Process(AiidaModel): - """AiiDA Process Model""" - - _orm_entity = orm.ProcessNode - - id: Optional[int] = Field(None, description='Unique id (pk)') - uuid: Optional[UUID] = Field(None, description='Universally unique identifier') - node_type: Optional[str] = Field(None, description='Node type') - process_type: Optional[str] = Field(None, description='Process type') - label: str = Field(description='Label of node') - description: Optional[str] = Field(None, description='Description of node') - ctime: Optional[datetime] = Field(None, description='Creation time') - mtime: Optional[datetime] = Field(None, description='Last modification time') - user_id: Optional[int] = Field(None, description='Created by user id (pk)') - dbcomputer_id: Optional[int] = Field(None, description='Associated computer id (pk)') - attributes: Optional[Dict] = Field( - None, - description='Variable attributes of the node', - ) - extras: Optional[Dict] = Field( - None, - description='Variable extras (unsealed) of the node', - ) - repository_metadata: Optional[Dict] = Field( - None, - description='Metadata about file repository associated with this node', - ) - - -class Process_Post(AiidaModel): - """AiiDA Process Post Model""" - - label: str = Field(description='Label of node') - inputs: dict = Field(description='Input parmeters') - process_entry_point: str = Field(description='Entry Point for process') diff --git a/aiida_restapi/models/__init__.py b/aiida_restapi/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/aiida_restapi/models/node.py b/aiida_restapi/models/node.py new file mode 100644 index 00000000..0f8096ae --- /dev/null +++ b/aiida_restapi/models/node.py @@ -0,0 +1,201 @@ +from __future__ import annotations + +import typing as t + +import pydantic as pdt +from 
aiida.orm import Node +from aiida.plugins import get_entry_points +from importlib_metadata import EntryPoint + + +class NodeStatistics(pdt.BaseModel): + """Pydantic model representing node statistics.""" + + total: int = pdt.Field( + description='Total number of nodes.', + examples=[47], + ) + types: dict[str, int] = pdt.Field( + description='Number of nodes by type.', + examples=[ + { + 'data.core.int.Int.': 42, + 'data.core.singlefile.SinglefileData.': 5, + } + ], + ) + ctime_by_day: dict[str, int] = pdt.Field( + description='Number of nodes created per day (YYYY-MM-DD).', + examples=[ + { + '2012-01-01': 10, + '2012-01-02': 15, + } + ], + ) + + +class NodeType(pdt.BaseModel): + """Pydantic model representing a node type.""" + + label: str = pdt.Field( + description='The class name of the node type.', + examples=['Int'], + ) + node_type: str = pdt.Field( + description='The AiiDA node type string.', + examples=['data.core.int.Int.'], + ) + nodes: str = pdt.Field( + description='The URL to access nodes of this type.', + examples=['../nodes?filters={"node_type":{"==":"data.core.int.Int."}}'], + ) + projections: str = pdt.Field( + description='The URL to access projectable properties of this node type.', + examples=['../nodes/projections?type=data.core.int.Int.'], + ) + node_schema: str = pdt.Field( + description='The URL to access the schema of this node type.', + examples=['../nodes/schema?type=data.core.int.Int.'], + ) + + +class RepoFileMetadata(pdt.BaseModel): + """Pydantic model representing the metadata of a file in the AiiDA repository.""" + + type: t.Literal['FILE'] = pdt.Field( + description='The type of the repository object.', + examples=['FILE'], + ) + binary: bool = pdt.Field( + False, + description='Whether the file is binary.', + examples=[True], + ) + size: int = pdt.Field( + description='The size of the file in bytes.', + examples=[1024], + ) + download: str = pdt.Field( + description='The URL to download the file.', + 
examples=['../nodes/{uuid}/repo/contents?filename=path/to/file.txt'], + ) + + +class RepoDirMetadata(pdt.BaseModel): + """Pydantic model representing the metadata of a directory in the AiiDA repository.""" + + type: t.Literal['DIRECTORY'] = pdt.Field( + description='The type of the repository object.', + examples=['DIRECTORY'], + ) + objects: dict[str, t.Union[RepoFileMetadata, 'RepoDirMetadata']] = pdt.Field( + description='A dictionary with the metadata of the objects in the directory.', + examples=[ + { + 'file.txt': { + 'type': 'FILE', + 'binary': False, + 'size': 2048, + 'download': '../nodes/{uuid}/repo/contents?filename=path/to/file.txt', + }, + 'subdir': { + 'type': 'DIRECTORY', + 'objects': {}, + }, + } + ], + ) + + +MetadataType = t.Union[RepoFileMetadata, RepoDirMetadata] + + +class NodeLink(Node.Model): + link_label: str = pdt.Field(description='The label of the link to the node.') + link_type: str = pdt.Field(description='The type of the link to the node.') + + +class NodeModelRegistry: + """Registry for AiiDA REST API node models. + + This class maintains mappings of node types and their corresponding Pydantic models. + + :ivar ModelUnion: A union type of all AiiDA node Pydantic models, discriminated by the `node_type` field. + """ + + def __init__(self) -> None: + self._build_node_mappings() + self.ModelUnion = t.Annotated[ + t.Union[self._get_post_models()], + pdt.Field(discriminator='node_type'), + ] + + def get_node_types(self) -> list[str]: + """Get the list of registered node class names. + + :return: List of node class names. + """ + return list(self._models.keys()) + + def get_node_class_name(self, node_type: str) -> str: + """Get the AiiDA node class name for a given node type. + + :param node_type: The AiiDA node type string. + :return: The corresponding node class name. 
+ """ + return node_type.rsplit('.', 2)[-2] + + def get_model(self, node_type: str, which: t.Literal['get', 'post'] = 'get') -> type[Node.Model]: + """Get the Pydantic model class for a given node type. + + :param node_type: The AiiDA node type string. + :return: The corresponding Pydantic model class. + """ + if (Model := self._models.get(node_type)) is None: + raise KeyError(f'Unknown node type: {node_type}') + if which not in Model: + raise KeyError(f'Unknown model type: {which}') + return Model[which] + + def _get_node_post_model(self, node_cls: Node) -> type[Node.Model]: + """Return a patched Model for the given node class with a literal `node_type` field. + + :param node_cls: The AiiDA node class. + :return: The patched ORM Node model. + """ + Model = node_cls.CreateModel + node_type = node_cls.class_node_type + # Here we patch in the `node_type` union descriminator field. + # We annotate it with `SkipJsonSchema` to keep it off the public openAPI schema. + Model.model_fields['node_type'] = pdt.fields.FieldInfo( + annotation=pdt.json_schema.SkipJsonSchema[t.Literal[node_type]], # type: ignore[misc,valid-type] + default=node_type, + ) + Model.model_rebuild(force=True) + return t.cast(type[Node.Model], Model) + + def _build_node_mappings(self) -> None: + """Build mapping of node type to node creation model.""" + self._models: dict[str, dict[str, type[Node.Model]]] = {} + entry_point: EntryPoint + for entry_point in get_entry_points('aiida.data'): + try: + node_cls = t.cast(Node, entry_point.load()) + except Exception as exception: + # Skip entry points that cannot be loaded + print(f'Warning: could not load entry point {entry_point.name}: {exception}') + continue + + self._models[node_cls.class_node_type] = { + 'get': node_cls.Model, + 'post': self._get_node_post_model(node_cls), + } + + def _get_post_models(self) -> tuple[type[Node.Model], ...]: + """Get a union type of all node 'post' models. + + :return: A union type of all node 'post' models. 
+ """ + post_models = [model_dict['post'] for model_dict in self._models.values()] + return tuple(post_models) diff --git a/aiida_restapi/resources.py b/aiida_restapi/resources.py deleted file mode 100644 index 7713b3b9..00000000 --- a/aiida_restapi/resources.py +++ /dev/null @@ -1,41 +0,0 @@ -from typing import Union - -from aiida.common.exceptions import EntryPointError, LoadingEntryPointError -from aiida.plugins.entry_point import get_entry_point_names, load_entry_point - -from aiida_restapi.exceptions import RestFeatureNotAvailable, RestInputValidationError -from aiida_restapi.identifiers import construct_full_type, load_entry_point_from_full_type - - -def get_all_download_formats(full_type: Union[str, None] = None) -> dict: - """Returns dict of possible node formats for all available node types""" - all_formats = {} - - if full_type: - try: - node_cls = load_entry_point_from_full_type(full_type) - except (TypeError, ValueError): - raise RestInputValidationError(f'The full type {full_type} is invalid.') - except EntryPointError: - raise RestFeatureNotAvailable('The download formats for this node type are not available.') - - try: - available_formats = node_cls.get_export_formats() - all_formats[full_type] = available_formats - except AttributeError: - pass - else: - entry_point_group = 'aiida.data' - - for name in get_entry_point_names(entry_point_group): - try: - node_cls = load_entry_point(entry_point_group, name) - available_formats = node_cls.get_export_formats() - except (AttributeError, LoadingEntryPointError): - continue - - if available_formats: - full_type = construct_full_type(node_cls.class_node_type, '') - all_formats[full_type] = available_formats - - return all_formats diff --git a/aiida_restapi/routers/auth.py b/aiida_restapi/routers/auth.py index a6b60bc4..4a371542 100644 --- a/aiida_restapi/routers/auth.py +++ b/aiida_restapi/routers/auth.py @@ -1,10 +1,12 @@ """Handle API authentication and authorization.""" -# pylint: 
disable=missing-function-docstring,missing-class-docstring -from datetime import datetime, timedelta -from typing import Any, Dict, Optional +from __future__ import annotations + +import typing as t +from datetime import datetime, timedelta, timezone import bcrypt +from aiida import orm from argon2 import PasswordHasher from argon2.exceptions import VerifyMismatchError from fastapi import APIRouter, Depends, HTTPException, status @@ -13,7 +15,6 @@ from pydantic import BaseModel from aiida_restapi import config -from aiida_restapi.models import User class Token(BaseModel): @@ -25,16 +26,17 @@ class TokenData(BaseModel): email: str -class UserInDB(User): +class UserInDB(orm.User.Model): hashed_password: str - disabled: Optional[bool] = None + disabled: t.Optional[bool] = None pwd_context = PasswordHasher() -oauth2_scheme = OAuth2PasswordBearer(tokenUrl='token') +oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f'{config.API_CONFIG["PREFIX"]}/auth/token') -router = APIRouter() +read_router = APIRouter(prefix='/auth') +write_router = APIRouter(prefix='/auth') def verify_password(plain_password: str, hashed_password: str) -> bool: @@ -54,14 +56,14 @@ def get_password_hash(password: str) -> str: return pwd_context.hash(password) -def get_user(db: dict, email: str) -> Optional[UserInDB]: +def get_user(db: dict, email: str) -> UserInDB | None: if email in db: user_dict = db[email] return UserInDB(**user_dict) return None -def authenticate_user(fake_db: dict, email: str, password: str) -> Optional[UserInDB]: +def authenticate_user(fake_db: dict, email: str, password: str) -> UserInDB | None: user = get_user(fake_db, email) if not user: @@ -73,18 +75,18 @@ def authenticate_user(fake_db: dict, email: str, password: str) -> Optional[User return user -def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str: +def create_access_token(data: dict, expires_delta: timedelta | None = None) -> str: to_encode = data.copy() if expires_delta: - expire = 
datetime.utcnow() + expires_delta + expire = datetime.now(timezone.utc) + expires_delta else: - expire = datetime.utcnow() + timedelta(minutes=15) + expire = datetime.now(timezone.utc) + timedelta(minutes=15) to_encode.update({'exp': expire}) encoded_jwt = jwt.encode(to_encode, config.SECRET_KEY, algorithm=config.ALGORITHM) return encoded_jwt -async def get_current_user(token: str = Depends(oauth2_scheme)) -> User: +async def get_current_user(token: t.Annotated[str, Depends(oauth2_scheme)]) -> orm.User.Model: credentials_exception = HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail='Could not validate credentials', @@ -92,7 +94,7 @@ async def get_current_user(token: str = Depends(oauth2_scheme)) -> User: ) try: payload = jwt.decode(token, config.SECRET_KEY, algorithms=[config.ALGORITHM]) - email: str = payload.get('sub') + email = payload.get('sub') if email is None: raise credentials_exception token_data = TokenData(email=email) @@ -105,17 +107,21 @@ async def get_current_user(token: str = Depends(oauth2_scheme)) -> User: async def get_current_active_user( - current_user: UserInDB = Depends(get_current_user), + current_user: t.Annotated[UserInDB, Depends(get_current_user)], ) -> UserInDB: if current_user.disabled: raise HTTPException(status_code=400, detail='Inactive user') return current_user -@router.post('/token', response_model=Token) +@write_router.post( + '/token', + response_model=Token, +) async def login_for_access_token( form_data: OAuth2PasswordRequestForm = Depends(), -) -> Dict[str, Any]: +) -> dict[str, t.Any]: + """Login to get access token.""" user = authenticate_user(config.fake_users_db, form_data.username, form_data.password) if not user: raise HTTPException( @@ -128,12 +134,12 @@ async def login_for_access_token( return {'access_token': access_token, 'token_type': 'bearer'} -@router.get('/auth/me/', response_model=User) -async def read_users_me(current_user: User = Depends(get_current_active_user)) -> User: +@read_router.get( + 
'/me/', + response_model=orm.User.Model, +) +async def read_users_me( + current_user: t.Annotated[orm.User.Model, Depends(get_current_active_user)], +) -> orm.User.Model: + """Get the current authenticated user.""" return current_user - - -# @router.get('/users/me/items/') -# async def read_own_items( -# current_user: User = Depends(get_current_active_user)): -# return [{'item_id': 'Foo', 'owner': current_user.email}] diff --git a/aiida_restapi/routers/computers.py b/aiida_restapi/routers/computers.py index b97d4666..92b212de 100644 --- a/aiida_restapi/routers/computers.py +++ b/aiida_restapi/routers/computers.py @@ -1,53 +1,113 @@ -"""Declaration of FastAPI application.""" +"""Declaration of FastAPI router for computers.""" -from typing import List, Optional +from __future__ import annotations + +import typing as t from aiida import orm from aiida.cmdline.utils.decorators import with_dbenv -from aiida.orm.querybuilder import QueryBuilder -from fastapi import APIRouter, Depends - -from aiida_restapi.models import Computer, User - -from .auth import get_current_active_user - -router = APIRouter() +from aiida.common.exceptions import NotExistent +from fastapi import APIRouter, Depends, HTTPException, Query +from aiida_restapi.common.pagination import PaginatedResults +from aiida_restapi.common.query import QueryParams, query_params +from aiida_restapi.services.entity import EntityService -@router.get('/computers', response_model=List[Computer]) -@with_dbenv() -async def read_computers() -> List[Computer]: - """Get list of all computers""" - - return Computer.get_entities() +from .auth import UserInDB, get_current_active_user +read_router = APIRouter(prefix='/computers') +write_router = APIRouter(prefix='/computers') -@router.get('/computers/projectable_properties', response_model=List[str]) -async def get_computers_projectable_properties() -> List[str]: - """Get projectable properties for computers endpoint""" +service = EntityService[orm.Computer, 
orm.Computer.Model](orm.Computer) - return Computer.get_projectable_properties() - -@router.get('/computers/{comp_id}', response_model=Computer) +@read_router.get( + '/schema', + response_model=dict, +) +async def get_computers_schema( + which: t.Literal['get', 'post'] = Query( + 'get', + description='Type of schema to retrieve: "get" or "post"', + ), +) -> dict: + """Get JSON schema for AiiDA computers.""" + try: + return service.get_schema(which=which) + except ValueError as exception: + raise HTTPException(status_code=422, detail=str(exception)) from exception + except Exception as exception: + raise HTTPException(status_code=500, detail=str(exception)) from exception + + +@read_router.get( + '/projections', + response_model=list[str], +) +async def get_computer_projections() -> list[str]: + """Get queryable projections for AiiDA computers.""" + return service.get_projections() + + +@read_router.get( + '', + response_model=PaginatedResults[orm.Computer.Model], + response_model_exclude_none=True, + response_model_exclude_unset=True, +) @with_dbenv() -async def read_computer(comp_id: int) -> Optional[Computer]: - """Get computer by id.""" - qbobj = QueryBuilder() - qbobj.append(orm.Computer, filters={'id': comp_id}, project='**', tag='computer').limit(1) - - return qbobj.dict()[0]['computer'] - - -@router.post('/computers', response_model=Computer) +async def get_computers( + queries: t.Annotated[QueryParams, Depends(query_params)], +) -> PaginatedResults[orm.Computer.Model]: + """Get AiiDA computers with optional filtering, sorting, and/or pagination.""" + return service.get_many(queries) + + +@read_router.get( + '/{pk}', + response_model=orm.Computer.Model, + response_model_exclude_none=True, + response_model_exclude_unset=True, +) +@with_dbenv() +async def get_computer(pk: str) -> orm.Computer.Model: + """Get AiiDA computer by pk.""" + try: + return service.get_one(pk) + except NotExistent as exception: + raise HTTPException(status_code=404, 
detail=str(exception)) from exception + except Exception as exception: + raise HTTPException(status_code=500, detail=str(exception)) from exception + + +@read_router.get( + '/{pk}/metadata', + response_model=dict[str, t.Any], +) +@with_dbenv() +async def get_computer_metadata(pk: str) -> dict[str, t.Any]: + """Get metadata of an AiiDA computer by pk.""" + try: + return service.get_field(pk, 'metadata') + except NotExistent as exception: + raise HTTPException(status_code=404, detail=str(exception)) from exception + except Exception as exception: + raise HTTPException(status_code=500, detail=str(exception)) from exception + + +@write_router.post( + '', + response_model=orm.Computer.Model, + response_model_exclude_none=True, + response_model_exclude_unset=True, +) @with_dbenv() async def create_computer( - computer: Computer, - current_user: User = Depends( # pylint: disable=unused-argument - get_current_active_user - ), -) -> Computer: + computer_model: orm.Computer.CreateModel, + current_user: t.Annotated[UserInDB, Depends(get_current_active_user)], +) -> orm.Computer.Model: """Create new AiiDA computer.""" - orm_computer = orm.Computer(**computer.dict(exclude_unset=True)).store() - - return Computer.from_orm(orm_computer) + try: + return service.add_one(computer_model) + except Exception as exception: + raise HTTPException(status_code=500, detail=str(exception)) diff --git a/aiida_restapi/routers/daemon.py b/aiida_restapi/routers/daemon.py index 8f2f4258..9965d67a 100644 --- a/aiida_restapi/routers/daemon.py +++ b/aiida_restapi/routers/daemon.py @@ -9,10 +9,10 @@ from fastapi import APIRouter, Depends, HTTPException from pydantic import BaseModel, Field -from ..models import User -from .auth import get_current_active_user +from .auth import UserInDB, get_current_active_user -router = APIRouter() +read_router = APIRouter(prefix='/daemon') +write_router = APIRouter(prefix='/daemon') class DaemonStatusModel(BaseModel): @@ -22,7 +22,10 @@ class 
DaemonStatusModel(BaseModel): num_workers: t.Optional[int] = Field(description='The number of workers if the daemon is running.') -@router.get('/daemon/status', response_model=DaemonStatusModel) +@read_router.get( + '/status', + response_model=DaemonStatusModel, +) @with_dbenv() async def get_daemon_status() -> DaemonStatusModel: """Return the daemon status.""" @@ -31,17 +34,21 @@ async def get_daemon_status() -> DaemonStatusModel: if not client.is_daemon_running: return DaemonStatusModel(running=False, num_workers=None) - response = client.get_numprocesses() + try: + response = client.get_numprocesses() + except DaemonException as exception: + raise HTTPException(status_code=500, detail=str(exception)) from exception return DaemonStatusModel(running=True, num_workers=response['numprocesses']) -@router.post('/daemon/start', response_model=DaemonStatusModel) +@write_router.post( + '/start', + response_model=DaemonStatusModel, +) @with_dbenv() async def get_daemon_start( - current_user: User = Depends( # pylint: disable=unused-argument - get_current_active_user - ), + current_user: t.Annotated[UserInDB, Depends(get_current_active_user)], ) -> DaemonStatusModel: """Start the daemon.""" client = get_daemon_client() @@ -51,20 +58,20 @@ async def get_daemon_start( try: client.start_daemon() + response = client.get_numprocesses() except DaemonException as exception: raise HTTPException(status_code=500, detail=str(exception)) from exception - response = client.get_numprocesses() - return DaemonStatusModel(running=True, num_workers=response['numprocesses']) -@router.post('/daemon/stop', response_model=DaemonStatusModel) +@write_router.post( + '/stop', + response_model=DaemonStatusModel, +) @with_dbenv() async def get_daemon_stop( - current_user: User = Depends( # pylint: disable=unused-argument - get_current_active_user - ), + current_user: t.Annotated[UserInDB, Depends(get_current_active_user)], ) -> DaemonStatusModel: """Stop the daemon.""" client = get_daemon_client() 
diff --git a/aiida_restapi/routers/groups.py b/aiida_restapi/routers/groups.py index e7e017e6..32cd5ddc 100644 --- a/aiida_restapi/routers/groups.py +++ b/aiida_restapi/routers/groups.py @@ -1,51 +1,113 @@ -"""Declaration of FastAPI application.""" +"""Declaration of FastAPI router for groups.""" -from typing import List, Optional +from __future__ import annotations + +import typing as t from aiida import orm from aiida.cmdline.utils.decorators import with_dbenv -from fastapi import APIRouter, Depends - -from aiida_restapi.models import Group, Group_Post, User - -from .auth import get_current_active_user - -router = APIRouter() - - -@router.get('/groups', response_model=List[Group]) -@with_dbenv() -async def read_groups() -> List[Group]: - """Get list of all groups""" +from aiida.common.exceptions import NotExistent +from fastapi import APIRouter, Depends, HTTPException, Query - return Group.get_entities() +from aiida_restapi.common.pagination import PaginatedResults +from aiida_restapi.common.query import QueryParams, query_params +from aiida_restapi.services.entity import EntityService +from .auth import UserInDB, get_current_active_user -@router.get('/groups/projectable_properties', response_model=List[str]) -async def get_groups_projectable_properties() -> List[str]: - """Get projectable properties for groups endpoint""" +read_router = APIRouter(prefix='/groups') +write_router = APIRouter(prefix='/groups') - return Group.get_projectable_properties() +service = EntityService[orm.Group, orm.Group.Model](orm.Group) -@router.get('/groups/{group_id}', response_model=Group) +@read_router.get( + '/schema', + response_model=dict, +) +async def get_groups_schema( + which: t.Literal['get', 'post'] = Query( + 'get', + description='Type of schema to retrieve: "get" or "post"', + ), +) -> dict: + """Get JSON schema for AiiDA groups.""" + try: + return service.get_schema(which=which) + except ValueError as exception: + raise HTTPException(status_code=422, 
detail=str(exception)) from exception + except Exception as exception: + raise HTTPException(status_code=500, detail=str(exception)) from exception + + +@read_router.get( + '/projections', + response_model=list[str], +) +async def get_group_projections() -> list[str]: + """Get queryable projections for AiiDA groups.""" + return service.get_projections() + + +@read_router.get( + '', + response_model=PaginatedResults[orm.Group.Model], + response_model_exclude_none=True, + response_model_exclude_unset=True, +) @with_dbenv() -async def read_group(group_id: int) -> Optional[Group]: - """Get group by id.""" - qbobj = orm.QueryBuilder() - - qbobj.append(orm.Group, filters={'id': group_id}, project='**', tag='group').limit(1) - return qbobj.dict()[0]['group'] - - -@router.post('/groups', response_model=Group) +async def get_groups( + queries: t.Annotated[QueryParams, Depends(query_params)], +) -> PaginatedResults[orm.Group.Model]: + """Get AiiDA groups with optional filtering, sorting, and/or pagination.""" + return service.get_many(queries) + + +@read_router.get( + '/{uuid}', + response_model=orm.Group.Model, + response_model_exclude_none=True, + response_model_exclude_unset=True, +) +@with_dbenv() +async def get_group(uuid: str) -> orm.Group.Model: + """Get AiiDA group by uuid.""" + try: + return service.get_one(uuid) + except NotExistent as exception: + raise HTTPException(status_code=404, detail=str(exception)) from exception + except Exception as exception: + raise HTTPException(status_code=500, detail=str(exception)) from exception + + +@read_router.get( + '/{uuid}/extras', + response_model=dict[str, t.Any], +) +@with_dbenv() +async def get_group_extras(uuid: str) -> dict[str, t.Any]: + """Get the extras of a group.""" + try: + return service.get_field(uuid, 'extras') + except NotExistent as exception: + raise HTTPException(status_code=404, detail=str(exception)) from exception + except Exception as exception: + raise HTTPException(status_code=500, 
detail=str(exception)) from exception + + +@write_router.post( + '', + response_model=orm.Group.Model, + response_model_exclude_none=True, + response_model_exclude_unset=True, +) @with_dbenv() async def create_group( - group: Group_Post, - current_user: User = Depends( # pylint: disable=unused-argument - get_current_active_user - ), -) -> Group: + group_model: orm.Group.CreateModel, + current_user: t.Annotated[UserInDB, Depends(get_current_active_user)], +) -> orm.Group.Model: """Create new AiiDA group.""" - orm_group = orm.Group(**group.dict(exclude_unset=True)).store() - return Group.from_orm(orm_group) + try: + return service.add_one(group_model) + except Exception as exception: + raise HTTPException(status_code=500, detail=str(exception)) from exception diff --git a/aiida_restapi/routers/nodes.py b/aiida_restapi/routers/nodes.py index 5f17a4e5..baef5b2c 100644 --- a/aiida_restapi/routers/nodes.py +++ b/aiida_restapi/routers/nodes.py @@ -1,57 +1,239 @@ -"""Declaration of FastAPI application.""" +"""Declaration of FastAPI router for nodes.""" +from __future__ import annotations + +import io import json -import os -import tempfile -from pathlib import Path -from typing import Any, Generator, List, Optional +import typing as t +import pydantic as pdt from aiida import orm from aiida.cmdline.utils.decorators import with_dbenv from aiida.common.exceptions import EntryPointError, LicensingException, NotExistent -from aiida.plugins.entry_point import load_entry_point -from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile +from fastapi import APIRouter, Depends, Form, HTTPException, Query, UploadFile from fastapi.responses import StreamingResponse -from pydantic import ValidationError +from typing_extensions import TypeAlias + +from aiida_restapi.common.pagination import PaginatedResults +from aiida_restapi.common.query import QueryParams, query_params +from aiida_restapi.config import API_CONFIG +from aiida_restapi.models.node import 
MetadataType, NodeModelRegistry, NodeStatistics, NodeType +from aiida_restapi.services.node import NodeLink, NodeService + +from .auth import UserInDB, get_current_active_user + +read_router = APIRouter(prefix='/nodes') +write_router = APIRouter(prefix='/nodes') + +service = NodeService[orm.Node, orm.Node.Model](orm.Node) -from aiida_restapi import models, resources +model_registry = NodeModelRegistry() -from .auth import get_current_active_user +if t.TYPE_CHECKING: + # Dummy type for static analysis + NodeModelUnion: TypeAlias = pdt.BaseModel +else: + # The real discriminated union built at runtime + NodeModelUnion = model_registry.ModelUnion -router = APIRouter() + +@read_router.get( + '/schema', + response_model=dict, +) +async def get_nodes_schema( + node_type: str | None = Query( + None, + description='The AiiDA node type string.', + alias='type', + ), + which: t.Literal['get', 'post'] = Query( + 'get', + description='Type of schema to retrieve', + ), +) -> dict: + """Get JSON schema for the base AiiDA node 'get' model.""" + if not node_type: + return orm.Node.Model.model_json_schema() + try: + model = model_registry.get_model(node_type, which) + return model.model_json_schema() + except KeyError as exception: + raise HTTPException(status_code=422, detail=str(exception)) from exception + except Exception as exception: + raise HTTPException(status_code=500, detail=str(exception)) from exception + + +@read_router.get( + '/projections', + response_model=list[str], +) +@with_dbenv() +async def get_node_projections( + node_type: str | None = Query( + None, + description='The AiiDA node type string.', + alias='type', + ), +) -> list[str]: + """Get queryable projections for AiiDA nodes.""" + try: + return service.get_projections(node_type) + except ValueError as exception: + raise HTTPException(status_code=422, detail=str(exception)) from exception + except Exception as exception: + raise HTTPException(status_code=500, detail=str(exception)) from exception 
-@router.get('/nodes', response_model=List[models.Node]) +@read_router.get( + '/statistics', + response_model=NodeStatistics, +) @with_dbenv() -async def read_nodes() -> List[models.Node]: - """Get list of all nodes""" - return models.Node.get_entities() +async def get_nodes_statistics(user: int | None = None) -> dict[str, t.Any]: + """Get node statistics.""" + + from aiida.manage import get_manager + + backend = get_manager().get_profile_storage() + return backend.query().get_creation_statistics(user_pk=user) + + +@read_router.get('/download_formats') +async def get_nodes_download_formats() -> dict[str, t.Any]: + """Get download formats for AiiDA nodes.""" + try: + return service.get_download_formats() + except EntryPointError as exception: + raise HTTPException(status_code=404, detail=str(exception)) from exception + except Exception as exception: + raise HTTPException(status_code=500, detail=str(exception)) from exception -@router.get('/nodes/projectable_properties', response_model=List[str]) -async def get_nodes_projectable_properties() -> List[str]: - """Get projectable properties for nodes endpoint""" +@read_router.get( + '', + response_model=PaginatedResults[orm.Node.Model], + response_model_exclude_none=True, + response_model_exclude_unset=True, +) +@with_dbenv() +async def get_nodes( + queries: t.Annotated[QueryParams, Depends(query_params)], +) -> PaginatedResults[orm.Node.Model]: + """Get AiiDA nodes with optional filtering, sorting, and/or pagination.""" + return service.get_many(queries) + + +@read_router.get( + '/types', + response_model=list[NodeType], +) +async def get_node_types() -> list: + """Get all node types in machine-actionable format.""" + api_prefix = API_CONFIG['PREFIX'] + return [ + { + 'label': model_registry.get_node_class_name(node_type), + 'node_type': node_type, + 'nodes': f'{api_prefix}/nodes?filters={{"node_type":"{node_type}"}}', + 'projections': f'{api_prefix}/nodes/projections?type={node_type}', + 'node_schema': 
f'{api_prefix}/nodes/schema?type={node_type}', + } + for node_type in sorted( + model_registry.get_node_types(), key=lambda node_type: model_registry.get_node_class_name(node_type) + ) + ] - return models.Node.get_projectable_properties() +@read_router.get( + '/{uuid}', + response_model=orm.Node.Model, + response_model_exclude_none=True, + response_model_exclude_unset=True, +) +@with_dbenv() +async def get_node(uuid: str) -> orm.Node.Model: + """Get AiiDA node by uuid.""" + try: + return service.get_one(uuid) + except NotExistent as exception: + raise HTTPException(status_code=422, detail=str(exception)) from exception + except Exception as exception: + raise HTTPException(status_code=500, detail=str(exception)) from exception -@router.get('/nodes/download_formats', response_model=dict[str, Any]) -async def get_nodes_download_formats() -> dict[str, Any]: - """Get download formats for nodes endpoint""" - return resources.get_all_download_formats() +@read_router.get( + '/{uuid}/attributes', + response_model=dict[str, t.Any], +) +@with_dbenv() +async def get_node_attributes(uuid: str) -> dict[str, t.Any]: + """Get the attributes of a node.""" + try: + return service.get_field(uuid, 'attributes') + except NotExistent as exception: + raise HTTPException(status_code=404, detail=str(exception)) from exception + except Exception as exception: + raise HTTPException(status_code=500, detail=str(exception)) from exception -@router.get('/nodes/{nodes_id}/download') +@read_router.get( + '/{uuid}/extras', + response_model=dict[str, t.Any], +) @with_dbenv() -async def download_node(nodes_id: int, download_format: Optional[str] = None) -> StreamingResponse: - """Get nodes by id.""" - from aiida.orm import load_node +async def get_node_extras(uuid: str) -> dict[str, t.Any]: + """Get the extras of a node.""" + try: + return service.get_field(uuid, 'extras') + except NotExistent as exception: + raise HTTPException(status_code=404, detail=str(exception)) from exception + except 
Exception as exception: + raise HTTPException(status_code=500, detail=str(exception)) from exception + +@read_router.get( + '/{uuid}/links', + response_model=PaginatedResults[NodeLink], + response_model_exclude_none=True, + response_model_exclude_unset=True, +) +@with_dbenv() +async def get_node_links( + uuid: str, + queries: t.Annotated[QueryParams, Depends(query_params)], + direction: t.Literal['incoming', 'outgoing'] = Query( + description='Specify whether to retrieve incoming or outgoing links.', + ), +) -> PaginatedResults[NodeLink]: + """Get the incoming or outgoing links of a node.""" try: - node = load_node(nodes_id) - except NotExistent: - raise HTTPException(status_code=404, detail=f'Could no find any node with id {nodes_id}') + return service.get_links(uuid, queries, direction=direction) + except NotExistent as exception: + raise HTTPException(status_code=404, detail=str(exception)) from exception + except Exception as exception: + raise HTTPException(status_code=500, detail=str(exception)) from exception + + +@read_router.get( + '/{uuid}/download', + response_class=StreamingResponse, +) +@with_dbenv() +async def download_node( + uuid: str, + download_format: str | None = Query( + None, + description='Format to download the node in', + ), +) -> StreamingResponse: + """Download AiiDA node by uuid in a given download format provided as a query parameter.""" + try: + node = orm.load_node(uuid) + except NotExistent as exception: + raise HTTPException(status_code=404, detail=str(exception)) from exception + except Exception as exception: + raise HTTPException(status_code=500, detail=str(exception)) from exception if download_format is None: raise HTTPException( @@ -63,14 +245,13 @@ async def download_node(nodes_id: int, download_format: Optional[str] = None) -> elif download_format in node.get_export_formats(): # byteobj, dict with {filename: filecontent} - import io try: exported_bytes, _ = node._exportcontent(download_format) - except LicensingException as 
exc: - raise HTTPException(status_code=500, detail=str(exc)) + except LicensingException as exception: + raise HTTPException(status_code=403, detail=str(exception)) from exception - def stream() -> Generator[bytes, None, None]: + def stream() -> t.Generator[bytes, None, None]: with io.BytesIO(exported_bytes) as handler: yield from handler @@ -85,91 +266,134 @@ def stream() -> Generator[bytes, None, None]: ) -@router.get('/nodes/{nodes_id}', response_model=models.Node) +@read_router.get( + '/{uuid}/repo/metadata', + response_model=dict[str, MetadataType], +) @with_dbenv() -async def read_node(nodes_id: int) -> Optional[models.Node]: - """Get nodes by id.""" - qbobj = orm.QueryBuilder() - qbobj.append(orm.Node, filters={'id': nodes_id}, project='**', tag='node').limit(1) - return qbobj.dict()[0]['node'] +async def get_node_repo_file_metadata(uuid: str) -> dict[str, dict]: + """Get the repository file metadata of a node.""" + try: + return service.get_repository_metadata(uuid) + except NotExistent as exception: + raise HTTPException(status_code=404, detail=str(exception)) from exception + except Exception as exception: + raise HTTPException(status_code=500, detail=str(exception)) from exception -@router.post('/nodes', response_model=models.Node) +@read_router.get( + '/{uuid}/repo/contents', + response_class=StreamingResponse, +) @with_dbenv() -async def create_node( - node: models.Node_Post, - current_user: models.User = Depends( # pylint: disable=unused-argument - get_current_active_user +async def get_node_repo_file_contents( + uuid: str, + filename: str | None = Query( + None, + description='Filename of repository content to retrieve', ), -) -> models.Node: - """Create new AiiDA node.""" - node_dict = node.dict(exclude_unset=True) - entry_point = node_dict.pop('entry_point', None) +) -> StreamingResponse: + """Get the repository contents of a node.""" + from urllib.parse import quote try: - cls = load_entry_point(group='aiida.data', name=entry_point) - except 
EntryPointError as exception: + node = orm.load_node(uuid) + except NotExistent as exception: raise HTTPException(status_code=404, detail=str(exception)) from exception - try: - orm_object = models.Node_Post.create_new_node(cls, node_dict) - except (TypeError, ValueError, KeyError) as exception: - raise HTTPException(status_code=400, detail=str(exception)) from exception + repo = node.base.repository - return models.Node.from_orm(orm_object) + if filename: + try: + file_content = repo.get_object_content(filename, mode='rb') + except FileNotFoundError as exception: + raise HTTPException(status_code=404, detail=str(exception)) from exception + def file_stream() -> t.Generator[bytes, None, None]: + with io.BytesIO(file_content) as handler: + yield from handler -@router.post('/nodes/singlefile', response_model=models.Node) -@with_dbenv() -async def create_upload_file( - params: str = Form(...), - upload_file: UploadFile = File(...), - current_user: models.User = Depends( # pylint: disable=unused-argument - get_current_active_user - ), -) -> models.Node: - """Endpoint for uploading file data + download_name = filename.split('/')[-1] or 'download' + quoted = quote(download_name) + headers = {'Content-Disposition': f"attachment; filename={download_name!r}; filename*=UTF-8''{quoted}"} + + return StreamingResponse(file_stream(), media_type='application/octet-stream', headers=headers) + + else: + zip_bytes = repo.get_zipped_objects() + + def zip_stream() -> t.Generator[bytes, None, None]: + with io.BytesIO(zip_bytes) as handler: + yield from handler - Note that in this multipart form case, json input can't be used. - Get the parameters as a string and manually pass through pydantic. 
- """ + download_name = f'node_{uuid}_repo.zip' + quoted = quote(download_name) + headers = {'Content-Disposition': f"attachment; filename={download_name!r}; filename*=UTF-8''{quoted}"} + + return StreamingResponse(zip_stream(), media_type='application/zip', headers=headers) + + +@write_router.post( + '', + response_model=orm.Node.Model, + response_model_exclude_none=True, + response_model_exclude_unset=True, +) +@with_dbenv() +async def create_node( + model: NodeModelUnion, + current_user: t.Annotated[UserInDB, Depends(get_current_active_user)], +) -> orm.Node.Model: + """Create new AiiDA node.""" + try: + return service.add_one(model) + except KeyError as exception: + raise HTTPException(status_code=422, detail=str(exception)) from exception + except Exception as exception: + raise HTTPException(status_code=500, detail=str(exception)) from exception + + +@write_router.post( + '/file-upload', + response_model=orm.Node.Model, + response_model_exclude_none=True, + response_model_exclude_unset=True, +) +@with_dbenv() +async def create_node_with_files( + params: t.Annotated[str, Form()], + files: list[UploadFile], + current_user: t.Annotated[UserInDB, Depends(get_current_active_user)], +) -> orm.Node.Model: + """Create new AiiDA node with files.""" try: - # Parse the JSON string into a dictionary - params_dict = json.loads(params) - # Validate against the Pydantic model - params_obj = models.Node_Post(**params_dict) + parameters = t.cast(dict, json.loads(params)) except json.JSONDecodeError as exception: - raise HTTPException( - status_code=400, - detail=f'Invalid JSON format: {exception!s}', - ) from exception - except ValidationError as exception: - raise HTTPException( - status_code=422, - detail=f'Validation failed: {exception}', - ) from exception + raise HTTPException(400, str(exception)) from exception - node_dict = params_obj.dict(exclude_unset=True) - entry_point = node_dict.pop('entry_point', None) + if not (node_type := parameters.get('node_type')): + raise 
HTTPException(422, "Missing 'node_type' in params") try: - cls = load_entry_point(group='aiida.data', name=entry_point) - except EntryPointError as exception: - raise HTTPException( - status_code=404, - detail=f'Could not load entry point: {exception}', - ) from exception + model_cls = model_registry.get_model(node_type, which='post') + model = model_cls(**parameters) + except KeyError as exception: + raise HTTPException(422, str(exception)) from exception + except pdt.ValidationError as exception: + raise HTTPException(422, str(exception)) from exception - with tempfile.NamedTemporaryFile(mode='wb', delete=False) as temp_file: - # Todo: read in chunks - content = await upload_file.read() - temp_file.write(content) - temp_path = temp_file.name + files_dict: dict[str, UploadFile] = {} - orm_object = models.Node_Post.create_new_node_with_file(cls, node_dict, Path(temp_path)) + for upload in files: + if (target := upload.filename) in files_dict: + raise HTTPException(422, f"Duplicate target path '{target}' in upload") + files_dict[target] = upload - # Clean up the temporary file - if os.path.exists(temp_path): - os.unlink(temp_path) - - return models.Node.from_orm(orm_object) + try: + return service.add_one(model, files=files_dict) + except json.JSONDecodeError as exception: + raise HTTPException(status_code=400, detail=str(exception)) from exception + except KeyError as exception: + raise HTTPException(status_code=422, detail=str(exception)) from exception + except Exception as exception: + raise HTTPException(status_code=500, detail=str(exception)) from exception diff --git a/aiida_restapi/routers/process.py b/aiida_restapi/routers/process.py deleted file mode 100644 index 905a468f..00000000 --- a/aiida_restapi/routers/process.py +++ /dev/null @@ -1,97 +0,0 @@ -"""Declaration of FastAPI router for processes.""" - -from typing import List, Optional - -from aiida import orm -from aiida.cmdline.utils.decorators import with_dbenv -from aiida.common.exceptions import 
NotExistent -from aiida.engine import submit -from aiida.orm.querybuilder import QueryBuilder -from aiida.plugins import load_entry_point_from_string -from fastapi import APIRouter, Depends, HTTPException - -from aiida_restapi.models import Process, Process_Post, User - -from .auth import get_current_active_user - -router = APIRouter() - - -def process_inputs(inputs: dict) -> dict: - """Process the inputs dictionary converting each node UUID into the corresponding node by loading it. - - A node UUID is indicated by the key ending with the suffix ``.uuid``. - - :param inputs: The inputs dictionary. - :returns: The deserialized inputs dictionary. - :raises HTTPException: If the inputs contain a UUID that does not correspond to an existing node. - """ - uuid_suffix = '.uuid' - results = {} - - for key, value in inputs.items(): - if isinstance(value, dict): - results[key] = process_inputs(value) - elif key.endswith(uuid_suffix): - try: - results[key[: -len(uuid_suffix)]] = orm.load_node(uuid=value) - except NotExistent as exc: - raise HTTPException( - status_code=404, - detail=f'Node with UUID `{value}` does not exist.', - ) from exc - else: - results[key] = value - - return results - - -@router.get('/processes', response_model=List[Process]) -@with_dbenv() -async def read_processes() -> List[Process]: - """Get list of all processes""" - - return Process.get_entities() - - -@router.get('/processes/projectable_properties', response_model=List[str]) -async def get_processes_projectable_properties() -> List[str]: - """Get projectable properties for processes endpoint""" - - return Process.get_projectable_properties() - - -@router.get('/processes/{proc_id}', response_model=Process) -@with_dbenv() -async def read_process(proc_id: int) -> Optional[Process]: - """Get process by id.""" - qbobj = QueryBuilder() - qbobj.append(orm.ProcessNode, filters={'id': proc_id}, project='**', tag='process').limit(1) - - return qbobj.dict()[0]['process'] - - -@router.post('/processes', 
response_model=Process) -@with_dbenv() -async def post_process( - process: Process_Post, - current_user: User = Depends( # pylint: disable=unused-argument - get_current_active_user - ), -) -> Optional[Process]: - """Create new process.""" - process_dict = process.dict(exclude_unset=True, exclude_none=True) - inputs = process_inputs(process_dict['inputs']) - entry_point = process_dict.get('process_entry_point') - - try: - entry_point_process = load_entry_point_from_string(entry_point) - except ValueError as exc: - raise HTTPException( - status_code=404, - detail=f"Entry point '{entry_point}' not recognized.", - ) from exc - - process_node = submit(entry_point_process, **inputs) - - return process_node diff --git a/aiida_restapi/routers/querybuilder.py b/aiida_restapi/routers/querybuilder.py new file mode 100644 index 00000000..b8df2bc3 --- /dev/null +++ b/aiida_restapi/routers/querybuilder.py @@ -0,0 +1,158 @@ +"""Declaration of FastAPI router for AiiDA's QueryBuilder.""" + +from __future__ import annotations + +import typing as t + +import pydantic as pdt +from aiida import orm +from aiida.cmdline.utils.decorators import with_dbenv +from fastapi import APIRouter, HTTPException, Query + +from aiida_restapi.common.pagination import PaginatedResults + +read_router = APIRouter(prefix='/querybuilder') + + +class QueryBuilderPathItem(pdt.BaseModel): + """Pydantic model for QueryBuilder path items.""" + + entity_type: str | list[str] = pdt.Field( + description='The AiiDA entity type.', + ) + orm_base: str = pdt.Field( + description='The ORM base class of the entity.', + ) + tag: str | None = pdt.Field( + None, + description='An optional tag for the path item.', + ) + joining_keyword: str | None = pdt.Field( + None, + description='The joining keyword for relationships (e.g., "input", "output").', + ) + joining_value: str | None = pdt.Field( + None, + description='The joining value for relationships (e.g., "input", "output").', + ) + edge_tag: str | None = pdt.Field( + 
None, + description='An optional tag for the edge.', + ) + outerjoin: bool = pdt.Field( + False, + description='Whether to perform an outer join.', + ) + + +class QueryBuilderDict(pdt.BaseModel): + """Pydantic model for QueryBuilder POST requests.""" + + path: list[str | QueryBuilderPathItem] = pdt.Field( + description='The QueryBuilder path as a list of entity types or path items.', + examples=[ + [ + ['data.core.int.Int.', 'data.core.float.Float.'], + { + 'entity_type': 'data.core.int.Int.', + 'orm_base': 'node', + 'tag': 'integers', + }, + { + 'entity_type': ['data.core.int.Int.', 'data.core.float.Float.'], + 'orm_base': 'node', + 'tag': 'numbers', + }, + ] + ], + ) + filters: dict[str, dict[str, t.Any]] | None = pdt.Field( + None, + description='The QueryBuilder filters as a dictionary mapping tags to filter conditions.', + examples=[ + { + 'integers': {'attributes.value': {'<': 42}}, + } + ], + ) + project: dict[str, str | list[str]] | None = pdt.Field( + None, + description='The QueryBuilder projection as a dictionary mapping tags to attributes to project.', + examples=[ + { + 'integers': ['uuid', 'attributes.value'], + } + ], + ) + limit: pdt.NonNegativeInt | None = pdt.Field( + 10, + description='The maximum number of results to return.', + examples=[5], + ) + offset: pdt.NonNegativeInt | None = pdt.Field( + 0, + description='The number of results to skip before starting to collect the result set.', + examples=[0], + ) + order_by: str | list[str] | dict[str, t.Any] | None = pdt.Field( + None, + description='The QueryBuilder order_by as a string, list of strings, ' + 'or dictionary mapping tags to order conditions.', + examples=[ + {'integers': {'pk': 'desc'}}, + ], + ) + distinct: bool = pdt.Field( + False, + description='Whether to return only distinct results.', + examples=[False], + ) + + +@read_router.post( + '', + response_model=PaginatedResults, + response_model_exclude_none=True, + response_model_exclude_unset=True, +) +@with_dbenv() +async def 
query_builder( + query: QueryBuilderDict, + flat: bool = Query( + False, + description='Whether to return results flat.', + ), +) -> PaginatedResults: + """Execute a QueryBuilder query based on the provided dictionary.""" + query_dict = query.model_dump() + + try: + limit = query_dict.pop('limit', 10) + offset = query_dict.get('offset', 0) + + # Get total count before applying the limit + qb = orm.QueryBuilder.from_dict(query_dict) + total = qb.count() + qb.limit(limit) + + # Run query builder + project = t.cast(dict, query_dict.get('project')) + if not project or any(p in ('*', ['*']) for p in project.values()): + # Projecting entities as entity models + return PaginatedResults[orm.Entity.Model]( + total=total, + page=offset // limit + 1, + page_size=limit, + results=[entity.to_model(minimal=True) for entity in t.cast(list[orm.Entity], qb.all(flat=True))], + ) + else: + # Projecting specific attributes + return PaginatedResults[t.Any]( + total=total, + page=offset // limit + 1, + page_size=limit, + results=qb.all(flat=flat), + ) + + except Exception as exc: + raise HTTPException(status_code=400, detail=str(exc)) diff --git a/aiida_restapi/routers/server.py b/aiida_restapi/routers/server.py new file mode 100644 index 00000000..c9fcf507 --- /dev/null +++ b/aiida_restapi/routers/server.py @@ -0,0 +1,227 @@ +"""Declaration of FastAPI application.""" + +from __future__ import annotations + +import re + +import pydantic as pdt +from aiida import __version__ as aiida_version +from fastapi import APIRouter, Request +from fastapi.responses import HTMLResponse +from fastapi.routing import APIRoute +from starlette.routing import Route + +from aiida_restapi.config import API_CONFIG + +read_router = APIRouter(prefix='/server') + + +class ServerInfo(pdt.BaseModel): + """API version information.""" + + API_major_version: str = pdt.Field( + description='Major version of the API', + examples=['0'], + ) + API_minor_version: str = pdt.Field( + description='Minor version of the API', 
+ examples=['1'], + ) + API_revision_version: str = pdt.Field( + description='Revision version of the API', + examples=['0a1'], + ) + API_prefix: str = pdt.Field( + description='Prefix for all API endpoints', + examples=['/api/v0'], + ) + AiiDA_version: str = pdt.Field( + description='Version of the AiiDA installation', + examples=['2.7.2.post0'], + ) + + +@read_router.get( + '/info', + response_model=ServerInfo, +) +async def get_server_info() -> ServerInfo: + """Get the API version information.""" + api_version = API_CONFIG['VERSION'].split('.') + return ServerInfo( + API_major_version=api_version[0], + API_minor_version=api_version[1], + API_revision_version=api_version[2], + API_prefix=API_CONFIG['PREFIX'], + AiiDA_version=aiida_version, + ) + + +class ServerEndpoint(pdt.BaseModel): + """API endpoint.""" + + path: str = pdt.Field( + description='Path of the endpoint', + examples=['../server/endpoints'], + ) + group: str | None = pdt.Field( + description='Group of the endpoint', + examples=['server'], + ) + methods: set[str] = pdt.Field( + description='HTTP methods supported by the endpoint', + examples=['GET'], + ) + description: str = pdt.Field( + '-', + description='Description of the endpoint', + examples=['Get a JSON-serializable dictionary of all registered API routes.'], + ) + + +@read_router.get( + '/endpoints', + response_model=dict[str, list[ServerEndpoint]], +) +async def get_server_endpoints(request: Request) -> dict[str, list[ServerEndpoint]]: + """Get a JSON-serializable dictionary of all registered API routes.""" + endpoints: list[ServerEndpoint] = [] + + for route in request.app.routes: + if route.path == '/': + continue + + group, methods, description = _get_route_parts(route) + base_url = str(request.base_url).rstrip('/') + + endpoint = { + 'path': base_url + route.path, + 'group': group, + 'methods': methods, + 'description': description, + } + + endpoints.append(ServerEndpoint(**endpoint)) + + return {'endpoints': endpoints} + + 
+@read_router.get( + '/endpoints/table', + name='endpoints', + response_class=HTMLResponse, +) +async def get_server_endpoints_table(request: Request) -> HTMLResponse: + """Get an HTML table of all registered API routes.""" + routes = request.app.routes + base_url = str(request.base_url).rstrip('/') + + rows = [] + + for route in routes: + if route.path == '/': + continue + + path = base_url + route.path + group, methods, description = _get_route_parts(route) + + disable_url = ( + ( + isinstance(route, APIRoute) + and any( + param + for param in route.dependant.path_params + + route.dependant.query_params + + route.dependant.body_params + if param.required + ) + ) + or (route.methods and 'POST' in route.methods) + or 'auth' in path + ) + + path_row = path if disable_url else f'{path}' + + rows.append(f""" +
| URL | +Group | +Methods | +Description | +
|---|