1313from neo4j import GraphDatabase
1414from src .adapters .interfaces .graph import GraphClient
1515from src .config import config
16- from src .constants .kg import IdentificationParams , Node , Predicate , Triple
16+ from src .constants .kg import (
17+ IdentificationParams ,
18+ Node ,
19+ Predicate ,
20+ SearchEntitiesResult ,
21+ SearchRelationshipsResult ,
22+ Triple ,
23+ )
1724from src .utils .logging import log
1825
1926
@@ -650,19 +657,79 @@ def get_connected_nodes(
650657 for record in result .records
651658 ]
652659
653- def search_relationships (self , limit : int = 10 , skip : int = 0 ) -> list [Triple ]:
660+ def search_relationships (
661+ self ,
662+ limit : int = 10 ,
663+ skip : int = 0 ,
664+ relationship_types : Optional [list [str ]] = None ,
665+ from_node_labels : Optional [list [str ]] = None ,
666+ to_node_labels : Optional [list [str ]] = None ,
667+ relationship_uuids : Optional [list [str ]] = None ,
668+ query_text : Optional [str ] = None ,
669+ query_search_target : Optional [str ] = "all" ,
670+ ) -> SearchRelationshipsResult :
654671 """
655672 Search the relationships of the graph.
656673 """
674+ filters = []
675+ if relationship_types :
676+ filters .append (
677+ "type(r) IN [" + "," .join (f"'{ t } '" for t in relationship_types ) + "]"
678+ )
679+ if from_node_labels :
680+ filters .append (
681+ "ANY(lbl IN labels(n) WHERE lbl IN ["
682+ + "," .join (
683+ f"'{ label } '" for label in self ._clean_labels (from_node_labels )
684+ )
685+ + "])"
686+ )
687+ if to_node_labels :
688+ filters .append (
689+ "ANY(lbl IN labels(m) WHERE lbl IN ["
690+ + "," .join (f"'{ label } '" for label in self ._clean_labels (to_node_labels ))
691+ + "])"
692+ )
693+ if relationship_uuids :
694+ filters .append (
695+ "r.uuid IN [" + "," .join (f"'{ u } '" for u in relationship_uuids ) + "]"
696+ )
697+ if query_text :
698+ if query_search_target == "all" :
699+ filters .append (
700+ f"(toLower(coalesce(n.name, n.Name, '')) CONTAINS toLower('{ query_text } ') OR "
701+ f"toLower(coalesce(m.name, m.Name, '')) CONTAINS toLower('{ query_text } ') OR "
702+ f"toLower(coalesce(r.description, r.Description, '')) CONTAINS toLower('{ query_text } '))"
703+ )
704+ elif query_search_target == "node_name" :
705+ filters .append (f"toLower(n.name) CONTAINS toLower('{ query_text } ')" )
706+ elif query_search_target == "relationship_description" :
707+ filters .append (
708+ f"toLower(r.description) CONTAINS toLower('{ query_text } ')"
709+ )
710+ elif query_search_target == "relationship_name" :
711+ filters .append (f"toLower(r.name) CONTAINS toLower('{ query_text } ')" )
657712 cypher_query = f"""
658- MATCH (n)-[r]-(m)
659- RETURN n.uuid as n_uuid, n.name as n_name, labels(n) as n_labels, n.description as n_description, properties(n) as n_properties,
660- r as rel, type(r) as rel_type, r.description as rel_description,
661- m.uuid as m_uuid, m.name as m_name, labels(m) as m_labels, m.description as m_description, properties(m) as m_properties
713+ MATCH (n)-[r]->(m)
714+ { "WHERE " + " AND " .join (filters ) if filters else "" }
715+ RETURN n.uuid AS n_uuid, n.name AS n_name, labels(n) AS n_labels,
716+ n.description AS n_description, properties(n) AS n_properties,
717+ r AS rel, type(r) AS rel_type, r.description AS rel_description,
718+ m.uuid AS m_uuid, m.name AS m_name, labels(m) AS m_labels,
719+ m.description AS m_description, properties(m) AS m_properties
662720 SKIP { skip }
663721 LIMIT { limit }
664722 """
723+ cypher_count = """
724+ MATCH ()-[r]-()
725+ RETURN count(r) AS total
726+ """
665727 result = self .driver .execute_query (cypher_query )
728+ count_result = self .driver .execute_query (cypher_count )
729+ total = 0
730+ if count_result and count_result .records :
731+ total = count_result .records [0 ].get ("total" ) or 0
732+
666733 triples : list [Triple ] = []
667734 for record in result .records :
668735 relationship = record .get ("rel" )
@@ -745,19 +812,49 @@ def search_relationships(self, limit: int = 10, skip: int = 0) -> list[Triple]:
745812 ),
746813 )
747814 )
748- return triples
815+ return SearchRelationshipsResult ( results = triples , total = total )
749816
750- def search_entities (self , limit : int = 10 , skip : int = 0 ) -> list [Node ]:
817+ def search_entities (
818+ self ,
819+ limit : int = 10 ,
820+ skip : int = 0 ,
821+ node_labels : Optional [list [str ]] = None ,
822+ node_uuids : Optional [list [str ]] = None ,
823+ query_text : Optional [str ] = None ,
824+ ) -> SearchEntitiesResult :
751825 """
752826 Search the entities of the graph.
753827 """
828+ filters = []
829+ if node_labels :
830+ filters .append (
831+ "ANY(lbl IN labels(n) WHERE lbl IN ["
832+ + "," .join (f"'{ label } '" for label in self ._clean_labels (node_labels ))
833+ + "])"
834+ )
835+ if node_uuids :
836+ filters .append (f"n.uuid IN [{ ',' .join (node_uuids )} ]" )
837+ if query_text :
838+ filters .append (
839+ f"(toLower(coalesce(n.name, n.Name, '')) CONTAINS toLower('{ query_text } '))"
840+ )
754841 cypher_query = f"""
755842 MATCH (n)
843+ { "WHERE " + " AND " .join (filters ) if filters else "" }
756844 RETURN n.uuid as uuid, n.name as name, labels(n) as labels, n.description as description, properties(n) as properties
757845 SKIP { skip }
758846 LIMIT { limit }
759847 """
848+ cypher_count = """
849+ MATCH (n)
850+ RETURN count(n) AS total
851+ """
760852 result = self .driver .execute_query (cypher_query )
853+ count_result = self .driver .execute_query (cypher_count )
854+ total = 0
855+ if count_result and count_result .records :
856+ total = count_result .records [0 ].get ("total" ) or 0
857+
761858 nodes : list [Node ] = []
762859 for record in result .records :
763860 properties_record = record .get ("properties" ) or {}
@@ -784,7 +881,7 @@ def search_entities(self, limit: int = 10, skip: int = 0) -> list[Node]:
784881 properties = properties ,
785882 )
786883 )
787- return nodes
884+ return SearchEntitiesResult ( results = nodes , total = total )
788885
789886
790887_neo4j_client = Neo4jClient ()
0 commit comments