Skip to content

Commit de2620d

Browse files
feat(search): pagination + search for relationships and entities
1 parent 841e8fe commit de2620d

File tree

9 files changed

+292
-33
lines changed

9 files changed

+292
-33
lines changed

src/adapters/graph.py

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,14 @@
1010

1111
from typing import Optional, Tuple
1212
from src.adapters.interfaces.graph import GraphClient
13-
from src.constants.kg import IdentificationParams, Node, Predicate, Triple
13+
from src.constants.kg import (
14+
IdentificationParams,
15+
Node,
16+
Predicate,
17+
SearchEntitiesResult,
18+
SearchRelationshipsResult,
19+
Triple,
20+
)
1421

1522

1623
class GraphAdapter:
@@ -190,17 +197,47 @@ def get_connected_nodes(
190197
node=node, uuids=uuids, limit=limit, with_labels=with_labels
191198
)
192199

193-
def search_relationships(self, limit: int = 10, skip: int = 0) -> list[Triple]:
200+
def search_relationships(
201+
self,
202+
limit: int = 10,
203+
skip: int = 0,
204+
relationship_types: Optional[list[str]] = None,
205+
from_node_labels: Optional[list[str]] = None,
206+
to_node_labels: Optional[list[str]] = None,
207+
query_text: Optional[str] = None,
208+
query_search_target: Optional[str] = "all",
209+
) -> SearchRelationshipsResult:
194210
"""
195211
Search the relationships of the graph.
196212
"""
197-
return self.graph.search_relationships(limit, skip)
213+
relationship_uuids = []
214+
# TODO: semantic search + src/core/agents/tools/kg_agent/KGAgentAddTripletsTool.py:165
215+
return self.graph.search_relationships(
216+
limit,
217+
skip,
218+
relationship_types,
219+
from_node_labels,
220+
to_node_labels,
221+
relationship_uuids,
222+
query_text,
223+
query_search_target,
224+
)
198225

199-
def search_entities(self, limit: int = 10, skip: int = 0) -> list[Node]:
226+
def search_entities(
227+
self,
228+
limit: int = 10,
229+
skip: int = 0,
230+
node_labels: Optional[list[str]] = None,
231+
query_text: Optional[str] = None,
232+
) -> SearchEntitiesResult:
200233
"""
201234
Search the entities of the graph.
202235
"""
203-
return self.graph.search_entities(limit, skip)
236+
node_uuids = []
237+
# TODO: semantic search
238+
return self.graph.search_entities(
239+
limit, skip, node_labels, node_uuids, query_text
240+
)
204241

205242

206243
_graph_adapter = GraphAdapter()

src/adapters/interfaces/graph.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,13 @@
1111
from abc import ABC, abstractmethod
1212
from typing import Optional, Tuple
1313

14-
from src.constants.kg import IdentificationParams, Node, Predicate, Triple
14+
from src.constants.kg import (
15+
IdentificationParams,
16+
Node,
17+
Predicate,
18+
SearchEntitiesResult,
19+
SearchRelationshipsResult,
20+
)
1521

1622

1723
class GraphClient(ABC):
@@ -196,14 +202,34 @@ def get_connected_nodes(
196202
raise NotImplementedError("get_connected_nodes method not implemented")
197203

198204
@abstractmethod
199-
def search_relationships(self, limit: int = 10, skip: int = 0) -> list[Triple]:
205+
def search_relationships(
206+
self,
207+
limit: int = 10,
208+
skip: int = 0,
209+
relationship_types: Optional[list[str]] = None,
210+
from_node_labels: Optional[list[str]] = None,
211+
to_node_labels: Optional[list[str]] = None,
212+
relationship_uuids: Optional[list[str]] = None,
213+
query_text: Optional[str] = None,
214+
query_search_target: Optional[
215+
str
216+
] = "all", # Search into the relationship desc or node names or relationship desc
217+
) -> SearchRelationshipsResult:
200218
"""
201219
Search the relationships of the graph.
202220
"""
203221
raise NotImplementedError("search_relationships method not implemented")
204222

205223
@abstractmethod
206-
def search_entities(self, limit: int = 10, skip: int = 0) -> list[Node]:
224+
def search_entities(
225+
self,
226+
limit: int = 10,
227+
skip: int = 0,
228+
node_labels: Optional[list[str]] = None,
229+
node_uuids: Optional[list[str]] = None,
230+
query_text: Optional[str] = None,
231+
) -> SearchEntitiesResult:
207232
"""
208233
Search the entities of the graph.
209234
"""
235+
raise NotImplementedError("search_entities method not implemented")

src/constants/kg.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,3 +111,21 @@ class IdentificationParams(BaseModel):
111111
)
112112

113113
model_config = ConfigDict(extra="allow")
114+
115+
116+
class SearchRelationshipsResult(BaseModel):
117+
"""
118+
Search relationships result model.
119+
"""
120+
121+
results: List[Triple]
122+
total: int
123+
124+
125+
class SearchEntitiesResult(BaseModel):
126+
"""
127+
Search entities result model.
128+
"""
129+
130+
results: List[Node]
131+
total: int

src/core/agents/tools/kg_agent/KGAgentAddTripletsTool.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,8 @@ def _run(self, *args, **kwargs) -> str:
162162
triplet.object,
163163
)
164164

165+
# TODO: save natural language relationship/predicate+s+o vector
166+
165167
# TODO: add changelog relationship created
166168

167169
return f"Triplets added successfully: {triplets}"

src/core/search/entities.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,19 @@
88
-----
99
"""
1010

11-
from src.constants.kg import Node
11+
from typing import Optional
12+
from src.constants.kg import SearchEntitiesResult
1213
from src.services.kg_agent.main import graph_adapter
1314

1415

15-
def search_entities(limit: int = 10, skip: int = 0) -> list[Node]:
16+
def search_entities(
17+
limit: int = 10,
18+
skip: int = 0,
19+
node_labels: Optional[list[str]] = None,
20+
query_text: Optional[str] = None,
21+
) -> SearchEntitiesResult:
1622
"""
1723
Search the entities of the graph.
1824
"""
19-
result = graph_adapter.search_entities(limit, skip)
25+
result = graph_adapter.search_entities(limit, skip, node_labels, query_text)
2026
return result

src/core/search/relationships.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,30 @@
88
-----
99
"""
1010

11-
from src.constants.kg import Triple
11+
from typing import Optional
12+
from src.constants.kg import SearchRelationshipsResult
1213
from src.services.kg_agent.main import graph_adapter
1314

1415

15-
def search_relationships(limit: int = 10, skip: int = 0) -> list[Triple]:
16+
def search_relationships(
17+
limit: int = 10,
18+
skip: int = 0,
19+
relationship_types: Optional[list[str]] = None,
20+
from_node_labels: Optional[list[str]] = None,
21+
to_node_labels: Optional[list[str]] = None,
22+
query_text: Optional[str] = None,
23+
query_search_target: Optional[str] = "all",
24+
) -> SearchRelationshipsResult:
1625
"""
1726
Search the relationships of the graph.
1827
"""
19-
result = graph_adapter.search_relationships(limit, skip)
28+
result = graph_adapter.search_relationships(
29+
limit,
30+
skip,
31+
relationship_types,
32+
from_node_labels,
33+
to_node_labels,
34+
query_text,
35+
query_search_target,
36+
)
2037
return result

src/lib/neo4j/client.py

Lines changed: 106 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,14 @@
1313
from neo4j import GraphDatabase
1414
from src.adapters.interfaces.graph import GraphClient
1515
from src.config import config
16-
from src.constants.kg import IdentificationParams, Node, Predicate, Triple
16+
from src.constants.kg import (
17+
IdentificationParams,
18+
Node,
19+
Predicate,
20+
SearchEntitiesResult,
21+
SearchRelationshipsResult,
22+
Triple,
23+
)
1724
from src.utils.logging import log
1825

1926

@@ -650,19 +657,79 @@ def get_connected_nodes(
650657
for record in result.records
651658
]
652659

653-
def search_relationships(self, limit: int = 10, skip: int = 0) -> list[Triple]:
660+
def search_relationships(
661+
self,
662+
limit: int = 10,
663+
skip: int = 0,
664+
relationship_types: Optional[list[str]] = None,
665+
from_node_labels: Optional[list[str]] = None,
666+
to_node_labels: Optional[list[str]] = None,
667+
relationship_uuids: Optional[list[str]] = None,
668+
query_text: Optional[str] = None,
669+
query_search_target: Optional[str] = "all",
670+
) -> SearchRelationshipsResult:
654671
"""
655672
Search the relationships of the graph.
656673
"""
674+
filters = []
675+
if relationship_types:
676+
filters.append(
677+
"type(r) IN [" + ",".join(f"'{t}'" for t in relationship_types) + "]"
678+
)
679+
if from_node_labels:
680+
filters.append(
681+
"ANY(lbl IN labels(n) WHERE lbl IN ["
682+
+ ",".join(
683+
f"'{label}'" for label in self._clean_labels(from_node_labels)
684+
)
685+
+ "])"
686+
)
687+
if to_node_labels:
688+
filters.append(
689+
"ANY(lbl IN labels(m) WHERE lbl IN ["
690+
+ ",".join(f"'{label}'" for label in self._clean_labels(to_node_labels))
691+
+ "])"
692+
)
693+
if relationship_uuids:
694+
filters.append(
695+
"r.uuid IN [" + ",".join(f"'{u}'" for u in relationship_uuids) + "]"
696+
)
697+
if query_text:
698+
if query_search_target == "all":
699+
filters.append(
700+
f"(toLower(coalesce(n.name, n.Name, '')) CONTAINS toLower('{query_text}') OR "
701+
f"toLower(coalesce(m.name, m.Name, '')) CONTAINS toLower('{query_text}') OR "
702+
f"toLower(coalesce(r.description, r.Description, '')) CONTAINS toLower('{query_text}'))"
703+
)
704+
elif query_search_target == "node_name":
705+
filters.append(f"toLower(n.name) CONTAINS toLower('{query_text}')")
706+
elif query_search_target == "relationship_description":
707+
filters.append(
708+
f"toLower(r.description) CONTAINS toLower('{query_text}')"
709+
)
710+
elif query_search_target == "relationship_name":
711+
filters.append(f"toLower(r.name) CONTAINS toLower('{query_text}')")
657712
cypher_query = f"""
658-
MATCH (n)-[r]-(m)
659-
RETURN n.uuid as n_uuid, n.name as n_name, labels(n) as n_labels, n.description as n_description, properties(n) as n_properties,
660-
r as rel, type(r) as rel_type, r.description as rel_description,
661-
m.uuid as m_uuid, m.name as m_name, labels(m) as m_labels, m.description as m_description, properties(m) as m_properties
713+
MATCH (n)-[r]->(m)
714+
{"WHERE " + " AND ".join(filters) if filters else ""}
715+
RETURN n.uuid AS n_uuid, n.name AS n_name, labels(n) AS n_labels,
716+
n.description AS n_description, properties(n) AS n_properties,
717+
r AS rel, type(r) AS rel_type, r.description AS rel_description,
718+
m.uuid AS m_uuid, m.name AS m_name, labels(m) AS m_labels,
719+
m.description AS m_description, properties(m) AS m_properties
662720
SKIP {skip}
663721
LIMIT {limit}
664722
"""
723+
cypher_count = """
724+
MATCH ()-[r]-()
725+
RETURN count(r) AS total
726+
"""
665727
result = self.driver.execute_query(cypher_query)
728+
count_result = self.driver.execute_query(cypher_count)
729+
total = 0
730+
if count_result and count_result.records:
731+
total = count_result.records[0].get("total") or 0
732+
666733
triples: list[Triple] = []
667734
for record in result.records:
668735
relationship = record.get("rel")
@@ -745,19 +812,49 @@ def search_relationships(self, limit: int = 10, skip: int = 0) -> list[Triple]:
745812
),
746813
)
747814
)
748-
return triples
815+
return SearchRelationshipsResult(results=triples, total=total)
749816

750-
def search_entities(self, limit: int = 10, skip: int = 0) -> list[Node]:
817+
def search_entities(
818+
self,
819+
limit: int = 10,
820+
skip: int = 0,
821+
node_labels: Optional[list[str]] = None,
822+
node_uuids: Optional[list[str]] = None,
823+
query_text: Optional[str] = None,
824+
) -> SearchEntitiesResult:
751825
"""
752826
Search the entities of the graph.
753827
"""
828+
filters = []
829+
if node_labels:
830+
filters.append(
831+
"ANY(lbl IN labels(n) WHERE lbl IN ["
832+
+ ",".join(f"'{label}'" for label in self._clean_labels(node_labels))
833+
+ "])"
834+
)
835+
if node_uuids:
836+
filters.append(f"n.uuid IN [{','.join(node_uuids)}]")
837+
if query_text:
838+
filters.append(
839+
f"(toLower(coalesce(n.name, n.Name, '')) CONTAINS toLower('{query_text}'))"
840+
)
754841
cypher_query = f"""
755842
MATCH (n)
843+
{"WHERE " + " AND ".join(filters) if filters else ""}
756844
RETURN n.uuid as uuid, n.name as name, labels(n) as labels, n.description as description, properties(n) as properties
757845
SKIP {skip}
758846
LIMIT {limit}
759847
"""
848+
cypher_count = """
849+
MATCH (n)
850+
RETURN count(n) AS total
851+
"""
760852
result = self.driver.execute_query(cypher_query)
853+
count_result = self.driver.execute_query(cypher_count)
854+
total = 0
855+
if count_result and count_result.records:
856+
total = count_result.records[0].get("total") or 0
857+
761858
nodes: list[Node] = []
762859
for record in result.records:
763860
properties_record = record.get("properties") or {}
@@ -784,7 +881,7 @@ def search_entities(self, limit: int = 10, skip: int = 0) -> list[Node]:
784881
properties=properties,
785882
)
786883
)
787-
return nodes
884+
return SearchEntitiesResult(results=nodes, total=total)
788885

789886

790887
_neo4j_client = Neo4jClient()

0 commit comments

Comments
 (0)