
Commit ea8e631

Fix/remove bug (#356)
* fix: nebula search bug
* fix: nebula search bug
* fix: auto create bug
* feat: add single-db-only assertion
* feat: make count_nodes support optional memory_type filtering
* fix: dim_field when filter non-embedding nodes
* feat: add optional whether include embedding when export graph
* fix[WIP]: remove oldest memory update
* feat: modify nebula search embedding efficiency
* fix: modify nebula remove old memory
1 parent 227b8ea commit ea8e631

File tree

src/memos/graph_dbs/nebular.py
src/memos/memories/textual/tree.py
src/memos/memories/textual/tree_text_memory/organize/manager.py

3 files changed: +89 −46 lines

src/memos/graph_dbs/nebular.py

Lines changed: 64 additions & 24 deletions
@@ -188,6 +188,19 @@ def _get_or_create_shared_client(cls, cfg: NebulaGraphDBConfig) -> tuple[str, "N
         client = cls._CLIENT_CACHE.get(key)
         if client is None:
             # Connection setting
+
+            tmp_client = NebulaClient(
+                hosts=cfg.uri,
+                username=cfg.user,
+                password=cfg.password,
+                session_config=SessionConfig(graph=None),
+                session_pool_config=SessionPoolConfig(size=1, wait_timeout=3000),
+            )
+            try:
+                cls._ensure_space_exists(tmp_client, cfg)
+            finally:
+                tmp_client.close()
+
             conn_conf: ConnectionConfig | None = getattr(cfg, "conn_config", None)
             if conn_conf is None:
                 conn_conf = ConnectionConfig.from_defults(
@@ -318,6 +331,7 @@ def __init__(self, config: NebulaGraphDBConfig):
         }
         """
 
+        assert config.use_multi_db is False, "Multi-DB MODE IS NOT SUPPORTED"
         self.config = config
         self.db_name = config.space
         self.user_name = config.user_name
@@ -429,15 +443,21 @@ def remove_oldest_memory(self, memory_type: str, keep_latest: int) -> None:
         if not self.config.use_multi_db and self.config.user_name:
             optional_condition = f"AND n.user_name = '{self.config.user_name}'"
 
-        query = f"""
-        MATCH (n@Memory)
-        WHERE n.memory_type = '{memory_type}'
-        {optional_condition}
-        ORDER BY n.updated_at DESC
-        OFFSET {keep_latest}
-        DETACH DELETE n
-        """
-        self.execute_query(query)
+        count = self.count_nodes(memory_type)
+
+        if count > keep_latest:
+            delete_query = f"""
+            MATCH (n@Memory)
+            WHERE n.memory_type = '{memory_type}'
+            {optional_condition}
+            ORDER BY n.updated_at DESC
+            OFFSET {keep_latest}
+            DETACH DELETE n
+            """
+            try:
+                self.execute_query(delete_query)
+            except Exception as e:
+                logger.warning(f"Delete old mem error: {e}")
 
     @timed
     def add_node(self, id: str, memory: str, metadata: dict[str, Any]) -> None:
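The rewritten method now counts matching nodes before deleting, so nothing is issued when the memory type is already within its budget, and a failed delete is downgraded to a warning instead of propagating. A minimal usage sketch, assuming `db` is an already-initialized NebulaGraphDB instance (the variable name is illustrative, not from the diff):

    # Hypothetical call: keeps only the 20 most recently updated WorkingMemory nodes.
    # Internally count_nodes("WorkingMemory") runs first, and the DETACH DELETE is
    # skipped entirely when the count is 20 or fewer.
    db.remove_oldest_memory(memory_type="WorkingMemory", keep_latest=20)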
@@ -597,14 +617,19 @@ def get_memory_count(self, memory_type: str) -> int:
         return -1
 
     @timed
-    def count_nodes(self, scope: str) -> int:
-        query = f"""
-        MATCH (n@Memory)
-        WHERE n.memory_type = "{scope}"
-        """
+    def count_nodes(self, scope: str | None = None) -> int:
+        query = "MATCH (n@Memory)"
+        conditions = []
+
+        if scope:
+            conditions.append(f'n.memory_type = "{scope}"')
         if not self.config.use_multi_db and self.config.user_name:
             user_name = self.config.user_name
-            query += f"\nAND n.user_name = '{user_name}'"
+            conditions.append(f"n.user_name = '{user_name}'")
+
+        if conditions:
+            query += "\nWHERE " + " AND ".join(conditions)
+
         query += "\nRETURN count(n) AS count"
 
         result = self.execute_query(query)
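With the now-optional `scope`, the same method serves both the per-type count used by `remove_oldest_memory` and a global node count. A hedged usage sketch, again assuming an initialized `db`:

    # Hypothetical usage of the new optional-scope signature.
    total = db.count_nodes()                   # all Memory nodes (user_name filter still applies)
    working = db.count_nodes("WorkingMemory")  # only nodes with memory_type = "WorkingMemory"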
@@ -985,8 +1010,7 @@ def search_by_embedding(
         dim = len(vector)
         vector_str = ",".join(f"{float(x)}" for x in vector)
         gql_vector = f"VECTOR<{dim}, FLOAT>([{vector_str}])"
-
-        where_clauses = []
+        where_clauses = [f"n.{self.dim_field} IS NOT NULL"]
         if scope:
             where_clauses.append(f'n.memory_type = "{scope}"')
         if status:
@@ -1008,15 +1032,12 @@ def search_by_embedding(
         where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""
 
         gql = f"""
-        MATCH (n@Memory)
+        let a = {gql_vector}
+        MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */)
         {where_clause}
-        ORDER BY inner_product(n.{self.dim_field}, {gql_vector}) DESC
-        APPROXIMATE
+        ORDER BY inner_product(n.{self.dim_field}, a) DESC
         LIMIT {top_k}
-        OPTIONS {{ METRIC: IP, TYPE: IVF, NPROBE: 8 }}
-        RETURN n.id AS id, inner_product(n.{self.dim_field}, {gql_vector}) AS score
-        """
-
+        RETURN n.id AS id, inner_product(n.{self.dim_field}, a) AS score"""
         try:
             result = self.execute_query(gql)
         except Exception as e:
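The reworked query binds the vector literal once with `let a = ...` and reuses it in both ORDER BY and RETURN, drops the APPROXIMATE/IVF options, excludes nodes without an embedding, and adds an index hint. The snippet below is illustrative only: it rebuilds the same GQL string for a toy 3-dimensional vector, with `embedding` standing in for `self.dim_field` and an arbitrary `top_k` and WHERE clause, none of which come from the diff:

    # Illustrative reconstruction of the GQL emitted by the new search_by_embedding.
    vector = [0.1, 0.2, 0.3]    # toy vector
    dim_field = "embedding"     # assumption: stands in for self.dim_field
    top_k = 5                   # arbitrary example value
    where_clause = f'WHERE n.{dim_field} IS NOT NULL AND n.memory_type = "LongTermMemory"'

    dim = len(vector)
    vector_str = ",".join(f"{float(x)}" for x in vector)
    gql_vector = f"VECTOR<{dim}, FLOAT>([{vector_str}])"

    gql = f"""
    let a = {gql_vector}
    MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */)
    {where_clause}
    ORDER BY inner_product(n.{dim_field}, a) DESC
    LIMIT {top_k}
    RETURN n.id AS id, inner_product(n.{dim_field}, a) AS score"""
    print(gql)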
@@ -1471,6 +1492,25 @@ def merge_nodes(self, id1: str, id2: str) -> str:
         """
         raise NotImplementedError
 
+    @classmethod
+    def _ensure_space_exists(cls, tmp_client, cfg):
+        """Lightweight check to ensure target graph (space) exists."""
+        db_name = getattr(cfg, "space", None)
+        if not db_name:
+            logger.warning("[NebulaGraphDBSync] No `space` specified in cfg.")
+            return
+
+        try:
+            res = tmp_client.execute("SHOW GRAPHS;")
+            existing = {row.values()[0].as_string() for row in res}
+            if db_name not in existing:
+                tmp_client.execute(f"CREATE GRAPH IF NOT EXISTS `{db_name}` TYPED MemOSBgeM3Type;")
+                logger.info(f"✅ Graph `{db_name}` created before session binding.")
+            else:
+                logger.debug(f"Graph `{db_name}` already exists.")
+        except Exception:
+            logger.exception("[NebulaGraphDBSync] Failed to ensure space exists")
+
     @timed
     def _ensure_database_exists(self):
         graph_type_name = "MemOSBgeM3Type"

src/memos/memories/textual/tree.py

Lines changed: 2 additions & 2 deletions
@@ -326,10 +326,10 @@ def load(self, dir: str) -> None:
         except Exception as e:
             logger.error(f"An error occurred while loading memories: {e}")
 
-    def dump(self, dir: str) -> None:
+    def dump(self, dir: str, include_embedding: bool = False) -> None:
         """Dump memories to os.path.join(dir, self.config.memory_filename)"""
         try:
-            json_memories = self.graph_store.export_graph()
+            json_memories = self.graph_store.export_graph(include_embedding=include_embedding)
 
             os.makedirs(dir, exist_ok=True)
             memory_file = os.path.join(dir, self.config.memory_filename)
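Callers can now choose whether the exported graph keeps embedding vectors; the default export stays lean. A hedged usage sketch, assuming `memory` is an already-configured TreeTextMemory instance (the name and paths are illustrative):

    # Hypothetical usage; `memory` stands for an initialized TreeTextMemory.
    memory.dump("/tmp/memos_dump")                           # embeddings omitted (default)
    memory.dump("/tmp/memos_dump", include_embedding=True)   # keep embedding vectors in the export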

src/memos/memories/textual/tree_text_memory/organize/manager.py

Lines changed: 23 additions & 20 deletions
@@ -67,30 +67,33 @@ def add(self, memories: list[TextualMemoryItem]) -> list[str]:
         except Exception as e:
             logger.exception("Memory processing error: ", exc_info=e)
 
-        try:
-            self.graph_store.remove_oldest_memory(
-                memory_type="WorkingMemory", keep_latest=self.memory_size["WorkingMemory"]
-            )
-        except Exception:
-            logger.warning(f"Remove WorkingMemory error: {traceback.format_exc()}")
-
-        try:
-            self.graph_store.remove_oldest_memory(
-                memory_type="LongTermMemory", keep_latest=self.memory_size["LongTermMemory"]
-            )
-        except Exception:
-            logger.warning(f"Remove LongTermMemory error: {traceback.format_exc()}")
-
-        try:
-            self.graph_store.remove_oldest_memory(
-                memory_type="UserMemory", keep_latest=self.memory_size["UserMemory"]
-            )
-        except Exception:
-            logger.warning(f"Remove UserMemory error: {traceback.format_exc()}")
+        # Only clean up if we're close to or over the limit
+        self._cleanup_memories_if_needed()
 
         self._refresh_memory_size()
         return added_ids
 
+    def _cleanup_memories_if_needed(self) -> None:
+        """
+        Only clean up memories if we're close to or over the limit.
+        This reduces unnecessary database operations.
+        """
+        cleanup_threshold = 0.8  # Clean up when 80% full
+
+        for memory_type, limit in self.memory_size.items():
+            current_count = self.current_memory_size.get(memory_type, 0)
+            threshold = int(limit * cleanup_threshold)
+
+            # Only clean up if we're at or above the threshold
+            if current_count >= threshold:
+                try:
+                    self.graph_store.remove_oldest_memory(
+                        memory_type=memory_type, keep_latest=limit
+                    )
+                    logger.debug(f"Cleaned up {memory_type}: {current_count} -> {limit}")
+                except Exception:
+                    logger.warning(f"Remove {memory_type} error: {traceback.format_exc()}")
+
     def replace_working_memory(self, memories: list[TextualMemoryItem]) -> None:
         """
         Replace WorkingMemory
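The three hard-coded cleanup calls are replaced by a single gated pass: cleanup for a memory type only runs once it reaches 80% of its limit, e.g. with a limit of 20 the threshold is int(20 * 0.8) = 16, so remove_oldest_memory is skipped until at least 16 nodes of that type exist. The standalone sketch below mirrors that gating logic; the limits and counts are made-up examples, not values from the MemOS config:

    # Standalone illustration of the gating check in _cleanup_memories_if_needed.
    memory_size = {"WorkingMemory": 20, "LongTermMemory": 1500, "UserMemory": 480}      # example limits
    current_memory_size = {"WorkingMemory": 17, "LongTermMemory": 900, "UserMemory": 100}
    cleanup_threshold = 0.8  # clean up when 80% full

    for memory_type, limit in memory_size.items():
        current = current_memory_size.get(memory_type, 0)
        threshold = int(limit * cleanup_threshold)
        if current >= threshold:
            print(f"cleanup {memory_type}: {current} nodes >= threshold {threshold}, keep_latest={limit}")
        else:
            print(f"skip {memory_type}: {current} nodes < threshold {threshold}")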
