Skip to content

Commit 807d8ee

Browse files
authored
Memory - add annotations, format return data, improve error handling (#113)
* move Neo4jMemory class to separate file, add ToolResult and structured output * update changelog, add try-catch for neo4j errors * Update CHANGELOG.md * add tool annotations, update tool arg types * replace List with list type hints
1 parent 06b4d4f commit 807d8ee

File tree

3 files changed

+420
-287
lines changed

3 files changed

+420
-287
lines changed

servers/mcp-neo4j-memory/CHANGELOG.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,16 @@
33
### Fixed
44

55
### Changed
6+
* Update tool return type hints for structured output
7+
* Move `Neo4jMemory` class and related classes to separate file
8+
* Change tool responses to return the `ToolResponse` object
9+
* Updated tool argument types with Pydantic models
610

711
### Added
12+
* Add structured output to tool responses
13+
* Add error handling to catch Neo4j specific errors and improve error responses
14+
* Implement `ToolError` class from FastMCP
15+
* Add tool annotations
816

917
## v0.2.0
1018

Lines changed: 245 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,245 @@
1+
import logging
2+
from typing import Any, Dict, List
3+
4+
from neo4j import AsyncDriver, RoutingControl
5+
from pydantic import BaseModel
6+
7+
8+
# Set up logging
9+
logger = logging.getLogger('mcp_neo4j_memory')
10+
logger.setLevel(logging.INFO)
11+
12+
# Models for our knowledge graph
13+
class Entity(BaseModel):
14+
name: str
15+
type: str
16+
observations: List[str]
17+
18+
class Relation(BaseModel):
19+
source: str
20+
target: str
21+
relationType: str
22+
23+
class KnowledgeGraph(BaseModel):
24+
entities: List[Entity]
25+
relations: List[Relation]
26+
27+
class ObservationAddition(BaseModel):
28+
entityName: str
29+
observations: List[str]
30+
31+
class ObservationDeletion(BaseModel):
32+
entityName: str
33+
observations: List[str]
34+
35+
class Neo4jMemory:
36+
def __init__(self, neo4j_driver: AsyncDriver):
37+
self.driver = neo4j_driver
38+
39+
async def create_fulltext_index(self):
40+
"""Create a fulltext search index for entities if it doesn't exist."""
41+
try:
42+
query = "CREATE FULLTEXT INDEX search IF NOT EXISTS FOR (m:Memory) ON EACH [m.name, m.type, m.observations];"
43+
await self.driver.execute_query(query, routing_control=RoutingControl.WRITE)
44+
logger.info("Created fulltext search index")
45+
except Exception as e:
46+
# Index might already exist, which is fine
47+
logger.debug(f"Fulltext index creation: {e}")
48+
49+
async def load_graph(self, filter_query: str = "*"):
50+
"""Load the entire knowledge graph from Neo4j."""
51+
logger.info("Loading knowledge graph from Neo4j")
52+
query = """
53+
CALL db.index.fulltext.queryNodes('search', $filter) yield node as entity, score
54+
OPTIONAL MATCH (entity)-[r]-(other)
55+
RETURN collect(distinct {
56+
name: entity.name,
57+
type: entity.type,
58+
observations: entity.observations
59+
}) as nodes,
60+
collect(distinct {
61+
source: startNode(r).name,
62+
target: endNode(r).name,
63+
relationType: type(r)
64+
}) as relations
65+
"""
66+
67+
result = await self.driver.execute_query(query, {"filter": filter_query}, routing_control=RoutingControl.READ)
68+
69+
if not result.records:
70+
return KnowledgeGraph(entities=[], relations=[])
71+
72+
record = result.records[0]
73+
nodes = record.get('nodes', list())
74+
rels = record.get('relations', list())
75+
76+
entities = [
77+
Entity(
78+
name=node['name'],
79+
type=node['type'],
80+
observations=node.get('observations', list())
81+
)
82+
for node in nodes if node.get('name')
83+
]
84+
85+
relations = [
86+
Relation(
87+
source=rel['source'],
88+
target=rel['target'],
89+
relationType=rel['relationType']
90+
)
91+
for rel in rels if rel.get('relationType')
92+
]
93+
94+
logger.debug(f"Loaded entities: {entities}")
95+
logger.debug(f"Loaded relations: {relations}")
96+
97+
return KnowledgeGraph(entities=entities, relations=relations)
98+
99+
async def create_entities(self, entities: List[Entity]) -> List[Entity]:
100+
"""Create multiple new entities in the knowledge graph."""
101+
logger.info(f"Creating {len(entities)} entities")
102+
for entity in entities:
103+
query = f"""
104+
WITH $entity as entity
105+
MERGE (e:Memory {{ name: entity.name }})
106+
SET e += entity {{ .type, .observations }}
107+
SET e:{entity.type}
108+
"""
109+
await self.driver.execute_query(query, {"entity": entity.model_dump()}, routing_control=RoutingControl.WRITE)
110+
111+
return entities
112+
113+
async def create_relations(self, relations: List[Relation]) -> List[Relation]:
114+
"""Create multiple new relations between entities."""
115+
logger.info(f"Creating {len(relations)} relations")
116+
for relation in relations:
117+
query = f"""
118+
WITH $relation as relation
119+
MATCH (from:Memory),(to:Memory)
120+
WHERE from.name = relation.source
121+
AND to.name = relation.target
122+
MERGE (from)-[r:{relation.relationType}]->(to)
123+
"""
124+
125+
await self.driver.execute_query(
126+
query,
127+
{"relation": relation.model_dump()},
128+
routing_control=RoutingControl.WRITE
129+
)
130+
131+
return relations
132+
133+
async def add_observations(self, observations: List[ObservationAddition]) -> List[Dict[str, Any]]:
134+
"""Add new observations to existing entities."""
135+
logger.info(f"Adding observations to {len(observations)} entities")
136+
query = """
137+
UNWIND $observations as obs
138+
MATCH (e:Memory { name: obs.entityName })
139+
WITH e, [o in obs.observations WHERE NOT o IN e.observations] as new
140+
SET e.observations = coalesce(e.observations,[]) + new
141+
RETURN e.name as name, new
142+
"""
143+
144+
result = await self.driver.execute_query(
145+
query,
146+
{"observations": [obs.model_dump() for obs in observations]},
147+
routing_control=RoutingControl.WRITE
148+
)
149+
150+
results = [{"entityName": record.get("name"), "addedObservations": record.get("new")} for record in result.records]
151+
return results
152+
153+
async def delete_entities(self, entity_names: List[str]) -> None:
154+
"""Delete multiple entities and their associated relations."""
155+
logger.info(f"Deleting {len(entity_names)} entities")
156+
query = """
157+
UNWIND $entities as name
158+
MATCH (e:Memory { name: name })
159+
DETACH DELETE e
160+
"""
161+
162+
await self.driver.execute_query(query, {"entities": entity_names}, routing_control=RoutingControl.WRITE)
163+
logger.info(f"Successfully deleted {len(entity_names)} entities")
164+
165+
async def delete_observations(self, deletions: List[ObservationDeletion]) -> None:
166+
"""Delete specific observations from entities."""
167+
logger.info(f"Deleting observations from {len(deletions)} entities")
168+
query = """
169+
UNWIND $deletions as d
170+
MATCH (e:Memory { name: d.entityName })
171+
SET e.observations = [o in coalesce(e.observations,[]) WHERE NOT o IN d.observations]
172+
"""
173+
await self.driver.execute_query(
174+
query,
175+
{"deletions": [deletion.model_dump() for deletion in deletions]},
176+
routing_control=RoutingControl.WRITE
177+
)
178+
logger.info(f"Successfully deleted observations from {len(deletions)} entities")
179+
180+
async def delete_relations(self, relations: List[Relation]) -> None:
181+
"""Delete multiple relations from the graph."""
182+
logger.info(f"Deleting {len(relations)} relations")
183+
for relation in relations:
184+
query = f"""
185+
WITH $relation as relation
186+
MATCH (source:Memory)-[r:{relation.relationType}]->(target:Memory)
187+
WHERE source.name = relation.source
188+
AND target.name = relation.target
189+
DELETE r
190+
"""
191+
await self.driver.execute_query(
192+
query,
193+
{"relation": relation.model_dump()},
194+
routing_control=RoutingControl.WRITE
195+
)
196+
logger.info(f"Successfully deleted {len(relations)} relations")
197+
198+
async def read_graph(self) -> KnowledgeGraph:
199+
"""Read the entire knowledge graph."""
200+
return await self.load_graph()
201+
202+
async def search_memories(self, query: str) -> KnowledgeGraph:
203+
"""Search for memories based on a query with Fulltext Search."""
204+
logger.info(f"Searching for memories with query: '{query}'")
205+
return await self.load_graph(query)
206+
207+
async def find_memories_by_name(self, names: List[str]) -> KnowledgeGraph:
208+
"""Find specific memories by their names. This does not use fulltext search."""
209+
logger.info(f"Finding {len(names)} memories by name")
210+
query = """
211+
MATCH (e:Memory)
212+
WHERE e.name IN $names
213+
RETURN e.name as name,
214+
e.type as type,
215+
e.observations as observations
216+
"""
217+
result_nodes = await self.driver.execute_query(query, {"names": names}, routing_control=RoutingControl.READ)
218+
entities: list[Entity] = list()
219+
for record in result_nodes.records:
220+
entities.append(Entity(
221+
name=record['name'],
222+
type=record['type'],
223+
observations=record.get('observations', list())
224+
))
225+
226+
# Get relations for found entities
227+
relations: list[Relation] = list()
228+
if entities:
229+
query = """
230+
MATCH (source:Memory)-[r]->(target:Memory)
231+
WHERE source.name IN $names OR target.name IN $names
232+
RETURN source.name as source,
233+
target.name as target,
234+
type(r) as relationType
235+
"""
236+
result_relations = await self.driver.execute_query(query, {"names": names}, routing_control=RoutingControl.READ)
237+
for record in result_relations.records:
238+
relations.append(Relation(
239+
source=record["source"],
240+
target=record["target"],
241+
relationType=record["relationType"]
242+
))
243+
244+
logger.info(f"Found {len(entities)} entities and {len(relations)} relations")
245+
return KnowledgeGraph(entities=entities, relations=relations)

0 commit comments

Comments
 (0)