Skip to content

Commit 7094fb2

Browse files
committed
feat: add outer vector db
1 parent ba299d2 commit 7094fb2

File tree

3 files changed

+138
-90
lines changed

3 files changed

+138
-90
lines changed

src/memos/graph_dbs/neo4j.py

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,6 @@
1414
logger = get_logger(__name__)
1515

1616

17-
def _parse_node(node_data: dict[str, Any]) -> dict[str, Any]:
18-
node = node_data.copy()
19-
20-
# Convert Neo4j datetime to string
21-
for time_field in ("created_at", "updated_at"):
22-
if time_field in node and hasattr(node[time_field], "isoformat"):
23-
node[time_field] = node[time_field].isoformat()
24-
node.pop("user_name", None)
25-
26-
return {"id": node.pop("id"), "memory": node.pop("memory", ""), "metadata": node}
27-
28-
2917
def _compose_node(item: dict[str, Any]) -> tuple[str, str, dict[str, Any]]:
3018
node_id = item["id"]
3119
memory = item["memory"]
@@ -349,7 +337,7 @@ def get_node(self, id: str) -> dict[str, Any] | None:
349337

350338
with self.driver.session(database=self.db_name) as session:
351339
record = session.run(query, params).single()
352-
return _parse_node(dict(record["n"])) if record else None
340+
return self._parse_node(dict(record["n"])) if record else None
353341

354342
def get_nodes(self, ids: list[str]) -> list[dict[str, Any]]:
355343
"""
@@ -377,7 +365,7 @@ def get_nodes(self, ids: list[str]) -> list[dict[str, Any]]:
377365

378366
with self.driver.session(database=self.db_name) as session:
379367
results = session.run(query, params)
380-
return [_parse_node(dict(record["n"])) for record in results]
368+
return [self._parse_node(dict(record["n"])) for record in results]
381369

382370
def get_edges(self, id: str, type: str = "ANY", direction: str = "ANY") -> list[dict[str, str]]:
383371
"""
@@ -493,7 +481,7 @@ def get_neighbors_by_tag(
493481

494482
with self.driver.session(database=self.db_name) as session:
495483
result = session.run(query, params)
496-
return [_parse_node(dict(record["n"])) for record in result]
484+
return [self._parse_node(dict(record["n"])) for record in result]
497485

498486
def get_children_with_embeddings(self, id: str) -> list[dict[str, Any]]:
499487
where_user = ""
@@ -575,8 +563,8 @@ def get_subgraph(
575563
if not centers or centers[0] is None:
576564
return {"core_node": None, "neighbors": [], "edges": []}
577565

578-
core_node = _parse_node(dict(centers[0]))
579-
neighbors = [_parse_node(dict(n)) for n in record["neighbors"] if n]
566+
core_node = self._parse_node(dict(centers[0]))
567+
neighbors = [self._parse_node(dict(n)) for n in record["neighbors"] if n]
580568
edges = []
581569
for rel_chain in record["rels"]:
582570
for rel in rel_chain:
@@ -859,7 +847,7 @@ def export_graph(self) -> dict[str, Any]:
859847
params["user_name"] = self.config.user_name
860848

861849
node_result = session.run(f"{node_query} RETURN n", params)
862-
nodes = [_parse_node(dict(record["n"])) for record in node_result]
850+
nodes = [self._parse_node(dict(record["n"])) for record in node_result]
863851

864852
# Export edges
865853
edge_result = session.run(
@@ -946,7 +934,7 @@ def get_all_memory_items(self, scope: str) -> list[dict]:
946934

947935
with self.driver.session(database=self.db_name) as session:
948936
results = session.run(query, params)
949-
return [_parse_node(dict(record["n"])) for record in results]
937+
return [self._parse_node(dict(record["n"])) for record in results]
950938

951939
def get_structure_optimization_candidates(self, scope: str) -> list[dict]:
952940
"""
@@ -973,7 +961,9 @@ def get_structure_optimization_candidates(self, scope: str) -> list[dict]:
973961

974962
with self.driver.session(database=self.db_name) as session:
975963
results = session.run(query, params)
976-
return [_parse_node({"id": record["id"], **dict(record["node"])}) for record in results]
964+
return [
965+
self._parse_node({"id": record["id"], **dict(record["node"])}) for record in results
966+
]
977967

978968
def drop_database(self) -> None:
979969
"""
@@ -1094,3 +1084,14 @@ def _index_exists(self, index_name: str) -> bool:
10941084
if record["name"] == index_name:
10951085
return True
10961086
return False
1087+
1088+
def _parse_node(self, node_data: dict[str, Any]) -> dict[str, Any]:
1089+
node = node_data.copy()
1090+
1091+
# Convert Neo4j datetime to string
1092+
for time_field in ("created_at", "updated_at"):
1093+
if time_field in node and hasattr(node[time_field], "isoformat"):
1094+
node[time_field] = node[time_field].isoformat()
1095+
node.pop("user_name", None)
1096+
1097+
return {"id": node.pop("id"), "memory": node.pop("memory", ""), "metadata": node}

src/memos/graph_dbs/neo4j_community.py

Lines changed: 117 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,34 @@
1-
import time
2-
31
from typing import Any
42

5-
from neo4j import GraphDatabase
6-
from neo4j.exceptions import ClientError
7-
83
from memos.configs.graph_db import Neo4jGraphDBConfig
9-
from memos.graph_dbs.neo4j import Neo4jGraphDB, _parse_node
4+
from memos.graph_dbs.neo4j import Neo4jGraphDB, _prepare_node_metadata
105
from memos.log import get_logger
6+
from memos.vec_dbs.factory import VecDBFactory
7+
from memos.vec_dbs.item import VecDBItem
118

129

1310
logger = get_logger(__name__)
1411

1512

1613
class Neo4jCommunityGraphDB(Neo4jGraphDB):
17-
"""Neo4j-based implementation of a graph memory store."""
14+
"""
15+
Neo4j Community Edition graph memory store.
16+
17+
Note:
18+
This class avoids Enterprise-only features:
19+
- No multi-database support
20+
- No vector index
21+
- No CREATE DATABASE
22+
"""
1823

1924
def __init__(self, config: Neo4jGraphDBConfig):
20-
"""Neo4j-based implementation of a graph memory store.
21-
22-
Tenant Modes:
23-
- use_multi_db = True:
24-
Dedicated Database Mode (Multi-Database Multi-Tenant).
25-
Each tenant or logical scope uses a separate Neo4j database.
26-
`db_name` is the specific tenant database.
27-
`user_name` can be None (optional).
28-
29-
- use_multi_db = False:
30-
Shared Database Multi-Tenant Mode.
31-
All tenants share a single Neo4j database.
32-
`db_name` is the shared database.
33-
`user_name` is required to isolate each tenant's data at the node level.
34-
All node queries will enforce `user_name` in WHERE conditions and store it in metadata,
35-
but it will be removed automatically before returning to external consumers.
36-
"""
25+
assert config.auto_create is False
26+
assert config.use_multi_db is False
27+
# Call parent init
28+
super().__init__(config)
3729

38-
self.config = config
39-
self.driver = GraphDatabase.driver(config.uri, auth=(config.user, config.password))
40-
self.db_name = config.db_name
41-
self.user_name = config.user_name
42-
self.system_db_name = config.db_name
43-
# Create only if not exists
44-
self.create_index(dimensions=config.embedding_dimension)
30+
# Init vector database
31+
self.vec_db = VecDBFactory.from_config(config.vec_config)
4532

4633
def create_index(
4734
self,
@@ -116,6 +103,54 @@ def remove_oldest_memory(self, memory_type: str, keep_latest: int) -> None:
116103
with self.driver.session(database=self.db_name) as session:
117104
session.run(query)
118105

106+
def add_node(self, id: str, memory: str, metadata: dict[str, Any]) -> None:
107+
# Safely process metadata
108+
metadata = _prepare_node_metadata(metadata)
109+
110+
# Extract required fields
111+
embedding = metadata.pop("embedding", None)
112+
if embedding is None:
113+
raise ValueError(f"Missing 'embedding' in metadata for node {id}")
114+
115+
# Merge node and set metadata
116+
created_at = metadata.pop("created_at")
117+
updated_at = metadata.pop("updated_at")
118+
vector_sync_status = "success"
119+
120+
try:
121+
# Write to Vector DB
122+
item = VecDBItem(
123+
id=id,
124+
vector=embedding,
125+
payload={
126+
"memory": memory,
127+
"metadata": metadata,
128+
"vector_sync": vector_sync_status,
129+
},
130+
)
131+
self.vec_db.add([item])
132+
except Exception as e:
133+
logger.warning(f"[VecDB] Vector insert failed for node {id}: {e}")
134+
vector_sync_status = "failed"
135+
136+
metadata["vector_sync"] = vector_sync_status
137+
query = """
138+
MERGE (n:Memory {id: $id})
139+
SET n.memory = $memory,
140+
n.created_at = datetime($created_at),
141+
n.updated_at = datetime($updated_at),
142+
n += $metadata
143+
"""
144+
with self.driver.session(database=self.db_name) as session:
145+
session.run(
146+
query,
147+
id=id,
148+
memory=memory,
149+
created_at=created_at,
150+
updated_at=updated_at,
151+
metadata=metadata,
152+
)
153+
119154
def get_children_with_embeddings(self, id: str) -> list[dict[str, Any]]:
120155
where_user = ""
121156
params = {"id": id}
@@ -184,8 +219,8 @@ def get_subgraph(
184219
if not centers or centers[0] is None:
185220
return {"core_node": None, "neighbors": [], "edges": []}
186221

187-
core_node = _parse_node(dict(centers[0]))
188-
neighbors = [_parse_node(dict(n)) for n in record["neighbors"] if n]
222+
core_node = self._parse_node(dict(centers[0]))
223+
neighbors = [self._parse_node(dict(n)) for n in record["neighbors"] if n]
189224
edges = []
190225
for rel_chain in record["rels"]:
191226
for rel in rel_chain:
@@ -209,32 +244,42 @@ def search_by_embedding(
209244
threshold: float | None = None,
210245
) -> list[dict]:
211246
"""
212-
Retrieve node IDs based on vector similarity.
247+
Retrieve node IDs based on vector similarity using external vector DB.
213248
214249
Args:
215250
vector (list[float]): The embedding vector representing query semantics.
216251
top_k (int): Number of top similar nodes to retrieve.
217252
scope (str, optional): Memory type filter (e.g., 'WorkingMemory', 'LongTermMemory').
218-
status (str, optional): Node status filter (e.g., 'active', 'archived').
219-
If provided, restricts results to nodes with matching status.
253+
status (str, optional): Node status filter (e.g., 'activated', 'archived').
220254
threshold (float, optional): Minimum similarity score threshold (0 ~ 1).
221255
222256
Returns:
223257
list[dict]: A list of dicts with 'id' and 'score', ordered by similarity.
224258
225259
Notes:
226-
- This method uses Neo4j native vector indexing to search for similar nodes.
227-
- If scope is provided, it restricts results to nodes with matching memory_type.
228-
- If 'status' is provided, only nodes with the matching status will be returned.
229-
- If threshold is provided, only results with score >= threshold will be returned.
230-
- Typical use case: restrict to 'status = activated' to avoid
231-
matching archived or merged nodes.
260+
- This method uses an external vector database (not Neo4j) to perform the search.
261+
- If 'scope' is provided, it restricts results to nodes with matching memory_type.
262+
- If 'status' is provided, it further filters nodes by status.
263+
- If 'threshold' is provided, only results with score >= threshold will be returned.
264+
- The returned IDs can be used to fetch full node data from Neo4j if needed.
232265
"""
233-
# TODO
234-
from your_vector_index import vector_index
266+
# Build VecDB filter
267+
vec_filter = {}
268+
if scope:
269+
vec_filter["metadata.memory_type"] = scope
270+
if status:
271+
vec_filter["metadata.status"] = status
272+
vec_filter["metadata.vector_sync"] = "success"
235273

236-
results = vector_index.query(vector, top_k=top_k)
237-
return [{"id": item.id, "score": item.score} for item in results]
274+
# Perform vector search
275+
results = self.vec_db.search(query_vector=vector, top_k=top_k, filter=vec_filter)
276+
277+
# Filter by threshold
278+
if threshold is not None:
279+
results = [r for r in results if r.score is None or r.score >= threshold]
280+
281+
# Return consistent format
282+
return [{"id": r.id, "score": r.score} for r in results]
238283

239284
def get_all_memory_items(self, scope: str) -> list[dict]:
240285
"""
@@ -264,7 +309,7 @@ def get_all_memory_items(self, scope: str) -> list[dict]:
264309

265310
with self.driver.session(database=self.db_name) as session:
266311
results = session.run(query, params)
267-
return [_parse_node(dict(record["n"])) for record in results]
312+
return [self._parse_node(dict(record["n"])) for record in results]
268313

269314
def get_structure_optimization_candidates(self, scope: str) -> list[dict]:
270315
"""
@@ -291,7 +336,9 @@ def get_structure_optimization_candidates(self, scope: str) -> list[dict]:
291336

292337
with self.driver.session(database=self.db_name) as session:
293338
results = session.run(query, params)
294-
return [_parse_node({"id": record["id"], **dict(record["node"])}) for record in results]
339+
return [
340+
self._parse_node({"id": record["id"], **dict(record["node"])}) for record in results
341+
]
295342

296343
def drop_database(self) -> None:
297344
"""
@@ -303,28 +350,9 @@ def drop_database(self) -> None:
303350
f"Shared Database Multi-Tenant mode"
304351
)
305352

353+
# Avoid enterprise feature
306354
def _ensure_database_exists(self):
307-
try:
308-
with self.driver.session(database="system") as session:
309-
session.run(f"CREATE DATABASE `{self.db_name}` IF NOT EXISTS")
310-
except ClientError as e:
311-
if "ExistingDatabaseFound" in str(e):
312-
pass # Ignore, database already exists
313-
else:
314-
raise
315-
316-
# Wait until the database is available
317-
for _ in range(10):
318-
with self.driver.session(database=self.system_db_name) as session:
319-
result = session.run(
320-
"SHOW DATABASES YIELD name, currentStatus RETURN name, currentStatus"
321-
)
322-
status_map = {r["name"]: r["currentStatus"] for r in result}
323-
if self.db_name in status_map and status_map[self.db_name] == "online":
324-
return
325-
time.sleep(1)
326-
327-
raise RuntimeError(f"Database {self.db_name} not ready after waiting.")
355+
pass
328356

329357
def _create_basic_property_indexes(self) -> None:
330358
"""
@@ -375,3 +403,23 @@ def _index_exists(self, index_name: str) -> bool:
375403
if record["name"] == index_name:
376404
return True
377405
return False
406+
407+
def _parse_node(self, node_data: dict[str, Any]) -> dict[str, Any]:
408+
"""Parse Neo4j node and optionally fetch embedding from vector DB."""
409+
node = node_data.copy()
410+
411+
# Convert Neo4j datetime to string
412+
for time_field in ("created_at", "updated_at"):
413+
if time_field in node and hasattr(node[time_field], "isoformat"):
414+
node[time_field] = node[time_field].isoformat()
415+
node.pop("user_name", None)
416+
417+
new_node = {"id": node.pop("id"), "memory": node.pop("memory", ""), "metadata": node}
418+
try:
419+
vec_item = self.vec_db.get_by_id(new_node["id"])
420+
if vec_item and vec_item.vector:
421+
new_node["embedding"] = vec_item.vector
422+
except Exception as e:
423+
logger.warning(f"Failed to fetch vector for node {new_node['id']}: {e}")
424+
new_node["embedding"] = None
425+
return new_node

src/memos/memories/textual/tree_text_memory/retrieve/recall.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@ def retrieve(
5656
# Step 3: Merge and deduplicate results
5757
combined = {item.id: item for item in graph_results + vector_results}
5858

59-
# Debug: 打印在 graph_results 中但不在 combined 中的 id
6059
graph_ids = {item.id for item in graph_results}
6160
combined_ids = set(combined.keys())
6261
lost_ids = graph_ids - combined_ids

0 commit comments

Comments
 (0)