Skip to content

Commit 88f76ee

Browse files
Implement node links endpoint
1 parent bb1cf3a commit 88f76ee

File tree

4 files changed

+85
-1
lines changed

4 files changed

+85
-1
lines changed
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
from .models import NodeLinks
12
from .registry import NodeModelRegistry
23

34
__all__ = [
5+
'NodeLinks',
46
'NodeModelRegistry',
57
]
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
import pydantic as pdt
2+
from aiida.orm import Node
3+
4+
5+
class NodeLinks(Node.Model):
6+
link_label: str = pdt.Field(description='The label of the link to the node.')
7+
link_type: str = pdt.Field(description='The type of the link to the node.')

aiida_restapi/repository/node.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@
1717
)
1818
from aiida.repository import File
1919

20+
from aiida_restapi.common.pagination import PaginatedResults
21+
from aiida_restapi.common.query import QueryParams
2022
from aiida_restapi.common.types import NodeModelType, NodeType
23+
from aiida_restapi.models.node import NodeLinks
2124

2225
from .entity import EntityRepository
2326

@@ -142,6 +145,47 @@ def get_node_attributes(self, node_id: int) -> dict[str, t.Any]:
142145
).first()[0],
143146
)
144147

148+
def get_node_links(
149+
self,
150+
node_id: int,
151+
queries: QueryParams,
152+
direction: t.Literal['incoming', 'outgoing'],
153+
) -> PaginatedResults[NodeLinks]:
154+
"""Get the incoming links of a node.
155+
156+
:param node_id: The id of the node to retrieve the incoming links for.
157+
:param queries: The query parameters, including filters, order_by, page_size, and page.
158+
:param direction: Specify whether to retrieve incoming or outgoing links.
159+
:return: The paginated requested linked nodes.
160+
"""
161+
node = self.entity_class.collection.get(pk=node_id)
162+
163+
start, end = (
164+
queries.page_size * (queries.page - 1),
165+
queries.page_size * queries.page,
166+
)
167+
168+
if direction == 'incoming':
169+
link_collection = node.base.links.get_incoming()
170+
else:
171+
link_collection = node.base.links.get_outgoing()
172+
173+
links: list[NodeLinks] = []
174+
for link in link_collection.all()[start:end]:
175+
link_params = link.node.serialize(minimal=True) | {
176+
'link_label': link.link_label,
177+
'link_type': link.link_type.value,
178+
}
179+
link_model = NodeLinks(**link_params)
180+
links.append(link_model)
181+
182+
return PaginatedResults(
183+
total=len(links),
184+
page=queries.page,
185+
page_size=queries.page_size,
186+
results=links,
187+
)
188+
145189
def create_entity(
146190
self,
147191
model: NodeModelType,

aiida_restapi/routers/nodes.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from aiida_restapi.common.pagination import PaginatedResults
1818
from aiida_restapi.common.query import QueryParams, query_params
1919
from aiida_restapi.models.node import NodeModelRegistry
20-
from aiida_restapi.repository.node import NodeRepository
20+
from aiida_restapi.repository.node import NodeLinks, NodeRepository
2121

2222
from .auth import UserInDB, get_current_active_user
2323

@@ -281,6 +281,37 @@ async def get_node_extras(node_id: int) -> dict[str, t.Any]:
281281
raise HTTPException(status_code=500, detail=str(err)) from err
282282

283283

284+
@read_router.get(
285+
'/nodes/{node_id}/links',
286+
response_model=PaginatedResults[NodeLinks],
287+
response_model_exclude_none=True,
288+
response_model_exclude_unset=True,
289+
)
290+
@with_dbenv()
291+
async def get_node_links(
292+
node_id: int,
293+
queries: t.Annotated[QueryParams, Depends(query_params)],
294+
direction: t.Literal['incoming', 'outgoing'] = Query(
295+
description='Specify whether to retrieve incoming or outgoing links.',
296+
),
297+
) -> PaginatedResults[NodeLinks]:
298+
"""Get the incoming or outgoing links of a node.
299+
300+
:param node_id: The id of the node to retrieve the incoming links for.
301+
:param queries: The query parameters, including filters, order_by, page_size, and page.
302+
:param direction: Specify whether to retrieve incoming or outgoing links.
303+
:return: The paginated requested linked nodes.
304+
:raises HTTPException: 404 if the node with the given id does not exist,
305+
500 for other failures during retrieval.
306+
"""
307+
try:
308+
return repository.get_node_links(node_id, queries, direction=direction)
309+
except NotExistent:
310+
raise HTTPException(status_code=404, detail=f'Could not find any node with id {node_id}')
311+
except Exception as err:
312+
raise HTTPException(status_code=500, detail=str(err)) from err
313+
314+
284315
@read_router.get('/nodes/{node_id}/download')
285316
@with_dbenv()
286317
async def download_node(

0 commit comments

Comments
 (0)