1+ import logging
2+ from typing import Any , Dict , List
3+
4+ from neo4j import AsyncDriver , RoutingControl
5+ from pydantic import BaseModel
6+
7+
8+ # Set up logging
9+ logger = logging .getLogger ('mcp_neo4j_memory' )
10+ logger .setLevel (logging .INFO )
11+
12+ # Models for our knowledge graph
13+ class Entity (BaseModel ):
14+ name : str
15+ type : str
16+ observations : List [str ]
17+
18+ class Relation (BaseModel ):
19+ source : str
20+ target : str
21+ relationType : str
22+
23+ class KnowledgeGraph (BaseModel ):
24+ entities : List [Entity ]
25+ relations : List [Relation ]
26+
27+ class ObservationAddition (BaseModel ):
28+ entityName : str
29+ observations : List [str ]
30+
31+ class ObservationDeletion (BaseModel ):
32+ entityName : str
33+ observations : List [str ]
34+
35+ class Neo4jMemory :
36+ def __init__ (self , neo4j_driver : AsyncDriver ):
37+ self .driver = neo4j_driver
38+
39+ async def create_fulltext_index (self ):
40+ """Create a fulltext search index for entities if it doesn't exist."""
41+ try :
42+ query = "CREATE FULLTEXT INDEX search IF NOT EXISTS FOR (m:Memory) ON EACH [m.name, m.type, m.observations];"
43+ await self .driver .execute_query (query , routing_control = RoutingControl .WRITE )
44+ logger .info ("Created fulltext search index" )
45+ except Exception as e :
46+ # Index might already exist, which is fine
47+ logger .debug (f"Fulltext index creation: { e } " )
48+
49+ async def load_graph (self , filter_query : str = "*" ):
50+ """Load the entire knowledge graph from Neo4j."""
51+ logger .info ("Loading knowledge graph from Neo4j" )
52+ query = """
53+ CALL db.index.fulltext.queryNodes('search', $filter) yield node as entity, score
54+ OPTIONAL MATCH (entity)-[r]-(other)
55+ RETURN collect(distinct {
56+ name: entity.name,
57+ type: entity.type,
58+ observations: entity.observations
59+ }) as nodes,
60+ collect(distinct {
61+ source: startNode(r).name,
62+ target: endNode(r).name,
63+ relationType: type(r)
64+ }) as relations
65+ """
66+
67+ result = await self .driver .execute_query (query , {"filter" : filter_query }, routing_control = RoutingControl .READ )
68+
69+ if not result .records :
70+ return KnowledgeGraph (entities = [], relations = [])
71+
72+ record = result .records [0 ]
73+ nodes = record .get ('nodes' , list ())
74+ rels = record .get ('relations' , list ())
75+
76+ entities = [
77+ Entity (
78+ name = node ['name' ],
79+ type = node ['type' ],
80+ observations = node .get ('observations' , list ())
81+ )
82+ for node in nodes if node .get ('name' )
83+ ]
84+
85+ relations = [
86+ Relation (
87+ source = rel ['source' ],
88+ target = rel ['target' ],
89+ relationType = rel ['relationType' ]
90+ )
91+ for rel in rels if rel .get ('relationType' )
92+ ]
93+
94+ logger .debug (f"Loaded entities: { entities } " )
95+ logger .debug (f"Loaded relations: { relations } " )
96+
97+ return KnowledgeGraph (entities = entities , relations = relations )
98+
99+ async def create_entities (self , entities : List [Entity ]) -> List [Entity ]:
100+ """Create multiple new entities in the knowledge graph."""
101+ logger .info (f"Creating { len (entities )} entities" )
102+ for entity in entities :
103+ query = f"""
104+ WITH $entity as entity
105+ MERGE (e:Memory {{ name: entity.name }})
106+ SET e += entity {{ .type, .observations }}
107+ SET e:{ entity .type }
108+ """
109+ await self .driver .execute_query (query , {"entity" : entity .model_dump ()}, routing_control = RoutingControl .WRITE )
110+
111+ return entities
112+
113+ async def create_relations (self , relations : List [Relation ]) -> List [Relation ]:
114+ """Create multiple new relations between entities."""
115+ logger .info (f"Creating { len (relations )} relations" )
116+ for relation in relations :
117+ query = f"""
118+ WITH $relation as relation
119+ MATCH (from:Memory),(to:Memory)
120+ WHERE from.name = relation.source
121+ AND to.name = relation.target
122+ MERGE (from)-[r:{ relation .relationType } ]->(to)
123+ """
124+
125+ await self .driver .execute_query (
126+ query ,
127+ {"relation" : relation .model_dump ()},
128+ routing_control = RoutingControl .WRITE
129+ )
130+
131+ return relations
132+
133+ async def add_observations (self , observations : List [ObservationAddition ]) -> List [Dict [str , Any ]]:
134+ """Add new observations to existing entities."""
135+ logger .info (f"Adding observations to { len (observations )} entities" )
136+ query = """
137+ UNWIND $observations as obs
138+ MATCH (e:Memory { name: obs.entityName })
139+ WITH e, [o in obs.observations WHERE NOT o IN e.observations] as new
140+ SET e.observations = coalesce(e.observations,[]) + new
141+ RETURN e.name as name, new
142+ """
143+
144+ result = await self .driver .execute_query (
145+ query ,
146+ {"observations" : [obs .model_dump () for obs in observations ]},
147+ routing_control = RoutingControl .WRITE
148+ )
149+
150+ results = [{"entityName" : record .get ("name" ), "addedObservations" : record .get ("new" )} for record in result .records ]
151+ return results
152+
153+ async def delete_entities (self , entity_names : List [str ]) -> None :
154+ """Delete multiple entities and their associated relations."""
155+ logger .info (f"Deleting { len (entity_names )} entities" )
156+ query = """
157+ UNWIND $entities as name
158+ MATCH (e:Memory { name: name })
159+ DETACH DELETE e
160+ """
161+
162+ await self .driver .execute_query (query , {"entities" : entity_names }, routing_control = RoutingControl .WRITE )
163+ logger .info (f"Successfully deleted { len (entity_names )} entities" )
164+
165+ async def delete_observations (self , deletions : List [ObservationDeletion ]) -> None :
166+ """Delete specific observations from entities."""
167+ logger .info (f"Deleting observations from { len (deletions )} entities" )
168+ query = """
169+ UNWIND $deletions as d
170+ MATCH (e:Memory { name: d.entityName })
171+ SET e.observations = [o in coalesce(e.observations,[]) WHERE NOT o IN d.observations]
172+ """
173+ await self .driver .execute_query (
174+ query ,
175+ {"deletions" : [deletion .model_dump () for deletion in deletions ]},
176+ routing_control = RoutingControl .WRITE
177+ )
178+ logger .info (f"Successfully deleted observations from { len (deletions )} entities" )
179+
180+ async def delete_relations (self , relations : List [Relation ]) -> None :
181+ """Delete multiple relations from the graph."""
182+ logger .info (f"Deleting { len (relations )} relations" )
183+ for relation in relations :
184+ query = f"""
185+ WITH $relation as relation
186+ MATCH (source:Memory)-[r:{ relation .relationType } ]->(target:Memory)
187+ WHERE source.name = relation.source
188+ AND target.name = relation.target
189+ DELETE r
190+ """
191+ await self .driver .execute_query (
192+ query ,
193+ {"relation" : relation .model_dump ()},
194+ routing_control = RoutingControl .WRITE
195+ )
196+ logger .info (f"Successfully deleted { len (relations )} relations" )
197+
198+ async def read_graph (self ) -> KnowledgeGraph :
199+ """Read the entire knowledge graph."""
200+ return await self .load_graph ()
201+
202+ async def search_memories (self , query : str ) -> KnowledgeGraph :
203+ """Search for memories based on a query with Fulltext Search."""
204+ logger .info (f"Searching for memories with query: '{ query } '" )
205+ return await self .load_graph (query )
206+
207+ async def find_memories_by_name (self , names : List [str ]) -> KnowledgeGraph :
208+ """Find specific memories by their names. This does not use fulltext search."""
209+ logger .info (f"Finding { len (names )} memories by name" )
210+ query = """
211+ MATCH (e:Memory)
212+ WHERE e.name IN $names
213+ RETURN e.name as name,
214+ e.type as type,
215+ e.observations as observations
216+ """
217+ result_nodes = await self .driver .execute_query (query , {"names" : names }, routing_control = RoutingControl .READ )
218+ entities : list [Entity ] = list ()
219+ for record in result_nodes .records :
220+ entities .append (Entity (
221+ name = record ['name' ],
222+ type = record ['type' ],
223+ observations = record .get ('observations' , list ())
224+ ))
225+
226+ # Get relations for found entities
227+ relations : list [Relation ] = list ()
228+ if entities :
229+ query = """
230+ MATCH (source:Memory)-[r]->(target:Memory)
231+ WHERE source.name IN $names OR target.name IN $names
232+ RETURN source.name as source,
233+ target.name as target,
234+ type(r) as relationType
235+ """
236+ result_relations = await self .driver .execute_query (query , {"names" : names }, routing_control = RoutingControl .READ )
237+ for record in result_relations .records :
238+ relations .append (Relation (
239+ source = record ["source" ],
240+ target = record ["target" ],
241+ relationType = record ["relationType" ]
242+ ))
243+
244+ logger .info (f"Found { len (entities )} entities and { len (relations )} relations" )
245+ return KnowledgeGraph (entities = entities , relations = relations )
0 commit comments