|
| 1 | +import json |
| 2 | +import re |
| 3 | + |
| 4 | +from datetime import datetime |
| 5 | + |
| 6 | +from memos.embedders.base import BaseEmbedder |
| 7 | +from memos.graph_dbs.neo4j import Neo4jGraphDB |
| 8 | +from memos.llms.base import BaseLLM |
| 9 | +from memos.log import get_logger |
| 10 | +from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata |
| 11 | + |
| 12 | + |
| 13 | +logger = get_logger(__name__) |
| 14 | + |
| 15 | + |
class ConflictDetector:
    """Detect stored memories that factually contradict a given memory item.

    Candidates are retrieved from the graph store by embedding similarity and
    each surviving candidate is then judged by the LLM.
    """

    # Minimum similarity score a retrieved node needs to be treated as a conflict candidate.
    EMBEDDING_THRESHOLD: float = 0.8

    def __init__(self, graph_store: Neo4jGraphDB, llm: BaseLLM):
        self.graph_store = graph_store
        self.llm = llm

    def detect(
        self, memory: TextualMemoryItem, top_k: int = 5, scope: str | None = None
    ) -> list[tuple[TextualMemoryItem, TextualMemoryItem]]:
        """
        Detect conflicts by finding the most similar items in the graph database based on
        embedding, then asking the LLM whether each of them contradicts ``memory``.

        Args:
            memory: The memory item to check; its ``metadata.embedding`` is used for the search.
            top_k: Number of top similar nodes to retrieve.
            scope: Optional memory type filter passed through to the graph store.
        Returns:
            List of conflict pairs, each a tuple ``(memory, candidate)``. Empty when no
            candidate is similar enough or the LLM reports no contradiction.
        """
        # 1. Search for similar memories based on embedding
        similar_nodes = self.graph_store.search_by_embedding(
            memory.metadata.embedding, top_k=top_k, scope=scope
        )

        # 2. Keep candidates above the similarity threshold, excluding the memory itself
        candidate_ids = [
            info["id"]
            for info in similar_nodes
            if info["score"] >= self.EMBEDDING_THRESHOLD and info["id"] != memory.id
        ]
        if not candidate_ids:
            # Nothing to compare against; avoid a pointless graph lookup.
            return []

        # 3. Judge conflicts using LLM
        conflict_pairs: list[tuple[TextualMemoryItem, TextualMemoryItem]] = []
        for node in self.graph_store.get_nodes(candidate_ids):
            candidate = TextualMemoryItem.from_dict(node)
            prompt = [
                {"role": "system", "content": "You are a conflict detector for memory items."},
                {
                    "role": "user",
                    "content": f"""
You are given two plaintext statements. Determine if these two statements are factually contradictory. Respond with only "yes" if they contradict each other, or "no" if they do not contradict each other. Do not provide any explanation or additional text.
Statement 1: {memory.memory!s}
Statement 2: {candidate.memory!s}""",
                },
            ]
            verdict = self.llm.generate(prompt).strip().lower()
            if "yes" in verdict:
                # Tuples, so the result matches the declared return type.
                conflict_pairs.append((memory, candidate))

        if conflict_pairs:
            conflict_text = "\n".join(
                f'"{first.memory!s}" <==CONFLICT==> "{second.memory!s}"'
                for first, second in conflict_pairs
            )
            logger.warning(
                f"Detected {len(conflict_pairs)} conflicts for memory {memory.id}\n {conflict_text}"
            )
        return conflict_pairs
| 75 | + |
| 76 | + |
class ConflictResolver:
    """Resolve a conflict between two memory items.

    The LLM is asked to fuse the two statements into one. If it succeeds, both
    originals are replaced in the graph by the fused memory; if it declines,
    the older memory is deleted and the newer one is kept.
    """

    def __init__(self, graph_store: Neo4jGraphDB, llm: BaseLLM, embedder: BaseEmbedder):
        self.graph_store = graph_store
        self.llm = llm
        self.embedder = embedder

    def resolve(self, memory_a: TextualMemoryItem, memory_b: TextualMemoryItem) -> None:
        """
        Resolve a detected conflict between two memory items using LLM fusion.

        The graph store is updated in place; nothing is returned. If the LLM
        response contains no non-empty ``<answer>...</answer>`` section, an error
        is logged and the graph is left untouched.

        Args:
            memory_a: The first conflicting memory item.
            memory_b: The second conflicting memory item.
        """

        # ———————————— 1. LLM generate fused memory ————————————
        # A set is the documented type for pydantic's ``include`` argument.
        metadata_for_resolve = {"key", "background", "confidence", "updated_at"}
        metadata_1 = memory_a.metadata.model_dump_json(include=metadata_for_resolve)
        metadata_2 = memory_b.metadata.model_dump_json(include=metadata_for_resolve)
        prompt = [
            {
                "role": "system",
                "content": "",
            },
            {
                "role": "user",
                "content": CONFLICT_RESOLVER_PROMPT.format(
                    statement_1=memory_a.memory,
                    metadata_1=metadata_1,
                    statement_2=memory_b.memory,
                    metadata_2=metadata_2,
                ),
            },
        ]
        response = self.llm.generate(prompt).strip()

        # ———————————— 2. Parse the response ————————————
        match = re.search(r"<answer>(.*?)</answer>", response, re.DOTALL)
        answer = match.group(1).strip() if match is not None else ""
        if not answer:
            # Missing tags or an empty answer: do not touch the graph, otherwise two
            # real memories could be replaced by an empty one.
            logger.error(f"Failed to parse LLM response: {response}")
            return

        # —————— 2.1 Can't resolve conflict, hard update by comparing timestamp ————
        # A short answer containing "no" is treated as the LLM declining to fuse.
        if len(answer) <= 10 and "no" in answer.lower():
            logger.warning(
                f"Conflict between {memory_a.id} and {memory_b.id} could not be resolved. "
            )
            self._hard_update(memory_a, memory_b)
            return

        # —————— 2.2 Conflict resolved, update metadata and memory ————
        fixed_metadata = self._merge_metadata(answer, memory_a.metadata, memory_b.metadata)
        merged_memory = TextualMemoryItem(memory=answer, metadata=fixed_metadata)
        logger.info(f"Resolved result: {merged_memory}")
        self._resolve_in_graph(memory_a, memory_b, merged_memory)

    def _hard_update(self, memory_a: TextualMemoryItem, memory_b: TextualMemoryItem):
        """
        Hard update: compare ``updated_at``, keep the newer memory and delete the older
        one from the graph store. On a tie ``memory_a`` is kept.
        """
        time_a = datetime.fromisoformat(memory_a.metadata.updated_at)
        time_b = datetime.fromisoformat(memory_b.metadata.updated_at)

        if time_a >= time_b:
            newer_mem, older_mem = memory_a, memory_b
        else:
            newer_mem, older_mem = memory_b, memory_a

        self.graph_store.delete_node(older_mem.id)
        logger.warning(
            f"Delete older memory {older_mem.id}: <{older_mem.memory}> due to conflict with {newer_mem.id}: <{newer_mem.memory}>"
        )

    def _resolve_in_graph(
        self,
        conflict_a: TextualMemoryItem,
        conflict_b: TextualMemoryItem,
        merged: TextualMemoryItem,
    ):
        """
        Replace the two conflicting nodes with ``merged`` in the graph store.

        The merged node is added, every edge of either conflicting node is re-pointed
        to it (skipping self-loops and edges that already exist), and the two
        conflicting nodes are then deleted.
        """
        conflict_ids = (conflict_a.id, conflict_b.id)
        # Collect the edges before any node is deleted.
        edges_a = self.graph_store.get_edges(conflict_a.id, type="ANY", direction="ANY")
        edges_b = self.graph_store.get_edges(conflict_b.id, type="ANY", direction="ANY")

        self.graph_store.add_node(
            merged.id, merged.memory, merged.metadata.model_dump(exclude_none=True)
        )

        for edge in edges_a + edges_b:
            new_from = merged.id if edge["from"] in conflict_ids else edge["from"]
            new_to = merged.id if edge["to"] in conflict_ids else edge["to"]
            if new_from == new_to:
                # An edge between the two conflicting nodes would become a self-loop.
                continue
            # Check if the edge already exists before adding
            if not self.graph_store.edge_exists(new_from, new_to, edge["type"], direction="ANY"):
                self.graph_store.add_edge(new_from, new_to, edge["type"])

        self.graph_store.delete_node(conflict_a.id)
        self.graph_store.delete_node(conflict_b.id)
        logger.debug(
            f"Remove {conflict_a.id} and {conflict_b.id}, and inherit their edges to {merged.id}."
        )

    def _merge_metadata(
        self,
        memory: str,
        metadata_a: TreeNodeTextualMemoryMetadata,
        metadata_b: TreeNodeTextualMemoryMetadata,
    ) -> TreeNodeTextualMemoryMetadata:
        """
        Build the metadata for the fused memory.

        ``sources`` are concatenated, ``embedding`` is recomputed from the fused text,
        and ``updated_at`` / ``created_at`` are set to the current time. Every other
        field is taken from ``metadata_a``, falling back to ``metadata_b`` when the
        value in ``metadata_a`` is None.

        Args:
            memory: The fused memory text.
            metadata_a: Metadata of the first conflicting memory.
            metadata_b: Metadata of the second conflicting memory.
        Returns:
            The validated metadata for the fused memory.
        """
        metadata_1 = metadata_a.model_dump()
        metadata_2 = metadata_b.model_dump()
        # One timestamp so created_at and updated_at are identical.
        now = datetime.now().isoformat()
        merged_metadata = {
            "sources": (metadata_1.get("sources") or []) + (metadata_2.get("sources") or []),
            "embedding": self.embedder.embed([memory])[0],
            "updated_at": now,
            "created_at": now,
        }
        for key, value in metadata_1.items():
            if key in merged_metadata:
                continue
            merged_metadata[key] = value if value is not None else metadata_2.get(key)
        return TreeNodeTextualMemoryMetadata.model_validate(merged_metadata)
| 198 | + |
| 199 | + |
# Prompt template used by ConflictResolver.resolve. It is filled with str.format
# using the placeholders {statement_1}, {metadata_1}, {statement_2} and {metadata_2}.
# The reply is expected to wrap its result in <answer></answer> tags, which
# resolve() extracts with a regular expression.
CONFLICT_RESOLVER_PROMPT = """You are given two facts that conflict with each other. You are also given some contextual metadata of them. Your task is to analyze the two facts in light of the contextual metadata and try to reconcile them into a single, consistent, non-conflicting fact.
- Don't output any explanation or additional text, just the final reconciled fact, try to be objective and remain independent of the context, don't use pronouns.
- Try to judge facts by using its time, confidence etc.
- Try to retain as much information as possible from the perspective of time.
If the conflict cannot be resolved, output <answer>No</answer>. Otherwise, output the fused, consistent fact in enclosed with <answer></answer> tags.

Output Example 1:
<answer>No</answer>

Output Example 2:
<answer> ... </answer>

Now reconcile the following two facts:
Statement 1: {statement_1}
Metadata 1: {metadata_1}
Statement 2: {statement_2}
Metadata 2: {metadata_2}
"""
0 commit comments