1- import time
2-
31from typing import Any
42
5- from neo4j import GraphDatabase
6- from neo4j .exceptions import ClientError
7-
83from memos .configs .graph_db import Neo4jGraphDBConfig
9- from memos .graph_dbs .neo4j import Neo4jGraphDB , _parse_node
4+ from memos .graph_dbs .neo4j import Neo4jGraphDB , _prepare_node_metadata
105from memos .log import get_logger
6+ from memos .vec_dbs .factory import VecDBFactory
7+ from memos .vec_dbs .item import VecDBItem
118
129
1310logger = get_logger (__name__ )
1411
1512
1613class Neo4jCommunityGraphDB (Neo4jGraphDB ):
17- """Neo4j-based implementation of a graph memory store."""
14+ """
15+ Neo4j Community Edition graph memory store.
16+
17+ Note:
18+ This class avoids Enterprise-only features:
19+ - No multi-database support
20+ - No vector index
21+ - No CREATE DATABASE
22+ """
1823
1924 def __init__ (self , config : Neo4jGraphDBConfig ):
20- """Neo4j-based implementation of a graph memory store.
21-
22- Tenant Modes:
23- - use_multi_db = True:
24- Dedicated Database Mode (Multi-Database Multi-Tenant).
25- Each tenant or logical scope uses a separate Neo4j database.
26- `db_name` is the specific tenant database.
27- `user_name` can be None (optional).
28-
29- - use_multi_db = False:
30- Shared Database Multi-Tenant Mode.
31- All tenants share a single Neo4j database.
32- `db_name` is the shared database.
33- `user_name` is required to isolate each tenant's data at the node level.
34- All node queries will enforce `user_name` in WHERE conditions and store it in metadata,
35- but it will be removed automatically before returning to external consumers.
36- """
25+ assert config .auto_create is False
26+ assert config .use_multi_db is False
27+ # Call parent init
28+ super ().__init__ (config )
3729
38- self .config = config
39- self .driver = GraphDatabase .driver (config .uri , auth = (config .user , config .password ))
40- self .db_name = config .db_name
41- self .user_name = config .user_name
42- self .system_db_name = config .db_name
43- # Create only if not exists
44- self .create_index (dimensions = config .embedding_dimension )
30+ # Init vector database
31+ self .vec_db = VecDBFactory .from_config (config .vec_config )
4532
4633 def create_index (
4734 self ,
@@ -116,6 +103,54 @@ def remove_oldest_memory(self, memory_type: str, keep_latest: int) -> None:
116103 with self .driver .session (database = self .db_name ) as session :
117104 session .run (query )
118105
106+ def add_node (self , id : str , memory : str , metadata : dict [str , Any ]) -> None :
107+ # Safely process metadata
108+ metadata = _prepare_node_metadata (metadata )
109+
110+ # Extract required fields
111+ embedding = metadata .pop ("embedding" , None )
112+ if embedding is None :
113+ raise ValueError (f"Missing 'embedding' in metadata for node { id } " )
114+
115+ # Merge node and set metadata
116+ created_at = metadata .pop ("created_at" )
117+ updated_at = metadata .pop ("updated_at" )
118+ vector_sync_status = "success"
119+
120+ try :
121+ # Write to Vector DB
122+ item = VecDBItem (
123+ id = id ,
124+ vector = embedding ,
125+ payload = {
126+ "memory" : memory ,
127+ "metadata" : metadata ,
128+ "vector_sync" : vector_sync_status ,
129+ },
130+ )
131+ self .vec_db .add ([item ])
132+ except Exception as e :
133+ logger .warning (f"[VecDB] Vector insert failed for node { id } : { e } " )
134+ vector_sync_status = "failed"
135+
136+ metadata ["vector_sync" ] = vector_sync_status
137+ query = """
138+ MERGE (n:Memory {id: $id})
139+ SET n.memory = $memory,
140+ n.created_at = datetime($created_at),
141+ n.updated_at = datetime($updated_at),
142+ n += $metadata
143+ """
144+ with self .driver .session (database = self .db_name ) as session :
145+ session .run (
146+ query ,
147+ id = id ,
148+ memory = memory ,
149+ created_at = created_at ,
150+ updated_at = updated_at ,
151+ metadata = metadata ,
152+ )
153+
119154 def get_children_with_embeddings (self , id : str ) -> list [dict [str , Any ]]:
120155 where_user = ""
121156 params = {"id" : id }
@@ -184,8 +219,8 @@ def get_subgraph(
184219 if not centers or centers [0 ] is None :
185220 return {"core_node" : None , "neighbors" : [], "edges" : []}
186221
187- core_node = _parse_node (dict (centers [0 ]))
188- neighbors = [_parse_node (dict (n )) for n in record ["neighbors" ] if n ]
222+ core_node = self . _parse_node (dict (centers [0 ]))
223+ neighbors = [self . _parse_node (dict (n )) for n in record ["neighbors" ] if n ]
189224 edges = []
190225 for rel_chain in record ["rels" ]:
191226 for rel in rel_chain :
@@ -209,32 +244,42 @@ def search_by_embedding(
209244 threshold : float | None = None ,
210245 ) -> list [dict ]:
211246 """
212- Retrieve node IDs based on vector similarity.
247+ Retrieve node IDs based on vector similarity using external vector DB .
213248
214249 Args:
215250 vector (list[float]): The embedding vector representing query semantics.
216251 top_k (int): Number of top similar nodes to retrieve.
217252 scope (str, optional): Memory type filter (e.g., 'WorkingMemory', 'LongTermMemory').
218- status (str, optional): Node status filter (e.g., 'active', 'archived').
219- If provided, restricts results to nodes with matching status.
253+ status (str, optional): Node status filter (e.g., 'activated', 'archived').
220254 threshold (float, optional): Minimum similarity score threshold (0 ~ 1).
221255
222256 Returns:
223257 list[dict]: A list of dicts with 'id' and 'score', ordered by similarity.
224258
225259 Notes:
226- - This method uses Neo4j native vector indexing to search for similar nodes.
227- - If scope is provided, it restricts results to nodes with matching memory_type.
228- - If 'status' is provided, only nodes with the matching status will be returned.
229- - If threshold is provided, only results with score >= threshold will be returned.
230- - Typical use case: restrict to 'status = activated' to avoid
231- matching archived or merged nodes.
260+ - This method uses an external vector database (not Neo4j) to perform the search.
261+ - If 'scope' is provided, it restricts results to nodes with matching memory_type.
262+ - If 'status' is provided, it further filters nodes by status.
263+ - If 'threshold' is provided, only results with score >= threshold will be returned.
264+ - The returned IDs can be used to fetch full node data from Neo4j if needed.
232265 """
233- # TODO
234- from your_vector_index import vector_index
266+ # Build VecDB filter
267+ vec_filter = {}
268+ if scope :
269+ vec_filter ["metadata.memory_type" ] = scope
270+ if status :
271+ vec_filter ["metadata.status" ] = status
272+ vec_filter ["metadata.vector_sync" ] = "success"
235273
236- results = vector_index .query (vector , top_k = top_k )
237- return [{"id" : item .id , "score" : item .score } for item in results ]
274+ # Perform vector search
275+ results = self .vec_db .search (query_vector = vector , top_k = top_k , filter = vec_filter )
276+
277+ # Filter by threshold
278+ if threshold is not None :
279+ results = [r for r in results if r .score is None or r .score >= threshold ]
280+
281+ # Return consistent format
282+ return [{"id" : r .id , "score" : r .score } for r in results ]
238283
239284 def get_all_memory_items (self , scope : str ) -> list [dict ]:
240285 """
@@ -264,7 +309,7 @@ def get_all_memory_items(self, scope: str) -> list[dict]:
264309
265310 with self .driver .session (database = self .db_name ) as session :
266311 results = session .run (query , params )
267- return [_parse_node (dict (record ["n" ])) for record in results ]
312+ return [self . _parse_node (dict (record ["n" ])) for record in results ]
268313
269314 def get_structure_optimization_candidates (self , scope : str ) -> list [dict ]:
270315 """
@@ -291,7 +336,9 @@ def get_structure_optimization_candidates(self, scope: str) -> list[dict]:
291336
292337 with self .driver .session (database = self .db_name ) as session :
293338 results = session .run (query , params )
294- return [_parse_node ({"id" : record ["id" ], ** dict (record ["node" ])}) for record in results ]
339+ return [
340+ self ._parse_node ({"id" : record ["id" ], ** dict (record ["node" ])}) for record in results
341+ ]
295342
296343 def drop_database (self ) -> None :
297344 """
@@ -303,28 +350,9 @@ def drop_database(self) -> None:
303350 f"Shared Database Multi-Tenant mode"
304351 )
305352
353+ # Avoid enterprise feature
306354 def _ensure_database_exists (self ):
307- try :
308- with self .driver .session (database = "system" ) as session :
309- session .run (f"CREATE DATABASE `{ self .db_name } ` IF NOT EXISTS" )
310- except ClientError as e :
311- if "ExistingDatabaseFound" in str (e ):
312- pass # Ignore, database already exists
313- else :
314- raise
315-
316- # Wait until the database is available
317- for _ in range (10 ):
318- with self .driver .session (database = self .system_db_name ) as session :
319- result = session .run (
320- "SHOW DATABASES YIELD name, currentStatus RETURN name, currentStatus"
321- )
322- status_map = {r ["name" ]: r ["currentStatus" ] for r in result }
323- if self .db_name in status_map and status_map [self .db_name ] == "online" :
324- return
325- time .sleep (1 )
326-
327- raise RuntimeError (f"Database { self .db_name } not ready after waiting." )
355+ pass
328356
329357 def _create_basic_property_indexes (self ) -> None :
330358 """
@@ -375,3 +403,23 @@ def _index_exists(self, index_name: str) -> bool:
375403 if record ["name" ] == index_name :
376404 return True
377405 return False
406+
407+ def _parse_node (self , node_data : dict [str , Any ]) -> dict [str , Any ]:
408+ """Parse Neo4j node and optionally fetch embedding from vector DB."""
409+ node = node_data .copy ()
410+
411+ # Convert Neo4j datetime to string
412+ for time_field in ("created_at" , "updated_at" ):
413+ if time_field in node and hasattr (node [time_field ], "isoformat" ):
414+ node [time_field ] = node [time_field ].isoformat ()
415+ node .pop ("user_name" , None )
416+
417+ new_node = {"id" : node .pop ("id" ), "memory" : node .pop ("memory" , "" ), "metadata" : node }
418+ try :
419+ vec_item = self .vec_db .get_by_id (new_node ["id" ])
420+ if vec_item and vec_item .vector :
421+ new_node ["embedding" ] = vec_item .vector
422+ except Exception as e :
423+ logger .warning (f"Failed to fetch vector for node { new_node ['id' ]} : { e } " )
424+ new_node ["embedding" ] = None
425+ return new_node
0 commit comments