From 7661da328e3167234d78028f6772cd4cbeec6217 Mon Sep 17 00:00:00 2001 From: Marko Budiselic Date: Mon, 1 Sep 2025 15:29:31 +0200 Subject: [PATCH 1/3] Add node neighborhood tool --- memgraph-toolbox/README.md | 2 + .../src/memgraph_toolbox/memgraph_toolbox.py | 2 + .../memgraph_toolbox/tests/test_toolbox.py | 3 +- .../src/memgraph_toolbox/tests/test_tools.py | 20 +++ .../src/memgraph_toolbox/tools/__init__.py | 1 + .../tools/node_neighborhood.py | 140 ++++++++++++++++++ 6 files changed, 167 insertions(+), 1 deletion(-) create mode 100644 memgraph-toolbox/src/memgraph_toolbox/tools/node_neighborhood.py diff --git a/memgraph-toolbox/README.md b/memgraph-toolbox/README.md index 3602819..842bfff 100644 --- a/memgraph-toolbox/README.md +++ b/memgraph-toolbox/README.md @@ -19,6 +19,8 @@ Below is a list of tools included in the toolbox, along with their descriptions: 7. `CypherTool` - Executes arbitrary [Cypher queries](https://memgraph.com/docs/querying) on a Memgraph database. 8. `ShowConstraintInfoTool` - Shows [constraint](https://memgraph.com/docs/fundamentals/constraints) information from a Memgraph database. 9. `ShowConfigTool` - Shows [configuration](https://memgraph.com/docs/database-management/configuration) information from a Memgraph database. +10. `NodeVectorSearchTool` - Searches for the most similar nodes using Memgraph's [vector search](https://memgraph.com/docs/querying/vector-search). +11. `NodeNeighborhoodTool` - Searches for the data attached to a given node using Memgraph's [deep-path traversals](https://memgraph.com/docs/advanced-algorithms/deep-path-traversal). 
## Usage diff --git a/memgraph-toolbox/src/memgraph_toolbox/memgraph_toolbox.py b/memgraph-toolbox/src/memgraph_toolbox/memgraph_toolbox.py index 6602a97..0c0504f 100644 --- a/memgraph-toolbox/src/memgraph_toolbox/memgraph_toolbox.py +++ b/memgraph-toolbox/src/memgraph_toolbox/memgraph_toolbox.py @@ -7,6 +7,7 @@ from .tools.constraint import ShowConstraintInfoTool from .tools.cypher import CypherTool from .tools.index import ShowIndexInfoTool +from .tools.node_neighborhood import NodeNeighborhoodTool from .tools.node_vector_search import NodeVectorSearchTool from .tools.page_rank import PageRankTool from .tools.schema import ShowSchemaInfoTool @@ -37,6 +38,7 @@ def __init__(self, db: Memgraph): self.add_tool(ShowConstraintInfoTool(db)) self.add_tool(CypherTool(db)) self.add_tool(ShowIndexInfoTool(db)) + self.add_tool(NodeNeighborhoodTool(db)) self.add_tool(NodeVectorSearchTool(db)) self.add_tool(PageRankTool(db)) self.add_tool(ShowSchemaInfoTool(db)) diff --git a/memgraph-toolbox/src/memgraph_toolbox/tests/test_toolbox.py b/memgraph-toolbox/src/memgraph_toolbox/tests/test_toolbox.py index d9715dc..0c42513 100644 --- a/memgraph-toolbox/src/memgraph_toolbox/tests/test_toolbox.py +++ b/memgraph-toolbox/src/memgraph_toolbox/tests/test_toolbox.py @@ -51,7 +51,7 @@ def test_memgraph_toolbox(): tools = toolkit.get_all_tools() # Check if we have all 9 tools - assert len(tools) == 10 + assert len(tools) == 11 # Check for specific tool names tool_names = [tool.name for tool in tools] @@ -66,6 +66,7 @@ def test_memgraph_toolbox(): "show_schema_info", "show_storage_info", "show_triggers", + "node_neighborhood", ] for expected_tool in expected_tools: diff --git a/memgraph-toolbox/src/memgraph_toolbox/tests/test_tools.py b/memgraph-toolbox/src/memgraph_toolbox/tests/test_tools.py index 5a5930b..3ce579d 100644 --- a/memgraph-toolbox/src/memgraph_toolbox/tests/test_tools.py +++ b/memgraph-toolbox/src/memgraph_toolbox/tests/test_tools.py @@ -6,6 +6,7 @@ from ..tools.constraint 
import ShowConstraintInfoTool from ..tools.cypher import CypherTool from ..tools.index import ShowIndexInfoTool +from ..tools.node_neighborhood import NodeNeighborhoodTool from ..tools.node_vector_search import NodeVectorSearchTool from ..tools.page_rank import PageRankTool from ..tools.schema import ShowSchemaInfoTool @@ -282,3 +283,22 @@ def test_node_vector_search_tool(): 'MATCH (n:Person) WHERE "embedding" IN keys(n) DETACH DELETE n' ) memgraph_client.query("DROP VECTOR INDEX my_index") + + +def test_node_neighborhood_tool(): + """Test the NodeNeighborhood tool.""" + url = "bolt://localhost:7687" + user = "" + password = "" + memgraph_client = Memgraph(url=url, username=user, password=password) + + memgraph_client.query("CREATE (:Person {id: 1})") + memgraph_client.query("CREATE (:Person {id: 2})") + memgraph_client.query("CREATE (:Person {id: 3})") + + node_neighborhood_tool = NodeNeighborhoodTool(db=memgraph_client) + result = node_neighborhood_tool.call( + {"node_id": 1, "max_distance": 2, "relationship_types": ["KNOWS"]} + ) + assert isinstance(result, list) + assert len(result) == 2 diff --git a/memgraph-toolbox/src/memgraph_toolbox/tools/__init__.py b/memgraph-toolbox/src/memgraph_toolbox/tools/__init__.py index e69de29..8b13789 100644 --- a/memgraph-toolbox/src/memgraph_toolbox/tools/__init__.py +++ b/memgraph-toolbox/src/memgraph_toolbox/tools/__init__.py @@ -0,0 +1 @@ + diff --git a/memgraph-toolbox/src/memgraph_toolbox/tools/node_neighborhood.py b/memgraph-toolbox/src/memgraph_toolbox/tools/node_neighborhood.py new file mode 100644 index 0000000..4186675 --- /dev/null +++ b/memgraph-toolbox/src/memgraph_toolbox/tools/node_neighborhood.py @@ -0,0 +1,140 @@ +from typing import Any, Dict, List + +from ..api.memgraph import Memgraph +from ..api.tool import BaseTool + + +class NodeNeighborhoodTool(BaseTool): + """ + Tool for finding nodes within a specified neighborhood distance in Memgraph. 
+ """ + + def __init__(self, db: Memgraph): + super().__init__( + name="node_neighborhood", + description=( + "Finds nodes within a specified distance from a given node. " + "This tool explores the graph neighborhood around a starting node, " + "returning all nodes and relationships found within the specified radius." + ), + input_schema={ + "type": "object", + "properties": { + "node_id": { + "type": "string", + "description": "The ID of the starting node to find neighborhood around", + }, + "max_distance": { + "type": "integer", + "description": "Maximum distance (hops) to search from the starting node. Default is 2.", + "default": 2, + }, + "relationship_types": { + "type": "array", + "items": {"type": "string"}, + "description": "List of relationship types to include in the search. If empty, all types are included.", + "default": [], + }, + "node_labels": { + "type": "array", + "items": {"type": "string"}, + "description": "List of node labels to include in the search. If empty, all labels are included.", + "default": [], + }, + "include_paths": { + "type": "boolean", + "description": "Whether to include the paths from start node to each neighbor. Default is false.", + "default": False, + }, + "limit": { + "type": "integer", + "description": "Maximum number of nodes to return. 
Default is 100.", + "default": 100, + }, + }, + "required": ["node_id"], + }, + ) + self.db = db + + def call(self, arguments: Dict[str, Any]) -> List[Dict[str, Any]]: + """Execute the neighborhood search and return the results.""" + node_id = arguments["node_id"] + max_distance = arguments.get("max_distance", 2) + relationship_types = arguments.get("relationship_types", []) + node_labels = arguments.get("node_labels", []) + include_paths = arguments.get("include_paths", False) + limit = arguments.get("limit", 100) + + # Build relationship type filter + rel_filter = "" + if relationship_types: + rel_types_str = ", ".join([f"'{rt}'" for rt in relationship_types]) + rel_filter = f"WHERE type(r) IN [{rel_types_str}]" + + # Build node label filter + label_filter = "" + if node_labels: + label_conditions = [ + f"ANY(label IN labels(n) WHERE label IN {node_labels})" + for _ in node_labels + ] + label_filter = f"AND {' AND '.join(label_conditions)}" + + if include_paths: + # Query with paths included + query = f""" + MATCH path = (start)-[*1..{max_distance}]-{rel_filter}(neighbor) + WHERE start.element_id = $node_id {label_filter} + WITH neighbor, path, length(path) as distance + RETURN DISTINCT neighbor, distance, path + ORDER BY distance, neighbor.element_id + LIMIT $limit + """ + else: + # Query without paths (more efficient) + query = f""" + MATCH (start)-[*1..{max_distance}]-{rel_filter}(neighbor) + WHERE start.element_id = $node_id {label_filter} + WITH DISTINCT neighbor, length(shortestPath((start)-[*]-(neighbor))) as distance + RETURN neighbor, distance + ORDER BY distance, neighbor.element_id + LIMIT $limit + """ + + params = {"node_id": node_id, "limit": limit} + + try: + results = self.db.query(query, params) + + # Process results to extract relevant information + processed_results = [] + for record in results: + node_data = { + "node_id": ( + record["neighbor"].element_id + if hasattr(record["neighbor"], "element_id") + else str(record["neighbor"]) + ), + 
"labels": ( + list(record["neighbor"].labels) + if hasattr(record["neighbor"], "labels") + else [] + ), + "properties": ( + dict(record["neighbor"]) + if hasattr(record["neighbor"], "__iter__") + else {} + ), + "distance": record["distance"], + } + + if include_paths and "path" in record: + node_data["path"] = str(record["path"]) + + processed_results.append(node_data) + + return processed_results + + except Exception as e: + return [{"error": f"Failed to find neighborhood: {str(e)}"}] From abc7abcaed11f4d5c333c48cc80c4f2f02840a80 Mon Sep 17 00:00:00 2001 From: Marko Budiselic Date: Mon, 1 Sep 2025 17:56:30 +0200 Subject: [PATCH 2/3] Fix impl and test --- .../src/memgraph_toolbox/tests/test_tools.py | 21 ++-- .../tools/node_neighborhood.py | 95 ++----------------- 2 files changed, 22 insertions(+), 94 deletions(-) diff --git a/memgraph-toolbox/src/memgraph_toolbox/tests/test_tools.py b/memgraph-toolbox/src/memgraph_toolbox/tests/test_tools.py index 3ce579d..86453c8 100644 --- a/memgraph-toolbox/src/memgraph_toolbox/tests/test_tools.py +++ b/memgraph-toolbox/src/memgraph_toolbox/tests/test_tools.py @@ -292,13 +292,22 @@ def test_node_neighborhood_tool(): password = "" memgraph_client = Memgraph(url=url, username=user, password=password) - memgraph_client.query("CREATE (:Person {id: 1})") - memgraph_client.query("CREATE (:Person {id: 2})") - memgraph_client.query("CREATE (:Person {id: 3})") + label = "TestNodeNeighborhoodToolLabel" + memgraph_client.query(f"MATCH (n:{label}) DETACH DELETE n;") + memgraph_client.query( + f"CREATE (p1:{label} {{id: 1}})-[:KNOWS]->(p2:{label} {{id: 2}}), (p2)-[:KNOWS]->(p3:{label} {{id: 3}});" + ) + memgraph_client.query( + f"CREATE (p4:{label} {{id: 4}})-[:KNOWS]->(p5:{label} {{id: 5}});" + ) + ids = memgraph_client.query( + f"MATCH (p1:{label} {{id:1}}) RETURN id(p1) AS node_id;" + ) + assert len(ids) == 1 + node_id = ids[0]["node_id"] node_neighborhood_tool = NodeNeighborhoodTool(db=memgraph_client) - result = 
node_neighborhood_tool.call( - {"node_id": 1, "max_distance": 2, "relationship_types": ["KNOWS"]} - ) + result = node_neighborhood_tool.call({"node_id": node_id, "max_distance": 2}) assert isinstance(result, list) assert len(result) == 2 + memgraph_client.query(f"MATCH (n:{label}) DETACH DELETE n;") diff --git a/memgraph-toolbox/src/memgraph_toolbox/tools/node_neighborhood.py b/memgraph-toolbox/src/memgraph_toolbox/tools/node_neighborhood.py index 4186675..981e23e 100644 --- a/memgraph-toolbox/src/memgraph_toolbox/tools/node_neighborhood.py +++ b/memgraph-toolbox/src/memgraph_toolbox/tools/node_neighborhood.py @@ -27,24 +27,7 @@ def __init__(self, db: Memgraph): "max_distance": { "type": "integer", "description": "Maximum distance (hops) to search from the starting node. Default is 2.", - "default": 2, - }, - "relationship_types": { - "type": "array", - "items": {"type": "string"}, - "description": "List of relationship types to include in the search. If empty, all types are included.", - "default": [], - }, - "node_labels": { - "type": "array", - "items": {"type": "string"}, - "description": "List of node labels to include in the search. If empty, all labels are included.", - "default": [], - }, - "include_paths": { - "type": "boolean", - "description": "Whether to include the paths from start node to each neighbor. 
Default is false.", - "default": False, + "default": 1, }, "limit": { "type": "integer", @@ -60,81 +43,17 @@ def __init__(self, db: Memgraph): def call(self, arguments: Dict[str, Any]) -> List[Dict[str, Any]]: """Execute the neighborhood search and return the results.""" node_id = arguments["node_id"] - max_distance = arguments.get("max_distance", 2) - relationship_types = arguments.get("relationship_types", []) - node_labels = arguments.get("node_labels", []) - include_paths = arguments.get("include_paths", False) + max_distance = arguments.get("max_distance", 1) limit = arguments.get("limit", 100) - # Build relationship type filter - rel_filter = "" - if relationship_types: - rel_types_str = ", ".join([f"'{rt}'" for rt in relationship_types]) - rel_filter = f"WHERE type(r) IN [{rel_types_str}]" - - # Build node label filter - label_filter = "" - if node_labels: - label_conditions = [ - f"ANY(label IN labels(n) WHERE label IN {node_labels})" - for _ in node_labels - ] - label_filter = f"AND {' AND '.join(label_conditions)}" - - if include_paths: - # Query with paths included - query = f""" - MATCH path = (start)-[*1..{max_distance}]-{rel_filter}(neighbor) - WHERE start.element_id = $node_id {label_filter} - WITH neighbor, path, length(path) as distance - RETURN DISTINCT neighbor, distance, path - ORDER BY distance, neighbor.element_id - LIMIT $limit - """ - else: - # Query without paths (more efficient) - query = f""" - MATCH (start)-[*1..{max_distance}]-{rel_filter}(neighbor) - WHERE start.element_id = $node_id {label_filter} - WITH DISTINCT neighbor, length(shortestPath((start)-[*]-(neighbor))) as distance - RETURN neighbor, distance - ORDER BY distance, neighbor.element_id - LIMIT $limit - """ - - params = {"node_id": node_id, "limit": limit} - + query = f"""MATCH (n)-[r*..{max_distance}]-(m) WHERE id(n) = {node_id} RETURN DISTINCT m LIMIT {limit};""" try: - results = self.db.query(query, params) - - # Process results to extract relevant information + results = 
self.db.query(query, {}) processed_results = [] for record in results: - node_data = { - "node_id": ( - record["neighbor"].element_id - if hasattr(record["neighbor"], "element_id") - else str(record["neighbor"]) - ), - "labels": ( - list(record["neighbor"].labels) - if hasattr(record["neighbor"], "labels") - else [] - ), - "properties": ( - dict(record["neighbor"]) - if hasattr(record["neighbor"], "__iter__") - else {} - ), - "distance": record["distance"], - } - - if include_paths and "path" in record: - node_data["path"] = str(record["path"]) - - processed_results.append(node_data) - + node_data = record["m"]; + properties = {k: v for k, v in node_data.items()} + processed_results.append(properties) return processed_results - except Exception as e: return [{"error": f"Failed to find neighborhood: {str(e)}"}] From c547e6ecddb68414c4b2706f18f98f003800327b Mon Sep 17 00:00:00 2001 From: Marko Budiselic Date: Tue, 2 Sep 2025 12:28:12 +0200 Subject: [PATCH 3/3] Fix the default in description --- .../src/memgraph_toolbox/tools/node_neighborhood.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/memgraph-toolbox/src/memgraph_toolbox/tools/node_neighborhood.py b/memgraph-toolbox/src/memgraph_toolbox/tools/node_neighborhood.py index 981e23e..ced337e 100644 --- a/memgraph-toolbox/src/memgraph_toolbox/tools/node_neighborhood.py +++ b/memgraph-toolbox/src/memgraph_toolbox/tools/node_neighborhood.py @@ -26,7 +26,7 @@ def __init__(self, db: Memgraph): }, "max_distance": { "type": "integer", - "description": "Maximum distance (hops) to search from the starting node. Default is 2.", + "description": "Maximum distance (hops) to search from the starting node. Default is 1.", "default": 1, }, "limit": {