Skip to content

Commit 9f09d34

Browse files
committed
feat: add structure reorganizer and conflict resolver
1 parent 2e6f6e8 commit 9f09d34

File tree

8 files changed

+887
-10
lines changed

8 files changed

+887
-10
lines changed

examples/core_memories/tree_textual_memory.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,9 @@ def embed_memory_item(memory: str) -> list[float]:
222222
print(f"{i}'th similar result is: " + str(r["memory"]))
223223
print(f"Successfully search {len(results)} memories")
224224

225+
# close the synchronous thread in memory manager
226+
my_tree_textual_memory.memory_manager.close()
227+
225228

226229
# my_tree_textual_memory.dump
227230
my_tree_textual_memory.dump("tmp/my_tree_textual_memory")

src/memos/graph_dbs/item.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import uuid
2+
3+
from typing import Any, Literal
4+
5+
from pydantic import BaseModel, ConfigDict, Field, field_validator
6+
7+
from memos.memories.textual.item import TextualMemoryItem
8+
9+
10+
# Node type for the graph database layer. It adds no fields or behavior on top
# of TextualMemoryItem.
class GraphDBNode(TextualMemoryItem):
    ...
12+
13+
14+
class GraphDBEdge(BaseModel):
    """Represents an edge in a graph database (corresponds to Neo4j relationship).

    Keys other than the declared fields are rejected (``extra="forbid"``).
    """

    # Generated as a random UUID4 string when the caller does not supply one.
    id: str = Field(
        default_factory=lambda: str(uuid.uuid4()), description="Unique identifier for the edge"
    )
    source: str = Field(..., description="Source node ID")
    target: str = Field(..., description="Target node ID")
    type: Literal["RELATED", "PARENT"] = Field(
        ..., description="Relationship type (must be one of 'RELATED', 'PARENT')"
    )
    properties: dict[str, Any] | None = Field(
        default=None, description="Additional properties for the edge"
    )

    model_config = ConfigDict(extra="forbid")

    @field_validator("id")
    @classmethod
    def validate_id(cls, v):
        """Validate that ID is a string that parses as a UUID.

        Any UUID version is accepted. ``uuid.UUID(v, version=4)`` would not
        check the version (it overrides the version bits of the parsed value),
        so it is not used here.

        Raises:
            ValueError: If ``v`` is not a string or cannot be parsed as a UUID.
        """
        if not isinstance(v, str):
            raise ValueError("ID must be a valid UUID string")
        try:
            # Parsing is the validation; the resulting UUID object is not needed.
            uuid.UUID(v)
        except ValueError as err:
            raise ValueError("ID must be a valid UUID string") from err
        return v

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "GraphDBEdge":
        """Create GraphDBEdge from dictionary (keys are passed as field values)."""
        return cls(**data)

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary format, omitting fields whose value is None."""
        return self.model_dump(exclude_none=True)

src/memos/graph_dbs/neo4j.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,16 @@ def get_memory_count(self, memory_type: str) -> int:
9292
result = session.run(query, memory_type=memory_type)
9393
return result.single()["count"]
9494

95+
def count_nodes(self, scope: str) -> int:
    """
    Count the `Memory` nodes whose `memory_type` equals `scope`.

    Args:
        scope: Value compared against each node's `memory_type` property.
    Returns:
        The `count` column of the single record returned by the query.
    """
    cypher = """
    MATCH (n:Memory)
    WHERE n.memory_type = $scope
    RETURN count(n) AS count
    """
    params = {"scope": scope}
    with self.driver.session(database=self.db_name) as session:
        record = session.run(cypher, params).single()
        return record["count"]
104+
95105
def remove_oldest_memory(self, memory_type: str, keep_latest: int) -> None:
96106
"""
97107
Remove all WorkingMemory nodes except the latest `keep_latest` entries.
@@ -730,6 +740,35 @@ def get_all_memory_items(self, scope: str) -> list[dict]:
730740
results = session.run(query, {"scope": scope})
731741
return [_parse_node(dict(record["n"])) for record in results]
732742

743+
def get_structure_optimization_candidates(self, scope: str) -> list[dict]:
    """
    Collect nodes that are likely candidates for structure optimization.

    The query unions two groups:
    - Case 1: `Memory` nodes in `scope` that have no relationships at all,
      have a null or empty `background`, or have exactly one outgoing
      PARENT relationship.
    - Case 2: every node reached through a PARENT relationship from a node in
      `scope` that has exactly one outgoing PARENT relationship. The child's
      own `memory_type` is not filtered.

    Args:
        scope: Value compared against the `memory_type` property.
    Returns:
        One `_parse_node` result per returned record.
    """
    cypher = """
    // Case 1
    MATCH (n:Memory)
    WHERE n.memory_type = $scope
    AND (
    NOT (n)--()
    OR n.background IS NULL OR n.background = ''
    OR size([ (n)-[:PARENT]->() | 1 ]) = 1
    )
    RETURN n.id AS id, n AS node
    UNION
    // Case 2
    MATCH (p:Memory)-[:PARENT]->(c:Memory)
    WHERE p.memory_type = $scope
    AND size([ (p)-[:PARENT]->() | 1 ]) = 1
    RETURN c.id AS id, c AS node
    """
    params = {"scope": scope}

    candidates = []
    with self.driver.session(database=self.db_name) as session:
        for record in session.run(cypher, params):
            # Start from the projected id, then let the node's own properties
            # fill in (and, on a key clash, override) the rest.
            raw_node = {"id": record["id"]}
            raw_node.update(dict(record["node"]))
            candidates.append(_parse_node(raw_node))
    return candidates
771+
733772
def drop_database(self) -> None:
734773
"""
735774
Permanently delete the entire database this instance is using.

src/memos/memories/textual/tree.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@ def __init__(self, config: TreeTextMemoryConfig):
3232
self.dispatcher_llm: OpenAILLM | OllamaLLM = LLMFactory.from_config(config.dispatcher_llm)
3333
self.embedder: OllamaEmbedder = EmbedderFactory.from_config(config.embedder)
3434
self.graph_store: Neo4jGraphDB = GraphStoreFactory.from_config(config.graph_db)
35-
self.memory_manager: MemoryManager = MemoryManager(self.graph_store, self.embedder)
35+
self.memory_manager: MemoryManager = MemoryManager(
36+
self.graph_store, self.embedder, self.extractor_llm
37+
)
3638

3739
def add(self, memories: list[TextualMemoryItem | dict[str, Any]]) -> None:
3840
"""Add memories.
@@ -286,4 +288,4 @@ def _cleanup_old_backups(root_dir: Path, keep_last_n: int) -> None:
286288
shutil.rmtree(old_dir)
287289
logger.info(f"Deleted old backup directory: {old_dir}")
288290
except Exception as e:
289-
logger.warning(f"Failed to delete backup {old_dir}: {e}")
291+
logger.warning(f"Failed to delete backup {old_dir}: {e}")
Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
import json
2+
import re
3+
4+
from datetime import datetime
5+
6+
from memos.embedders.base import BaseEmbedder
7+
from memos.graph_dbs.neo4j import Neo4jGraphDB
8+
from memos.llms.base import BaseLLM
9+
from memos.log import get_logger
10+
from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata
11+
12+
13+
logger = get_logger(__name__)
14+
15+
16+
class ConflictDetector:
    """Find stored memories that contradict a given memory item.

    Candidates come from an embedding search in the graph store; an LLM then
    judges each candidate against the incoming memory.
    """

    EMBEDDING_THRESHOLD: float = 0.8  # Threshold for embedding similarity to consider conflict

    def __init__(self, graph_store: Neo4jGraphDB, llm: BaseLLM):
        self.graph_store = graph_store
        self.llm = llm

    def detect(
        self, memory: TextualMemoryItem, top_k: int = 5, scope: str | None = None
    ) -> list[tuple[TextualMemoryItem, TextualMemoryItem]]:
        """
        Detect conflicts by finding the most similar items in the graph database based on embedding, then use LLM to judge conflict.
        Args:
            memory: The memory item; its `metadata.embedding` is used for the search.
            top_k: Number of top similar nodes to retrieve.
            scope: Optional memory type filter.
        Returns:
            List of conflict pairs (each pair is a tuple: (memory, candidate)).
            Empty when no candidate passes the similarity threshold or the LLM
            reports no contradiction.
        """
        # 1. Search for similar memories based on embedding
        embedding = memory.metadata.embedding
        embedding_candidates_info = self.graph_store.search_by_embedding(
            embedding, top_k=top_k, scope=scope
        )
        # 2. Filter based on similarity threshold (and never compare a memory with itself)
        embedding_candidates_ids = [
            info["id"]
            for info in embedding_candidates_info
            if info["score"] >= self.EMBEDDING_THRESHOLD and info["id"] != memory.id
        ]
        if not embedding_candidates_ids:
            # Nothing similar enough: skip the node fetch and the LLM calls.
            return []
        # 3. Judge conflicts using LLM
        embedding_candidates = self.graph_store.get_nodes(embedding_candidates_ids)
        conflict_pairs: list[tuple[TextualMemoryItem, TextualMemoryItem]] = []
        for candidate_data in embedding_candidates:
            embedding_candidate = TextualMemoryItem.from_dict(candidate_data)
            prompt = [
                {"role": "system", "content": "You are a conflict detector for memory items."},
                {
                    "role": "user",
                    "content": f"""
                    You are given two plaintext statements. Determine if these two statements are factually contradictory. Respond with only "yes" if they contradict each other, or "no" if they do not contradict each other. Do not provide any explanation or additional text.
                    Statement 1: {memory.memory!s}
                    Statement 2: {embedding_candidate.memory!s}""",
                },
            ]
            result = self.llm.generate(prompt).strip().lower()
            # Substring match so answers such as '"yes"' or 'yes.' still count.
            if "yes" in result:
                conflict_pairs.append((memory, embedding_candidate))
        if conflict_pairs:
            conflict_text = "\n".join(
                f'"{pair[0].memory!s}" <==CONFLICT==> "{pair[1].memory!s}"'
                for pair in conflict_pairs
            )
            logger.warning(
                f"Detected {len(conflict_pairs)} conflicts for memory {memory.id}\n {conflict_text}"
            )
        return conflict_pairs
75+
76+
77+
class ConflictResolver:
    """Resolve a pair of conflicting memory items inside the graph store.

    The LLM is asked to fuse the two statements. On success both originals are
    replaced by one merged node that inherits their edges; otherwise the older
    of the two is deleted.
    """

    def __init__(self, graph_store: Neo4jGraphDB, llm: BaseLLM, embedder: BaseEmbedder):
        self.graph_store = graph_store
        self.llm = llm
        self.embedder = embedder

    def resolve(self, memory_a: TextualMemoryItem, memory_b: TextualMemoryItem) -> None:
        """
        Resolve detected conflicts between two memory items using LLM fusion.
        Args:
            memory_a: The first conflicting memory item.
            memory_b: The second conflicting memory item.
        Returns:
            None. The outcome is applied to the graph store: either a merged
            node replaces both items, or the older item is deleted. If the LLM
            response has no usable `<answer>` section, an error is logged and
            the graph is left untouched.
        """

        # ———————————— 1. LLM generate fused memory ————————————
        # A set is the documented form for pydantic's `include` argument.
        metadata_for_resolve = {"key", "background", "confidence", "updated_at"}
        metadata_1 = memory_a.metadata.model_dump_json(include=metadata_for_resolve)
        metadata_2 = memory_b.metadata.model_dump_json(include=metadata_for_resolve)
        prompt = [
            {
                "role": "system",
                "content": "",
            },
            {
                "role": "user",
                "content": CONFLICT_RESOLVER_PROMPT.format(
                    statement_1=memory_a.memory,
                    metadata_1=metadata_1,
                    statement_2=memory_b.memory,
                    metadata_2=metadata_2,
                ),
            },
        ]
        response = self.llm.generate(prompt).strip()

        # ———————————— 2. Parse the response ————————————
        match = re.search(r"<answer>(.*?)</answer>", response, re.DOTALL)
        answer = match.group(1).strip() if match else ""
        if not answer:
            # Missing tags or an empty answer: there is nothing to merge.
            logger.error(f"Failed to parse LLM response: {response}")
            return

        try:
            # —————— 2.1 Can't resolve conflict, hard update by comparing timestamp ————
            if len(answer) <= 10 and "no" in answer.lower():
                logger.warning(
                    f"Conflict between {memory_a.id} and {memory_b.id} could not be resolved. "
                )
                self._hard_update(memory_a, memory_b)
            # —————— 2.2 Conflict resolved, update metadata and memory ————
            else:
                fixed_metadata = self._merge_metadata(answer, memory_a.metadata, memory_b.metadata)
                merged_memory = TextualMemoryItem(memory=answer, metadata=fixed_metadata)
                logger.info(f"Resolved result: {merged_memory}")
                self._resolve_in_graph(memory_a, memory_b, merged_memory)
        except json.decoder.JSONDecodeError:
            # Kept from the original so JSON errors raised by the calls above
            # are logged instead of propagating.
            logger.error(f"Failed to parse LLM response: {response}")

    def _hard_update(self, memory_a: TextualMemoryItem, memory_b: TextualMemoryItem):
        """
        Hard update: compare updated_at, keep the newer one, delete the older one from the graph.

        On a tie `memory_a` is kept. An item without `updated_at` is treated as
        the older one (assumes the field may be unset — TODO confirm against
        the metadata model).
        """
        raw_a = memory_a.metadata.updated_at
        raw_b = memory_b.metadata.updated_at
        time_a = datetime.fromisoformat(raw_a) if raw_a else None
        time_b = datetime.fromisoformat(raw_b) if raw_b else None

        if time_b is None:
            a_is_newer = True
        elif time_a is None:
            a_is_newer = False
        else:
            a_is_newer = time_a >= time_b

        newer_mem, older_mem = (memory_a, memory_b) if a_is_newer else (memory_b, memory_a)

        self.graph_store.delete_node(older_mem.id)
        logger.warning(
            f"Delete older memory {older_mem.id}: <{older_mem.memory}> due to conflict with {newer_mem.id}: <{newer_mem.memory}>"
        )

    def _resolve_in_graph(
        self,
        conflict_a: TextualMemoryItem,
        conflict_b: TextualMemoryItem,
        merged: TextualMemoryItem,
    ):
        """
        Replace both conflicting nodes with `merged`, re-attaching their edges.

        Edges are read before any mutation, the merged node is added, each edge
        is re-pointed to the merged node, and finally both originals are deleted.
        """
        edges_a = self.graph_store.get_edges(conflict_a.id, type="ANY", direction="ANY")
        edges_b = self.graph_store.get_edges(conflict_b.id, type="ANY", direction="ANY")
        all_edges = edges_a + edges_b

        self.graph_store.add_node(
            merged.id, merged.memory, merged.metadata.model_dump(exclude_none=True)
        )

        replaced_ids = (conflict_a.id, conflict_b.id)
        for edge in all_edges:
            new_from = merged.id if edge["from"] in replaced_ids else edge["from"]
            new_to = merged.id if edge["to"] in replaced_ids else edge["to"]
            if new_from == new_to:
                # An edge between the two conflicting nodes would become a self-loop.
                continue
            # Check if the edge already exists before adding
            if not self.graph_store.edge_exists(new_from, new_to, edge["type"], direction="ANY"):
                self.graph_store.add_edge(new_from, new_to, edge["type"])

        self.graph_store.delete_node(conflict_a.id)
        self.graph_store.delete_node(conflict_b.id)
        logger.debug(
            f"Remove {conflict_a.id} and {conflict_b.id}, and inherit their edges to {merged.id}."
        )

    def _merge_metadata(
        self,
        memory: str,
        metadata_a: TreeNodeTextualMemoryMetadata,
        metadata_b: TreeNodeTextualMemoryMetadata,
    ) -> TreeNodeTextualMemoryMetadata:
        """
        Build the metadata for the fused memory text.

        `sources` are concatenated, `embedding` is recomputed from `memory`,
        and `updated_at` / `created_at` are set to the current time. Every
        other field takes `metadata_a`'s value unless it is None, in which case
        `metadata_b`'s value is used.
        """
        metadata_1 = metadata_a.model_dump()
        metadata_2 = metadata_b.model_dump()
        now = datetime.now().isoformat()
        merged_metadata = {
            "sources": (metadata_1["sources"] or []) + (metadata_2["sources"] or []),
            "embedding": self.embedder.embed([memory])[0],
            # Must be "updated_at" (the field read elsewhere in this class);
            # a misspelled key would leave the stale timestamp in place.
            "updated_at": now,
            "created_at": now,
        }
        for key in metadata_1:
            if key in merged_metadata:
                continue
            merged_metadata[key] = (
                metadata_1[key] if metadata_1[key] is not None else metadata_2[key]
            )
        return TreeNodeTextualMemoryMetadata.model_validate(merged_metadata)
198+
199+
200+
# User-message template for ConflictResolver.resolve. Filled with str.format();
# placeholders: statement_1, metadata_1, statement_2, metadata_2. The reply is
# expected inside <answer></answer> tags, with "No" meaning "cannot reconcile".
CONFLICT_RESOLVER_PROMPT = """You are given two facts that conflict with each other. You are also given some contextual metadata of them. Your task is to analyze the two facts in light of the contextual metadata and try to reconcile them into a single, consistent, non-conflicting fact.
- Don't output any explanation or additional text, just the final reconciled fact, try to be objective and remain independent of the context, don't use pronouns.
- Try to judge facts by using its time, confidence etc.
- Try to retain as much information as possible from the perspective of time.
If the conflict cannot be resolved, output <answer>No</answer>. Otherwise, output the fused, consistent fact enclosed in <answer></answer> tags.

Output Example 1:
<answer>No</answer>

Output Example 2:
<answer> ... </answer>

Now reconcile the following two facts:
Statement 1: {statement_1}
Metadata 1: {metadata_1}
Statement 2: {statement_2}
Metadata 2: {metadata_2}
"""

0 commit comments

Comments
 (0)