diff --git a/aiida_restapi/__init__.py b/aiida_restapi/__init__.py
index d3618b6e..27f6e987 100644
--- a/aiida_restapi/__init__.py
+++ b/aiida_restapi/__init__.py
@@ -1,5 +1,3 @@
-"""AiiDA REST API for data queries and workflow managment."""
+"""AiiDA REST API for data queries and workflow management."""
 
 __version__ = '0.1.0a1'
-
-from .main import app  # noqa: F401
diff --git a/aiida_restapi/routers/comments.py b/aiida_restapi/cli/__init__.py
similarity index 100%
rename from aiida_restapi/routers/comments.py
rename to aiida_restapi/cli/__init__.py
diff --git a/aiida_restapi/cli/main.py b/aiida_restapi/cli/main.py
new file mode 100644
index 00000000..f72e2a37
--- /dev/null
+++ b/aiida_restapi/cli/main.py
@@ -0,0 +1,31 @@
+import os
+
+import click
+import uvicorn
+
+
+@click.group()
+def cli() -> None:
+    """AiiDA REST API management CLI."""
+
+
+@cli.command()
+@click.option('--host', default='127.0.0.1', show_default=True)
+@click.option('--port', default=8000, show_default=True, type=int)
+@click.option('--read-only', is_flag=True)
+@click.option('--watch', is_flag=True)
+def start(read_only: bool, watch: bool, host: str, port: int) -> None:
+    """Start the AiiDA REST API service."""
+
+    os.environ['AIIDA_RESTAPI_READ_ONLY'] = '1' if read_only else '0'
+
+    click.echo(f'Starting REST API (read_only={read_only}, watch={watch}) on {host}:{port}')
+
+    uvicorn.run(
+        'aiida_restapi.main:create_app',
+        host=host,
+        port=port,
+        reload=watch,
+        reload_dirs=['aiida_restapi'],
+        factory=True,
+    )
diff --git a/aiida_restapi/common/__init__.py b/aiida_restapi/common/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/aiida_restapi/common/pagination.py b/aiida_restapi/common/pagination.py
new file mode 100644
index 00000000..2c0f57b0
--- /dev/null
+++ b/aiida_restapi/common/pagination.py
@@ -0,0 +1,19 @@
+"""Pagination utilities."""
+
+from __future__ import annotations
+
+import typing as t
+
+import pydantic as pdt
+from aiida.orm import Entity
+
+ResultType = t.TypeVar('ResultType', bound=Entity.Model)
+
+__all__ = ('PaginatedResults',)
+
+
+class PaginatedResults(pdt.BaseModel, t.Generic[ResultType]):
+    total: int
+    page: int
+    page_size: int
+    results: list[ResultType]
diff --git a/aiida_restapi/common/query.py b/aiida_restapi/common/query.py
new file mode 100644
index 00000000..8f496635
--- /dev/null
+++ b/aiida_restapi/common/query.py
@@ -0,0 +1,90 @@
+"""REST API query utilities."""
+
+from __future__ import annotations
+
+import json
+import typing as t
+
+import pydantic as pdt
+from fastapi import HTTPException, Query
+
+
+class QueryParams(pdt.BaseModel):
+    filters: dict[str, t.Any] = pdt.Field(
+        default_factory=dict,
+        description='AiiDA QueryBuilder filters',
+        examples=[
+            {'node_type': {'==': 'data.core.int.Int.'}},
+            {'attributes.value': {'>': 42}},
+        ],
+    )
+    order_by: str | list[str] | dict[str, t.Any] | None = pdt.Field(
+        None,
+        description='Fields to sort by',
+        examples=[
+            {'attributes.value': 'desc'},
+        ],
+    )
+    page_size: pdt.PositiveInt = pdt.Field(
+        10,
+        description='Number of results per page',
+        examples=[10],
+    )
+    page: pdt.PositiveInt = pdt.Field(
+        1,
+        description='Page number',
+        examples=[1],
+    )
+
+
+def query_params(
+    filters: str | None = Query(
+        None,
+        description='AiiDA QueryBuilder filters as JSON string',
+    ),
+    order_by: str | None = Query(
+        None,
+        description='Fields to sort by as JSON string',
+    ),
+    page_size: pdt.PositiveInt = Query(
+        10,
+        description='Number of results per page',
+    ),
+    page: pdt.PositiveInt = Query(
+        1,
+        description='Page number',
+    ),
+) -> QueryParams:
+    """Parse query parameters into a structured object.
+
+    :param filters: AiiDA QueryBuilder filters as JSON string.
+    :param order_by: Fields to sort by as JSON string.
+    :param page_size: Number of results per page.
+    :param page: Page number.
+    :return: Structured query parameters.
+    :raises HTTPException: If filters or order_by cannot be parsed as JSON.
+    """
+    query_filters: dict[str, t.Any] = {}
+    query_order_by: str | list[str] | dict[str, t.Any] | None = None
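+    # Parse the raw JSON query strings defensively: malformed input yields a 400 rather than a 500.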
+    if filters:
+        try:
+            query_filters = json.loads(filters)
+        except Exception as exception:
+            raise HTTPException(
+                status_code=400,
+                detail=f'Could not parse filters as JSON: {exception}',
+            ) from exception
+    if order_by:
+        try:
+            query_order_by = json.loads(order_by)
+        except Exception as exception:
+            raise HTTPException(
+                status_code=400,
+                detail=f'Could not parse order_by as JSON: {exception}',
+            ) from exception
+    return QueryParams(
+        filters=query_filters,
+        order_by=query_order_by,
+        page_size=page_size,
+        page=page,
+    )
diff --git a/aiida_restapi/common/types.py b/aiida_restapi/common/types.py
new file mode 100644
index 00000000..0b4c2014
--- /dev/null
+++ b/aiida_restapi/common/types.py
@@ -0,0 +1,13 @@
+"""Common type variables."""
+
+from __future__ import annotations
+
+import typing as t
+
+from aiida import orm
+
+EntityType = t.TypeVar('EntityType', bound='orm.Entity')
+EntityModelType = t.TypeVar('EntityModelType', bound='orm.Entity.Model')
+
+NodeType = t.TypeVar('NodeType', bound='orm.Node')
+NodeModelType = t.TypeVar('NodeModelType', bound='orm.Node.Model')
diff --git a/aiida_restapi/config.py b/aiida_restapi/config.py
index 0316437d..5145211c 100644
--- a/aiida_restapi/config.py
+++ b/aiida_restapi/config.py
@@ -1,5 +1,7 @@
 """Configuration of API"""
 
+from aiida_restapi import __version__
+
 # to get a string like this run:
 # openssl rand -hex 32
 SECRET_KEY = '09d25e094faa6ca2556c818166b7a9563b93f7099f6f0f4caa6cf63b88e8d3e7'
@@ -18,3 +20,8 @@
         'disabled': False,
     }
 }
+
+API_CONFIG = {
+    'PREFIX': '/api/v0',
+    'VERSION': __version__,
+}
diff --git a/aiida_restapi/aiida_db_mappings.py b/aiida_restapi/graphql/aiida_db_mappings.py
similarity index 100%
rename from aiida_restapi/aiida_db_mappings.py
rename to aiida_restapi/graphql/aiida_db_mappings.py
diff --git a/aiida_restapi/graphql/comments.py b/aiida_restapi/graphql/comments.py
index 49467ac7..c0003718 100644
--- a/aiida_restapi/graphql/comments.py
+++ b/aiida_restapi/graphql/comments.py
@@ -6,7 +6,7 @@
 import graphene as gr
 from aiida.orm import Comment
 
-from aiida_restapi.filter_syntax import parse_filter_str
+from aiida_restapi.graphql.filter_syntax import parse_filter_str
 
 from .orm_factories import (
     ENTITY_DICT_TYPE,
diff --git a/aiida_restapi/graphql/computers.py b/aiida_restapi/graphql/computers.py
index bc877186..520e1522 100644
--- a/aiida_restapi/graphql/computers.py
+++ b/aiida_restapi/graphql/computers.py
@@ -6,7 +6,7 @@
 import graphene as gr
 from aiida.orm import Computer
 
-from aiida_restapi.filter_syntax import parse_filter_str
+from aiida_restapi.graphql.filter_syntax import parse_filter_str
 from aiida_restapi.graphql.plugins import QueryPlugin
 
 from .nodes import NodesQuery
diff --git a/aiida_restapi/filter_syntax.py b/aiida_restapi/graphql/filter_syntax.py
similarity index 98%
rename from aiida_restapi/filter_syntax.py
rename to aiida_restapi/graphql/filter_syntax.py
index ba7010f2..6861fe30 100644
--- a/aiida_restapi/filter_syntax.py
+++ b/aiida_restapi/graphql/filter_syntax.py
@@ -13,8 +13,8 @@
 
 from lark import Lark, Token, Tree
 
-from . import static
-from .utils import parse_date
+from .. import static
+from ..utils import parse_date
 
 FILTER_GRAMMAR = resources.open_text(static, 'filter_grammar.lark')
diff --git a/aiida_restapi/graphql/groups.py b/aiida_restapi/graphql/groups.py
index 21f7b121..b811c825 100644
--- a/aiida_restapi/graphql/groups.py
+++ b/aiida_restapi/graphql/groups.py
@@ -6,7 +6,7 @@
 import graphene as gr
 from aiida.orm import Group
 
-from aiida_restapi.filter_syntax import parse_filter_str
+from aiida_restapi.graphql.filter_syntax import parse_filter_str
 from aiida_restapi.graphql.nodes import NodesQuery
 from aiida_restapi.graphql.plugins import QueryPlugin
diff --git a/aiida_restapi/graphql/logs.py b/aiida_restapi/graphql/logs.py
index eebd2e6f..aaaace5e 100644
--- a/aiida_restapi/graphql/logs.py
+++ b/aiida_restapi/graphql/logs.py
@@ -6,7 +6,7 @@
 import graphene as gr
 from aiida.orm import Log
 
-from aiida_restapi.filter_syntax import parse_filter_str
+from aiida_restapi.graphql.filter_syntax import parse_filter_str
 
 from .orm_factories import (
     ENTITY_DICT_TYPE,
diff --git a/aiida_restapi/graphql/nodes.py b/aiida_restapi/graphql/nodes.py
index b0e0fae1..8766c7da 100644
--- a/aiida_restapi/graphql/nodes.py
+++ b/aiida_restapi/graphql/nodes.py
@@ -6,7 +6,7 @@
 import graphene as gr
 from aiida import orm
 
-from aiida_restapi.filter_syntax import parse_filter_str
+from aiida_restapi.graphql.filter_syntax import parse_filter_str
 from aiida_restapi.graphql.plugins import QueryPlugin
 
 from .comments import CommentsQuery
diff --git a/aiida_restapi/graphql/orm_factories.py b/aiida_restapi/graphql/orm_factories.py
index fade6fa0..1c091216 100644
--- a/aiida_restapi/graphql/orm_factories.py
+++ b/aiida_restapi/graphql/orm_factories.py
@@ -12,7 +12,7 @@
 from graphql import GraphQLError
 from pydantic import Json
 
-from aiida_restapi.aiida_db_mappings import ORM_MAPPING, get_model_from_orm
+from aiida_restapi.graphql.aiida_db_mappings import ORM_MAPPING, get_model_from_orm
 
 from .config import ENTITY_LIMIT
 from .utils import JSON, selected_field_names_naive
diff --git a/aiida_restapi/graphql/users.py b/aiida_restapi/graphql/users.py
index cc0a8077..473d907c 100644
--- a/aiida_restapi/graphql/users.py
+++ b/aiida_restapi/graphql/users.py
@@ -6,7 +6,7 @@
 import graphene as gr
 from aiida.orm import User
 
-from aiida_restapi.filter_syntax import parse_filter_str
+from aiida_restapi.graphql.filter_syntax import parse_filter_str
 
 from .nodes import NodesQuery
 from .orm_factories import (
diff --git a/aiida_restapi/identifiers.py b/aiida_restapi/identifiers.py
deleted file mode 100644
index e00a4668..00000000
--- a/aiida_restapi/identifiers.py
+++ /dev/null
@@ -1,170 +0,0 @@
-###########################################################################
-# Copyright (c), The AiiDA team. All rights reserved.                     #
-# This file is part of the AiiDA code.                                    #
-#                                                                         #
-# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
-# For further information on the license, see the LICENSE.txt file        #
-# For further information please visit http://www.aiida.net               #
-###########################################################################
-"""Utility functions to work with node "full types" which identify node types.
-
-A node's `full_type` is defined as a string that uniquely defines the node type. A valid `full_type` is constructed by
-concatenating the `node_type` and `process_type` of a node with the `FULL_TYPE_CONCATENATOR`. Each segment of the full
-type can optionally be terminated by a single `LIKE_OPERATOR_CHARACTER` to indicate that the `node_type` or
-`process_type` should start with that value but can be followed by any amount of other characters. A full type is
-invalid if it does not contain exactly one `FULL_TYPE_CONCATENATOR` character. Additionally, each segment can contain
-at most one occurrence of the `LIKE_OPERATOR_CHARACTER` and it has to be at the end of the segment.
-
-Examples of valid full types:
-
-    'data.bool.Bool.|'
-    'process.calculation.calcfunction.%|%'
-    'process.calculation.calcjob.CalcJobNode.|aiida.calculations:arithmetic.add'
-    'process.calculation.calcfunction.CalcFunctionNode.|aiida.workflows:codtools.primitive_structure_from_cif'
-
-Examples of invalid full types:
-
-    'data.bool'  # Only a single segment without concatenator
-    'data.|bool.Bool.|process.'  # More than one concatenator
-    'process.calculation%.calcfunction.|aiida.calculations:arithmetic.add'  # Like operator not at end of segment
-    'process.calculation%.calcfunction.%|aiida.calculations:arithmetic.add'  # More than one operator in segment
-
-"""
-
-from typing import Any
-
-from aiida.common.escaping import escape_for_sql_like
-
-FULL_TYPE_CONCATENATOR = '|'
-LIKE_OPERATOR_CHARACTER = '%'
-DEFAULT_NAMESPACE_LABEL = '~no-entry-point~'
-
-
-def validate_full_type(full_type: str) -> None:
-    """Validate that the `full_type` is a valid full type unique node identifier.
-
-    :param full_type: a `Node` full type
-    :raises ValueError: if the `full_type` is invalid
-    :raises TypeError: if the `full_type` is not a string type
-    """
-    from aiida.common.lang import type_check
-
-    type_check(full_type, str)
-
-    if FULL_TYPE_CONCATENATOR not in full_type:
-        raise ValueError(
-            f'full type `{full_type}` does not include the required concatenator symbol `{FULL_TYPE_CONCATENATOR}`.'
-        )
-    elif full_type.count(FULL_TYPE_CONCATENATOR) > 1:
-        raise ValueError(
-            f'full type `{full_type}` includes the concatenator symbol `{FULL_TYPE_CONCATENATOR}` more than once.'
-        )
-
-
-def construct_full_type(node_type: str, process_type: str) -> str:
-    """Return the full type, which fully identifies the type of any `Node` with the given `node_type` and
-    `process_type`.
-
-    :param node_type: the `node_type` of the `Node`
-    :param process_type: the `process_type` of the `Node`
-    :return: the full type, which is a unique identifier
-    """
-    if node_type is None:
-        node_type = ''
-
-    if process_type is None:
-        process_type = ''
-
-    return f'{node_type}{FULL_TYPE_CONCATENATOR}{process_type}'
-
-
-def get_full_type_filters(full_type: str) -> dict[str, Any]:
-    """Return the `QueryBuilder` filters that will return all `Nodes` identified by the given `full_type`.
-
-    :param full_type: the `full_type` node type identifier
-    :return: dictionary of filters to be passed for the `filters` keyword in `QueryBuilder.append`
-    :raises ValueError: if the `full_type` is invalid
-    :raises TypeError: if the `full_type` is not a string type
-    """
-    validate_full_type(full_type)
-
-    filters: dict[str, Any] = {}
-    node_type, process_type = full_type.split(FULL_TYPE_CONCATENATOR)
-
-    for entry in (node_type, process_type):
-        if entry.count(LIKE_OPERATOR_CHARACTER) > 1:
-            raise ValueError(f'full type component `{entry}` contained more than one like-operator character')
-
-        if LIKE_OPERATOR_CHARACTER in entry and entry[-1] != LIKE_OPERATOR_CHARACTER:
-            raise ValueError(f'like-operator character in full type component `{entry}` is not at the end')
-
-    if LIKE_OPERATOR_CHARACTER in node_type:
-        # Remove the trailing `LIKE_OPERATOR_CHARACTER`, escape the string and reattach the character
-        node_type = node_type[:-1]
-        node_type = escape_for_sql_like(node_type) + LIKE_OPERATOR_CHARACTER
-        filters['node_type'] = {'like': node_type}
-    else:
-        filters['node_type'] = {'==': node_type}
-
-    if LIKE_OPERATOR_CHARACTER in process_type:
-        # Remove the trailing `LIKE_OPERATOR_CHARACTER` ()
-        # If that was the only specification, just ignore this filter (looking for any process_type)
-        # If there was more: escape the string and reattach the character
-        process_type = process_type[:-1]
-        if process_type:
-            process_type = escape_for_sql_like(process_type) + LIKE_OPERATOR_CHARACTER
-            filters['process_type'] = {'like': process_type}
-    elif process_type:
-        filters['process_type'] = {'==': process_type}
-    else:
-        # A `process_type=''` is used to represents both `process_type='' and `process_type=None`.
-        # This is because there is no simple way to single out null `process_types`, and therefore
-        # we consider them together with empty-string process_types.
-        # Moreover, the existence of both is most likely a bug of migrations and thus both share
-        # this same "erroneous" origin.
-        filters['process_type'] = {'or': [{'==': ''}, {'==': None}]}
-
-    return filters
-
-
-def load_entry_point_from_full_type(full_type: str) -> Any:
-    """Return the loaded entry point for the given `full_type` unique node type identifier.
-
-    :param full_type: the `full_type` unique node type identifier
-    :raises ValueError: if the `full_type` is invalid
-    :raises TypeError: if the `full_type` is not a string type
-    :raises `~aiida.common.exceptions.EntryPointError`: if the corresponding entry point cannot be loaded
-    """
-    from aiida.common import EntryPointError
-    from aiida.common.utils import strip_prefix
-    from aiida.plugins.entry_point import (
-        is_valid_entry_point_string,
-        load_entry_point,
-        load_entry_point_from_string,
-    )
-
-    data_prefix = 'data.'
-
-    validate_full_type(full_type)
-
-    node_type, process_type = full_type.split(FULL_TYPE_CONCATENATOR)
-
-    if is_valid_entry_point_string(process_type):
-        try:
-            return load_entry_point_from_string(process_type)
-        except EntryPointError:
-            raise EntryPointError(f'could not load entry point `{process_type}`')
-
-    elif node_type.startswith(data_prefix):
-        base_name = strip_prefix(node_type, data_prefix)
-        entry_point_name = base_name.rsplit('.', 2)[0]
-
-        try:
-            return load_entry_point('aiida.data', entry_point_name)
-        except EntryPointError:
-            raise EntryPointError(f'could not load entry point `{process_type}`')
-
-    # Here we are dealing with a `ProcessNode` with a `process_type` that is not an entry point string.
-    # Which means it is most likely a full module path (the fallback option) and we cannot necessarily load the
-    # class from this. We could try with `importlib` but not sure that we should
-    raise EntryPointError('entry point of the given full type cannot be loaded')
diff --git a/aiida_restapi/main.py b/aiida_restapi/main.py
index 5e75b8c3..d85bf9ca 100644
--- a/aiida_restapi/main.py
+++ b/aiida_restapi/main.py
@@ -1,16 +1,45 @@
 """Declaration of FastAPI application."""
 
-from fastapi import FastAPI
+import os
 
+from fastapi import APIRouter, FastAPI
+from fastapi.responses import RedirectResponse
+
+from aiida_restapi.config import API_CONFIG
 from aiida_restapi.graphql import main
-from aiida_restapi.routers import auth, computers, daemon, groups, nodes, process, users
-
-app = FastAPI()
-app.include_router(auth.router)
-app.include_router(computers.router)
-app.include_router(daemon.router)
-app.include_router(nodes.router)
-app.include_router(groups.router)
-app.include_router(users.router)
-app.include_router(process.router)
-app.add_route('/graphql', main.app, name='graphql')
+from aiida_restapi.routers import auth, computers, daemon, groups, nodes, querybuilder, server, submit, users
+
+
+def create_app() -> FastAPI:
+    """Create the FastAPI application and include the routers.
+
+    :return: The FastAPI application.
+    :rtype: FastAPI
+    """
+
+    read_only = os.getenv('AIIDA_RESTAPI_READ_ONLY') == '1'
+
+    app = FastAPI()
+
+    api_router = APIRouter(prefix=API_CONFIG['PREFIX'])
+
+    api_router.add_route(
+        '/',
+        lambda _: RedirectResponse(url=api_router.url_path_for('endpoints')),
+    )
+
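+    # Read routers are always mounted; write routers are skipped when the service runs in read-only mode.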
+    for module in (auth, server, users, computers, groups, nodes, querybuilder, submit, daemon):
+        if read_router := getattr(module, 'read_router', None):
+            api_router.include_router(read_router)
+        if not read_only and (write_router := getattr(module, 'write_router', None)):
+            api_router.include_router(write_router)
+
+    api_router.add_route(
+        '/graphql',
+        main.app,
+        methods=['POST'],
+    )
+
+    app.include_router(api_router)
+
+    return app
diff --git a/aiida_restapi/models.py b/aiida_restapi/models.py
deleted file mode 100644
index ecd291ec..00000000
--- a/aiida_restapi/models.py
+++ /dev/null
@@ -1,340 +0,0 @@
-"""Schemas for AiiDA REST API.
-
-Models in this module mirror those in
-`aiida.backends.djsite.db.models` and `aiida.backends.sqlalchemy.models`
-"""
-# pylint: disable=too-few-public-methods
-
-import inspect
-from datetime import datetime
-from pathlib import Path
-from typing import Any, ClassVar, Dict, List, Optional, Type, TypeVar
-from uuid import UUID
-
-from aiida import orm
-from fastapi import Form
-from pydantic import BaseModel, ConfigDict, Field
-
-# Template type for subclasses of `AiidaModel`
-ModelType = TypeVar('ModelType', bound='AiidaModel')
-
-
-def as_form(cls: Type[BaseModel]) -> Type[BaseModel]:
-    """
-    Adds an as_form class method to decorated models. The as_form class method
-    can be used with FastAPI endpoints
-
-    Note: Taken from https://github.com/tiangolo/fastapi/issues/2387
-    """
-    new_parameters = []
-
-    for field_name, model_field in cls.model_fields.items():
-        new_parameters.append(
-            inspect.Parameter(
-                name=field_name,
-                kind=inspect.Parameter.POSITIONAL_ONLY,
-                default=Form(...) if model_field.is_required() else Form(model_field.default),
-                annotation=model_field.annotation,
-            )
-        )
-
-    async def as_form_func(**data: Dict[str, Any]) -> Any:
-        return cls(**data)
-
-    sig = inspect.signature(as_form_func)
-    sig = sig.replace(parameters=new_parameters)
-    as_form_func.__signature__ = sig  # type: ignore
-    setattr(cls, 'as_form', as_form_func)
-    return cls
-
-
-class AiidaModel(BaseModel):
-    """A mapping of an AiiDA entity to a pydantic model."""
-
-    _orm_entity: ClassVar[Type[orm.entities.Entity]] = orm.entities.Entity
-    model_config = ConfigDict(from_attributes=True, extra='forbid')
-
-    @classmethod
-    def get_projectable_properties(cls) -> List[str]:
-        """Return projectable properties."""
-        return list(cls.schema()['properties'].keys())
-
-    @classmethod
-    def get_entities(
-        cls: Type[ModelType],
-        *,
-        page_size: Optional[int] = None,
-        page: int = 0,
-        project: Optional[List[str]] = None,
-        order_by: Optional[List[str]] = None,
-    ) -> List[ModelType]:
-        """Return a list of entities (with pagination).
-
-        :param project: properties to project (default: all available)
-        :param page_size: the page size (default: infinite)
-        :param page: the page to return, if page_size set
-        """
-        if project is None:
-            project = cls.get_projectable_properties()
-        else:
-            assert set(cls.get_projectable_properties()).issuperset(
-                project
-            ), f'projection not subset of projectable properties: {project!r}'
-        query = orm.QueryBuilder().append(cls._orm_entity, tag='fields', project=project)
-        if page_size is not None:
-            query.offset(page_size * (page - 1))
-            query.limit(page_size)
-        if order_by is not None:
-            assert set(cls.get_projectable_properties()).issuperset(
-                order_by
-            ), f'order_by not subset of projectable properties: {project!r}'
-            query.order_by({'fields': order_by})
-        return [cls(**result['fields']) for result in query.dict()]
-
-
-class Comment(AiidaModel):
-    """AiiDA Comment model."""
-
-    _orm_entity = orm.Comment
-
-    id: Optional[int] = Field(None, description='Unique comment id (pk)')
-    uuid: str = Field(description='Unique comment uuid')
-    ctime: Optional[datetime] = Field(None, description='Creation time')
-    mtime: Optional[datetime] = Field(None, description='Last modification time')
-    content: Optional[str] = Field(None, description='Comment content')
-    dbnode_id: Optional[int] = Field(None, description='Unique node id (pk)')
-    user_id: Optional[int] = Field(None, description='Unique user id (pk)')
-
-
-class User(AiidaModel):
-    """AiiDA User model."""
-
-    _orm_entity = orm.User
-    model_config = ConfigDict(extra='allow')
-
-    id: Optional[int] = Field(None, description='Unique user id (pk)')
-    email: str = Field(description='Email address of the user')
-    first_name: Optional[str] = Field(None, description='First name of the user')
-    last_name: Optional[str] = Field(None, description='Last name of the user')
-    institution: Optional[str] = Field(None, description='Host institution or workplace of the user')
-
-
-class Computer(AiidaModel):
-    """AiiDA Computer Model."""
-
-    _orm_entity = orm.Computer
-
-    id: Optional[int] = Field(None, description='Unique computer id (pk)')
-    uuid: Optional[str] = Field(None, description='Unique id for computer')
-    label: str = Field(description='Used to identify a computer. Must be unique')
-    hostname: Optional[str] = Field(None, description='Label that identifies the computer within the network')
-    scheduler_type: Optional[str] = Field(
-        None,
-        description='The scheduler (and plugin) that the computer uses to manage jobs',
-    )
-    transport_type: Optional[str] = Field(
-        None,
-        description='The transport (and plugin) \
-            required to copy files and communicate to and from the computer',
-    )
-    metadata: Optional[dict] = Field(
-        None,
-        description='General settings for these communication and management protocols',
-    )
-
-    description: Optional[str] = Field(None, description='Description of node')
-
-
-class Node(AiidaModel):
-    """AiiDA Node Model."""
-
-    _orm_entity = orm.Node
-
-    id: Optional[int] = Field(None, description='Unique id (pk)')
-    uuid: Optional[UUID] = Field(None, description='Unique uuid')
-    node_type: Optional[str] = Field(None, description='Node type')
-    process_type: Optional[str] = Field(None, description='Process type')
-    label: str = Field(description='Label of node')
-    description: Optional[str] = Field(None, description='Description of node')
-    ctime: Optional[datetime] = Field(None, description='Creation time')
-    mtime: Optional[datetime] = Field(None, description='Last modification time')
-    user_id: Optional[int] = Field(None, description='Created by user id (pk)')
-    dbcomputer_id: Optional[int] = Field(None, description='Associated computer id (pk)')
-    attributes: Optional[Dict] = Field(
-        None,
-        description='Variable attributes of the node',
-    )
-    extras: Optional[Dict] = Field(
-        None,
-        description='Variable extras (unsealed) of the node',
-    )
-    repository_metadata: Optional[Dict] = Field(
-        None,
-        description='Metadata about file repository associated with this node',
-    )
-
-
-@as_form
-class Node_Post(AiidaModel):
-    """AiiDA model for posting Nodes."""
-
-    entry_point: str = Field(description='Entry_point')
-    process_type: Optional[str] = Field(None, description='Process type')
-    label: Optional[str] = Field(None, description='Label of node')
-    description: Optional[str] = Field(None, description='Description of node')
-    user_id: Optional[int] = Field(None, description='Created by user id (pk)')
-    dbcomputer_id: Optional[int] = Field(None, description='Associated computer id (pk)')
-    attributes: Optional[Dict] = Field(
-        None,
-        description='Variable attributes of the node',
-    )
-    extras: Optional[Dict] = Field(
-        None,
-        description='Variable extras (unsealed) of the node',
-    )
-
-    @classmethod
-    def create_new_node(
-        cls: Type[ModelType],
-        orm_class: orm.Node,
-        node_dict: dict,
-    ) -> orm.Node:
-        """Create and Store new Node"""
-        attributes = node_dict.pop('attributes', {})
-        extras = node_dict.pop('extras', {})
-        repository_metadata = node_dict.pop('repository_metadata', {})
-
-        if issubclass(orm_class, orm.BaseType):
-            orm_object = orm_class(
-                attributes['value'],
-                **node_dict,
-            )
-        elif issubclass(orm_class, orm.Dict):
-            orm_object = orm_class(
-                dict=attributes,
-                **node_dict,
-            )
-        elif issubclass(orm_class, orm.InstalledCode):
-            orm_object = orm_class(
-                computer=orm.load_computer(pk=node_dict.get('dbcomputer_id')),
-                filepath_executable=attributes['filepath_executable'],
-            )
-            orm_object.label = node_dict.get('label')
-        elif issubclass(orm_class, orm.PortableCode):
-            orm_object = orm_class(
-                computer=orm.load_computer(pk=node_dict.get('dbcomputer_id')),
-                filepath_executable=attributes['filepath_executable'],
-                filepath_files=attributes['filepath_files'],
-            )
-            orm_object.label = node_dict.get('label')
-        else:
-            orm_object = orm_class(**node_dict)
-            orm_object.base.attributes.set_many(attributes)
-
-        orm_object.base.extras.set_many(extras)
-        orm_object.base.repository.repository_metadata = repository_metadata
-        orm_object.store()
-        return orm_object
-
-    @classmethod
-    def create_new_node_with_file(
-        cls: Type[ModelType],
-        orm_class: orm.Node,
-        node_dict: dict,
-        file: Path,
-    ) -> orm.Node:
-        """Create and Store new Node with file"""
-        attributes = node_dict.pop('attributes', {})
-        extras = node_dict.pop('extras', {})
-
-        orm_object = orm_class(file=file, **node_dict, **attributes)
-
-        orm_object.base.extras.set_many(extras)
-        orm_object.store()
-        return orm_object
-
-
-class Group(AiidaModel):
-    """AiiDA Group model."""
-
-    _orm_entity = orm.Group
-
-    id: int = Field(description='Unique id (pk)')
-    uuid: UUID = Field(description='Universally unique id')
-    label: str = Field(description='Label of group')
-    type_string: str = Field(description='type of the group')
-    description: Optional[str] = Field(None, description='Description of group')
-    extras: Optional[Dict] = Field(None, description='extra data about for the group')
-    time: datetime = Field(description='Created time')
-    user_id: int = Field(description='Created by user id (pk)')
-
-    @classmethod
-    def from_orm(cls, orm_entity: orm.Group) -> orm.Group:
-        """Convert from ORM object.
-
-        Args:
-            obj: The ORM entity to convert
-
-        Returns:
-            The converted Group object
-        """
-        query = (
-            orm.QueryBuilder()
-            .append(
-                cls._orm_entity,
-                filters={'pk': orm_entity.id},
-                tag='fields',
-                project=['user_id', 'time'],
-            )
-            .limit(1)
-        )
-        orm_entity.user_id = query.dict()[0]['fields']['user_id']
-
-        return super().from_orm(orm_entity)
-
-
-class Group_Post(AiidaModel):
-    """AiiDA Group Post model."""
-
-    _orm_entity = orm.Group
-
-    label: str = Field(description='Used to access the group. Must be unique.')
-    type_string: Optional[str] = Field(None, description='Type of the group')
-    description: Optional[str] = Field(None, description='Short description of the group.')
-
-
-class Process(AiidaModel):
-    """AiiDA Process Model"""
-
-    _orm_entity = orm.ProcessNode
-
-    id: Optional[int] = Field(None, description='Unique id (pk)')
-    uuid: Optional[UUID] = Field(None, description='Universally unique identifier')
-    node_type: Optional[str] = Field(None, description='Node type')
-    process_type: Optional[str] = Field(None, description='Process type')
-    label: str = Field(description='Label of node')
-    description: Optional[str] = Field(None, description='Description of node')
-    ctime: Optional[datetime] = Field(None, description='Creation time')
-    mtime: Optional[datetime] = Field(None, description='Last modification time')
-    user_id: Optional[int] = Field(None, description='Created by user id (pk)')
-    dbcomputer_id: Optional[int] = Field(None, description='Associated computer id (pk)')
-    attributes: Optional[Dict] = Field(
-        None,
-        description='Variable attributes of the node',
-    )
-    extras: Optional[Dict] = Field(
-        None,
-        description='Variable extras (unsealed) of the node',
-    )
-    repository_metadata: Optional[Dict] = Field(
-        None,
-        description='Metadata about file repository associated with this node',
-    )
-
-
-class Process_Post(AiidaModel):
-    """AiiDA Process Post Model"""
-
-    label: str = Field(description='Label of node')
-    inputs: dict = Field(description='Input parmeters')
-    process_entry_point: str = Field(description='Entry Point for process')
diff --git a/aiida_restapi/models/__init__.py b/aiida_restapi/models/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/aiida_restapi/models/node.py b/aiida_restapi/models/node.py
new file mode 100644
index 00000000..0f8096ae
--- /dev/null
+++ b/aiida_restapi/models/node.py
@@ -0,0 +1,201 @@
+from __future__ import annotations
+
+import typing as t
+
+import pydantic as pdt
+from aiida.orm import Node
+from aiida.plugins import get_entry_points
+from importlib_metadata import EntryPoint
+
+
+class NodeStatistics(pdt.BaseModel):
+    """Pydantic model representing node statistics."""
+
+    total: int = pdt.Field(
+        description='Total number of nodes.',
+        examples=[47],
+    )
+    types: dict[str, int] = pdt.Field(
+        description='Number of nodes by type.',
+        examples=[
+            {
+                'data.core.int.Int.': 42,
+                'data.core.singlefile.SinglefileData.': 5,
+            }
+        ],
+    )
+    ctime_by_day: dict[str, int] = pdt.Field(
+        description='Number of nodes created per day (YYYY-MM-DD).',
+        examples=[
+            {
+                '2012-01-01': 10,
+                '2012-01-02': 15,
+            }
+        ],
+    )
+
+
+class NodeType(pdt.BaseModel):
+    """Pydantic model representing a node type."""
+
+    label: str = pdt.Field(
+        description='The class name of the node type.',
+        examples=['Int'],
+    )
+    node_type: str = pdt.Field(
+        description='The AiiDA node type string.',
+        examples=['data.core.int.Int.'],
+    )
+    nodes: str = pdt.Field(
+        description='The URL to access nodes of this type.',
+        examples=['../nodes?filters={"node_type":"data.core.int.Int."}'],
+    )
+    projections: str = pdt.Field(
+        description='The URL to access projectable properties of this node type.',
+        examples=['../nodes/projections?type=data.core.int.Int.'],
+    )
+    node_schema: str = pdt.Field(
+        description='The URL to access the schema of this node type.',
+        examples=['../nodes/schema?type=data.core.int.Int.'],
+    )
+
+
+class RepoFileMetadata(pdt.BaseModel):
+    """Pydantic model representing the metadata of a file in the AiiDA repository."""
+
+    type: t.Literal['FILE'] = pdt.Field(
+        description='The type of the repository object.',
+        examples=['FILE'],
+    )
+    binary: bool = pdt.Field(
+        False,
+        description='Whether the file is binary.',
+        examples=[True],
+    )
+    size: int = pdt.Field(
+        description='The size of the file in bytes.',
+        examples=[1024],
+    )
+    download: str = pdt.Field(
+        description='The URL to download the file.',
+        examples=['../nodes/{uuid}/repo/contents?filename=path/to/file.txt'],
+    )
+
+
+class RepoDirMetadata(pdt.BaseModel):
+    """Pydantic model representing the metadata of a directory in the AiiDA repository."""
+
+    type: t.Literal['DIRECTORY'] = pdt.Field(
+        description='The type of the repository object.',
+        examples=['DIRECTORY'],
+    )
+    objects: dict[str, t.Union[RepoFileMetadata, 'RepoDirMetadata']] = pdt.Field(
+        description='A dictionary with the metadata of the objects in the directory.',
+        examples=[
+            {
+                'file.txt': {
+                    'type': 'FILE',
+                    'binary': False,
+                    'size': 2048,
+                    'download': '../nodes/{uuid}/repo/contents?filename=path/to/file.txt',
+                },
+                'subdir': {
+                    'type': 'DIRECTORY',
+                    'objects': {},
+                },
+            }
+        ],
+    )
+
+
+MetadataType = t.Union[RepoFileMetadata, RepoDirMetadata]
+
+
+class NodeLink(Node.Model):
+    link_label: str = pdt.Field(description='The label of the link to the node.')
+    link_type: str = pdt.Field(description='The type of the link to the node.')
+
+
+class NodeModelRegistry:
+    """Registry for AiiDA REST API node models.
+
+    This class maintains mappings of node types and their corresponding Pydantic models.
+
+    :ivar ModelUnion: A union type of all AiiDA node Pydantic models, discriminated by the `node_type` field.
+    """
+
+    def __init__(self) -> None:
+        self._build_node_mappings()
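+        # Build a discriminated union over all registered 'post' models, keyed by the literal `node_type` field.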
+        self.ModelUnion = t.Annotated[
+            t.Union[self._get_post_models()],
+            pdt.Field(discriminator='node_type'),
+        ]
+
+    def get_node_types(self) -> list[str]:
+        """Get the list of registered node class names.
+
+        :return: List of node class names.
+        """
+        return list(self._models.keys())
+
+    def get_node_class_name(self, node_type: str) -> str:
+        """Get the AiiDA node class name for a given node type.
+
+        :param node_type: The AiiDA node type string.
+        :return: The corresponding node class name.
+        """
+        return node_type.rsplit('.', 2)[-2]
+
+    def get_model(self, node_type: str, which: t.Literal['get', 'post'] = 'get') -> type[Node.Model]:
+        """Get the Pydantic model class for a given node type.
+
+        :param node_type: The AiiDA node type string.
+        :param which: Which model to return: 'get' or 'post'.
+        :return: The corresponding Pydantic model class.
+        """
+        if (Model := self._models.get(node_type)) is None:
+            raise KeyError(f'Unknown node type: {node_type}')
+        if which not in Model:
+            raise KeyError(f'Unknown model type: {which}')
+        return Model[which]
+
+    def _get_node_post_model(self, node_cls: type[Node]) -> type[Node.Model]:
+        """Return a patched Model for the given node class with a literal `node_type` field.
+
+        :param node_cls: The AiiDA node class.
+        :return: The patched ORM Node model.
+        """
+        Model = node_cls.CreateModel
+        node_type = node_cls.class_node_type
+        # Here we patch in the `node_type` union discriminator field.
+        # We annotate it with `SkipJsonSchema` to keep it off the public OpenAPI schema.
+        Model.model_fields['node_type'] = pdt.fields.FieldInfo(
+            annotation=pdt.json_schema.SkipJsonSchema[t.Literal[node_type]],  # type: ignore[misc,valid-type]
+            default=node_type,
+        )
+        Model.model_rebuild(force=True)
+        return t.cast(type[Node.Model], Model)
+
+    def _build_node_mappings(self) -> None:
+        """Build mapping of node type to node creation model."""
+        self._models: dict[str, dict[str, type[Node.Model]]] = {}
+        entry_point: EntryPoint
+        for entry_point in get_entry_points('aiida.data'):
+            try:
+                node_cls = t.cast(type[Node], entry_point.load())
+            except Exception as exception:
+                # Skip entry points that cannot be loaded
+                print(f'Warning: could not load entry point {entry_point.name}: {exception}')
+                continue
+
+            self._models[node_cls.class_node_type] = {
+                'get': node_cls.Model,
+                'post': self._get_node_post_model(node_cls),
+            }
+
+    def _get_post_models(self) -> tuple[type[Node.Model], ...]:
+        """Get the tuple of all node 'post' models used to build the union.
+
+        :return: A tuple of all node 'post' model classes.
+        """
+        post_models = [model_dict['post'] for model_dict in self._models.values()]
+        return tuple(post_models)
diff --git a/aiida_restapi/resources.py b/aiida_restapi/resources.py
deleted file mode 100644
index 7713b3b9..00000000
--- a/aiida_restapi/resources.py
+++ /dev/null
@@ -1,41 +0,0 @@
-from typing import Union
-
-from aiida.common.exceptions import EntryPointError, LoadingEntryPointError
-from aiida.plugins.entry_point import get_entry_point_names, load_entry_point
-
-from aiida_restapi.exceptions import RestFeatureNotAvailable, RestInputValidationError
-from aiida_restapi.identifiers import construct_full_type, load_entry_point_from_full_type
-
-
-def get_all_download_formats(full_type: Union[str, None] = None) -> dict:
-    """Returns dict of possible node formats for all available node types"""
-    all_formats = {}
-
-    if full_type:
-        try:
-            node_cls = load_entry_point_from_full_type(full_type)
-        except (TypeError, ValueError):
-            raise RestInputValidationError(f'The full type {full_type} is invalid.')
-        except EntryPointError:
-            raise RestFeatureNotAvailable('The download formats for this node type are not available.')
-
-        try:
-            available_formats = node_cls.get_export_formats()
-            all_formats[full_type] = available_formats
-        except AttributeError:
-            pass
-    else:
-        entry_point_group = 'aiida.data'
-
-        for name in get_entry_point_names(entry_point_group):
-            try:
-                node_cls = load_entry_point(entry_point_group, name)
-                available_formats = node_cls.get_export_formats()
-            except (AttributeError, LoadingEntryPointError):
-                continue
-
-            if available_formats:
-                full_type = construct_full_type(node_cls.class_node_type, '')
-                all_formats[full_type] = available_formats
-
-    return all_formats
diff --git a/aiida_restapi/routers/auth.py b/aiida_restapi/routers/auth.py
index a6b60bc4..4a371542 100644
--- a/aiida_restapi/routers/auth.py
+++ b/aiida_restapi/routers/auth.py
@@ -1,10 +1,12 @@
 """Handle API authentication and authorization."""
 
-# pylint: disable=missing-function-docstring,missing-class-docstring
-from datetime import datetime, timedelta
-from typing import Any, Dict, Optional
+from __future__ import annotations
+
+import typing as t
+from datetime import datetime, timedelta, timezone
 
 import bcrypt
+from aiida import orm
 from argon2 import PasswordHasher
 from argon2.exceptions import VerifyMismatchError
 from fastapi import APIRouter, Depends, HTTPException, status
@@ -13,7 +15,6 @@
 from pydantic import BaseModel
 
 from aiida_restapi import config
-from aiida_restapi.models import User
 
 
 class Token(BaseModel):
@@ -25,16 +26,17 @@ class TokenData(BaseModel):
     email: str
 
 
-class UserInDB(User):
+class UserInDB(orm.User.Model):
     hashed_password: str
-    disabled: Optional[bool] = None
+    disabled: t.Optional[bool] = None
 
 
 pwd_context = PasswordHasher()
 
-oauth2_scheme = OAuth2PasswordBearer(tokenUrl='token')
+oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f'{config.API_CONFIG["PREFIX"]}/auth/token')
 
-router = APIRouter()
+read_router = APIRouter(prefix='/auth')
+write_router = APIRouter(prefix='/auth')
 
 
 def verify_password(plain_password: str, hashed_password: str) -> bool:
@@ -54,14 +56,14 @@
     return pwd_context.hash(password)
 
 
-def get_user(db: dict, email: str) -> Optional[UserInDB]:
+def get_user(db: dict, email: str) -> UserInDB | None:
     if email in db:
         user_dict = db[email]
         return UserInDB(**user_dict)
     return None
 
 
-def authenticate_user(fake_db: dict, email: str, password: str) -> Optional[UserInDB]:
+def authenticate_user(fake_db: dict, email: str, password: str) -> UserInDB | None:
     user = get_user(fake_db, email)
 
    if not user:
@@ -73,18 +75,18 @@
     return user
 
 
-def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
+def create_access_token(data: dict, expires_delta: timedelta | None = None) -> str:
     to_encode = data.copy()
 
     if expires_delta:
-        expire = datetime.utcnow() + expires_delta
+        expire = datetime.now(timezone.utc) + expires_delta
     else:
-        expire = datetime.utcnow() + timedelta(minutes=15)
+        expire = datetime.now(timezone.utc) + timedelta(minutes=15)
 
     to_encode.update({'exp': expire})
     encoded_jwt = jwt.encode(to_encode, config.SECRET_KEY, algorithm=config.ALGORITHM)
     return encoded_jwt
 
 
-async def get_current_user(token: str = Depends(oauth2_scheme)) -> User:
+async def get_current_user(token: t.Annotated[str, Depends(oauth2_scheme)]) -> orm.User.Model:
     credentials_exception = HTTPException(
         status_code=status.HTTP_401_UNAUTHORIZED,
         detail='Could not validate credentials',
@@ -92,7 +94,7 @@
     )
     try:
         payload = jwt.decode(token, config.SECRET_KEY, algorithms=[config.ALGORITHM])
-        email: str = payload.get('sub')
+        email = payload.get('sub')
         if email is None:
             raise credentials_exception
         token_data = TokenData(email=email)
@@ -105,17 +107,21 @@
 
 
 async def get_current_active_user(
-    current_user: UserInDB = Depends(get_current_user),
+    current_user: t.Annotated[UserInDB, Depends(get_current_user)],
 ) -> UserInDB:
     if current_user.disabled:
         raise HTTPException(status_code=400, detail='Inactive user')
     return current_user
 
 
-@router.post('/token', response_model=Token)
+@write_router.post(
+    '/token',
+    response_model=Token,
+)
 async def login_for_access_token(
     form_data: OAuth2PasswordRequestForm = Depends(),
-) -> Dict[str, Any]:
+) -> dict[str, t.Any]:
+    """Login to get access token."""
     user = authenticate_user(config.fake_users_db, form_data.username, form_data.password)
     if not user:
         raise HTTPException(
@@ -128,12 +134,12 @@
     return {'access_token': access_token, 'token_type': 'bearer'}
 
 
-@router.get('/auth/me/', response_model=User)
-async def read_users_me(current_user: User = Depends(get_current_active_user)) -> User:
+@read_router.get(
+    '/me/',
+    response_model=orm.User.Model,
+)
+async def read_users_me(
+    current_user: t.Annotated[orm.User.Model, Depends(get_current_active_user)],
+) -> orm.User.Model:
+    """Get the current authenticated user."""
     return current_user
-
-
-# @router.get('/users/me/items/')
-# async def read_own_items(
-#     current_user: User = Depends(get_current_active_user)):
-#     return [{'item_id': 'Foo', 'owner': current_user.email}]
diff --git a/aiida_restapi/routers/computers.py b/aiida_restapi/routers/computers.py
index b97d4666..92b212de 100644
--- a/aiida_restapi/routers/computers.py
+++ b/aiida_restapi/routers/computers.py
@@ -1,53 +1,113 @@
-"""Declaration of FastAPI application."""
+"""Declaration of FastAPI router for computers."""
 
-from typing import List, Optional
+from __future__ import annotations
+
+import typing as t
 
 from aiida import orm
 from aiida.cmdline.utils.decorators import with_dbenv
-from aiida.orm.querybuilder import QueryBuilder
-from fastapi import APIRouter, Depends
-
-from aiida_restapi.models import Computer, User
-
-from .auth import get_current_active_user
-
-router = APIRouter()
+from aiida.common.exceptions import NotExistent
+from fastapi import APIRouter, Depends, HTTPException, Query
 
+from aiida_restapi.common.pagination import PaginatedResults
+from aiida_restapi.common.query import QueryParams, query_params
+from aiida_restapi.services.entity import EntityService
 
-@router.get('/computers', response_model=List[Computer])
-@with_dbenv()
-async def read_computers() -> List[Computer]:
-    """Get list of all computers"""
-
-    return Computer.get_entities()
+from .auth import UserInDB, get_current_active_user
 
+read_router = APIRouter(prefix='/computers')
+write_router = APIRouter(prefix='/computers')
 
-@router.get('/computers/projectable_properties', response_model=List[str])
-async def get_computers_projectable_properties() -> List[str]:
-    """Get projectable properties for computers endpoint"""
+service = EntityService[orm.Computer, orm.Computer.Model](orm.Computer)
 
-    return Computer.get_projectable_properties()
 
-@router.get('/computers/{comp_id}', response_model=Computer)
+@read_router.get(
+    '/schema',
+    response_model=dict,
+)
+async def get_computers_schema(
+    which: t.Literal['get', 'post'] = Query(
+        'get',
+        description='Type of schema to retrieve: "get" or "post"',
+    ),
+) -> dict:
+    """Get JSON schema for AiiDA computers."""
+    try:
+        return service.get_schema(which=which)
+    except ValueError as exception:
+        raise HTTPException(status_code=422, detail=str(exception)) from exception
+    except Exception as exception:
+        raise HTTPException(status_code=500, detail=str(exception)) from exception
+
+
+@read_router.get(
+    '/projections',
+    response_model=list[str],
+)
+async def get_computer_projections() -> list[str]:
+    """Get queryable projections for AiiDA computers."""
+    return service.get_projections()
+
+
+@read_router.get(
+    '',
+    response_model=PaginatedResults[orm.Computer.Model],
+    response_model_exclude_none=True,
+    response_model_exclude_unset=True,
+)
 @with_dbenv()
-async def read_computer(comp_id: int) -> Optional[Computer]:
-    """Get computer by id."""
-    qbobj = QueryBuilder()
-    qbobj.append(orm.Computer, filters={'id': comp_id}, project='**', tag='computer').limit(1)
-
-    return qbobj.dict()[0]['computer']
-
-
-@router.post('/computers', response_model=Computer)
+async def get_computers(
+    queries: t.Annotated[QueryParams, Depends(query_params)],
+) -> PaginatedResults[orm.Computer.Model]:
+    """Get AiiDA computers with optional filtering, sorting, and/or pagination."""
+    return service.get_many(queries)
+
+
+@read_router.get(
+    '/{pk}',
+    response_model=orm.Computer.Model,
+    response_model_exclude_none=True,
+    response_model_exclude_unset=True,
+)
+@with_dbenv()
+async def get_computer(pk: str) -> orm.Computer.Model:
+    """Get AiiDA computer by pk."""
+    try:
+        return service.get_one(pk)
+    except NotExistent as exception:
+        raise HTTPException(status_code=404, detail=str(exception)) from exception
+    except Exception as exception:
+        raise HTTPException(status_code=500, detail=str(exception)) from exception
+
+
+@read_router.get(
+    '/{pk}/metadata',
+    response_model=dict[str, t.Any],
+)
+@with_dbenv()
+async def get_computer_metadata(pk: str) -> dict[str, t.Any]:
+    """Get metadata of an AiiDA computer by pk."""
+    try:
+        return service.get_field(pk, 'metadata')
+    except NotExistent as exception:
+        raise HTTPException(status_code=404, detail=str(exception)) from exception
+    except Exception as exception:
+        raise HTTPException(status_code=500, detail=str(exception)) from exception
+
+
+@write_router.post(
+    '',
+    response_model=orm.Computer.Model,
+    response_model_exclude_none=True,
+    response_model_exclude_unset=True,
+)
 @with_dbenv()
 async def create_computer(
-    computer: Computer,
-    current_user: User = Depends(  # pylint: disable=unused-argument
-        get_current_active_user
-    ),
-) -> Computer:
+    computer_model: orm.Computer.CreateModel,
+    current_user: t.Annotated[UserInDB, Depends(get_current_active_user)],
+) -> orm.Computer.Model:
     """Create new AiiDA computer."""
-    orm_computer = orm.Computer(**computer.dict(exclude_unset=True)).store()
-
-    return Computer.from_orm(orm_computer)
+    try:
+        return service.add_one(computer_model)
+    except Exception as exception:
+        raise HTTPException(status_code=500, detail=str(exception)) from exception
diff --git a/aiida_restapi/routers/daemon.py b/aiida_restapi/routers/daemon.py
index 8f2f4258..9965d67a 100644
--- a/aiida_restapi/routers/daemon.py
+++ b/aiida_restapi/routers/daemon.py
@@ -9,10 +9,10 @@
 from fastapi import APIRouter, Depends, HTTPException
 from pydantic import BaseModel, Field
 
-from ..models import User
-from .auth import get_current_active_user
+from .auth import UserInDB, get_current_active_user
 
-router = APIRouter()
+read_router = APIRouter(prefix='/daemon')
+write_router = APIRouter(prefix='/daemon')
 
 
 class DaemonStatusModel(BaseModel):
@@ -22,7 +22,10 @@
     num_workers: t.Optional[int] = Field(description='The number of workers if the daemon is running.')
 
 
-@router.get('/daemon/status', response_model=DaemonStatusModel)
+@read_router.get(
+    '/status',
+    response_model=DaemonStatusModel,
+)
 @with_dbenv()
 async def get_daemon_status() -> DaemonStatusModel:
     """Return the daemon status."""
@@ -31,17 +34,21 @@
     if not client.is_daemon_running:
         return DaemonStatusModel(running=False, num_workers=None)
 
-    response = client.get_numprocesses()
+    try:
+        response = client.get_numprocesses()
+    except DaemonException as exception:
+        raise HTTPException(status_code=500, detail=str(exception)) from exception
 
     return DaemonStatusModel(running=True, num_workers=response['numprocesses'])
 
 
-@router.post('/daemon/start', response_model=DaemonStatusModel)
+@write_router.post(
+    '/start',
+    response_model=DaemonStatusModel,
+)
 @with_dbenv()
 async def get_daemon_start(
-    current_user: User = Depends(  # pylint: disable=unused-argument
-        get_current_active_user
-    ),
+    current_user: t.Annotated[UserInDB, Depends(get_current_active_user)],
 ) -> DaemonStatusModel:
     """Start the daemon."""
     client = get_daemon_client()
@@ -51,20 +58,20 @@ async def get_daemon_start(
 
     try:
         client.start_daemon()
+        response = client.get_numprocesses()
     except DaemonException as exception:
         raise HTTPException(status_code=500, detail=str(exception)) from exception
 
-    response = client.get_numprocesses()
-
     return DaemonStatusModel(running=True, num_workers=response['numprocesses'])
 
 
-@router.post('/daemon/stop', response_model=DaemonStatusModel)
+@write_router.post(
+    '/stop',
+    response_model=DaemonStatusModel,
+)
 @with_dbenv()
 async def get_daemon_stop(
-    current_user: User = Depends(  # pylint: disable=unused-argument
-        get_current_active_user
-    ),
+    current_user: t.Annotated[UserInDB, Depends(get_current_active_user)],
 ) -> DaemonStatusModel:
     """Stop the daemon."""
     client = get_daemon_client()
diff --git a/aiida_restapi/routers/groups.py b/aiida_restapi/routers/groups.py
index e7e017e6..32cd5ddc 100644
--- a/aiida_restapi/routers/groups.py
+++ b/aiida_restapi/routers/groups.py
@@ -1,51 +1,113 @@
-"""Declaration of FastAPI application."""
+"""Declaration of FastAPI router for groups."""
 
-from typing import List, Optional
+from __future__ import annotations
+
+import typing as t
 
 from aiida import orm
 from aiida.cmdline.utils.decorators import with_dbenv
-from fastapi import APIRouter, Depends
-
-from aiida_restapi.models import Group, Group_Post, User
-
-from .auth import get_current_active_user
-
-router = APIRouter()
-
-
-@router.get('/groups', response_model=List[Group])
-@with_dbenv()
-async def read_groups() -> List[Group]:
-    """Get list of all groups"""
+from aiida.common.exceptions import NotExistent
+from fastapi import APIRouter, Depends, HTTPException, Query
 
-    return Group.get_entities()
+from aiida_restapi.common.pagination import PaginatedResults
+from aiida_restapi.common.query import QueryParams, query_params
+from aiida_restapi.services.entity import EntityService
 
+from .auth import UserInDB, get_current_active_user
 
-@router.get('/groups/projectable_properties', response_model=List[str])
-async def get_groups_projectable_properties() -> List[str]:
-    """Get projectable properties for groups endpoint"""
+read_router = APIRouter(prefix='/groups')
+write_router = APIRouter(prefix='/groups')
 
-    return Group.get_projectable_properties()
+service = EntityService[orm.Group, orm.Group.Model](orm.Group)
 
-@router.get('/groups/{group_id}', response_model=Group)
+
+@read_router.get(
+    '/schema',
+    response_model=dict,
+)
+async def get_groups_schema(
+    which: t.Literal['get', 'post'] = Query(
+        'get',
+        description='Type of schema to retrieve: "get" or "post"',
+    ),
+) -> dict:
+    """Get JSON schema for AiiDA groups."""
+    try:
+        return service.get_schema(which=which)
+    except ValueError as exception:
+        raise HTTPException(status_code=422, detail=str(exception)) from exception
+    except Exception as exception:
+        raise HTTPException(status_code=500, detail=str(exception)) from exception
+
+
+@read_router.get(
+    '/projections',
+    response_model=list[str],
+)
+async def get_group_projections() -> list[str]:
+    """Get queryable projections for AiiDA groups."""
+    return service.get_projections()
+
+
+@read_router.get(
+    '',
+    response_model=PaginatedResults[orm.Group.Model],
+    response_model_exclude_none=True,
+    response_model_exclude_unset=True,
+)
 @with_dbenv()
-async def read_group(group_id: int) -> Optional[Group]:
-    """Get group by id."""
-    qbobj = orm.QueryBuilder()
-
-    qbobj.append(orm.Group, filters={'id': group_id}, project='**', tag='group').limit(1)
-    return qbobj.dict()[0]['group']
-
-
-@router.post('/groups', response_model=Group)
+async def get_groups(
+    queries: t.Annotated[QueryParams, Depends(query_params)],
+) -> PaginatedResults[orm.Group.Model]:
+    """Get AiiDA groups with optional filtering, sorting, and/or pagination."""
+    return service.get_many(queries)
+
+
+@read_router.get(
+    '/{uuid}',
+    response_model=orm.Group.Model,
+    response_model_exclude_none=True,
+    response_model_exclude_unset=True,
+)
+@with_dbenv()
+async def get_group(uuid: str) -> orm.Group.Model:
+    """Get AiiDA group by uuid."""
+    try:
+        return service.get_one(uuid)
+    except NotExistent as exception:
+        raise HTTPException(status_code=404, detail=str(exception)) from exception
+    except Exception as exception:
+        raise HTTPException(status_code=500, detail=str(exception)) from exception
+
+
+@read_router.get(
+    '/{uuid}/extras',
+    response_model=dict[str, t.Any],
+)
+@with_dbenv()
+async def get_group_extras(uuid: str) -> dict[str, t.Any]:
+    """Get the extras of a group."""
+    try:
+        return service.get_field(uuid, 'extras')
+    except NotExistent as exception:
+        raise HTTPException(status_code=404, detail=str(exception)) from exception
+    except Exception as exception:
+        raise HTTPException(status_code=500, detail=str(exception)) from exception
+
+
+@write_router.post(
+    '',
+    response_model=orm.Group.Model,
+    response_model_exclude_none=True,
+    response_model_exclude_unset=True,
+)
 @with_dbenv()
 async def create_group(
-    group: Group_Post,
-    current_user: User = Depends(  # pylint: disable=unused-argument
-        get_current_active_user
-    ),
-) -> Group:
+    group_model: orm.Group.CreateModel,
+    current_user: t.Annotated[UserInDB, Depends(get_current_active_user)],
+) -> orm.Group.Model:
     """Create new AiiDA group."""
-    orm_group = orm.Group(**group.dict(exclude_unset=True)).store()
-    return Group.from_orm(orm_group)
+    try:
+        return service.add_one(group_model)
+    except Exception as exception:
+        raise HTTPException(status_code=500, detail=str(exception)) from exception
diff --git a/aiida_restapi/routers/nodes.py b/aiida_restapi/routers/nodes.py
index 5f17a4e5..baef5b2c 100644
--- a/aiida_restapi/routers/nodes.py
+++ b/aiida_restapi/routers/nodes.py
@@ -1,57 +1,239 @@
-"""Declaration of FastAPI application."""
+"""Declaration of FastAPI router for nodes."""
+
+from __future__ import annotations
 
+import io
 import json
-import os
-import tempfile
-from pathlib import Path
-from typing import Any, Generator, List, Optional
+import typing as t
 
+import pydantic as pdt
 from aiida import orm
 from aiida.cmdline.utils.decorators import with_dbenv
 from aiida.common.exceptions import EntryPointError, LicensingException, NotExistent
-from aiida.plugins.entry_point import load_entry_point
-from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
+from fastapi import APIRouter, Depends, Form, HTTPException, Query, UploadFile
 from fastapi.responses import StreamingResponse
-from pydantic import ValidationError
+from typing_extensions import TypeAlias
+
+from aiida_restapi.common.pagination import PaginatedResults
+from aiida_restapi.common.query import QueryParams, query_params
+from aiida_restapi.config import API_CONFIG
+from aiida_restapi.models.node import MetadataType, NodeModelRegistry, NodeStatistics, NodeType
+from aiida_restapi.services.node import NodeLink, NodeService
+
+from .auth import UserInDB, get_current_active_user
+
+read_router = APIRouter(prefix='/nodes')
+write_router = APIRouter(prefix='/nodes')
+
+service = NodeService[orm.Node, orm.Node.Model](orm.Node)
-from aiida_restapi import models, resources
+model_registry = NodeModelRegistry()
 
-from .auth import get_current_active_user
+if t.TYPE_CHECKING:
+    # Dummy type for static analysis
+    NodeModelUnion: TypeAlias = pdt.BaseModel
+else:
+    # The real discriminated union built at runtime
+    NodeModelUnion = model_registry.ModelUnion
 
-router = APIRouter()
 
+@read_router.get(
+    '/schema',
+    response_model=dict,
+)
+async def get_nodes_schema(
+    node_type: str | None = Query(
+        None,
+        description='The AiiDA node type string.',
+        alias='type',
+    ),
+    which: t.Literal['get', 'post'] = Query(
+        'get',
+        description='Type of schema to retrieve',
+    ),
+) -> dict:
+    """Get JSON schema for the base AiiDA node 'get' model."""
+    if not node_type:
+        return orm.Node.Model.model_json_schema()
+    try:
+        model = model_registry.get_model(node_type, which)
+        return model.model_json_schema()
+    except KeyError as exception:
+        raise HTTPException(status_code=422, detail=str(exception)) from exception
+    except Exception as exception:
+        raise HTTPException(status_code=500, detail=str(exception)) from exception
+
+
+@read_router.get(
+    '/projections',
+    response_model=list[str],
+)
+@with_dbenv()
+async def get_node_projections(
+    node_type: str | None = Query(
+        None,
+        description='The AiiDA node type string.',
+        alias='type',
+    ),
+) -> list[str]:
+    """Get queryable projections for AiiDA nodes."""
+    try:
+        return service.get_projections(node_type)
+    except ValueError as exception:
+        raise HTTPException(status_code=422, detail=str(exception)) from exception
+    except Exception as exception:
+        raise HTTPException(status_code=500, detail=str(exception)) from exception
 
-@router.get('/nodes', response_model=List[models.Node])
+
+@read_router.get(
+    '/statistics',
+    response_model=NodeStatistics,
+)
 @with_dbenv()
-async def read_nodes() -> List[models.Node]:
-    """Get list of all nodes"""
-    return models.Node.get_entities()
+async def get_nodes_statistics(user: int | None = None) -> dict[str, t.Any]:
+    """Get node statistics."""
+
+    from aiida.manage import get_manager
+
+    backend = get_manager().get_profile_storage()
+    return backend.query().get_creation_statistics(user_pk=user)
+
+
+@read_router.get('/download_formats')
+async def get_nodes_download_formats() -> dict[str, t.Any]:
+    """Get download formats for AiiDA nodes."""
+    try:
+        return service.get_download_formats()
+    except EntryPointError as exception:
+        raise HTTPException(status_code=404, detail=str(exception)) from exception
+    except Exception as exception:
+        raise HTTPException(status_code=500, detail=str(exception)) from exception
 
-@router.get('/nodes/projectable_properties', response_model=List[str])
-async def get_nodes_projectable_properties() -> List[str]:
-    """Get projectable properties for nodes endpoint"""
 
+@read_router.get(
+    '',
+    response_model=PaginatedResults[orm.Node.Model],
+    response_model_exclude_none=True,
+    response_model_exclude_unset=True,
+)
+@with_dbenv()
+async def get_nodes(
+    queries: t.Annotated[QueryParams, Depends(query_params)],
+) -> PaginatedResults[orm.Node.Model]:
+    """Get AiiDA nodes with optional filtering, sorting, and/or pagination."""
+    return service.get_many(queries)
+
+
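+# Each entry below links a node type to its nodes, projections, and schema endpoints.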
'node_schema': f'{api_prefix}/nodes/schema?type={node_type}',
+        }
+        for node_type in sorted(
+            model_registry.get_node_types(), key=lambda node_type: model_registry.get_node_class_name(node_type)
+        )
+    ]
-    return models.Node.get_projectable_properties()
+@read_router.get(
+    '/{uuid}',
+    response_model=orm.Node.Model,
+    response_model_exclude_none=True,
+    response_model_exclude_unset=True,
+)
+@with_dbenv()
+async def get_node(uuid: str) -> orm.Node.Model:
+    """Get AiiDA node by uuid."""
+    try:
+        return service.get_one(uuid)
+    except NotExistent as exception:
+        raise HTTPException(status_code=404, detail=str(exception)) from exception
+    except Exception as exception:
+        raise HTTPException(status_code=500, detail=str(exception)) from exception
-@router.get('/nodes/download_formats', response_model=dict[str, Any])
-async def get_nodes_download_formats() -> dict[str, Any]:
-    """Get download formats for nodes endpoint"""
-    return resources.get_all_download_formats()
+@read_router.get(
+    '/{uuid}/attributes',
+    response_model=dict[str, t.Any],
+)
+@with_dbenv()
+async def get_node_attributes(uuid: str) -> dict[str, t.Any]:
+    """Get the attributes of a node."""
+    try:
+        return service.get_field(uuid, 'attributes')
+    except NotExistent as exception:
+        raise HTTPException(status_code=404, detail=str(exception)) from exception
+    except Exception as exception:
+        raise HTTPException(status_code=500, detail=str(exception)) from exception
-@router.get('/nodes/{nodes_id}/download')
+@read_router.get(
+    '/{uuid}/extras',
+    response_model=dict[str, t.Any],
+)
 @with_dbenv()
-async def download_node(nodes_id: int, download_format: Optional[str] = None) -> StreamingResponse:
-    """Get nodes by id."""
-    from aiida.orm import load_node
+async def get_node_extras(uuid: str) -> dict[str, t.Any]:
+    """Get the extras of a node."""
+    try:
+        return service.get_field(uuid, 'extras')
+    except NotExistent as exception:
+        raise HTTPException(status_code=404, detail=str(exception)) from exception
+    except Exception as exception:
+        raise HTTPException(status_code=500, detail=str(exception)) from exception
+
+@read_router.get(
+    '/{uuid}/links',
+    response_model=PaginatedResults[NodeLink],
+    response_model_exclude_none=True,
+    response_model_exclude_unset=True,
+)
+@with_dbenv()
+async def get_node_links(
+    uuid: str,
+    queries: t.Annotated[QueryParams, Depends(query_params)],
+    direction: t.Literal['incoming', 'outgoing'] = Query(
+        description='Specify whether to retrieve incoming or outgoing links.',
+    ),
+) -> PaginatedResults[NodeLink]:
+    """Get the incoming or outgoing links of a node."""
     try:
-        node = load_node(nodes_id)
-    except NotExistent:
-        raise HTTPException(status_code=404, detail=f'Could no find any node with id {nodes_id}')
+        return service.get_links(uuid, queries, direction=direction)
+    except NotExistent as exception:
+        raise HTTPException(status_code=404, detail=str(exception)) from exception
+    except Exception as exception:
+        raise HTTPException(status_code=500, detail=str(exception)) from exception
+
+
+@read_router.get(
+    '/{uuid}/download',
+    response_class=StreamingResponse,
+)
+@with_dbenv()
+async def download_node(
+    uuid: str,
+    download_format: str | None = Query(
+        None,
+        description='Format to download the node in',
+    ),
+) -> StreamingResponse:
+    """Download AiiDA node by uuid in a given download format provided as a query parameter."""
+    try:
+        node = orm.load_node(uuid)
+    except NotExistent as exception:
+        raise HTTPException(status_code=404, detail=str(exception)) from exception
+    except
Exception as exception: + raise HTTPException(status_code=500, detail=str(exception)) from exception if download_format is None: raise HTTPException( @@ -63,14 +245,13 @@ async def download_node(nodes_id: int, download_format: Optional[str] = None) -> elif download_format in node.get_export_formats(): # byteobj, dict with {filename: filecontent} - import io try: exported_bytes, _ = node._exportcontent(download_format) - except LicensingException as exc: - raise HTTPException(status_code=500, detail=str(exc)) + except LicensingException as exception: + raise HTTPException(status_code=403, detail=str(exception)) from exception - def stream() -> Generator[bytes, None, None]: + def stream() -> t.Generator[bytes, None, None]: with io.BytesIO(exported_bytes) as handler: yield from handler @@ -85,91 +266,134 @@ def stream() -> Generator[bytes, None, None]: ) -@router.get('/nodes/{nodes_id}', response_model=models.Node) +@read_router.get( + '/{uuid}/repo/metadata', + response_model=dict[str, MetadataType], +) @with_dbenv() -async def read_node(nodes_id: int) -> Optional[models.Node]: - """Get nodes by id.""" - qbobj = orm.QueryBuilder() - qbobj.append(orm.Node, filters={'id': nodes_id}, project='**', tag='node').limit(1) - return qbobj.dict()[0]['node'] +async def get_node_repo_file_metadata(uuid: str) -> dict[str, dict]: + """Get the repository file metadata of a node.""" + try: + return service.get_repository_metadata(uuid) + except NotExistent as exception: + raise HTTPException(status_code=404, detail=str(exception)) from exception + except Exception as exception: + raise HTTPException(status_code=500, detail=str(exception)) from exception -@router.post('/nodes', response_model=models.Node) +@read_router.get( + '/{uuid}/repo/contents', + response_class=StreamingResponse, +) @with_dbenv() -async def create_node( - node: models.Node_Post, - current_user: models.User = Depends( # pylint: disable=unused-argument - get_current_active_user +async def get_node_repo_file_contents( + uuid: str, + filename: str | None = Query( + None, + description='Filename of repository content to retrieve', ), -) -> models.Node: - """Create new AiiDA node.""" - node_dict = node.dict(exclude_unset=True) - entry_point = node_dict.pop('entry_point', None) +) -> StreamingResponse: + """Get the repository contents of a node.""" + from urllib.parse import quote try: - cls = load_entry_point(group='aiida.data', name=entry_point) - except EntryPointError as exception: + node = orm.load_node(uuid) + except NotExistent as exception: raise HTTPException(status_code=404, detail=str(exception)) from exception - try: - orm_object = models.Node_Post.create_new_node(cls, node_dict) - except (TypeError, ValueError, KeyError) as exception: - raise HTTPException(status_code=400, detail=str(exception)) from exception + repo = node.base.repository - return models.Node.from_orm(orm_object) + if filename: + try: + file_content = repo.get_object_content(filename, mode='rb') + except FileNotFoundError as exception: + raise HTTPException(status_code=404, detail=str(exception)) from exception + def file_stream() -> t.Generator[bytes, None, None]: + with io.BytesIO(file_content) as handler: + yield from handler -@router.post('/nodes/singlefile', response_model=models.Node) -@with_dbenv() -async def create_upload_file( - params: str = Form(...), - upload_file: UploadFile = File(...), - current_user: models.User = Depends( # pylint: disable=unused-argument - get_current_active_user - ), -) -> models.Node: - """Endpoint for uploading file data + 
download_name = filename.split('/')[-1] or 'download'
+            quoted = quote(download_name)
+            headers = {'Content-Disposition': f"attachment; filename=\"{download_name}\"; filename*=UTF-8''{quoted}"}
+
+            return StreamingResponse(file_stream(), media_type='application/octet-stream', headers=headers)
+
+    else:
+        zip_bytes = repo.get_zipped_objects()
+
+        def zip_stream() -> t.Generator[bytes, None, None]:
+            with io.BytesIO(zip_bytes) as handler:
+                yield from handler
-    Note that in this multipart form case, json input can't be used.
-    Get the parameters as a string and manually pass through pydantic.
-    """
+        download_name = f'node_{uuid}_repo.zip'
+        quoted = quote(download_name)
+        headers = {'Content-Disposition': f"attachment; filename=\"{download_name}\"; filename*=UTF-8''{quoted}"}
+
+        return StreamingResponse(zip_stream(), media_type='application/zip', headers=headers)
+
+
+@write_router.post(
+    '',
+    response_model=orm.Node.Model,
+    response_model_exclude_none=True,
+    response_model_exclude_unset=True,
+)
+@with_dbenv()
+async def create_node(
+    model: NodeModelUnion,
+    current_user: t.Annotated[UserInDB, Depends(get_current_active_user)],
+) -> orm.Node.Model:
+    """Create new AiiDA node."""
+    try:
+        return service.add_one(model)
+    except KeyError as exception:
+        raise HTTPException(status_code=422, detail=str(exception)) from exception
+    except Exception as exception:
+        raise HTTPException(status_code=500, detail=str(exception)) from exception
+
+
+@write_router.post(
+    '/file-upload',
+    response_model=orm.Node.Model,
+    response_model_exclude_none=True,
+    response_model_exclude_unset=True,
+)
+@with_dbenv()
+async def create_node_with_files(
+    params: t.Annotated[str, Form()],
+    files: list[UploadFile],
+    current_user: t.Annotated[UserInDB, Depends(get_current_active_user)],
+) -> orm.Node.Model:
+    """Create new AiiDA node with files."""
     try:
-        # Parse the JSON string into a dictionary
-        params_dict = json.loads(params)
-        # Validate against the Pydantic model
-        params_obj = models.Node_Post(**params_dict)
+        parameters = t.cast(dict, json.loads(params))
     except json.JSONDecodeError as exception:
-        raise HTTPException(
-            status_code=400,
-            detail=f'Invalid JSON format: {exception!s}',
-        ) from exception
-    except ValidationError as exception:
-        raise HTTPException(
-            status_code=422,
-            detail=f'Validation failed: {exception}',
-        ) from exception
+        raise HTTPException(400, str(exception)) from exception
-    node_dict = params_obj.dict(exclude_unset=True)
-    entry_point = node_dict.pop('entry_point', None)
+    if not (node_type := parameters.get('node_type')):
+        raise HTTPException(422, "Missing 'node_type' in params")
     try:
-        cls = load_entry_point(group='aiida.data', name=entry_point)
-    except EntryPointError as exception:
-        raise HTTPException(
-            status_code=404,
-            detail=f'Could not load entry point: {exception}',
-        ) from exception
+        model_cls = model_registry.get_model(node_type, which='post')
+        model = model_cls(**parameters)
+    except KeyError as exception:
+        raise HTTPException(422, str(exception)) from exception
+    except pdt.ValidationError as exception:
+        raise HTTPException(422, str(exception)) from exception
-    with tempfile.NamedTemporaryFile(mode='wb', delete=False) as temp_file:
-        # Todo: read in chunks
-        content = await upload_file.read()
-        temp_file.write(content)
-        temp_path = temp_file.name
+    files_dict: dict[str, UploadFile] = {}
-    orm_object = models.Node_Post.create_new_node_with_file(cls, node_dict, Path(temp_path))
+    for upload in files:
+        if (target := upload.filename) in files_dict:
+
raise HTTPException(422, f"Duplicate target path '{target}' in upload") + files_dict[target] = upload - # Clean up the temporary file - if os.path.exists(temp_path): - os.unlink(temp_path) - - return models.Node.from_orm(orm_object) + try: + return service.add_one(model, files=files_dict) + except json.JSONDecodeError as exception: + raise HTTPException(status_code=400, detail=str(exception)) from exception + except KeyError as exception: + raise HTTPException(status_code=422, detail=str(exception)) from exception + except Exception as exception: + raise HTTPException(status_code=500, detail=str(exception)) from exception diff --git a/aiida_restapi/routers/process.py b/aiida_restapi/routers/process.py deleted file mode 100644 index 905a468f..00000000 --- a/aiida_restapi/routers/process.py +++ /dev/null @@ -1,97 +0,0 @@ -"""Declaration of FastAPI router for processes.""" - -from typing import List, Optional - -from aiida import orm -from aiida.cmdline.utils.decorators import with_dbenv -from aiida.common.exceptions import NotExistent -from aiida.engine import submit -from aiida.orm.querybuilder import QueryBuilder -from aiida.plugins import load_entry_point_from_string -from fastapi import APIRouter, Depends, HTTPException - -from aiida_restapi.models import Process, Process_Post, User - -from .auth import get_current_active_user - -router = APIRouter() - - -def process_inputs(inputs: dict) -> dict: - """Process the inputs dictionary converting each node UUID into the corresponding node by loading it. - - A node UUID is indicated by the key ending with the suffix ``.uuid``. - - :param inputs: The inputs dictionary. - :returns: The deserialized inputs dictionary. - :raises HTTPException: If the inputs contain a UUID that does not correspond to an existing node. 
- """ - uuid_suffix = '.uuid' - results = {} - - for key, value in inputs.items(): - if isinstance(value, dict): - results[key] = process_inputs(value) - elif key.endswith(uuid_suffix): - try: - results[key[: -len(uuid_suffix)]] = orm.load_node(uuid=value) - except NotExistent as exc: - raise HTTPException( - status_code=404, - detail=f'Node with UUID `{value}` does not exist.', - ) from exc - else: - results[key] = value - - return results - - -@router.get('/processes', response_model=List[Process]) -@with_dbenv() -async def read_processes() -> List[Process]: - """Get list of all processes""" - - return Process.get_entities() - - -@router.get('/processes/projectable_properties', response_model=List[str]) -async def get_processes_projectable_properties() -> List[str]: - """Get projectable properties for processes endpoint""" - - return Process.get_projectable_properties() - - -@router.get('/processes/{proc_id}', response_model=Process) -@with_dbenv() -async def read_process(proc_id: int) -> Optional[Process]: - """Get process by id.""" - qbobj = QueryBuilder() - qbobj.append(orm.ProcessNode, filters={'id': proc_id}, project='**', tag='process').limit(1) - - return qbobj.dict()[0]['process'] - - -@router.post('/processes', response_model=Process) -@with_dbenv() -async def post_process( - process: Process_Post, - current_user: User = Depends( # pylint: disable=unused-argument - get_current_active_user - ), -) -> Optional[Process]: - """Create new process.""" - process_dict = process.dict(exclude_unset=True, exclude_none=True) - inputs = process_inputs(process_dict['inputs']) - entry_point = process_dict.get('process_entry_point') - - try: - entry_point_process = load_entry_point_from_string(entry_point) - except ValueError as exc: - raise HTTPException( - status_code=404, - detail=f"Entry point '{entry_point}' not recognized.", - ) from exc - - process_node = submit(entry_point_process, **inputs) - - return process_node diff --git a/aiida_restapi/routers/querybuilder.py b/aiida_restapi/routers/querybuilder.py new file mode 100644 index 00000000..b8df2bc3 --- /dev/null +++ b/aiida_restapi/routers/querybuilder.py @@ -0,0 +1,158 @@ +"""Declaration of FastAPI router for AiiDA's QueryBuilder.""" + +from __future__ import annotations + +import typing as t + +import pydantic as pdt +from aiida import orm +from aiida.cmdline.utils.decorators import with_dbenv +from fastapi import APIRouter, HTTPException, Query + +from aiida_restapi.common.pagination import PaginatedResults + +read_router = APIRouter(prefix='/querybuilder') + + +class QueryBuilderPathItem(pdt.BaseModel): + """Pydantic model for QueryBuilder path items.""" + + entity_type: str | list[str] = pdt.Field( + description='The AiiDA entity type.', + ) + orm_base: str = pdt.Field( + description='The ORM base class of the entity.', + ) + tag: str | None = pdt.Field( + None, + description='An optional tag for the path item.', + ) + joining_keyword: str | None = pdt.Field( + None, + description='The joining keyword for relationships (e.g., "input", "output").', + ) + joining_value: str | None = pdt.Field( + None, + description='The joining value for relationships (e.g., "input", "output").', + ) + edge_tag: str | None = pdt.Field( + None, + description='An optional tag for the edge.', + ) + outerjoin: bool = pdt.Field( + False, + description='Whether to perform an outer join.', + ) + + +class QueryBuilderDict(pdt.BaseModel): + """Pydantic model for QueryBuilder POST requests.""" + + path: list[str | QueryBuilderPathItem] = pdt.Field( + 
description='The QueryBuilder path as a list of entity types or path items.', + examples=[ + [ + ['data.core.int.Int.', 'data.core.float.Float.'], + { + 'entity_type': 'data.core.int.Int.', + 'orm_base': 'node', + 'tag': 'integers', + }, + { + 'entity_type': ['data.core.int.Int.', 'data.core.float.Float.'], + 'orm_base': 'node', + 'tag': 'numbers', + }, + ] + ], + ) + filters: dict[str, dict[str, t.Any]] | None = pdt.Field( + None, + description='The QueryBuilder filters as a dictionary mapping tags to filter conditions.', + examples=[ + { + 'integers': {'attributes.value': {'<': 42}}, + } + ], + ) + project: dict[str, str | list[str]] | None = pdt.Field( + None, + description='The QueryBuilder projection as a dictionary mapping tags to attributes to project.', + examples=[ + { + 'integers': ['uuid', 'attributes.value'], + } + ], + ) + limit: pdt.NonNegativeInt | None = pdt.Field( + 10, + description='The maximum number of results to return.', + examples=[5], + ) + offset: pdt.NonNegativeInt | None = pdt.Field( + 0, + description='The number of results to skip before starting to collect the result set.', + examples=[0], + ) + order_by: str | list[str] | dict[str, t.Any] | None = pdt.Field( + None, + description='The QueryBuilder order_by as a string, list of strings, ' + 'or dictionary mapping tags to order conditions.', + examples=[ + {'integers': {'pk': 'desc'}}, + ], + ) + distinct: bool = pdt.Field( + False, + description='Whether to return only distinct results.', + examples=[False], + ) + + +@read_router.post( + '', + response_model=PaginatedResults, + response_model_exclude_none=True, + response_model_exclude_unset=True, +) +@with_dbenv() +async def query_builder( + query: QueryBuilderDict, + flat: bool = Query( + False, + description='Whether to return results flat.', + ), +) -> PaginatedResults: + """Execute a QueryBuilder query based on the provided dictionary.""" + query_dict = query.model_dump() + + try: + limit = query_dict.pop('limit', 10) + offset = query_dict.get('offset', 0) + + # Get total count before applying the limit + qb = orm.QueryBuilder.from_dict(query_dict) + total = qb.count() + qb.limit(limit) + + # Run query builder + project = t.cast(dict, query_dict.get('project')) + if not project or any(p in ('*', ['*']) for p in project.values()): + # Projecting entities as entity models + return PaginatedResults[orm.Entity.Model]( + total=total, + page=offset // limit + 1, + page_size=limit, + results=[entity.to_model(minimal=True) for entity in t.cast(list[orm.Entity], qb.all(flat=True))], + ) + else: + # Projecting specific attributes + return PaginatedResults[t.Any]( + total=total, + page=offset // limit + 1, + page_size=limit, + results=qb.all(flat=flat), + ) + + except Exception as exc: + raise HTTPException(status_code=400, detail=str(exc)) diff --git a/aiida_restapi/routers/server.py b/aiida_restapi/routers/server.py new file mode 100644 index 00000000..c9fcf507 --- /dev/null +++ b/aiida_restapi/routers/server.py @@ -0,0 +1,227 @@ +"""Declaration of FastAPI application.""" + +from __future__ import annotations + +import re + +import pydantic as pdt +from aiida import __version__ as aiida_version +from fastapi import APIRouter, Request +from fastapi.responses import HTMLResponse +from fastapi.routing import APIRoute +from starlette.routing import Route + +from aiida_restapi.config import API_CONFIG + +read_router = APIRouter(prefix='/server') + + +class ServerInfo(pdt.BaseModel): + """API version information.""" + + API_major_version: str = pdt.Field( + 
description='Major version of the API',
+        examples=['0'],
+    )
+    API_minor_version: str = pdt.Field(
+        description='Minor version of the API',
+        examples=['1'],
+    )
+    API_revision_version: str = pdt.Field(
+        description='Revision version of the API',
+        examples=['0a1'],
+    )
+    API_prefix: str = pdt.Field(
+        description='Prefix for all API endpoints',
+        examples=['/api/v0'],
+    )
+    AiiDA_version: str = pdt.Field(
+        description='Version of the AiiDA installation',
+        examples=['2.7.2.post0'],
+    )
+
+
+@read_router.get(
+    '/info',
+    response_model=ServerInfo,
+)
+async def get_server_info() -> ServerInfo:
+    """Get the API version information."""
+    api_version = API_CONFIG['VERSION'].split('.')
+    return ServerInfo(
+        API_major_version=api_version[0],
+        API_minor_version=api_version[1],
+        API_revision_version=api_version[2],
+        API_prefix=API_CONFIG['PREFIX'],
+        AiiDA_version=aiida_version,
+    )
+
+
+class ServerEndpoint(pdt.BaseModel):
+    """API endpoint."""
+
+    path: str = pdt.Field(
+        description='Path of the endpoint',
+        examples=['../server/endpoints'],
+    )
+    group: str | None = pdt.Field(
+        description='Group of the endpoint',
+        examples=['server'],
+    )
+    methods: set[str] = pdt.Field(
+        description='HTTP methods supported by the endpoint',
+        examples=['GET'],
+    )
+    description: str = pdt.Field(
+        '-',
+        description='Description of the endpoint',
+        examples=['Get a JSON-serializable dictionary of all registered API routes.'],
+    )
+
+
+@read_router.get(
+    '/endpoints',
+    response_model=dict[str, list[ServerEndpoint]],
+)
+async def get_server_endpoints(request: Request) -> dict[str, list[ServerEndpoint]]:
+    """Get a JSON-serializable dictionary of all registered API routes."""
+    endpoints: list[ServerEndpoint] = []
+
+    for route in request.app.routes:
+        if route.path == '/':
+            continue
+
+        group, methods, description = _get_route_parts(route)
+        base_url = str(request.base_url).rstrip('/')
+
+        endpoint = {
+            'path': base_url + route.path,
+            'group': group,
+            'methods': methods,
+            'description': description,
+        }
+
+        endpoints.append(ServerEndpoint(**endpoint))
+
+    return {'endpoints': endpoints}
+
+
+@read_router.get(
+    '/endpoints/table',
+    name='endpoints',
+    response_class=HTMLResponse,
+)
+async def get_server_endpoints_table(request: Request) -> HTMLResponse:
+    """Get an HTML table of all registered API routes."""
+    routes = request.app.routes
+    base_url = str(request.base_url).rstrip('/')
+
+    rows = []
+
+    for route in routes:
+        if route.path == '/':
+            continue
+
+        path = base_url + route.path
+        group, methods, description = _get_route_parts(route)
+
+        disable_url = (
+            (
+                isinstance(route, APIRoute)
+                and any(
+                    param
+                    for param in route.dependant.path_params
+                    + route.dependant.query_params
+                    + route.dependant.body_params
+                    if param.required
+                )
+            )
+            or (route.methods and 'POST' in route.methods)
+            or 'auth' in path
+        )
+
+        path_row = path if disable_url else f'<a href="{path}">{path}</a>'
+
+        rows.append(f"""
+            <tr>
+                <td>{path_row}</td>
+                <td>{group or '-'}</td>
+                <td>{', '.join(methods)}</td>
+                <td>{description or '-'}</td>
+            </tr>
+        """)
+
+    return HTMLResponse(
+        content=f"""
+        <!DOCTYPE html>
+        <html>
+            <head>
+                <title>AiiDA REST API Endpoints</title>
+            </head>
+            <body>
+                <h1>AiiDA REST API Endpoints</h1>
+                <table border="1">
+                    <thead>
+                        <tr>
+                            <th>URL</th>
+                            <th>Group</th>
+                            <th>Methods</th>
+                            <th>Description</th>
+                        </tr>
+                    </thead>
+                    <tbody>
+                        {''.join(rows)}
+                    </tbody>
+                </table>
+            </body>
+        </html>
+        """
+    )
+
+
+def _get_route_parts(route: Route) -> tuple[str | None, set[str], str]:
+    """Return the parts of a route: group, methods, description.
+
+    :param route: A FastAPI/Starlette Route object.
+    :return: A tuple containing the group, methods, and description of the route.
+    """
+    prefix = re.escape(API_CONFIG['PREFIX'])
+    match = re.match(rf'^{prefix}/([^/]+)/?.*', route.path)
+    group = match.group(1) if match else None
+    methods = (route.methods or set()) - {'HEAD', 'OPTIONS'}
+    description = (route.endpoint.__doc__ or '').split('\n')[0].strip()
+    return group, methods, description
diff --git a/aiida_restapi/routers/submit.py b/aiida_restapi/routers/submit.py
new file mode 100644
index 00000000..49fa0b21
--- /dev/null
+++ b/aiida_restapi/routers/submit.py
@@ -0,0 +1,90 @@
+"""Declaration of FastAPI router for submission."""
+
+from __future__ import annotations
+
+import typing as t
+
+import pydantic as pdt
+from aiida import engine, orm
+from aiida.cmdline.utils.decorators import with_dbenv
+from aiida.common.exceptions import NotExistent
+from aiida.plugins.entry_point import load_entry_point_from_string
+from fastapi import APIRouter, Depends, HTTPException
+
+from .auth import UserInDB, get_current_active_user
+
+write_router = APIRouter(prefix='/submit')
+
+
+def process_inputs(inputs: dict[str, t.Any]) -> dict[str, t.Any]:
+    """Process the inputs dictionary converting each node UUID into the corresponding node by loading it.
+
+    A node UUID is indicated by the key ending with the suffix ``.uuid``.
+
+    :param inputs: The inputs dictionary.
+    :returns: The deserialized inputs dictionary.
+    :raises HTTPException: 404 if the inputs contain a UUID that does not correspond to an existing node.
+    """
+    uuid_suffix = '.uuid'
+    results = {}
+
+    for key, value in inputs.items():
+        if isinstance(value, dict):
+            results[key] = process_inputs(value)
+        elif key.endswith(uuid_suffix):
+            try:
+                results[key[: -len(uuid_suffix)]] = orm.load_node(uuid=value)
+            except NotExistent as exc:
+                raise HTTPException(status_code=404, detail=f'Node with UUID `{value}` does not exist.') from exc
+        else:
+            results[key] = value
+
+    return results
+
+
+class ProcessSubmitModel(pdt.BaseModel):
+    label: str = pdt.Field(
+        '',
+        description='The label of the process',
+        examples=['My process', 'Test calculation'],
+    )
+    entry_point: str = pdt.Field(
+        description='The entry point of the process',
+        examples=['core.arithmetic.add'],
+    )
+    inputs: dict[str, t.Any] = pdt.Field(
+        description='The inputs of the process',
+        examples=[{'x': 1, 'y': 2}],
+    )
+
+    @pdt.field_validator('inputs')
+    @classmethod
+    def validate_inputs(cls, inputs: dict[str, t.Any]) -> dict[str, t.Any]:
+        """Process the inputs dictionary.
+
+        :param inputs: The inputs to validate.
+        :returns: The validated inputs.
+ """ + return process_inputs(inputs) + + +@write_router.post( + '', + response_model=orm.Node.Model, + response_model_exclude_none=True, + response_model_exclude_unset=True, +) +@with_dbenv() +async def submit_process( + process: ProcessSubmitModel, + current_user: t.Annotated[UserInDB, Depends(get_current_active_user)], +) -> orm.Node.Model: + """Submit new AiiDA process.""" + try: + entry_point_process = load_entry_point_from_string(process.entry_point) + process_node = engine.submit(entry_point_process, **process.inputs) + return t.cast(orm.Node.Model, process_node.to_model()) + except ValueError as err: + raise HTTPException(status_code=404, detail=f"Entry point '{process.entry_point}' not recognized.") from err + except Exception as err: + raise HTTPException(status_code=500, detail=str(err)) from err diff --git a/aiida_restapi/routers/users.py b/aiida_restapi/routers/users.py index 2277f514..122ffc11 100644 --- a/aiida_restapi/routers/users.py +++ b/aiida_restapi/routers/users.py @@ -1,51 +1,96 @@ -"""Declaration of FastAPI application.""" +"""Declaration of FastAPI router for users.""" -from typing import List, Optional +from __future__ import annotations + +import typing as t from aiida import orm from aiida.cmdline.utils.decorators import with_dbenv -from aiida.orm.querybuilder import QueryBuilder -from fastapi import APIRouter, Depends +from aiida.common.exceptions import NotExistent +from fastapi import APIRouter, Depends, HTTPException, Query -from aiida_restapi.models import User +from aiida_restapi.common.pagination import PaginatedResults +from aiida_restapi.common.query import QueryParams, query_params +from aiida_restapi.services.entity import EntityService -from .auth import get_current_active_user +from .auth import UserInDB, get_current_active_user -router = APIRouter() +read_router = APIRouter(prefix='/users') +write_router = APIRouter(prefix='/users') +service = EntityService[orm.User, orm.User.Model](orm.User) -@router.get('/users', response_model=List[User]) -@with_dbenv() -async def read_users() -> List[User]: - """Get list of all users""" - return User.get_entities() +@read_router.get( + '/schema', + response_model=dict, +) +async def get_users_schema( + which: t.Literal['get', 'post'] = Query( + 'get', + description='Type of schema to retrieve: "get" or "post"', + ), +) -> dict: + """Get JSON schema for AiiDA users.""" + try: + return service.get_schema(which=which) + except ValueError as exception: + raise HTTPException(status_code=422, detail=str(exception)) from exception + except Exception as exception: + raise HTTPException(status_code=500, detail=str(exception)) from exception -@router.get('/users/projectable_properties', response_model=List[str]) -async def get_users_projectable_properties() -> List[str]: - """Get projectable properties for users endpoint""" - return User.get_projectable_properties() +@read_router.get( + '/projections', + response_model=list[str], +) +async def get_user_projections() -> list[str]: + """Get queryable projections for AiiDA users.""" + return service.get_projections() -@router.get('/users/{user_id}', response_model=User) +@read_router.get( + '', + response_model=PaginatedResults[orm.User.Model], + response_model_exclude_none=True, + response_model_exclude_unset=True, +) @with_dbenv() -async def read_user(user_id: int) -> Optional[User]: - """Get user by id.""" - qbobj = QueryBuilder() - qbobj.append(orm.User, filters={'id': user_id}, project='**', tag='user').limit(1) +async def get_users( + queries: t.Annotated[QueryParams, 
Depends(query_params)], +) -> PaginatedResults[orm.User.Model]: + """Get AiiDA users with optional filtering, sorting, and/or pagination.""" + return service.get_many(queries) - return qbobj.dict()[0]['user'] + +@read_router.get( + '/{pk}', + response_model=orm.User.Model, +) +@with_dbenv() +async def get_user(pk: int) -> orm.User.Model: + """Get AiiDA user by pk.""" + try: + return service.get_one(pk) + except NotExistent as exception: + raise HTTPException(status_code=404, detail=str(exception)) from exception + except Exception as exception: + raise HTTPException(status_code=500, detail=str(exception)) from exception -@router.post('/users', response_model=User) +@write_router.post( + '', + response_model=orm.User.Model, + response_model_exclude_none=True, + response_model_exclude_unset=True, +) @with_dbenv() async def create_user( - user: User, - current_user: User = Depends( # pylint: disable=unused-argument - get_current_active_user - ), -) -> User: + user_model: orm.User.CreateModel, + current_user: t.Annotated[UserInDB, Depends(get_current_active_user)], +) -> orm.User.Model: """Create new AiiDA user.""" - orm_user = orm.User(**user.dict(exclude_unset=True)).store() - return User.from_orm(orm_user) + try: + return service.add_one(user_model) + except Exception as exception: + raise HTTPException(status_code=500, detail=str(exception)) diff --git a/aiida_restapi/services/__init__.py b/aiida_restapi/services/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/aiida_restapi/services/entity.py b/aiida_restapi/services/entity.py new file mode 100644 index 00000000..947a06fd --- /dev/null +++ b/aiida_restapi/services/entity.py @@ -0,0 +1,118 @@ +"""REST API entity repository.""" + +from __future__ import annotations + +import typing as t + +from aiida_restapi.common.pagination import PaginatedResults +from aiida_restapi.common.query import QueryParams +from aiida_restapi.common.types import EntityModelType, EntityType + + +class EntityService(t.Generic[EntityType, EntityModelType]): + """Utility class for AiiDA REST API operations. + + This class provides methods to retrieve AiiDA entities with optional filtering, sorting, and pagination. + + :param entity_class: The AiiDA ORM entity class associated with this utility, e.g. `orm.User`, `orm.Node`, etc. + """ + + def __init__(self, entity_class: type[EntityType]) -> None: + self.entity_class = entity_class + + def get_schema(self, which: t.Literal['get', 'post'] | None = None) -> dict: + """Get JSON schema for the AiiDA entity. + + :param which: The type of schema to retrieve: 'get' or 'post'. + :type which: str | None + :return: A dictionary with 'get' and 'post' keys containing the respective JSON schemas. + :rtype: dict + :raises ValueError: If the 'which' parameter is not 'get' or 'post'. + """ + if not which: + return { + 'get': self.entity_class.Model.model_json_schema(), + 'post': self.entity_class.CreateModel.model_json_schema(), + } + elif which == 'get': + return self.entity_class.Model.model_json_schema() + elif which == 'post': + return self.entity_class.CreateModel.model_json_schema() + raise ValueError(f'Schema type "{which}" not supported; expected "get" or "post"') + + def get_projections(self) -> list[str]: + """Get queryable projections for the AiiDA entity. + + :return: The list of queryable projections for the AiiDA entity. 
+ :rtype: list[str] + """ + return self.entity_class.fields.keys() + + def get_many(self, queries: QueryParams) -> PaginatedResults[EntityModelType]: + """Get AiiDA entities with optional filtering, sorting, and/or pagination. + + :param queries: The query parameters, including filters, order_by, page_size, and page. + :type queries: QueryParams + :return: The paginated results, including total count, current page, page size, and list of entity models. + :rtype: PaginatedResults[EntityModelType] + """ + total = self.entity_class.collection.count(filters=queries.filters) + results = self.entity_class.collection.find( + filters=queries.filters, + order_by=queries.order_by, + limit=queries.page_size, + offset=queries.page_size * (queries.page - 1), + ) + return PaginatedResults( + total=total, + page=queries.page, + page_size=queries.page_size, + results=[self._to_model(result) for result in results], + ) + + def get_one(self, identifier: str | int) -> EntityModelType: + """Get an AiiDA entity by id. + + :param identifier: The id of the entity to retrieve. + :type identifier: str | int + :return: The AiiDA entity model, e.g. `orm.User.Model`, `orm.Node.Model`, etc. + :rtype: EntityModelType + """ + entity = self.entity_class.collection.get(**{self.entity_class.identity_field: identifier}) + return self._to_model(entity) + + def get_field(self, identifier: str | int, field: str) -> t.Any: + """Get a specific field of an entity. + + :param identifier: The id of the entity to retrieve the extras for. + :type identifier: str | int + :param field: The specific field to retrieve. + :type field: str + :return: The value of the specified field. + :rtype: t.Any + """ + return self.entity_class.collection.query( + filters={self.entity_class.identity_field: identifier}, + project=[field], + ).first()[0] + + def add_one(self, model: EntityModelType) -> EntityModelType: + """Create new AiiDA entity from its model. + + :param model: The Pydantic model of the entity to create. + :type model: EntityModelType + :return: The created and stored AiiDA `Entity` instance. + :rtype: EntityModelType + """ + entity = self.entity_class.from_model(model).store() + return self._to_model(entity) + + def _to_model(self, entity: EntityType) -> EntityModelType: + """Convert an AiiDA entity to its Pydantic model. + + :param entity: The AiiDA entity to convert. + :type entity: EntityType + :return: The Pydantic model of the entity, excluding any fields specified in `excluded_fields`. 
+ :rtype: EntityModelType + """ + return t.cast(EntityModelType, entity.to_model(minimal=True)) diff --git a/aiida_restapi/services/node.py b/aiida_restapi/services/node.py new file mode 100644 index 00000000..68d95adb --- /dev/null +++ b/aiida_restapi/services/node.py @@ -0,0 +1,354 @@ +"""REST API node repository.""" + +from __future__ import annotations + +import typing as t + +from aiida.common import EntryPointError +from aiida.common.escaping import escape_for_sql_like +from aiida.common.exceptions import DbContentError, LoadingEntryPointError +from aiida.common.lang import type_check +from aiida.orm.utils import load_node_class +from aiida.plugins.entry_point import ( + get_entry_point_names, + is_valid_entry_point_string, + load_entry_point, + load_entry_point_from_string, +) +from aiida.repository import File + +from aiida_restapi.common.pagination import PaginatedResults +from aiida_restapi.common.query import QueryParams +from aiida_restapi.common.types import NodeModelType, NodeType +from aiida_restapi.models.node import NodeLink + +from .entity import EntityService + +if t.TYPE_CHECKING: + from fastapi import UploadFile + + +class NodeService(EntityService[NodeType, NodeModelType]): + """Utility class for AiiDA Node REST API operations.""" + + FULL_TYPE_CONCATENATOR = '|' + LIKE_OPERATOR_CHARACTER = '%' + DEFAULT_NAMESPACE_LABEL = '~no-entry-point~' + + def get_projections(self, node_type: str | None = None) -> list[str]: + """Get projectable properties for the AiiDA entity. + + :param node_type: The AiiDA node type. + :type node_type: str | None + :return: The list of projectable properties for the AiiDA node. + :rtype: list[str] + """ + if not node_type: + return super().get_projections() + else: + node_cls = self._load_entry_point_from_node_type(node_type) + return sorted(node_cls.fields.keys()) + + def get_download_formats(self, full_type: str | None = None) -> dict: + """Returns dict of possible node formats for all available node types. + + :param full_type: The full type of the AiiDA node. + :type full_type: str | None + :return: A dictionary with full types as keys and list of available formats as values. + :rtype: dict[str, list[str]] + """ + all_formats = {} + + if full_type: + node_cls = self._load_entry_point_from_full_type(full_type) + try: + available_formats = node_cls.get_export_formats() + all_formats[full_type] = available_formats + except AttributeError: + pass + else: + entry_point_group = 'aiida.data' + + for name in get_entry_point_names(entry_point_group): + try: + node_cls = load_entry_point(entry_point_group, name) + available_formats = node_cls.get_export_formats() + except (AttributeError, LoadingEntryPointError): + continue + + if available_formats: + full_type = self._construct_full_type(node_cls.class_node_type, '') + all_formats[full_type] = available_formats + + return all_formats + + def get_repository_metadata(self, uuid: str) -> dict[str, dict]: + """Get the repository metadata of a node. + + :param uuid: The uuid of the node to retrieve the repository metadata for. + :type uuid: str + :return: A dictionary with the repository file metadata. 
+ :rtype: dict[str, dict] + """ + node = self.entity_class.collection.get(uuid=uuid) + total_size = 0 + + def get_metadata(objects: list[File], path: str | None = None) -> dict[str, dict]: + nonlocal total_size + + content: dict = {} + + for obj in objects: + obj_name = f'{path}/{obj.name}' if path else obj.name + if obj.is_dir(): + content[obj.name] = { + 'type': 'DIRECTORY', + 'objects': get_metadata( + node.base.repository.list_objects(obj_name), + obj_name, + ), + } + elif obj.is_file(): + size = node.base.repository.get_object_size(obj_name) + + binary = False + try: + with node.base.repository.open(obj_name, 'rb') as f: + binary = b'\x00' in f.read(8192) + f.seek(0) + except (UnicodeDecodeError, TypeError): + binary = True + + content[obj.name] = { + 'type': 'FILE', + 'binary': binary, + 'size': size, + 'download': f'/nodes/{uuid}/repo/contents?filename={obj_name}', + } + total_size += size + + return content + + metadata = get_metadata(node.base.repository.list_objects()) + + if total_size: + metadata['zipped'] = { + 'type': 'FILE', + 'binary': True, + 'size': total_size, + 'download': f'/nodes/{uuid}/repo/contents', + } + + return metadata + + def get_links( + self, + uuid: str, + queries: QueryParams, + direction: t.Literal['incoming', 'outgoing'], + ) -> PaginatedResults[NodeLink]: + """Get the incoming links of a node. + + :param uuid: The uuid of the node to retrieve the incoming links for. + :type uuid: str + :param queries: The query parameters, including filters, order_by, page_size, and page. + :type queries: QueryParams + :param direction: Specify whether to retrieve incoming or outgoing links. + :type direction: str + :return: The paginated requested linked nodes. + :rtype: PaginatedResults[NodeLink] + """ + node = self.entity_class.collection.get(uuid=uuid) + + if direction == 'incoming': + link_collection = node.base.links.get_incoming() + else: + link_collection = node.base.links.get_outgoing() + + all_links = link_collection.all() + + start, end = ( + queries.page_size * (queries.page - 1), + queries.page_size * queries.page, + ) + + links = [ + NodeLink( + **link.node.serialize(minimal=True), + link_label=link.link_label, + link_type=link.link_type.value, + ) + for link in all_links[start:end] + ] + + return PaginatedResults( + total=len(all_links), + page=queries.page, + page_size=queries.page_size, + results=links, + ) + + def add_one( + self, + model: NodeModelType, + files: dict[str, UploadFile] | None = None, + ) -> NodeModelType: + """Create new AiiDA node from its model. + + :param node_model: The AiiDA ORM model of the node to create. + :type model: NodeModelType + :param files: Optional list of files to attach to the node. + :type files: dict[str, UploadFile] | None + :return: The created and stored AiiDA node instance. + :rtype: NodeModelType + """ + node_cls = self._load_entry_point_from_node_type(model.node_type) + node = t.cast(NodeType, node_cls.from_model(model)) + for path, upload in (files or {}).items(): + upload.file.seek(0) + node.base.repository.put_object_from_filelike(upload.file, path) + node.store() + return self._to_model(node) + + def _validate_full_type(self, full_type: str) -> None: + """Validate that the `full_type` is a valid full type unique node identifier. 
+
+        :param full_type: a `Node` full type
+        :type full_type: str
+        :raises ValueError: if the `full_type` is invalid
+        :raises TypeError: if the `full_type` is not a string type
+        """
+
+        type_check(full_type, str)
+
+        if self.FULL_TYPE_CONCATENATOR not in full_type:
+            raise ValueError(
+                f'full type `{full_type}` does not include the required concatenator symbol '
+                f'`{self.FULL_TYPE_CONCATENATOR}`.'
+            )
+        elif full_type.count(self.FULL_TYPE_CONCATENATOR) > 1:
+            raise ValueError(
+                f'full type `{full_type}` includes the concatenator symbol '
+                f'`{self.FULL_TYPE_CONCATENATOR}` more than once.'
+            )
+
+    def _construct_full_type(self, node_type: str, process_type: str) -> str:
+        """Return the full type, which fully identifies the type of any `Node` with the given `node_type` and
+        `process_type`.
+
+        :param node_type: the `node_type` of the `Node`
+        :type node_type: str
+        :param process_type: the `process_type` of the `Node`
+        :type process_type: str
+        :return: the full type, which is a unique identifier
+        :rtype: str
+        """
+        if node_type is None:
+            node_type = ''
+
+        if process_type is None:
+            process_type = ''
+
+        return f'{node_type}{self.FULL_TYPE_CONCATENATOR}{process_type}'
+
+    def _get_full_type_filters(self, full_type: str) -> dict[str, t.Any]:
+        """Return the `QueryBuilder` filters that will return all `Nodes` identified by the given `full_type`.
+
+        :param full_type: the `full_type` node type identifier
+        :type full_type: str
+        :return: dictionary of filters to be passed for the `filters` keyword in `QueryBuilder.append`
+        :rtype: dict[str, t.Any]
+        :raises ValueError: if the `full_type` is invalid
+        :raises TypeError: if the `full_type` is not a string type
+        """
+        self._validate_full_type(full_type)
+
+        filters: dict[str, t.Any] = {}
+        node_type, process_type = full_type.split(self.FULL_TYPE_CONCATENATOR)
+
+        for entry in (node_type, process_type):
+            if entry.count(self.LIKE_OPERATOR_CHARACTER) > 1:
+                raise ValueError(f'full type component `{entry}` contained more than one like-operator character')
+
+            if self.LIKE_OPERATOR_CHARACTER in entry and entry[-1] != self.LIKE_OPERATOR_CHARACTER:
+                raise ValueError(f'like-operator character in full type component `{entry}` is not at the end')
+
+        if self.LIKE_OPERATOR_CHARACTER in node_type:
+            # Remove the trailing `LIKE_OPERATOR_CHARACTER`, escape the string and reattach the character
+            node_type = node_type[:-1]
+            node_type = escape_for_sql_like(node_type) + self.LIKE_OPERATOR_CHARACTER
+            filters['node_type'] = {'like': node_type}
+        else:
+            filters['node_type'] = {'==': node_type}
+
+        if self.LIKE_OPERATOR_CHARACTER in process_type:
+            # Remove the trailing `LIKE_OPERATOR_CHARACTER`.
+            # If that was the only specification, just ignore this filter (looking for any process_type).
+            # If there was more: escape the string and reattach the character.
+            process_type = process_type[:-1]
+            if process_type:
+                process_type = escape_for_sql_like(process_type) + self.LIKE_OPERATOR_CHARACTER
+                filters['process_type'] = {'like': process_type}
+        elif process_type:
+            filters['process_type'] = {'==': process_type}
+        else:
+            # A `process_type=''` is used to represent both `process_type=''` and `process_type=None`.
+            # This is because there is no simple way to single out null `process_type`s, and therefore
+            # we consider them together with empty-string process_types.
+            # Moreover, the existence of both is most likely a bug of migrations and thus both share
+            # this same "erroneous" origin.
+            filters['process_type'] = {'or': [{'==': ''}, {'==': None}]}
+
+        return filters
+
+    def _load_entry_point_from_node_type(self, node_type: str) -> NodeType:
+        """Return the loaded entry point for the given `node_type`.
+
+        :param node_type: the `node_type` unique node type identifier
+        :type node_type: str
+        :return: the loaded entry point
+        :rtype: NodeType
+        :raises ValueError: if the `node_type` is invalid
+        """
+        try:
+            return t.cast(NodeType, load_node_class(node_type))
+        except DbContentError as exception:
+            raise ValueError(f'invalid node type `{node_type}`') from exception
+
+    def _load_entry_point_from_full_type(self, full_type: str) -> t.Any:
+        """Return the loaded entry point for the given `full_type` unique node type identifier.
+
+        :param full_type: the `full_type` unique node type identifier
+        :type full_type: str
+        :return: the loaded entry point
+        :rtype: t.Any
+        :raises ValueError: if the `full_type` is invalid
+        :raises TypeError: if the `full_type` is not a string type
+        :raises `~aiida.common.exceptions.EntryPointError`: if the corresponding entry point cannot be loaded
+        """
+
+        data_prefix = 'data.'
+
+        self._validate_full_type(full_type)
+
+        node_type, process_type = full_type.split(self.FULL_TYPE_CONCATENATOR)
+
+        if is_valid_entry_point_string(process_type):
+            try:
+                return load_entry_point_from_string(process_type)
+            except EntryPointError as exception:
+                raise EntryPointError(f'could not load entry point `{process_type}`') from exception
+
+        elif node_type.startswith(data_prefix):
+            base_name = node_type.removeprefix(data_prefix)
+            entry_point_name = base_name.rsplit('.', 2)[0]
+
+            try:
+                return load_entry_point('aiida.data', entry_point_name)
+            except EntryPointError as exception:
+                raise EntryPointError(f'could not load entry point `{entry_point_name}`') from exception
+
+        # Here we are dealing with a `ProcessNode` with a `process_type` that is not an entry point string.
+        # Which means it is most likely a full module path (the fallback option) and we cannot necessarily load the
+        # class from this.
We could try with `importlib` but not sure that we should + raise EntryPointError('entry point of the given full type cannot be loaded') diff --git a/docs/source/conf.py b/docs/source/conf.py index 80ebe50f..0eb03539 100755 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -71,20 +71,40 @@ } autodoc_typehints = 'none' nitpick_ignore = [ - ('py:class', name) - for name in [ - 'pydantic.main.BaseModel', - 'pydantic.types.Json', - 'graphene.types.generic.GenericScalar', - 'graphene.types.objecttype.ObjectType', - 'graphene.types.scalars.String', - 'aiida_restapi.aiida_db_mappings.Config', - 'aiida_restapi.models.Config', - 'aiida_restapi.routers.auth.Config', - 'aiida_restapi.graphql.orm_factories.AiidaOrmObjectType', - 'aiida_restapi.graphql.nodes.LinkObjectType', - 'aiida_restapi.graphql.orm_factories.multirow_cls_factory..AiidaOrmRowsType', - ] + *[ + ('py:class', name) + for name in [ + 'pydantic.main.BaseModel', + 'pydantic.types.Json', + 'graphene.types.generic.GenericScalar', + 'graphene.types.objecttype.ObjectType', + 'graphene.types.scalars.String', + 'graphene.types.objecttype.ObjectTypeMeta.__new__..InterObjectType', + 'graphene.types.objecttype.AiidaOrmObjectType', + 'graphene.types.scalars.Scalar', + 'aiida.orm.users.User.Model', + 'aiida.common.exceptions.FeatureNotAvailable', + 'aiida.common.exceptions.InputValidationError', + 'aiida.common.exceptions.EntryPointError', + 'aiida_restapi.aiida_db_mappings.Config', + 'aiida_restapi.models.Config', + 'aiida_restapi.routers.auth.Config', + 'aiida_restapi.routers.nodes.NodeType', + 'aiida_restapi.graphql.orm_factories.AiidaOrmObjectType', + 'aiida_restapi.graphql.nodes.LinkObjectType', + 'aiida_restapi.graphql.orm_factories.multirow_cls_factory..AiidaOrmRowsType', + ] + ], + *[ + ('py:obj', name) + for name in [ + 'aiida_restapi.common.types.EntityModelType', + 'aiida_restapi.common.types.EntityType', + 'aiida_restapi.common.types.NodeType', + 'aiida_restapi.common.types.NodeModelType', + ] + ], + ('py:exc', 'HTTPException'), ] suppress_warnings = ['etoc.toctree'] diff --git a/pyproject.toml b/pyproject.toml index 5a7aea6b..4f9c430d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,14 +16,16 @@ classifiers = [ 'Topic :: Scientific/Engineering', ] dependencies = [ - 'aiida-core~=2.5', + 'aiida-core @ git+https://github.com/edan-bainglass/aiida-core.git@fix-node-serialization', 'fastapi~=0.115.5', 'uvicorn[standard]~=0.32.1', 'pydantic~=2.0', + 'eval_type_backport', 'starlette-graphene3~=0.6.0', 'graphene~=3.0', 'python-dateutil~=2.0', 'lark~=0.11.0', + 'python-multipart~=0.0.20', ] dynamic = ['description', 'version'] keywords = ['aiida', 'workflows'] @@ -55,8 +57,12 @@ testing = [ 'httpx~=0.27.2', 'numpy~=1.21', 'anyio~=4.6.0', + 'beautifulsoup4~=4.12', ] +[project.scripts] +aiida-restapi = 'aiida_restapi.cli.main:cli' + [project.urls] Source = 'https://github.com/aiidateam/aiida-restapi' diff --git a/tests/conftest.py b/tests/conftest.py index 162fe557..6f01672c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,6 @@ """Test fixtures specific to this package.""" -# pylint: disable=too-many-arguments +import os import tempfile from datetime import datetime from typing import Any, Callable, Mapping, MutableMapping, Optional, Union @@ -12,15 +12,49 @@ from aiida.common.exceptions import NotExistent from aiida.engine import ProcessState from aiida.orm import WorkChainNode, WorkFunctionNode +from fastapi import FastAPI from fastapi.testclient import TestClient from httpx import ASGITransport, AsyncClient +from 
starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request -from aiida_restapi import app, config +from aiida_restapi import config +from aiida_restapi.config import API_CONFIG +from aiida_restapi.main import create_app from aiida_restapi.routers.auth import UserInDB, get_current_user pytest_plugins = ['aiida.tools.pytest_fixtures'] +class PrefixMiddleware(BaseHTTPMiddleware): + def __init__(self, app: FastAPI, prefix: str = ''): + super().__init__(app) + self.prefix = prefix or API_CONFIG['PREFIX'] + + async def dispatch(self, request: Request, call_next): + if not request.url.path.startswith(self.prefix): + request.scope['path'] = self.prefix + request.url.path + return await call_next(request) + + +@pytest.fixture(scope='session') +def app(): + """Return fastapi app.""" + app = create_app() + app.add_middleware(PrefixMiddleware) + yield app + + +@pytest.fixture(scope='function') +def read_only_app(): + """Return fastapi app.""" + os.environ['AIIDA_RESTAPI_READ_ONLY'] = '1' + app = create_app() + app.add_middleware(PrefixMiddleware) + yield app + os.environ['AIIDA_RESTAPI_READ_ONLY'] = '0' + + @pytest.fixture(scope='session', autouse=True) def aiida_profile(aiida_config, aiida_profile_factory): """Create and load a profile with RabbitMQ as broker.""" @@ -34,7 +68,7 @@ def clear_database_auto(aiida_profile_clean): # pylint: disable=unused-argument @pytest.fixture(scope='function') -def client(): +def client(app): """Return fastapi test client.""" yield TestClient(app) @@ -50,7 +84,7 @@ def anyio_backend(): @pytest.fixture(scope='function') -async def async_client(): +async def async_client(app): """Return fastapi async test client.""" async with AsyncClient(transport=ASGITransport(app=app), base_url='http://test') as async_test_client: yield async_test_client @@ -103,7 +137,7 @@ def example_processes(): calc.base.attributes.set('process_label', process_label) calc.store() - calcs.append(calc.pk) + calcs.append(calc.uuid) calc = WorkChainNode() calc.set_process_state(state) @@ -117,7 +151,7 @@ def example_processes(): calc.pause() calc.store() - calcs.append(calc.pk) + calcs.append(calc.uuid) return calcs @@ -158,7 +192,7 @@ def default_groups(): test_user_2 = orm.User(email='stravinsky@symphony.org', first_name='Igor', last_name='Stravinsky').store() group_1 = orm.Group(label='test_label_1', user=test_user_1).store() group_2 = orm.Group(label='test_label_2', user=test_user_2).store() - return [group_1.pk, group_2.pk] + return [group_1.uuid, group_2.uuid] @pytest.fixture(scope='function') @@ -169,18 +203,18 @@ def default_nodes(): node_3 = orm.Str('test_string').store() node_4 = orm.Bool(False).store() - return [node_1.pk, node_2.pk, node_3.pk, node_4.pk] + return [node_1.uuid, node_2.uuid, node_3.uuid, node_4.uuid] @pytest.fixture(scope='function') def array_data_node(): - """Populate database with downloadable node (implmenting a _prepare_* function).""" + """Populate database with downloadable node (implementing a _prepare_* function).""" return orm.ArrayData(np.arange(4)).store() @pytest.fixture(scope='function') -def authenticate(): +def authenticate(app): """Authenticate user. Since this goes via modifying the app, undo modifications afterwards. 
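A sketch of how these fixtures compose in a test (hypothetical, not part of the patch): the `PrefixMiddleware` above rewrites unprefixed request paths, so a test can address routes with or without the `/api/v0` prefix, assuming `create_app` mounts its routers under `API_CONFIG['PREFIX']` and exposes the `/server/info` route shown earlier.

```python
# Hypothetical test built on the session-scoped `app` fixture above.
from fastapi.testclient import TestClient


def test_prefix_rewrite(app):
    """Both spellings should hit the same route, thanks to PrefixMiddleware."""
    client = TestClient(app)
    assert client.get('/server/info').status_code == 200  # rewritten to /api/v0/server/info
    assert client.get('/api/v0/server/info').status_code == 200  # already prefixed, left untouched
```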
@@ -213,7 +247,7 @@ def mutate_mapping(

 @pytest.fixture
 def orm_regression(data_regression):
-    """A variant of data_regression.check, that replaces nondetermistic fields (like uuid)."""
+    """A variant of data_regression.check that replaces non-deterministic fields (like uuid)."""

     def _func(
         data: dict,
@@ -289,7 +323,7 @@ def _func(
         level_name: str = 'level 1',
         message='',
         node: Optional[orm.nodes.Node] = None,
-    ) -> orm.Comment:
+    ) -> orm.Log:
         orm_node = node or orm.Data().store()
         return orm.Log(datetime.now(pytz.UTC), loggername, level_name, orm_node.pk, message=message).store()
diff --git a/tests/test_app.py b/tests/test_app.py
new file mode 100644
index 00000000..09a9bb01
--- /dev/null
+++ b/tests/test_app.py
@@ -0,0 +1,12 @@
+# test main application
+
+from fastapi import FastAPI
+from fastapi.testclient import TestClient
+
+
+def test_read_only_mode(read_only_app: FastAPI):
+    client = TestClient(read_only_app)
+    response = client.get('/computers/')
+    assert response.status_code == 200
+    response = client.post('/computers/', json={'name': 'new_computer'})
+    assert response.status_code == 405
diff --git a/tests/test_auth.py b/tests/test_auth.py
index c2d7f9c4..31e05ed4 100644
--- a/tests/test_auth.py
+++ b/tests/test_auth.py
@@ -2,15 +2,11 @@

 from fastapi.testclient import TestClient

-from aiida_restapi import app

-client = TestClient(app)
-
-
-def test_authenticate_user():
+def test_authenticate_user(client: TestClient):
     """Test authenticating as a user."""
     # authenticate with username and password
-    response = client.post('/token', data={'username': 'johndoe@example.com', 'password': 'secret'})
+    response = client.post('/auth/token', data={'username': 'johndoe@example.com', 'password': 'secret'})
     assert response.status_code == 200, response.content

     token = response.json()['access_token']
diff --git a/tests/test_computers.py b/tests/test_computers.py
index 4d98cb9a..f4ae8f0f 100644
--- a/tests/test_computers.py
+++ b/tests/test_computers.py
@@ -1,40 +1,36 @@
 """Test the /computers endpoint"""

+from __future__ import annotations

-def test_get_computers(default_computers, client):  # pylint: disable=unused-argument
-    """Test listing existing computer."""
-    response = client.get('/computers/')
+import pytest
+from aiida import orm
+from fastapi.testclient import TestClient

-    assert response.status_code == 200
-    assert len(response.json()) == 2

+def test_get_computer_projectable_properties(client: TestClient):
+    """Test get projectable properties for computer."""
+    response = client.get('/computers/projections')
+    assert response.status_code == 200
+    assert response.json() == sorted(orm.Computer.fields.keys())

-def test_get_computers_projectable(client):
-    """Test get projectable properites for computer."""
-    response = client.get('/computers/projectable_properties')

+@pytest.mark.usefixtures('default_computers')
+def test_get_computers(client: TestClient):
+    """Test listing existing computer."""
+    response = client.get('/computers/')
     assert response.status_code == 200
-    assert response.json() == [
-        'id',
-        'uuid',
-        'label',
-        'hostname',
-        'scheduler_type',
-        'transport_type',
-        'metadata',
-        'description',
-    ]
-
-
-def test_get_single_computers(default_computers, client):  # pylint: disable=unused-argument
-    """Test retrieving a single computer."""
+    assert len(response.json()['results']) == 2
+
+
+def test_get_computer(client: TestClient, default_computers: list[int | None]):
+    """Test retrieving a single computer."""
     for comp_id in default_computers:
         response = client.get(f'/computers/{comp_id}')
         assert response.status_code == 200


-def test_create_computer(client, authenticate):  # pylint: disable=unused-argument
+@pytest.mark.usefixtures('authenticate')
+def test_create_computer(client: TestClient):
     """Test creating a new computer."""
     response = client.post(
         '/computers',
@@ -46,7 +42,6 @@ def test_create_computer(client, authenticate):  # pylint: disable=unused-argume
         },
     )
     assert response.status_code == 200, response.content
-
     response = client.get('/computers')
-    computers = [comp['label'] for comp in response.json()]
+    computers = [comp['label'] for comp in response.json()['results']]

     assert 'test_comp' in computers
diff --git a/tests/test_daemon.py b/tests/test_daemon.py
index e3124a57..5efdaa44 100644
--- a/tests/test_daemon.py
+++ b/tests/test_daemon.py
@@ -3,13 +3,9 @@
 import pytest
 from fastapi.testclient import TestClient

-from aiida_restapi import app
-
-client = TestClient(app)
-

 @pytest.mark.usefixtures('stopped_daemon_client', 'authenticate')
-def test_status_and_start():
+def test_status_and_start(client: TestClient):
     """Test ``/daemon/status`` when the daemon is not running and ``/daemon/start``."""
     response = client.get('/daemon/status')
     assert response.status_code == 200, response.content
@@ -30,7 +26,7 @@ def test_status_and_start():


 @pytest.mark.usefixtures('started_daemon_client', 'authenticate')
-def test_status_and_stop():
+def test_status_and_stop(client: TestClient):
     """Test ``/daemon/status`` when the daemon is running and ``/daemon/stop``."""
     response = client.get('/daemon/status')
     assert response.status_code == 200, response.content
diff --git a/tests/test_graphql/test_comments.py b/tests/test_graphql/test_comments.py
index 93711cd7..3e68556f 100644
--- a/tests/test_graphql/test_comments.py
+++ b/tests/test_graphql/test_comments.py
@@ -13,7 +13,7 @@ def test_comment(create_comment, orm_regression):
     fields = field_names_from_orm(type(comment))
     schema = create_schema([CommentQueryPlugin])
     client = Client(schema)
-    executed = client.execute('{ comment(id: %r) { %s } }' % (comment.id, ' '.join(fields)))
+    executed = client.execute('{ comment(id: %r) { %s } }' % (comment.pk, ' '.join(fields)))
     orm_regression(executed)
diff --git a/tests/test_graphql/test_computers.py b/tests/test_graphql/test_computers.py
index 036c1c56..8875246c 100644
--- a/tests/test_graphql/test_computers.py
+++ b/tests/test_graphql/test_computers.py
@@ -13,7 +13,7 @@ def test_computer(create_computer, orm_regression):
     fields = field_names_from_orm(type(computer))
     schema = create_schema([ComputerQueryPlugin])
     client = Client(schema)
-    executed = client.execute('{ computer(id: %r) { %s } }' % (computer.id, ' '.join(fields)))
+    executed = client.execute('{ computer(id: %r) { %s } }' % (computer.pk, ' '.join(fields)))
     orm_regression(executed)

@@ -24,7 +24,7 @@ def test_computer_nodes(create_computer, create_node, orm_regression):
     create_node(label='node 2', computer=computer)
     schema = create_schema([ComputerQueryPlugin])
     client = Client(schema)
-    executed = client.execute('{ computer(id: %r) { nodes { count rows{ label } } } }' % (computer.id))
+    executed = client.execute('{ computer(id: %r) { nodes { count rows{ label } } } }' % (computer.pk))
     orm_regression(executed)
diff --git a/tests/test_filter_syntax.py b/tests/test_graphql/test_filter_syntax.py
similarity index 95%
rename from tests/test_filter_syntax.py
rename to tests/test_graphql/test_filter_syntax.py
index 5d7ebb70..a1a2d98d 100644
--- a/tests/test_filter_syntax.py
+++ b/tests/test_graphql/test_filter_syntax.py
@@ -4,7 +4,7 @@

 import pytest

-from aiida_restapi.filter_syntax import parse_filter_str
+from aiida_restapi.graphql.filter_syntax import parse_filter_str


 @pytest.mark.parametrize(
diff --git a/tests/test_graphql/test_full.py b/tests/test_graphql/test_full.py
index a1a811a0..ad1e55c4 100644
--- a/tests/test_graphql/test_full.py
+++ b/tests/test_graphql/test_full.py
@@ -9,5 +9,5 @@ def test_full(create_node, orm_regression):
     """Test loading the full schema."""
     node = create_node(label='node 1')
     client = Client(SCHEMA)
-    executed = client.execute('{ aiidaVersion node(id: %r) { label } }' % (node.id))
+    executed = client.execute('{ aiidaVersion node(id: %r) { label } }' % (node.pk))
     orm_regression(executed)
diff --git a/tests/test_graphql/test_groups.py b/tests/test_graphql/test_groups.py
index 0e0b1073..80a41f7f 100644
--- a/tests/test_graphql/test_groups.py
+++ b/tests/test_graphql/test_groups.py
@@ -13,7 +13,7 @@ def test_group(create_group, orm_regression):
     fields = field_names_from_orm(type(group))
     schema = create_schema([GroupQueryPlugin])
     client = Client(schema)
-    executed = client.execute('{ group(id: %r) { %s } }' % (group.id, ' '.join(fields)))
+    executed = client.execute('{ group(id: %r) { %s } }' % (group.pk, ' '.join(fields)))
     orm_regression(executed)

@@ -34,7 +34,7 @@ def test_group_nodes(create_group, create_node, orm_regression):
     group.add_nodes([create_node(label='node 1'), create_node(label='node 2')])
     schema = create_schema([GroupQueryPlugin])
     client = Client(schema)
-    executed = client.execute('{ group(id: %r) { nodes { count rows{ label } } } }' % (group.id))
+    executed = client.execute('{ group(id: %r) { nodes { count rows{ label } } } }' % (group.pk))
     orm_regression(executed)
diff --git a/tests/test_graphql/test_logs.py b/tests/test_graphql/test_logs.py
index 3f5a6f18..ca92ce54 100644
--- a/tests/test_graphql/test_logs.py
+++ b/tests/test_graphql/test_logs.py
@@ -13,7 +13,7 @@ def test_log(create_log, orm_regression):
     fields = field_names_from_orm(type(log))
     schema = create_schema([LogQueryPlugin])
     client = Client(schema)
-    executed = client.execute('{ log(id: %r) { %s } }' % (log.id, ' '.join(fields)))
+    executed = client.execute('{ log(id: %r) { %s } }' % (log.pk, ' '.join(fields)))
     orm_regression(executed)
diff --git a/tests/test_graphql/test_nodes.py b/tests/test_graphql/test_nodes.py
index 18ee144c..dc944af9 100644
--- a/tests/test_graphql/test_nodes.py
+++ b/tests/test_graphql/test_nodes.py
@@ -14,7 +14,7 @@ def test_node(create_node, orm_regression):
     fields = field_names_from_orm(type(node))
     schema = create_schema([NodeQueryPlugin])
     client = Client(schema)
-    executed = client.execute('{ node(id: %r) { %s } }' % (node.id, ' '.join(fields)))
+    executed = client.execute('{ node(id: %r) { %s } }' % (node.pk, ' '.join(fields)))
     orm_regression(executed)

@@ -25,7 +25,7 @@ def test_node_logs(create_node, create_log, orm_regression):
     create_log(message='log 2', node=node)
     schema = create_schema([NodeQueryPlugin])
     client = Client(schema)
-    executed = client.execute('{ node(id: %r) { logs { count rows{ message } } } }' % (node.id))
+    executed = client.execute('{ node(id: %r) { logs { count rows{ message } } } }' % (node.pk))
     orm_regression(executed)

@@ -36,7 +36,7 @@ def test_node_comments(create_node, create_comment, orm_regression):
     create_comment(content='comment 2', node=node)
     schema = create_schema([NodeQueryPlugin])
     client = Client(schema)
-    executed = client.execute('{ node(id: %r) { comments { count rows{ content } } } }' % (node.id))
+    executed = client.execute('{ node(id: %r) { comments { count rows{ content } } } }' % (node.pk))
     orm_regression(executed)

@@ -53,7 +53,7 @@ def test_node_incoming(create_node, orm_regression):
     schema = create_schema([NodeQueryPlugin])
     client = Client(schema)
     executed = client.execute(
-        '{ node(id: %r) { incoming { count rows{ node { label } link { label type } } } } }' % (node.id)
+        '{ node(id: %r) { incoming { count rows{ node { label } link { label type } } } } }' % (node.pk)
     )
     orm_regression(executed)

@@ -72,7 +72,7 @@ def test_node_outgoing(create_node, orm_regression):
     schema = create_schema([NodeQueryPlugin])
     client = Client(schema)
     executed = client.execute(
-        '{ node(id: %r) { outgoing { count rows{ node { label } link { label type } } } } }' % (node.id)
+        '{ node(id: %r) { outgoing { count rows{ node { label } link { label type } } } } }' % (node.pk)
     )
     orm_regression(executed)

@@ -89,7 +89,7 @@ def test_node_ancestors(create_node, orm_regression):

     schema = create_schema([NodeQueryPlugin])
     client = Client(schema)
-    executed = client.execute('{ node(id: %r) { ancestors { count rows{ label } } } }' % (node.id))
+    executed = client.execute('{ node(id: %r) { ancestors { count rows{ label } } } }' % (node.pk))
     orm_regression(executed)

@@ -106,7 +106,7 @@ def test_node_descendants(create_node, orm_regression):

     schema = create_schema([NodeQueryPlugin])
     client = Client(schema)
-    executed = client.execute('{ node(id: %r) { descendants { count rows{ label } } } }' % (node.id))
+    executed = client.execute('{ node(id: %r) { descendants { count rows{ label } } } }' % (node.pk))
     orm_regression(executed)
diff --git a/tests/test_graphql/test_users.py b/tests/test_graphql/test_users.py
index cd6b52d0..a753333c 100644
--- a/tests/test_graphql/test_users.py
+++ b/tests/test_graphql/test_users.py
@@ -13,7 +13,7 @@ def test_user(create_user, orm_regression):
     fields = field_names_from_orm(type(user))
     schema = create_schema([UserQueryPlugin])
     client = Client(schema)
-    executed = client.execute('{ user(id: %r) { %s } }' % (user.id, ' '.join(fields)))
+    executed = client.execute('{ user(id: %r) { %s } }' % (user.pk, ' '.join(fields)))
     orm_regression(executed)
diff --git a/tests/test_groups.py b/tests/test_groups.py
index 09e245d1..442ec65c 100644
--- a/tests/test_groups.py
+++ b/tests/test_groups.py
@@ -1,52 +1,40 @@
 """Test the /groups endpoint"""

+from __future__ import annotations

-def test_get_group(default_groups, client):  # pylint: disable=unused-argument
+import pytest
+from aiida import orm
+from fastapi.testclient import TestClient
+
+
+def test_get_group_projectable_properties(client: TestClient):
+    """Test get projectable properties for group."""
+    response = client.get('/groups/projections')
+    assert response.status_code == 200
+    assert response.json() == sorted(orm.Group.fields.keys())
+
+
+@pytest.mark.usefixtures('default_groups')
+def test_get_groups(client: TestClient):
     """Test listing existing groups."""
     response = client.get('/groups')

     assert response.status_code == 200
-    assert len(response.json()) == 2
-
+    assert len(response.json()['results']) == 2

-def test_get_group_projectable(client):
-    """Test get projectable properites for group."""
-    response = client.get('/groups/projectable_properties')

-    assert response.status_code == 200
-    assert response.json() == [
-        'id',
-        'uuid',
-        'label',
-        'type_string',
-        'description',
-        'extras',
-        'time',
-        'user_id',
-    ]
-
-
-def test_get_single_group(default_groups, client):  # pylint: disable=unused-argument
+def test_get_group(client: TestClient, default_groups: list[str]):
     """Test retrieving a single group."""
-
     for group_id in default_groups:
         response = client.get(f'/groups/{group_id}')
         assert response.status_code == 200


-def test_create_group(client, authenticate):  # pylint: disable=unused-argument
+@pytest.mark.usefixtures('authenticate')
+def test_create_group(client: TestClient):
     """Test creating a new group."""
     response = client.post('/groups', json={'label': 'test_label_create'})
     assert response.status_code == 200, response.content
-
+    assert response.json()['user']
     response = client.get('/groups')
-    first_names = [group['label'] for group in response.json()]
-
+    first_names = [group['label'] for group in response.json()['results']]
     assert 'test_label_create' in first_names
-
-
-def test_create_group_returns_user_id(client, authenticate):  # pylint: disable=unused-argument
-    """Test creating a new group returns user_id."""
-    response = client.post('/groups', json={'label': 'test_label_create'})
-    assert response.status_code == 200, response.content
-
-    assert response.json()['user_id']
diff --git a/tests/test_models.py b/tests/test_models.py
deleted file mode 100644
index b28416b5..00000000
--- a/tests/test_models.py
+++ /dev/null
@@ -1,48 +0,0 @@
-"""Test that all aiida entity models can be loaded loaded into pydantic models."""
-
-from aiida import orm
-
-from aiida_restapi import models
-
-
-def replace_dynamic(data: dict) -> dict:
-    """Replace dynamic fields with their type name."""
-    for key in ['id', 'uuid', 'dbnode_id', 'user_id', 'mtime', 'ctime', 'time']:
-        if key in data:
-            data[key] = type(data[key]).__name__
-    return data
-
-
-def test_comment_get_entities(data_regression):
-    """Test ``Comment.get_entities``"""
-    orm_user = orm.User(email='verdi@opera.net', first_name='Giuseppe', last_name='Verdi').store()
-    orm_node = orm.Data().store()
-    orm.Comment(orm_node, orm_user, 'content').store()
-    py_comments = models.Comment.get_entities(order_by=['id'])
-    data_regression.check([replace_dynamic(c.dict()) for c in py_comments])
-
-
-def test_user_get_entities(data_regression):
-    """Test ``User.get_entities``"""
-    orm.User(email='verdi@opera.net', first_name='Giuseppe', last_name='Verdi').store()
-    py_users = models.User.get_entities(order_by=['id'])
-    data_regression.check([replace_dynamic(c.dict()) for c in py_users])
-
-
-def test_computer_get_entities(data_regression):
-    """Test ``Computer.get_entities``"""
-    orm.Computer(
-        label='test_comp_1',
-        hostname='localhost_1',
-        transport_type='core.local',
-        scheduler_type='core.pbspro',
-    ).store()
-    py_computer = models.Computer.get_entities()
-    data_regression.check([replace_dynamic(c.dict()) for c in py_computer])
-
-
-def test_group_get_entities(data_regression):
-    """Test ``Group.get_entities``"""
-    orm.Group(label='regression_label_1', description='regrerssion_test').store()
-    py_group = models.Group.get_entities(order_by=['id'])
-    data_regression.check([replace_dynamic(c.dict()) for c in py_group])
diff --git a/tests/test_nodes.py b/tests/test_nodes.py
index 72c84443..200a34b1 100644
--- a/tests/test_nodes.py
+++ b/tests/test_nodes.py
@@ -1,34 +1,167 @@
 """Test the /nodes endpoint"""

+from __future__ import annotations
+
 import io
 import json
+import typing as t

 import pytest
+from aiida import orm
+from fastapi.testclient import TestClient
+from httpx import AsyncClient

+if t.TYPE_CHECKING:
+    from pydantic import BaseModel

-def test_get_nodes_projectable(client):
-    """Test get projectable properites for nodes."""
-    response = client.get('/nodes/projectable_properties')
+def test_get_node_projectable_properties(client: TestClient):
+    """Test get projectable properties for nodes."""
+    response = client.get('/nodes/projections')

     assert response.status_code == 200
-    assert response.json() == [
-        'id',
-        'uuid',
-        'node_type',
-        'process_type',
-        'label',
-        'description',
-        'ctime',
-        'mtime',
-        'user_id',
-        'dbcomputer_id',
-        'attributes',
-        'extras',
-        'repository_metadata',
-    ]
+    assert response.json() == sorted(orm.Node.fields.keys())
+
+
+def test_get_node_projectable_properties_by_type(client: TestClient):
+    """Test get projectable properties for nodes by valid type."""
+    response = client.get('/nodes/projections?type=data.core.int.Int.')
+    assert response.status_code == 200
+    result = response.json()
+    assert result == sorted(orm.Int.fields.keys())
+    assert 'attributes.source' in result
+    assert 'attributes.value' in result
+
+
+def test_get_node_projectable_properties_by_invalid_type(client: TestClient):
+    """Test get projectable properties for nodes with invalid type."""
+    response = client.get('/nodes/projections?type=this_is_not_a_valid_type')
+    assert response.status_code == 422
+
+
+def test_get_node_schema(client: TestClient):
+    """Test get schema for nodes."""
+    response = client.get('/nodes/schema')
+    assert response.status_code == 200
+    result = response.json()
+    assert 'properties' in result
+    assert sorted(result['properties'].keys()) == sorted(orm.Node.fields.keys())
+    assert 'attributes' in result['properties']
+    attributes = result['properties']['attributes']
+    assert '$ref' in attributes
+    assert attributes['$ref'].endswith('AttributesModel')
+    assert '$defs' in result
+    assert 'AttributesModel' in result['$defs']
+    assert not result['$defs']['AttributesModel']['properties']
+
+
+@pytest.mark.parametrize(
+    'which, model, name',
+    [
+        ['get', orm.Int.Model, 'AttributesModel'],
+        ['post', orm.Int.CreateModel, 'AttributesCreateModel'],
+    ],
+)
+def test_get_node_schema_by_type(client: TestClient, which: str, model: type[BaseModel], name: str):
+    """Test get schema for nodes by valid type."""
+    response = client.get(f'/nodes/schema?type=data.core.int.Int.&which={which}')
+    assert response.status_code == 200
+    result = response.json()
+    assert 'properties' in result
+    assert sorted(result['properties'].keys()) == sorted(model.model_fields.keys())
+    assert 'attributes' in result['properties']
+    attributes = result['properties']['attributes']
+    assert '$ref' in attributes
+    assert attributes['$ref'].endswith(name)
+    assert '$defs' in result
+    assert name in result['$defs']
+    assert result['$defs'][name]['title'] == f'Int{name}'
+    assert result['$defs'][name]['properties']
+    fields = orm.Int.fields.keys()
+    assert all(f'attributes.{key}' in fields for key in result['$defs'][name]['properties'].keys())
+
+
+def test_get_node_schema_by_invalid_type(client: TestClient):
+    """Test get schema for nodes with invalid type."""
+    response = client.get('/nodes/schema?type=this_is_not_a_valid_type')
+    assert response.status_code == 422
+
+
+@pytest.mark.usefixtures('default_nodes')
+def test_get_nodes(client: TestClient):
+    """Test listing existing nodes."""
+    response = client.get('/nodes')
+    assert response.status_code == 200
+    assert len(response.json()['results']) == 4
+    result = next(iter(response.json()['results']), None)
+    assert result is not None
+    assert set(result.keys()) == {'pk', 'uuid', 'node_type', 'label', 'description', 'ctime', 'mtime', 'user'}
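The list endpoints now wrap their payload in the `PaginatedResults` envelope (`total`, `page`, `page_size`, `results`) instead of returning a bare list. A client-side sketch of walking all pages, assuming a deployment reachable at `http://localhost:8000` behind the `/api/v0` prefix (both are assumptions, not fixed by this patch):

```python
# Sketch: iterate over a paginated listing endpoint with httpx.
import httpx

with httpx.Client(base_url='http://localhost:8000/api/v0') as client:
    page = 1
    nodes = []
    while True:
        response = client.get('/nodes', params={'page': page, 'page_size': 50})
        response.raise_for_status()
        payload = response.json()  # PaginatedResults: total, page, page_size, results
        nodes.extend(payload['results'])
        if page * payload['page_size'] >= payload['total']:
            break  # last page reached
        page += 1

print(f'fetched {len(nodes)} of {payload["total"]} nodes')
```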
+
+
+@pytest.mark.usefixtures('default_nodes')
+def test_get_nodes_by_type(client: TestClient):
+    """Test listing existing nodes by type."""
+    filters = {'node_type': {'in': ['data.core.int.Int.', 'data.core.float.Float.']}}
+    response = client.get(f'/nodes?filters={json.dumps(filters)}')
+    assert response.status_code == 200
+    results = response.json()['results']
+    assert len(results) == 2
+    assert any(result['node_type'] == 'data.core.int.Int.' for result in results)
+    assert any(result['node_type'] == 'data.core.float.Float.' for result in results)
+
+
+@pytest.mark.usefixtures('default_nodes')
+def test_get_nodes_with_filters(client: TestClient):
+    """Test listing existing nodes with filters."""
+    filters = {'attributes.value': 1.1}
+    response = client.get(f'/nodes?filters={json.dumps(filters)}')
+    assert response.status_code == 200
+    results = response.json()['results']
+    assert len(results) == 1
+    assert results[0]['node_type'] == 'data.core.float.Float.'
+
+    # Attributes are excluded, so we need to check separately by id
+    check = client.get(f'/nodes/{results[0]["uuid"]}/attributes')
+    assert check.status_code == 200
+    assert check.json()['value'] == 1.1
+
+
+@pytest.mark.usefixtures('default_nodes')
+def test_get_nodes_in_order(client: TestClient):
+    """Test listing existing nodes in order."""
+    order_by = {'ctime': 'desc'}
+    response = client.get(f'/nodes?order_by={json.dumps(order_by)}')
+    assert response.status_code == 200
+    results = response.json()['results']
+    assert len(results) == 4
+    ctimes = [result['ctime'] for result in results]
+    assert ctimes == sorted(ctimes, reverse=True)
+
+
+@pytest.mark.usefixtures('default_nodes')
+def test_get_nodes_pagination(client: TestClient):
+    """Test listing existing nodes with pagination."""
+    response = client.get('/nodes?page_size=2&page=1')
+    assert response.status_code == 200
+    results = response.json()['results']
+    assert len(results) == 2
+    assert all(result['pk'] in (1, 2) for result in results)
+
+    response = client.get('/nodes?page_size=2&page=2')
+    assert response.status_code == 200
+    results = response.json()['results']
+    assert len(results) == 2
+    assert all(result['pk'] in (3, 4) for result in results)


-def test_get_download_formats(client):
+def test_get_node(client: TestClient, default_nodes: list[str | None]):
+    """Test retrieving a single node."""
+    for node_id in default_nodes:
+        response = client.get(f'/nodes/{node_id}')
+        assert response.status_code == 200, response.content
+        assert response.json()['uuid'] == node_id
+
+
+def test_get_download_formats(client: TestClient):
     """Test get download formats for nodes."""
     response = client.get('/nodes/download_formats')

@@ -65,141 +198,169 @@ def test_get_download_formats(client):
         raise AssertionError(f'The value {value} in key {key!r} is not contained in the response: {response_json}')


-def test_get_single_nodes(default_nodes, client):  # pylint: disable=unused-argument
-    """Test retrieving a single nodes."""
-
-    for nodes_id in default_nodes:
-        response = client.get(f'/nodes/{nodes_id}')
-        assert response.status_code == 200
-
-
-def test_get_nodes(default_nodes, client):  # pylint: disable=unused-argument
-    """Test listing existing nodes."""
-    response = client.get('/nodes')
+def test_get_node_repository_metadata(client: TestClient, array_data_node: orm.ArrayData):
+    """Test retrieving repository metadata for a node."""
+    response = client.get(f'/nodes/{array_data_node.uuid}/repo/metadata')
     assert response.status_code == 200
-    assert len(response.json()) == 4
-
-
-def test_create_dict(client, authenticate):  # pylint: disable=unused-argument
+    result = response.json()
+    default = orm.ArrayData.default_array_name + '.npy'
+    assert default in result
+    assert all(t in result[default] for t in ('type', 'binary', 'size', 'download'))
+    assert result[default]['type'] == 'FILE'
+    assert result[default]['binary'] is True
+    assert 'zipped' in result
+    assert all(t in result['zipped'] for t in ('type', 'binary', 'size', 'download'))
+    assert result['zipped']['type'] == 'FILE'
+    assert result['zipped']['binary'] is True
+
+
+@pytest.mark.usefixtures('authenticate')
+def test_create_dict(client: TestClient):
     """Test creating a new dict."""
     response = client.post(
         '/nodes',
         json={
-            'entry_point': 'core.dict',
-            'attributes': {'x': 1, 'y': 2},
+            'node_type': 'data.core.dict.Dict.',
             'label': 'test_dict',
+            'attributes': {'value': {'x': 1, 'y': 2}},
         },
     )
     assert response.status_code == 200, response.content


 @pytest.mark.anyio
-async def test_create_code(default_computers, async_client, authenticate):  # pylint: disable=unused-argument
+@pytest.mark.usefixtures('authenticate')
+async def test_create_code(async_client: AsyncClient, default_computers: list[int | None]):
     """Test creating a new Code."""
-
     for comp_id in default_computers:
+        computer = orm.load_computer(comp_id)
         response = await async_client.post(
             '/nodes',
             json={
-                'entry_point': 'core.code.installed',
-                'dbcomputer_id': comp_id,
-                'attributes': {'filepath_executable': '/bin/true'},
+                'node_type': 'data.core.code.installed.InstalledCode.',
                 'label': 'test_code',
+                'attributes': {
+                    'filepath_executable': '/bin/true',
+                    'computer': computer.label,
+                },
             },
         )
         assert response.status_code == 200, response.content


-def test_create_list(client, authenticate):  # pylint: disable=unused-argument
+@pytest.mark.usefixtures('authenticate')
+def test_create_list(client: TestClient):
     """Test creating a new list."""
     response = client.post(
         '/nodes',
         json={
-            'entry_point': 'core.list',
-            'attributes': {'list': [2, 3]},
+            'node_type': 'data.core.list.List.',
+            'attributes': {'value': [2, 3]},
         },
     )
-
     assert response.status_code == 200, response.content


-def test_create_int(client, authenticate):  # pylint: disable=unused-argument
+@pytest.mark.usefixtures('authenticate')
+def test_create_int(client: TestClient):
     """Test creating a new Int."""
     response = client.post(
         '/nodes',
         json={
-            'entry_point': 'core.int',
+            'node_type': 'data.core.int.Int.',
             'attributes': {'value': 6},
         },
     )
     assert response.status_code == 200, response.content


-def test_create_float(client, authenticate):  # pylint: disable=unused-argument
+@pytest.mark.usefixtures('authenticate')
+def test_create_float(client: TestClient):
     """Test creating a new Float."""
     response = client.post(
         '/nodes',
         json={
-            'entry_point': 'core.float',
+            'node_type': 'data.core.float.Float.',
             'attributes': {'value': 6.6},
         },
     )
     assert response.status_code == 200, response.content


-def test_create_string(client, authenticate):  # pylint: disable=unused-argument
+@pytest.mark.usefixtures('authenticate')
+def test_create_string(client: TestClient):
     """Test creating a new string."""
     response = client.post(
         '/nodes',
         json={
-            'entry_point': 'core.str',
+            'node_type': 'data.core.str.Str.',
             'attributes': {'value': 'test_string'},
         },
     )
     assert response.status_code == 200, response.content


-def test_create_bool(client, authenticate):  # pylint: disable=unused-argument
+@pytest.mark.usefixtures('authenticate')
+def test_create_bool(client: TestClient):
     """Test creating a new Bool."""
     response = client.post(
         '/nodes',
         json={
-            'entry_point': 'core.bool',
-            'attributes': {'value': 'True'},
+            'node_type': 'data.core.bool.Bool.',
+            'attributes': {'value': True},
         },
     )
     assert response.status_code == 200, response.content


-def test_create_structure_data(client, authenticate):  # pylint: disable=unused-argument
+@pytest.mark.usefixtures('authenticate')
+def test_create_structure_data(client: TestClient):
     """Test creating a new StructureData."""
     response = client.post(
         '/nodes',
         json={
-            'entry_point': 'core.structure',
-            'process_type': None,
+            'node_type': 'data.core.structure.StructureData.',
             'description': '',
             'attributes': {
-                'cell': [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
-                'pbc': [True, True, True],
-                'ase': None,
-                'pymatgen': None,
-                'pymatgen_structure': None,
-                'pymatgen_molecule': None,
+                'cell': [
+                    [1.0, 0.0, 0.0],
+                    [0.0, 1.0, 0.0],
+                    [0.0, 0.0, 1.0],
+                ],
+                'pbc1': True,
+                'pbc2': True,
+                'pbc3': True,
+                'kinds': [
+                    {
+                        'name': 'H',
+                        'mass': 1.00784,
+                        'symbols': ['H'],
+                        'weights': [1.0],
+                    }
+                ],
+                'sites': [
+                    {
+                        'position': [0.0, 0.0, 0.0],
+                        'kind_name': 'H',
+                    },
+                    {
+                        'position': [0.5, 0.5, 0.5],
+                        'kind_name': 'H',
+                    },
+                ],
             },
         },
     )
-
     assert response.status_code == 200, response.content


-def test_create_orbital_data(client, authenticate):  # pylint: disable=unused-argument
+@pytest.mark.usefixtures('authenticate')
+def test_create_orbital_data(client: TestClient):
     """Test creating a new OrbitalData."""
     response = client.post(
         '/nodes',
         json={
-            'entry_point': 'core.orbital',
-            'process_type': None,
+            'node_type': 'data.core.orbital.OrbitalData.',
             'description': '',
             'attributes': {
                 'orbital_dicts': [
@@ -222,118 +383,246 @@ def test_create_orbital_data(client, authenticate):  # pylint: disable=unused-ar
                 },
             },
         },
     )
-
     assert response.status_code == 200, response.content


-def test_create_single_file_upload(client, authenticate):  # pylint: disable=unused-argument
+@pytest.mark.usefixtures('authenticate')
+def test_create_single_file(client: TestClient):
     """Testing file upload"""
-    test_file = {
-        'upload_file': (
-            'test_file.txt',
-            io.BytesIO(b'Some test strings'),
-            'multipart/form-data',
+    files = [
+        (
+            'files',
+            (
+                'test_file.txt',
+                io.BytesIO(b'Some test strings'),
+                'text/plain',
+            ),
+        )
+    ]
+    data = {
+        'params': json.dumps(
+            {
+                'node_type': 'data.core.singlefile.SinglefileData.',
+            }
         )
     }
+
+    response = client.post('/nodes/file-upload', files=files, data=data)
+    assert response.status_code == 200, response.json()
+
+    check = client.get(f'/nodes/{response.json()["uuid"]}/repo/metadata')
+    assert check.status_code == 200, check.content
+    result = check.json()
+    assert 'test_file.txt' in result
+    assert result['test_file.txt']['type'] == 'FILE'
+    assert result['test_file.txt']['size'] == len(b'Some test strings')
+    assert result['test_file.txt']['binary'] is False
+
+
+@pytest.mark.usefixtures('authenticate')
+def test_create_single_file_binary(client: TestClient):
+    """Testing binary file upload"""
+    binary_content = b'\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09'
+    files = [
+        (
+            'files',
+            (
+                'binary_file.bin',
+                io.BytesIO(binary_content),
+                'application/octet-stream',
+            ),
+        )
+    ]
     data = {
         'params': json.dumps(
             {
-                'entry_point': 'core.singlefile',
-                'process_type': None,
-                'description': 'Testing single upload file',
-                'attributes': {},
+                'node_type': 'data.core.singlefile.SinglefileData.',
             }
         )
     }

-    response = client.post('/nodes/singlefile', files=test_file, data=data)
+    response = client.post('/nodes/file-upload', files=files, data=data)
     assert response.status_code == 200, response.json()

+    check = client.get(f'/nodes/{response.json()["uuid"]}/repo/metadata')
+    assert check.status_code == 200, check.content
+    result = check.json()
+    assert 'binary_file.bin' in result
+    assert result['binary_file.bin']['type'] == 'FILE'
+    assert result['binary_file.bin']['size'] == len(binary_content)
+    assert result['binary_file.bin']['binary'] is True
+
+
+@pytest.mark.usefixtures('authenticate')
+def test_create_folder_data(client: TestClient):
+    """Testing folder upload"""
+    files = [
+        (
+            'files',
+            (
+                'folder/file1.txt',
+                io.BytesIO(b'Content of file 1'),
+                'text/plain',
+            ),
+        ),
+        (
+            'files',
+            (
+                'folder/file2.txt',
+                io.BytesIO(b'Content of file 2'),
+                'text/plain',
+            ),
+        ),
+    ]
+    data = {
+        'params': json.dumps(
+            {
+                'node_type': 'data.core.folder.FolderData.',
+            }
+        )
+    }
+
+    response = client.post('/nodes/file-upload', files=files, data=data)
+    assert response.status_code == 200, response.json()
+
+    check = client.get(f'/nodes/{response.json()["uuid"]}/repo/metadata')
+    assert check.status_code == 200, check.content
+    result = check.json()
+    assert 'folder' in result
+    assert result['folder']['type'] == 'DIRECTORY'
+    assert 'objects' in result['folder']
+    objects = result['folder']['objects']
+    assert len(objects) == 2
+    for file in files:
+        filename = file[1][0].split('/', 1)[1]
+        assert filename in objects
+        assert objects[filename]['type'] == 'FILE'
+        expected_size = len(file[1][1].getvalue())
+        assert objects[filename]['size'] == expected_size
+        assert objects[filename]['binary'] is False
+
+
+@pytest.mark.usefixtures('authenticate')
+def test_create_node_with_files_has_zipped_metadata(client: TestClient):
+    """Test link for zipped repo content is present when creating node with files."""
+    files = [
+        (
+            'files',
+            (
+                'file1.txt',
+                io.BytesIO(b'Content of file 1'),
+                'text/plain',
+            ),
+        ),
+        (
+            'files',
+            (
+                'file2.txt',
+                io.BytesIO(b'Content of file 2'),
+                'text/plain',
+            ),
+        ),
+    ]
+    data = {
+        'params': json.dumps(
+            {
+                'node_type': 'data.core.int.Int.',
+                'attributes': {'value': 42},
+            }
+        )
+    }

-def test_create_node_wrong_value(client, authenticate):  # pylint: disable=unused-argument
-    """Test creating a new node with wrong value."""
-    response = client.post(
-        '/nodes',
-        json={
-            'entry_point': 'core.float',
-            'attributes': {'value': 'tests'},
-        },
-    )
-    assert response.status_code == 400, response.content
+    response = client.post('/nodes/file-upload', files=files, data=data)
+    assert response.status_code == 200, response.json()

+    check = client.get(f'/nodes/{response.json()["uuid"]}/repo/metadata')
+    assert check.status_code == 200, check.content
+    result = check.json()
+    assert 'zipped' in result
+    assert result['zipped']['type'] == 'FILE'
+    assert result['zipped']['binary'] is True
+    assert result['zipped']['size'] == sum(len(file[1][1].getvalue()) for file in files)
+
+
+@pytest.mark.parametrize(
+    'node_type, value',
+    [
+        ('data.core.int.Int.', 'test'),
+        ('data.core.float.Float.', [1, 2, 3]),
+        ('data.core.str.Str.', 5),
+    ],
+)
+@pytest.mark.usefixtures('authenticate')
+def test_create_node_wrong_value(client: TestClient, node_type: str, value: t.Any):
+    """Test creating a new node with wrong value."""
     response = client.post(
         '/nodes',
         json={
-            'entry_point': 'core.int',
-            'attributes': {'value': 'tests'},
+            'node_type': node_type,
+            'attributes': {'value': value},
         },
     )
-    assert response.status_code == 400, response.content
+    assert response.status_code == 422, response.content


-def test_create_node_wrong_attribute(client, authenticate):  # pylint: disable=unused-argument
-    """Test adding node with wrong attributes."""
+@pytest.mark.usefixtures('default_computers', 'authenticate')
+def test_create_unknown_entry_point(client: TestClient):
+    """Test error message when specifying unknown ``entry_point``."""
     response = client.post(
         '/nodes',
         json={
-            'entry_point': 'core.str',
-            'attributes': {'value1': 5},
+            'node_type': 'data.core.nonexistent.NonExistentType.',
+            'label': 'test_code',
         },
     )
-    assert response.status_code == 400, response.content
+    assert response.status_code == 422, response.content


-def test_create_unknown_entry_point(default_computers, client, authenticate):  # pylint: disable=unused-argument
-    """Test error message when specifying unknown ``entry_point``."""
+@pytest.mark.usefixtures('default_computers', 'authenticate')
+def test_create_additional_attribute(client: TestClient):
+    """Test that additional properties are ignored."""
     response = client.post(
         '/nodes',
         json={
-            'entry_point': 'core.not.existing.entry.point',
-            'label': 'test_code',
+            'node_type': 'data.core.int.Int.',
+            'attributes': {
+                'value': 5,
+                'extra_thing': 'should_not_be_here',
+            },
         },
     )
-    assert response.status_code == 404, response.content
-    assert response.json()['detail'] == "Entry point 'core.not.existing.entry.point' not found in group 'aiida.data'"
-
-
-def test_create_additional_attribute(default_computers, client, authenticate):  # pylint: disable=unused-argument
-    """Test adding additional properties returns errors."""
+    assert response.status_code == 200, response.content

-    for comp_id in default_computers:
-        response = client.post(
-            '/nodes',
-            json={
-                'uuid': '3',
-                'entry_point': 'core.code.installed',
-                'dbcomputer_id': comp_id,
-                'attributes': {'filepath_executable': '/bin/true'},
-                'label': 'test_code',
-            },
-        )
-        assert response.status_code == 422, response.content
+    check = client.get(f'/nodes/{response.json()["uuid"]}/attributes')
+    assert check.status_code == 200, check.content
+    result = check.json()
+    assert 'value' in result
+    assert 'extra_thing' not in result


-def test_create_bool_with_extra(client, authenticate):  # pylint: disable=unused-argument
+@pytest.mark.usefixtures('authenticate')
+def test_create_bool_with_extra(client: TestClient):
     """Test creating a new Bool with extra."""
     response = client.post(
         '/nodes',
         json={
-            'entry_point': 'core.bool',
-            'attributes': {'value': 'True'},
+            'node_type': 'data.core.bool.Bool.',
+            'attributes': {'value': True},
             'extras': {'extra_one': 'value_1', 'extra_two': 'value_2'},
         },
     )
+    assert response.status_code == 200, response.content

-    check_response = client.get(f"/nodes/{response.json()['id']}")
-
-    assert check_response.status_code == 200, response.content
-    assert check_response.json()['extras']['extra_one'] == 'value_1'
-    assert check_response.json()['extras']['extra_two'] == 'value_2'
+    # We exclude extras from the node response, so we check by retrieving them separately
+    check = client.get(f'/nodes/{response.json()["uuid"]}/extras')
+    assert check.status_code == 200, check.content
+    result = check.json()
+    assert result['extra_one'] == 'value_1'
+    assert result['extra_two'] == 'value_2'
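Repository files are now attached at creation time through the multipart `/nodes/file-upload` endpoint: repeated `files` parts carry the content, and the `params` form field carries the node definition as a JSON string. A hedged sketch with `httpx` against a real deployment (the base URL and token are placeholders, not part of this patch):

```python
# Sketch: create a FolderData node with one repository file via /nodes/file-upload.
import json

import httpx

BASE = 'http://localhost:8000/api/v0'  # assumed deployment
TOKEN = '...'  # placeholder: obtain via POST {BASE}/auth/token

headers = {'Authorization': f'Bearer {TOKEN}'}

# Each 'files' entry is (field_name, (filename, content, content_type)).
files = [('files', ('input/data.txt', b'hello world', 'text/plain'))]
data = {'params': json.dumps({'node_type': 'data.core.folder.FolderData.'})}

response = httpx.post(f'{BASE}/nodes/file-upload', headers=headers, files=files, data=data)
response.raise_for_status()
uuid = response.json()['uuid']

# Inspect the stored repository layout (includes a 'zipped' download entry).
metadata = httpx.get(f'{BASE}/nodes/{uuid}/repo/metadata', headers=headers).json()
print(metadata)
```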
 @pytest.mark.anyio
-async def test_get_download_node(array_data_node, async_client):
+async def test_get_download_node(async_client: AsyncClient, array_data_node: orm.ArrayData):
     """Test download node /nodes/{nodes_id}/download.

     The async client is needed to avoid an error caused by an I/O operation on closed file"""
@@ -351,3 +640,40 @@ async def test_get_download_node(array_data_node, async_client):
     response = await async_client.get(f'/nodes/{array_data_node.pk}/download')
     assert response.status_code == 422, response.json()
     assert 'Please specify the download format' in response.json()['detail']
+
+
+@pytest.mark.usefixtures('default_nodes')
+@pytest.mark.anyio
+async def test_get_statistics(async_client: AsyncClient):
+    """Test get statistics for nodes."""
+
+    from datetime import datetime
+
+    default_user_reference_json = {
+        'total': 4,
+        'types': {
+            'data.core.float.Float.': 1,
+            'data.core.str.Str.': 1,
+            'data.core.bool.Bool.': 1,
+            'data.core.int.Int.': 1,
+        },
+        'ctime_by_day': {datetime.today().strftime('%Y-%m-%d'): 4},
+    }
+
+    # Test without specifying user, should use default user
+    response = await async_client.get('/nodes/statistics')
+    assert response.status_code == 200, response.json()
+    assert response.json() == default_user_reference_json
+
+    # Test that the output is the same when we use the pk of the default user
+    from aiida import orm
+
+    default_user_pk = orm.User(email='').collection.get_default().pk
+    response = await async_client.get(f'/nodes/statistics?user={default_user_pk}')
+    assert response.status_code == 200, response.json()
+    assert response.json() == default_user_reference_json
+
+    # Test empty response for nonexistent user
+    response = await async_client.get('/nodes/statistics?user=99999')
+    assert response.status_code == 200, response.json()
+    assert response.json() == {'total': 0, 'types': {}, 'ctime_by_day': {}}
diff --git a/tests/test_querybuilder.py b/tests/test_querybuilder.py
new file mode 100644
index 00000000..5ff35c26
--- /dev/null
+++ b/tests/test_querybuilder.py
@@ -0,0 +1,101 @@
+"""Test the /querybuilder endpoint"""
+
+import pytest
+from aiida import orm
+from fastapi.testclient import TestClient
+
+
+@pytest.mark.usefixtures('default_nodes')
+def test_querybuilder_all(client: TestClient):
+    """Test a simple QueryBuilder request."""
+    response = client.post(
+        '/querybuilder',
+        json={
+            'path': [
+                {
+                    'entity_type': 'data.core.base.',
+                    'orm_base': 'node',
+                    'tag': 'nodes',
+                },
+            ],
+            'project': {
+                'nodes': [
+                    'attributes.value',
+                ],
+            },
+        },
+    )
+    assert response.status_code == 200
+    data = response.json()
+    assert 'results' in data
+    assert len(data['results']) == 4
+
+
+@pytest.mark.usefixtures('default_nodes')
+def test_querybuilder_numeric_flat(client: TestClient):
+    """Test a flat QueryBuilder request."""
+    response = client.post(
+        '/querybuilder?flat=true',
+        json={
+            'path': [
+                {
+                    'entity_type': ['data.core.int.Int.', 'data.core.float.Float.'],
+                    'orm_base': 'node',
+                    'tag': 'nodes',
+                },
+            ],
+            'project': {
+                'nodes': [
+                    'attributes.value',
+                ],
+            },
+        },
+    )
+    assert response.status_code == 200
+    data = response.json()
+    assert 'results' in data
+    assert data['results'] == [1, 1.1]
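The body POSTed to `/querybuilder` mirrors the `path`/`filters`/`project` structure of AiiDA's `QueryBuilder`, and `?flat=true` collapses each projected row into a single flat list of values, as the tests above show. A minimal request sketch against an assumed local deployment (base URL is an assumption):

```python
# Sketch: run a QueryBuilder query over HTTP and read the flattened results.
import httpx

BASE = 'http://localhost:8000/api/v0'  # assumed deployment

body = {
    'path': [
        {'entity_type': 'data.core.int.Int.', 'orm_base': 'node', 'tag': 'nodes'},
    ],
    'project': {'nodes': ['pk', 'attributes.value']},
}

response = httpx.post(f'{BASE}/querybuilder', params={'flat': 'true'}, json=body)
response.raise_for_status()
# With flat=true the projected values arrive as one flat list rather than
# a list of per-row lists.
print(response.json()['results'])
```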
+
+
+def test_querybuilder_integer_in_group(client: TestClient, default_nodes: list[str], default_groups: list[str]):
+    """Test a QueryBuilder request filtering integers by group membership."""
+    node = orm.load_node(default_nodes[0])
+    group = orm.load_group(default_groups[0])
+    group.add_nodes(node)
+    response = client.post(
+        '/querybuilder?flat=true',
+        json={
+            'path': [
+                {
+                    'entity_type': 'group.core.',
+                    'orm_base': 'group',
+                    'tag': 'group',
+                },
+                {
+                    'entity_type': 'data.core.int.Int.',
+                    'orm_base': 'node',
+                    'joining_keyword': 'with_group',
+                    'joining_value': 'group',
+                    'tag': 'node',
+                },
+            ],
+            'filters': {
+                'group': {
+                    'pk': group.pk,
+                }
+            },
+            'project': {
+                'group': [
+                    'label',
+                ],
+                'node': [
+                    'pk',
+                    'attributes.value',
+                ],
+            },
+        },
+    )
+    assert response.status_code == 200
+    data = response.json()
+    assert 'results' in data
+    assert data['results'] == [group.label, node.pk, node.value]
diff --git a/tests/test_server.py b/tests/test_server.py
new file mode 100644
index 00000000..51e738c0
--- /dev/null
+++ b/tests/test_server.py
@@ -0,0 +1,48 @@
+# test the server router
+
+from bs4 import BeautifulSoup
+from fastapi.testclient import TestClient
+
+from aiida_restapi.config import API_CONFIG
+
+
+def test_get_server_endpoints(client: TestClient):
+    response = client.get('/server/endpoints')
+    assert response.status_code == 200
+    data = response.json()
+    assert 'endpoints' in data
+    assert isinstance(data['endpoints'], list)
+    assert len(data['endpoints']) > 0
+    for endpoint in data['endpoints']:
+        assert 'path' in endpoint
+        assert 'group' in endpoint
+        assert 'methods' in endpoint
+        assert 'description' in endpoint
+
+
+def test_get_server_endpoints_table(client: TestClient):
+    response = client.get('/server/endpoints/table')
+    assert response.status_code == 200
+    assert 'text/html' in response.headers['content-type']
+
+    bs = BeautifulSoup(response.text, 'html.parser')
+    assert bs.find('table') is not None
+    assert len(bs.find_all('th')) == 4
+
+    tbody = bs.find('tbody')
+    assert tbody is not None
+
+    for row in tbody.find_all('tr'):
+        cols = row.find_all('td')
+        path = cols[0].get_text()
+
+        # Check that the group, if not empty, is in path immediately after the prefix
+        if (group := cols[1].get_text()) != '-':
+            assert path.startswith(f'{client.base_url}{API_CONFIG["PREFIX"]}/{group}')
+
+        # Check that those endpoints that should be links are indeed links
+        method = cols[2].get_text()
+        if method == 'GET' and 'auth' not in path and '{' not in path:
+            assert cols[0].find('a') is not None
+        else:
+            assert cols[0].find('a') is None
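Process submission moved from `/processes` to `/submit`, taking an `entry_point` plus dotted `inputs` keys that reference stored nodes by UUID. A hedged sketch of the call shape (the UUIDs, token, and base URL are placeholders; the payload otherwise follows the tests below):

```python
# Sketch: submit an arithmetic add calculation through the /submit endpoint.
import httpx

BASE = 'http://localhost:8000/api/v0'  # assumed deployment
TOKEN = '...'  # placeholder: obtain via POST {BASE}/auth/token

response = httpx.post(
    f'{BASE}/submit',
    headers={'Authorization': f'Bearer {TOKEN}'},
    json={
        'label': 'add-via-rest',
        'entry_point': 'aiida.calculations:core.arithmetic.add',
        'inputs': {
            'code.uuid': '<code-uuid>',       # placeholder UUIDs of stored nodes
            'x.uuid': '<int-node-uuid>',
            'y.uuid': '<int-node-uuid>',
        },
    },
)
response.raise_for_status()
```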
diff --git a/tests/test_processes.py b/tests/test_submit.py
similarity index 55%
rename from tests/test_processes.py
rename to tests/test_submit.py
index f32e2191..867c03f5 100644
--- a/tests/test_processes.py
+++ b/tests/test_submit.py
@@ -3,55 +3,21 @@
 import io

 import pytest
-from aiida.orm import Dict, SinglefileData
-
-
-def test_get_processes(example_processes, client):  # pylint: disable=unused-argument
-    """Test listing existing processes."""
-    response = client.get('/processes/')
-
-    assert response.status_code == 200
-    assert len(response.json()) == 12
-
-
-def test_get_processes_projectable(client):
-    """Test get projectable properties for processes."""
-    response = client.get('/processes/projectable_properties')
-
-    assert response.status_code == 200
-    assert response.json() == [
-        'id',
-        'uuid',
-        'node_type',
-        'process_type',
-        'label',
-        'description',
-        'ctime',
-        'mtime',
-        'user_id',
-        'dbcomputer_id',
-        'attributes',
-        'extras',
-        'repository_metadata',
-    ]
-
-
-def test_get_single_processes(example_processes, client):  # pylint: disable=unused-argument
-    """Test retrieving a single processes."""
-    for proc_id in example_processes:
-        response = client.get(f'/processes/{proc_id}')
-        assert response.status_code == 200
+from aiida import orm
+from fastapi.testclient import TestClient
+from httpx import AsyncClient


 @pytest.mark.anyio
-async def test_add_process(default_test_add_process, async_client, authenticate):  # pylint: disable=unused-argument
+@pytest.mark.usefixtures('authenticate')
+async def test_add_process(async_client: AsyncClient, default_test_add_process: list[str]):
     """Test adding new process"""
     code_id, x_id, y_id = default_test_add_process
     response = await async_client.post(
-        '/processes',
+        '/submit',
         json={
             'label': 'test_new_process',
-            'process_entry_point': 'aiida.calculations:core.arithmetic.add',
+            'entry_point': 'aiida.calculations:core.arithmetic.add',
             'inputs': {
                 'code.uuid': code_id,
                 'x.uuid': x_id,
@@ -65,14 +31,15 @@ async def test_add_process(default_test_add_process, async_client, authenticate)
     assert response.status_code == 200


-def test_add_process_invalid_entry_point(default_test_add_process, client, authenticate):  # pylint: disable=unused-argument
+@pytest.mark.usefixtures('authenticate')
+def test_add_process_invalid_entry_point(client: TestClient, default_test_add_process: list[str]):
     """Test adding new process with invalid entry point"""
     code_id, x_id, y_id = default_test_add_process
     response = client.post(
-        '/processes',
+        '/submit',
         json={
             'label': 'test_new_process',
-            'process_entry_point': 'wrong_entry_point',
+            'entry_point': 'wrong_entry_point',
             'inputs': {
                 'code.uuid': code_id,
                 'x.uuid': x_id,
@@ -86,14 +53,15 @@ def test_add_process_invalid_entry_point(default_test_add_process, client, authe
     assert response.status_code == 404


-def test_add_process_invalid_node_id(default_test_add_process, client, authenticate):  # pylint: disable=unused-argument
+@pytest.mark.usefixtures('authenticate')
+def test_add_process_invalid_node_id(client: TestClient, default_test_add_process):
     """Test adding new process with invalid Node ID"""
     code_id, x_id, _ = default_test_add_process
     response = client.post(
-        '/processes',
+        '/submit',
         json={
             'label': 'test_new_process',
-            'process_entry_point': 'aiida.calculations:core.arithmetic.add',
+            'entry_point': 'aiida.calculations:core.arithmetic.add',
             'inputs': {
                 'code.uuid': code_id,
                 'x.uuid': x_id,
@@ -109,21 +77,22 @@ def test_add_process_invalid_node_id(default_test_add_process, client, authentic


 @pytest.mark.anyio
-async def test_add_process_nested_inputs(default_test_add_process, async_client, authenticate):  # pylint: disable=unused-argument
+@pytest.mark.usefixtures('authenticate')
+async def test_add_process_nested_inputs(async_client: AsyncClient, default_test_add_process):
     """Test adding new process that has nested inputs"""
     code_id, _, _ = default_test_add_process
-    template = Dict(
+    template = orm.Dict(
         {
             'files_to_copy': [('file', 'file.txt')],
         }
     ).store()
-    single_file = SinglefileData(io.StringIO('content')).store()
+    single_file = orm.SinglefileData(io.StringIO('content')).store()
     response = await async_client.post(
-        '/processes',
+        '/submit',
         json={
             'label': 'test_new_process',
-            'process_entry_point': 'aiida.calculations:core.templatereplacer',
+            'entry_point': 'aiida.calculations:core.templatereplacer',
             'inputs': {
                 'code.uuid': code_id,
                 'template.uuid': template.uuid,
diff --git a/tests/test_users.py b/tests/test_users.py
index d7ccf259..4baa637c 100644
--- a/tests/test_users.py
+++ b/tests/test_users.py
@@ -1,16 +1,22 @@
 """Test the /users endpoint"""

+from __future__ import annotations
+
 import pytest
+from aiida import orm
+from fastapi.testclient import TestClient
+from httpx import AsyncClient

-def test_get_single_user(default_users, client):  # pylint: disable=unused-argument
-    """Test retrieving a single user."""
-    for user_id in default_users:
-        response = client.get(f'/users/{user_id}')
-        assert response.status_code == 200

+def test_get_user_projectable_properties(client: TestClient):
+    """Test get projectable properties for users."""
+    response = client.get('/users/projections')
+    assert response.status_code == 200
+    assert response.json() == sorted(orm.User.fields.keys())

-def test_get_users(default_users, client):  # pylint: disable=unused-argument
+@pytest.mark.usefixtures('default_users')
+def test_get_users(client: TestClient):
     """Test listing existing users.

     Note: Besides the default users set up by the pytest fixture the test profile
@@ -18,23 +24,22 @@ def test_get_users(default_users, client):  # pylint: disable=unused-argument
     """
     response = client.get('/users')
     assert response.status_code == 200
-    assert len(response.json()) == 2 + 1
+    assert len(response.json()['results']) == 2 + 1
+
+
+def test_get_user(client: TestClient, default_users: list[int | None]):
+    """Test retrieving a single user."""
+    for user_id in default_users:
+        response = client.get(f'/users/{user_id}')
+        assert response.status_code == 200


 @pytest.mark.anyio
-async def test_create_user(async_client, authenticate):  # pylint: disable=unused-argument
+@pytest.mark.usefixtures('authenticate')
+async def test_create_user(async_client: AsyncClient):
     """Test creating a new user."""
     response = await async_client.post('/users', json={'first_name': 'New', 'email': 'aiida@localhost'})
     assert response.status_code == 200, response.content
-
     response = await async_client.get('/users')
-    first_names = [user['first_name'] for user in response.json()]
+    first_names = [user['first_name'] for user in response.json()['results']]
     assert 'New' in first_names
-
-
-def test_get_users_projectable(client):
-    """Test get projectable properites for users."""
-    response = client.get('/users/projectable_properties')
-
-    assert response.status_code == 200
-    assert response.json() == ['id', 'email', 'first_name', 'last_name', 'institution']
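Putting the pieces together: the token flow now lives under `/auth/token`, and list responses are paginated. A closing sketch of an authenticated round-trip, assuming a local deployment and the demo credentials used in the test suite (both assumptions):

```python
# Sketch: authenticate, create a user, and verify it appears in the paginated listing.
import httpx

BASE = 'http://localhost:8000/api/v0'  # assumed deployment

with httpx.Client(base_url=BASE) as client:
    token_response = client.post(
        '/auth/token',
        data={'username': 'johndoe@example.com', 'password': 'secret'},  # demo credentials
    )
    token_response.raise_for_status()
    token = token_response.json()['access_token']

    created = client.post(
        '/users',
        json={'first_name': 'New', 'email': 'aiida@localhost'},
        headers={'Authorization': f'Bearer {token}'},
    )
    created.raise_for_status()

    users = client.get('/users').json()['results']  # paginated envelope
    assert any(user['first_name'] == 'New' for user in users)
```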