Skip to content

Commit 4abb764

Browse files
authored
Merge branch 'dev' into dev
2 parents 32665dd + 013d0ad commit 4abb764

File tree

14 files changed

+426
-552
lines changed

14 files changed

+426
-552
lines changed

src/memos/graph_dbs/nebular.py

Lines changed: 98 additions & 15 deletions
Original file line number · Diff line number · Diff line change
@@ -41,6 +41,20 @@ def _format_datetime(value: str | datetime) -> str:
4141
return str(value)
4242

4343

44+
def _normalize_datetime(val):
45+
"""
46+
Normalize datetime to ISO 8601 UTC string with +00:00.
47+
- If val is datetime object -> keep isoformat() (Neo4j)
48+
- If val is string without timezone -> append +00:00 (Nebula)
49+
- Otherwise just str()
50+
"""
51+
if hasattr(val, "isoformat"):
52+
return val.isoformat()
53+
if isinstance(val, str) and not val.endswith(("+00:00", "Z", "+08:00")):
54+
return val + "+08:00"
55+
return str(val)
56+
57+
4458
class SessionPoolError(Exception):
4559
pass
4660

@@ -62,6 +76,7 @@ def __init__(
6276
self.hosts = hosts
6377
self.user = user
6478
self.password = password
79+
self.minsize = minsize
6580
self.maxsize = maxsize
6681
self.pool = Queue(maxsize)
6782
self.lock = Lock()
@@ -79,13 +94,13 @@ def _create_and_add_client(self):
7994
self.clients.append(client)
8095

8196
def get_client(self, timeout: float = 5.0):
82-
from nebulagraph_python import NebulaClient
83-
8497
try:
8598
return self.pool.get(timeout=timeout)
8699
except Empty:
87100
with self.lock:
88101
if len(self.clients) < self.maxsize:
102+
from nebulagraph_python import NebulaClient
103+
89104
client = NebulaClient(self.hosts, self.user, self.password)
90105
self.clients.append(client)
91106
return client
@@ -120,6 +135,25 @@ def __exit__(self, exc_type, exc_val, exc_tb):
120135

121136
return _ClientContext(self)
122137

138+
def reset_pool(self):
139+
"""⚠️ Emergency reset: Close all clients and clear the pool."""
140+
logger.warning("[Pool] Resetting all clients. Existing sessions will be lost.")
141+
with self.lock:
142+
for client in self.clients:
143+
try:
144+
client.close()
145+
except Exception:
146+
logger.error("Fail to close!!!")
147+
self.clients.clear()
148+
while not self.pool.empty():
149+
try:
150+
self.pool.get_nowait()
151+
except Empty:
152+
break
153+
for _ in range(self.minsize):
154+
self._create_and_add_client()
155+
logger.info("[Pool] Pool has been reset successfully.")
156+
123157

124158
class NebulaGraphDB(BaseGraphDB):
125159
"""
@@ -181,12 +215,27 @@ def __init__(self, config: NebulaGraphDBConfig):
181215

182216
def execute_query(self, gql: str, timeout: float = 5.0, auto_set_db: bool = True):
183217
with self.pool.get() as client:
184-
if auto_set_db and self.db_name:
185-
client.execute(f"SESSION SET GRAPH `{self.db_name}`")
186218
try:
219+
if auto_set_db and self.db_name:
220+
client.execute(f"SESSION SET GRAPH `{self.db_name}`")
187221
return client.execute(gql, timeout=timeout)
188-
except Exception:
222+
except Exception as e:
189223
logger.error(f"Fail to run gql {gql} trace: {traceback.format_exc()}")
224+
if "Session not found" in str(e):
225+
logger.warning("[execute_query] Session expired, replacing client.")
226+
try:
227+
client.close()
228+
except Exception:
229+
logger.error("Fail to close!!!!!")
230+
finally:
231+
if client in self.pool.clients:
232+
self.pool.clients.remove(client)
233+
from nebulagraph_python import NebulaClient
234+
235+
new_client = NebulaClient(self.pool.hosts, self.pool.user, self.pool.password)
236+
self.pool.clients.append(new_client)
237+
return new_client.execute(gql, timeout=timeout)
238+
raise
190239

191240
def close(self):
192241
self.pool.close()
@@ -923,9 +972,11 @@ def clear(self) -> None:
923972
except Exception as e:
924973
logger.error(f"[ERROR] Failed to clear database: {e}")
925974

926-
def export_graph(self) -> dict[str, Any]:
975+
def export_graph(self, include_embedding: bool = False) -> dict[str, Any]:
927976
"""
928977
Export all graph nodes and edges in a structured form.
978+
Args:
979+
include_embedding (bool): Whether to include the large embedding field.
929980
930981
Returns:
931982
{
@@ -942,12 +993,41 @@ def export_graph(self) -> dict[str, Any]:
942993
edge_query += f' WHERE r.user_name = "{username}"'
943994

944995
try:
945-
full_node_query = f"{node_query} RETURN n"
946-
node_result = self.execute_query(full_node_query)
996+
if include_embedding:
997+
return_fields = "n"
998+
else:
999+
return_fields = ",".join(
1000+
[
1001+
"n.id AS id",
1002+
"n.memory AS memory",
1003+
"n.user_name AS user_name",
1004+
"n.user_id AS user_id",
1005+
"n.session_id AS session_id",
1006+
"n.status AS status",
1007+
"n.key AS key",
1008+
"n.confidence AS confidence",
1009+
"n.tags AS tags",
1010+
"n.created_at AS created_at",
1011+
"n.updated_at AS updated_at",
1012+
"n.memory_type AS memory_type",
1013+
"n.sources AS sources",
1014+
"n.source AS source",
1015+
"n.node_type AS node_type",
1016+
"n.visibility AS visibility",
1017+
"n.usage AS usage",
1018+
"n.background AS background",
1019+
]
1020+
)
1021+
1022+
full_node_query = f"{node_query} RETURN {return_fields}"
1023+
node_result = self.execute_query(full_node_query, timeout=20)
9471024
nodes = []
1025+
logger.debug(f"Debugging: {node_result}")
9481026
for row in node_result:
949-
node_wrapper = row.values()[0].as_node()
950-
props = node_wrapper.get_properties()
1027+
if include_embedding:
1028+
props = row.values()[0].as_node().get_properties()
1029+
else:
1030+
props = {k: v.value for k, v in row.items()}
9511031

9521032
node = self._parse_node(props)
9531033
nodes.append(node)
@@ -956,7 +1036,7 @@ def export_graph(self) -> dict[str, Any]:
9561036

9571037
try:
9581038
full_edge_query = f"{edge_query} RETURN a.id AS source, b.id AS target, type(r) as edge"
959-
edge_result = self.execute_query(full_edge_query)
1039+
edge_result = self.execute_query(full_edge_query, timeout=20)
9601040
edges = [
9611041
{
9621042
"source": row.values()[0].value,
@@ -1023,6 +1103,7 @@ def get_all_memory_items(self, scope: str) -> list[dict]:
10231103
MATCH (n@Memory)
10241104
{where_clause}
10251105
RETURN n
1106+
LIMIT 100
10261107
"""
10271108
nodes = []
10281109
try:
@@ -1065,7 +1146,7 @@ def get_structure_optimization_candidates(self, scope: str) -> list[dict]:
10651146
node_props = rec["n"].as_node().get_properties()
10661147
candidates.append(self._parse_node(node_props))
10671148
except Exception as e:
1068-
logger.error(f"Failed : {e}")
1149+
logger.error(f"Failed : {e}, traceback: {traceback.format_exc()}")
10691150
return candidates
10701151

10711152
def drop_database(self) -> None:
@@ -1318,15 +1399,17 @@ def _parse_node(self, props: dict[str, Any]) -> dict[str, Any]:
13181399
parsed = {k: self._parse_value(v) for k, v in props.items()}
13191400

13201401
for tf in ("created_at", "updated_at"):
1321-
if tf in parsed and hasattr(parsed[tf], "isoformat"):
1322-
parsed[tf] = parsed[tf].isoformat()
1402+
if tf in parsed and parsed[tf] is not None:
1403+
parsed[tf] = _normalize_datetime(parsed[tf])
13231404

13241405
node_id = parsed.pop("id")
13251406
memory = parsed.pop("memory", "")
13261407
parsed.pop("user_name", None)
13271408
metadata = parsed
13281409
metadata["type"] = metadata.pop("node_type")
1329-
metadata["embedding"] = metadata.pop(self.dim_field)
1410+
1411+
if self.dim_field in metadata:
1412+
metadata["embedding"] = metadata.pop(self.dim_field)
13301413

13311414
return {"id": node_id, "memory": memory, "metadata": metadata}
13321415

src/memos/mem_os/utils/format_utils.py

Lines changed: 11 additions & 3 deletions
Original file line number · Diff line number · Diff line change
@@ -570,15 +570,23 @@ def convert_graph_to_tree_forworkmem(
570570
else:
571571
other_roots.append(root_id)
572572

573-
def build_tree(node_id: str) -> dict[str, Any]:
574-
"""Recursively build tree structure"""
573+
def build_tree(node_id: str, visited=None) -> dict[str, Any] | None:
574+
"""Recursively build tree structure with cycle detection"""
575+
if visited is None:
576+
visited = set()
577+
578+
if node_id in visited:
579+
logger.warning(f"[build_tree] Detected cycle at node {node_id}, skipping.")
580+
return None
581+
visited.add(node_id)
582+
575583
if node_id not in node_map:
576584
return None
577585

578586
children_ids = children_map.get(node_id, [])
579587
children = []
580588
for child_id in children_ids:
581-
child_tree = build_tree(child_id)
589+
child_tree = build_tree(child_id, visited)
582590
if child_tree:
583591
children.append(child_tree)
584592

src/memos/memories/textual/tree_text_memory/organize/conflict.py renamed to src/memos/memories/textual/tree_text_memory/organize/handler.py

Lines changed: 30 additions & 48 deletions
Original file line number · Diff line number · Diff line change
@@ -1,6 +1,5 @@
11
import json
22
import re
3-
43
from datetime import datetime
54

65
from dateutil import parser
@@ -11,82 +10,68 @@
1110
from memos.log import get_logger
1211
from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata
1312
from memos.templates.tree_reorganize_prompts import (
14-
CONFLICT_DETECTOR_PROMPT,
15-
CONFLICT_RESOLVER_PROMPT,
13+
MEMORY_RELATION_DETECTOR_PROMPT,
14+
MEMORY_RELATION_RESOLVER_PROMPT,
1615
)
1716

18-
1917
logger = get_logger(__name__)
2018

2119

22-
class ConflictHandler:
20+
class NodeHandler:
2321
EMBEDDING_THRESHOLD: float = 0.8 # Threshold for embedding similarity to consider conflict
2422

2523
def __init__(self, graph_store: Neo4jGraphDB, llm: BaseLLM, embedder: BaseEmbedder):
2624
self.graph_store = graph_store
2725
self.llm = llm
2826
self.embedder = embedder
2927

30-
def detect(
31-
self, memory: TextualMemoryItem, top_k: int = 5, scope: str | None = None
32-
) -> list[tuple[TextualMemoryItem, TextualMemoryItem]]:
33-
"""
34-
Detect conflicts by finding the most similar items in the graph database based on embedding, then use LLM to judge conflict.
35-
Args:
36-
memory: The memory item (should have an embedding attribute or field).
37-
top_k: Number of top similar nodes to retrieve.
38-
scope: Optional memory type filter.
39-
Returns:
40-
List of conflict pairs (each pair is a tuple: (memory, candidate)).
41-
"""
28+
def detect(self, memory, top_k: int = 5, scope=None):
4229
# 1. Search for similar memories based on embedding
4330
embedding = memory.metadata.embedding
4431
embedding_candidates_info = self.graph_store.search_by_embedding(
45-
embedding, top_k=top_k, scope=scope
32+
embedding, top_k=top_k, scope=scope, threshold=self.EMBEDDING_THRESHOLD
4633
)
4734
# 2. Filter based on similarity threshold
4835
embedding_candidates_ids = [
49-
info["id"]
50-
for info in embedding_candidates_info
51-
if info["score"] >= self.EMBEDDING_THRESHOLD and info["id"] != memory.id
36+
info["id"] for info in embedding_candidates_info if info["id"] != memory.id
5237
]
5338
# 3. Judge conflicts using LLM
5439
embedding_candidates = self.graph_store.get_nodes(embedding_candidates_ids)
55-
conflict_pairs = []
40+
detected_relationships = []
5641
for embedding_candidate in embedding_candidates:
5742
embedding_candidate = TextualMemoryItem.from_dict(embedding_candidate)
5843
prompt = [
59-
{
60-
"role": "system",
61-
"content": "You are a conflict detector for memory items.",
62-
},
6344
{
6445
"role": "user",
65-
"content": CONFLICT_DETECTOR_PROMPT.format(
66-
statement_1=memory.memory,
67-
statement_2=embedding_candidate.memory,
46+
"content": MEMORY_RELATION_DETECTOR_PROMPT.format(
47+
statement_1=memory.memory, statement_2=embedding_candidate.memory
6848
),
69-
},
49+
}
7050
]
7151
result = self.llm.generate(prompt).strip()
72-
if "yes" in result.lower():
73-
conflict_pairs.append([memory, embedding_candidate])
74-
if len(conflict_pairs):
75-
conflict_text = "\n".join(
76-
f'"{pair[0].memory!s}" <==CONFLICT==> "{pair[1].memory!s}"'
77-
for pair in conflict_pairs
78-
)
79-
logger.warning(
80-
f"Detected {len(conflict_pairs)} conflicts for memory {memory.id}\n {conflict_text}"
81-
)
82-
return conflict_pairs
52+
if result == "contradictory":
53+
logger.warning(
54+
f'detected "{memory.memory}" <==CONFLICT==> "{embedding_candidate.memory}"'
55+
)
56+
detected_relationships.append([memory, embedding_candidate, "contradictory"])
57+
elif result == "redundant":
58+
logger.warning(
59+
f'detected "{memory.memory}" <==REDUNDANT==> "{embedding_candidate.memory}"'
60+
)
61+
detected_relationships.append([memory, embedding_candidate, "redundant"])
62+
elif result == "independent":
63+
pass
64+
else:
65+
pass
66+
return detected_relationships
8367

84-
def resolve(self, memory_a: TextualMemoryItem, memory_b: TextualMemoryItem) -> None:
68+
def resolve(self, memory_a: TextualMemoryItem, memory_b: TextualMemoryItem, relation) -> None:
8569
"""
8670
Resolve detected conflicts between two memory items using LLM fusion.
8771
Args:
8872
memory_a: The first conflicting memory item.
8973
memory_b: The second conflicting memory item.
74+
relation: relation
9075
Returns:
9176
A fused TextualMemoryItem representing the resolved memory.
9277
"""
@@ -96,13 +81,10 @@ def resolve(self, memory_a: TextualMemoryItem, memory_b: TextualMemoryItem) -> N
9681
metadata_1 = memory_a.metadata.model_dump_json(include=metadata_for_resolve)
9782
metadata_2 = memory_b.metadata.model_dump_json(include=metadata_for_resolve)
9883
prompt = [
99-
{
100-
"role": "system",
101-
"content": "",
102-
},
10384
{
10485
"role": "user",
105-
"content": CONFLICT_RESOLVER_PROMPT.format(
86+
"content": MEMORY_RELATION_RESOLVER_PROMPT.format(
87+
relation=relation,
10688
statement_1=memory_a.memory,
10789
metadata_1=metadata_1,
10890
statement_2=memory_b.memory,
@@ -119,7 +101,7 @@ def resolve(self, memory_a: TextualMemoryItem, memory_b: TextualMemoryItem) -> N
119101
# —————— 2.1 Can't resolve conflict, hard update by comparing timestamp ————
120102
if len(answer) <= 10 and "no" in answer.lower():
121103
logger.warning(
122-
f"Conflict between {memory_a.id} and {memory_b.id} could not be resolved. "
104+
f"{relation} between {memory_a.id} and {memory_b.id} could not be resolved. "
123105
)
124106
self._hard_update(memory_a, memory_b)
125107
# —————— 2.2 Conflict resolved, update metadata and memory ————

0 commit comments

Comments (0)