Skip to content

Commit d451778

Browse files
Move query_params dependency logic into QueryParams model
1 parent 107cbc5 commit d451778

File tree

7 files changed

+122
-92
lines changed

7 files changed

+122
-92
lines changed

aiida_restapi/common/query.py

Lines changed: 30 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -6,23 +6,24 @@
66
import typing as t
77

88
import pydantic as pdt
9-
from fastapi import HTTPException, Query
109

1110

1211
class QueryParams(pdt.BaseModel):
13-
filters: dict[str, t.Any] = pdt.Field(
14-
default_factory=dict,
12+
filters: dict[str, t.Any] | None = pdt.Field(
13+
None,
1514
description='AiiDA QueryBuilder filters',
1615
examples=[
17-
{'node_type': {'==': 'data.core.int.Int.'}},
18-
{'attributes.value': {'>': 42}},
16+
'{"node_type": "data.core.int.Int."}',
17+
'{"attributes.value": {">": 42}}',
1918
],
2019
)
2120
order_by: str | list[str] | dict[str, t.Any] | None = pdt.Field(
2221
None,
2322
description='Fields to sort by',
2423
examples=[
25-
{'attributes.value': 'desc'},
24+
'pk',
25+
'uuid,label',
26+
'{"attributes.value": "desc"}',
2627
],
2728
)
2829
page_size: pdt.PositiveInt = pdt.Field(
@@ -36,55 +37,26 @@ class QueryParams(pdt.BaseModel):
3637
examples=[1],
3738
)
3839

39-
40-
def query_params(
41-
filters: str | None = Query(
42-
None,
43-
description='AiiDA QueryBuilder filters as JSON string',
44-
),
45-
order_by: str | None = Query(
46-
None,
47-
description='Comma-separated list of fields to sort by',
48-
),
49-
page_size: pdt.PositiveInt = Query(
50-
10,
51-
description='Number of results per page',
52-
),
53-
page: pdt.PositiveInt = Query(
54-
1,
55-
description='Page number',
56-
),
57-
) -> QueryParams:
58-
"""Parse query parameters into a structured object.
59-
60-
:param filters: AiiDA QueryBuilder filters as JSON string.
61-
:param order_by: Comma-separated string of fields to sort by.
62-
:param page_size: Number of results per page.
63-
:param page: Page number.
64-
:return: Structured query parameters.
65-
:raises HTTPException: If filters cannot be parsed as JSON.
66-
"""
67-
query_filters: dict[str, t.Any] = {}
68-
query_order_by: str | list[str] | dict[str, t.Any] | None = None
69-
if filters:
70-
try:
71-
query_filters = json.loads(filters)
72-
except Exception as exception:
73-
raise HTTPException(
74-
status_code=400,
75-
detail=f'Could not parse filters as JSON: {exception}',
76-
) from exception
77-
if order_by:
78-
try:
79-
query_order_by = json.loads(order_by)
80-
except Exception as exception:
81-
raise HTTPException(
82-
status_code=400,
83-
detail=f'Could not parse order_by as JSON: {exception}',
84-
) from exception
85-
return QueryParams(
86-
filters=query_filters,
87-
order_by=query_order_by,
88-
page_size=page_size,
89-
page=page,
90-
)
40+
@pdt.field_validator('filters', mode='before')
41+
@classmethod
42+
def parse_filters(cls, value: t.Any) -> dict[str, t.Any] | None:
43+
if value:
44+
try:
45+
return json.loads(value)
46+
except Exception as exception:
47+
raise ValueError(f'Could not parse filters as JSON: {exception}') from exception
48+
return None
49+
50+
@pdt.field_validator('order_by', mode='before')
51+
@classmethod
52+
def parse_order_by(cls, value: t.Any) -> str | list[str] | dict[str, t.Any] | None:
53+
if value:
54+
# Due to allowing list[str] on the field, FastAPI will always convert query to a list
55+
raw: str = value[0]
56+
if raw.startswith('{') or raw.startswith('['):
57+
try:
58+
return json.loads(raw)
59+
except Exception as exception:
60+
raise ValueError(f'Could not parse order_by as JSON: {exception}') from exception
61+
return raw.split(',')
62+
return None

aiida_restapi/repository/entity.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -45,23 +45,23 @@ def get_projectable_properties(self) -> list[str]:
4545
"""
4646
return self.entity_class.fields.keys()
4747

48-
def get_entities(self, queries: QueryParams) -> PaginatedResults[EntityModelType]:
48+
def get_entities(self, query_params: QueryParams) -> PaginatedResults[EntityModelType]:
4949
"""Get AiiDA entities with optional filtering, sorting, and/or pagination.
5050
51-
:param queries: The query parameters, including filters, order_by, page_size, and page.
51+
:param query_params: The query parameters, including filters, order_by, page_size, and page.
5252
:return: The paginated results, including total count, current page, page size, and list of entity models.
5353
"""
54-
total = self.entity_class.collection.count(filters=queries.filters)
54+
total = self.entity_class.collection.count(filters=query_params.filters)
5555
results = self.entity_class.collection.find(
56-
filters=queries.filters,
57-
order_by=queries.order_by,
58-
limit=queries.page_size,
59-
offset=queries.page_size * (queries.page - 1),
56+
filters=query_params.filters,
57+
order_by=query_params.order_by,
58+
limit=query_params.page_size,
59+
offset=query_params.page_size * (query_params.page - 1),
6060
)
6161
return PaginatedResults(
6262
total=total,
63-
page=queries.page,
64-
page_size=queries.page_size,
63+
page=query_params.page,
64+
page_size=query_params.page_size,
6565
results=[self.to_model(result) for result in results],
6666
)
6767

aiida_restapi/repository/node.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -148,21 +148,21 @@ def get_node_attributes(self, uuid: str) -> dict[str, t.Any]:
148148
def get_node_links(
149149
self,
150150
uuid: str,
151-
queries: QueryParams,
151+
query_params: QueryParams,
152152
direction: t.Literal['incoming', 'outgoing'],
153153
) -> PaginatedResults[NodeLinks]:
154154
"""Get the incoming links of a node.
155155
156156
:param uuid: The uuid of the node to retrieve the incoming links for.
157-
:param queries: The query parameters, including filters, order_by, page_size, and page.
157+
:param query_params: The query parameters, including filters, order_by, page_size, and page.
158158
:param direction: Specify whether to retrieve incoming or outgoing links.
159159
:return: The paginated requested linked nodes.
160160
"""
161161
node = self.entity_class.collection.get(uuid=uuid)
162162

163163
start, end = (
164-
queries.page_size * (queries.page - 1),
165-
queries.page_size * queries.page,
164+
query_params.page_size * (query_params.page - 1),
165+
query_params.page_size * query_params.page,
166166
)
167167

168168
if direction == 'incoming':
@@ -181,8 +181,8 @@ def get_node_links(
181181

182182
return PaginatedResults(
183183
total=len(links),
184-
page=queries.page,
185-
page_size=queries.page_size,
184+
page=query_params.page,
185+
page_size=query_params.page_size,
186186
results=links,
187187
)
188188

aiida_restapi/routers/computers.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from fastapi import APIRouter, Depends, HTTPException, Query
1111

1212
from aiida_restapi.common.pagination import PaginatedResults
13-
from aiida_restapi.common.query import QueryParams, query_params
13+
from aiida_restapi.common.query import QueryParams
1414
from aiida_restapi.repository.entity import EntityRepository
1515

1616
from .auth import UserInDB, get_current_active_user
@@ -66,14 +66,27 @@ async def get_computer_projectable_properties() -> list[str]:
6666
)
6767
@with_dbenv()
6868
async def get_computers(
69-
queries: t.Annotated[QueryParams, Depends(query_params)],
69+
query_params: t.Annotated[
70+
QueryParams,
71+
Query(
72+
default_factory=QueryParams,
73+
description='Query parameters for filtering, sorting, and pagination.',
74+
),
75+
],
7076
) -> PaginatedResults[orm.Computer.Model]:
7177
"""Get AiiDA computers with optional filtering, sorting, and/or pagination.
7278
73-
:param queries: The query parameters, including filters, order_by, page_size, and page.
79+
:param query_params: The query parameters, including filters, order_by, page_size, and page.
7480
:return: The paginated results, including total count, current page, page size, and list of computer models.
81+
:raises HTTPException: 422 if the query parameters are invalid,
82+
500 for other failures during retrieval.
7583
"""
76-
return repository.get_entities(queries)
84+
try:
85+
return repository.get_entities(query_params)
86+
except ValueError as exception:
87+
raise HTTPException(status_code=422, detail=str(exception)) from exception
88+
except Exception as exception:
89+
raise HTTPException(status_code=500, detail=str(exception)) from exception
7790

7891

7992
@read_router.get(

aiida_restapi/routers/groups.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from fastapi import APIRouter, Depends, HTTPException, Query
1111

1212
from aiida_restapi.common.pagination import PaginatedResults
13-
from aiida_restapi.common.query import QueryParams, query_params
13+
from aiida_restapi.common.query import QueryParams
1414
from aiida_restapi.repository.entity import EntityRepository
1515

1616
from .auth import UserInDB, get_current_active_user
@@ -66,14 +66,27 @@ async def get_group_projectable_properties() -> list[str]:
6666
)
6767
@with_dbenv()
6868
async def get_groups(
69-
queries: t.Annotated[QueryParams, Depends(query_params)],
69+
query_params: t.Annotated[
70+
QueryParams,
71+
Query(
72+
default_factory=QueryParams,
73+
description='Query parameters for filtering, sorting, and pagination.',
74+
),
75+
],
7076
) -> PaginatedResults[orm.Group.Model]:
7177
"""Get AiiDA groups with optional filtering, sorting, and/or pagination.
7278
73-
:param queries: The query parameters, including filters, order_by, page_size, and page.
79+
:param query_params: The query parameters, including filters, order_by, page_size, and page.
7480
:return: The paginated results, including total count, current page, page size, and list of group models.
81+
:raises HTTPException: 422 if the query parameters are invalid,
82+
500 for other failures during retrieval.
7583
"""
76-
return repository.get_entities(queries)
84+
try:
85+
return repository.get_entities(query_params)
86+
except ValueError as exception:
87+
raise HTTPException(status_code=422, detail=str(exception)) from exception
88+
except Exception as exception:
89+
raise HTTPException(status_code=500, detail=str(exception)) from exception
7790

7891

7992
@read_router.get(

aiida_restapi/routers/nodes.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from typing_extensions import TypeAlias
1616

1717
from aiida_restapi.common.pagination import PaginatedResults
18-
from aiida_restapi.common.query import QueryParams, query_params
18+
from aiida_restapi.common.query import QueryParams
1919
from aiida_restapi.config import API_CONFIG
2020
from aiida_restapi.models.node import NodeModelRegistry
2121
from aiida_restapi.repository.node import NodeLinks, NodeRepository
@@ -182,14 +182,27 @@ async def get_nodes_download_formats() -> dict[str, t.Any]:
182182
)
183183
@with_dbenv()
184184
async def get_nodes(
185-
queries: t.Annotated[QueryParams, Depends(query_params)],
185+
query_params: t.Annotated[
186+
QueryParams,
187+
Query(
188+
default_factory=QueryParams,
189+
description='Query parameters for filtering, sorting, and pagination.',
190+
),
191+
],
186192
) -> PaginatedResults[orm.Node.Model]:
187193
"""Get AiiDA nodes with optional filtering, sorting, and/or pagination.
188194
189-
:param queries: The query parameters, including filters, order_by, page_size, and page.
195+
:param query_params: The query parameters, including filters, order_by, page_size, and page.
190196
:return: The paginated results, including total count, current page, page size, and list of node models.
197+
:raises HTTPException: 422 if the query parameters are invalid,
198+
500 for other failures during retrieval.
191199
"""
192-
return repository.get_entities(queries)
200+
try:
201+
return repository.get_entities(query_params)
202+
except ValueError as exception:
203+
raise HTTPException(status_code=422, detail=str(exception)) from exception
204+
except Exception as exception:
205+
raise HTTPException(status_code=500, detail=str(exception)) from exception
193206

194207

195208
class NodeType(pdt.BaseModel):
@@ -311,22 +324,28 @@ async def get_node_extras(uuid: str) -> dict[str, t.Any]:
311324
@with_dbenv()
312325
async def get_node_links(
313326
uuid: str,
314-
queries: t.Annotated[QueryParams, Depends(query_params)],
327+
query_params: t.Annotated[
328+
QueryParams,
329+
Query(
330+
default_factory=QueryParams,
331+
description='Query parameters for filtering, sorting, and pagination.',
332+
),
333+
],
315334
direction: t.Literal['incoming', 'outgoing'] = Query(
316335
description='Specify whether to retrieve incoming or outgoing links.',
317336
),
318337
) -> PaginatedResults[NodeLinks]:
319338
"""Get the incoming or outgoing links of a node.
320339
321340
:param uuid: The uuid of the node to retrieve the incoming links for.
322-
:param queries: The query parameters, including filters, order_by, page_size, and page.
341+
:param query_params: The query parameters, including filters, order_by, page_size, and page.
323342
:param direction: Specify whether to retrieve incoming or outgoing links.
324343
:return: The paginated requested linked nodes.
325344
:raises HTTPException: 404 if the node with the given uuid does not exist,
326345
500 for other failures during retrieval.
327346
"""
328347
try:
329-
return repository.get_node_links(uuid, queries, direction=direction)
348+
return repository.get_node_links(uuid, query_params, direction=direction)
330349
except NotExistent as exception:
331350
raise HTTPException(status_code=404, detail=str(exception)) from exception
332351
except Exception as exception:

aiida_restapi/routers/users.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from fastapi import APIRouter, Depends, HTTPException, Query
1111

1212
from aiida_restapi.common.pagination import PaginatedResults
13-
from aiida_restapi.common.query import QueryParams, query_params
13+
from aiida_restapi.common.query import QueryParams
1414
from aiida_restapi.repository.entity import EntityRepository
1515

1616
from .auth import UserInDB, get_current_active_user
@@ -66,14 +66,27 @@ async def get_user_projectable_properties() -> list[str]:
6666
)
6767
@with_dbenv()
6868
async def get_users(
69-
queries: t.Annotated[QueryParams, Depends(query_params)],
69+
query_params: t.Annotated[
70+
QueryParams,
71+
Query(
72+
default_factory=QueryParams,
73+
description='Query parameters for filtering, sorting, and pagination.',
74+
),
75+
],
7076
) -> PaginatedResults[orm.User.Model]:
7177
"""Get AiiDA users with optional filtering, sorting, and/or pagination.
7278
73-
:param queries: The query parameters, including filters, order_by, page_size, and page.
79+
:param query_params: The query parameters, including filters, order_by, page_size, and page.
7480
:return: The paginated results, including total count, current page, page size, and list of user models.
81+
:raises HTTPException: 422 if the query parameters are invalid,
82+
500 for other failures during retrieval.
7583
"""
76-
return repository.get_entities(queries)
84+
try:
85+
return repository.get_entities(query_params)
86+
except ValueError as exception:
87+
raise HTTPException(status_code=422, detail=str(exception)) from exception
88+
except Exception as exception:
89+
raise HTTPException(status_code=500, detail=str(exception)) from exception
7790

7891

7992
@read_router.get(

0 commit comments

Comments
 (0)