Commit 6d7f410

feat: optimize memory search deduplication and fix parsing bugs
- Tune similarity threshold to 0.92 for 'dedup=sim' to preserve subtle semantic nuances.
- Implement recall expansion (5x Top-K) when deduplicating to ensure output diversity.
- Remove aggressive filling logic to strictly enforce the similarity threshold.
- Fix attribute error in MultiModalStructMemReader by correctly importing parse_json_result.
- Replace fragile eval() with robust parse_json_result in TaskGoalParser to handle JSON booleans.
1 parent 0f7f84a commit 6d7f410

File tree

4 files changed, +29 -32 lines

4 files changed

+29
-32
lines changed

src/memos/api/handlers/search_handler.py

Lines changed: 14 additions & 24 deletions
@@ -55,12 +55,19 @@ def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse
         """
         self.logger.info(f"[SearchHandler] Search Req is: {search_req}")
 
+        # Increase recall pool if deduplication is enabled to ensure diversity
+        original_top_k = search_req.top_k
+        if search_req.dedup == "sim":
+            search_req.top_k = original_top_k * 5
+
         cube_view = self._build_cube_view(search_req)
 
         results = cube_view.search_memories(search_req)
         if search_req.dedup == "sim":
-            results = self._dedup_text_memories(results, search_req.top_k)
+            results = self._dedup_text_memories(results, original_top_k)
         self._strip_embeddings(results)
+        # Restore original top_k for downstream logic or response metadata
+        search_req.top_k = original_top_k
 
         self.logger.info(
             f"[SearchHandler] Final search results: count={len(results)} results={results}"
@@ -104,35 +111,18 @@ def _dedup_text_memories(self, results: dict[str, Any], target_top_k: int) -> di
             bucket_idx = flat[idx][0]
             if len(selected_by_bucket[bucket_idx]) >= target_top_k:
                 continue
-            if self._is_unrelated(idx, selected_global, similarity_matrix, 0.85):
+            # Use 0.92 threshold strictly
+            if self._is_unrelated(idx, selected_global, similarity_matrix, 0.92):
                 selected_by_bucket[bucket_idx].append(idx)
                 selected_global.append(idx)
 
-        for bucket_idx in range(len(buckets)):
-            if len(selected_by_bucket[bucket_idx]) >= min(
-                target_top_k, len(indices_by_bucket[bucket_idx])
-            ):
-                continue
-            remaining_indices = [
-                idx
-                for idx in indices_by_bucket.get(bucket_idx, [])
-                if idx not in selected_by_bucket[bucket_idx]
-            ]
-            if not remaining_indices:
-                continue
-            # Fill to target_top_k with the least-similar candidates to preserve diversity.
-            remaining_indices.sort(
-                key=lambda idx: self._max_similarity(idx, selected_global, similarity_matrix)
-            )
-            for idx in remaining_indices:
-                if len(selected_by_bucket[bucket_idx]) >= target_top_k:
-                    break
-                selected_by_bucket[bucket_idx].append(idx)
-                selected_global.append(idx)
+        # Removed the 'filling' logic that was pulling back similar items.
+        # Now it will only return items that truly pass the 0.92 threshold,
+        # up to target_top_k.
 
         for bucket_idx, bucket in enumerate(buckets):
             selected_indices = selected_by_bucket.get(bucket_idx, [])
-            bucket["memories"] = [flat[i][1] for i in selected_indices[:target_top_k]]
+            bucket["memories"] = [flat[i][1] for i in selected_indices]
         return results
 
     @staticmethod
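To make the new selection behavior concrete, here is a minimal, self-contained sketch of the greedy threshold dedup this diff enforces. The greedy_dedup helper, its embeddings input, and the normalization step are simplified stand-ins for the handler's similarity_matrix, _is_unrelated, and bucket bookkeeping, not the actual implementation.

import numpy as np

def greedy_dedup(embeddings: np.ndarray, top_k: int, threshold: float = 0.92) -> list[int]:
    """Keep a candidate only if its max cosine similarity to every
    already-selected item stays below `threshold`; stop at `top_k`."""
    # Normalize rows so a plain dot product equals cosine similarity.
    normed = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    selected: list[int] = []
    for idx in range(len(normed)):  # candidates assumed pre-sorted by relevance
        if len(selected) >= top_k:
            break
        if selected and float(np.max(normed[selected] @ normed[idx])) >= threshold:
            continue  # too close to something already kept; no fill-back later
        selected.append(idx)
    return selected  # may hold fewer than top_k items by design

This is also why the handler now multiplies search_req.top_k by 5 before retrieval: with the fill-back pass removed, dedup can legitimately return fewer than top_k memories, so a larger recall pool is needed to keep the final output both diverse and reasonably full.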

src/memos/mem_reader/multi_modal_struct.py

Lines changed: 2 additions & 1 deletion
@@ -10,6 +10,7 @@
 from memos.mem_reader.read_multi_modal import MultiModalParser, detect_lang
 from memos.mem_reader.read_multi_modal.base import _derive_key
 from memos.mem_reader.simple_struct import PROMPT_DICT, SimpleStructMemReader
+from memos.mem_reader.utils import parse_json_result
 from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata
 from memos.templates.tool_mem_prompts import TOOL_TRAJECTORY_PROMPT_EN, TOOL_TRAJECTORY_PROMPT_ZH
 from memos.types import MessagesType
@@ -377,7 +378,7 @@ def _get_llm_response(
         messages = [{"role": "user", "content": prompt}]
         try:
             response_text = self.llm.generate(messages)
-            response_json = self.parse_json_result(response_text)
+            response_json = parse_json_result(response_text)
         except Exception as e:
             logger.error(f"[LLM] Exception during chat generation: {e}")
             response_json = {
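The import fix matters because the method previously called self.parse_json_result, which is not defined on MultiModalStructMemReader, so every call fell through to the exception fallback and the LLM response was discarded. The commit switches to the module-level helper from memos.mem_reader.utils. For orientation only (an assumption about the helper's behavior, not the actual memos implementation), a helper of this kind typically strips markdown fences and falls back to an empty dict on malformed output:

import json
import re

def parse_json_result(response_text: str) -> dict:
    # Drop ```json ... ``` fences, then load the outermost JSON object.
    cleaned = re.sub(r"^```(?:json)?\s*|\s*```$", "", response_text.strip())
    match = re.search(r"\{.*\}", cleaned, flags=re.DOTALL)
    if not match:
        return {}
    try:
        return json.loads(match.group(0))
    except json.JSONDecodeError:
        return {}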

src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py

Lines changed: 0 additions & 1 deletion
@@ -4,7 +4,6 @@
 from pathlib import Path
 from typing import Any
 
-
 import numpy as np
 
 from memos.dependency import require_python_package

src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py

Lines changed: 13 additions & 6 deletions
@@ -5,7 +5,10 @@
 from memos.llms.base import BaseLLM
 from memos.log import get_logger
 from memos.memories.textual.tree_text_memory.retrieve.retrieval_mid_structs import ParsedTaskGoal
-from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import FastTokenizer
+from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import (
+    FastTokenizer,
+    parse_json_result,
+)
 from memos.memories.textual.tree_text_memory.retrieve.utils import TASK_PARSE_PROMPT
 
 
@@ -111,8 +114,10 @@ def _parse_response(self, response: str, **kwargs) -> ParsedTaskGoal:
         for attempt_times in range(attempts):
             try:
                 context = kwargs.get("context", "")
-                response = response.replace("```", "").replace("json", "").strip()
-                response_json = eval(response)
+                response_json = parse_json_result(response)
+                if not response_json:
+                    raise ValueError("Parsed JSON is empty")
+
                 return ParsedTaskGoal(
                     memories=response_json.get("memories", []),
                     keys=response_json.get("keys", []),
@@ -123,6 +128,8 @@ def _parse_response(self, response: str, **kwargs) -> ParsedTaskGoal:
                     context=context,
                 )
             except Exception as e:
-                raise ValueError(
-                    f"Failed to parse LLM output: {e}\nRaw response:\n{response} retried: {attempt_times + 1}/{attempts + 1}"
-                ) from e
+                if attempt_times == attempts - 1:
+                    raise ValueError(
+                        f"Failed to parse LLM output: {e}\nRaw response:\n{response} retried: {attempt_times + 1}/{attempts}"
+                    ) from e
+                continue
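The eval() removal is the substantive fix here: JSON literals such as true, false, and null are not valid Python names, so eval() raises NameError the moment the LLM emits a boolean or null field, whereas a JSON parser accepts them natively. A quick illustration (the field names are illustrative, not the actual ParsedTaskGoal schema):

import json

payload = '{"memories": [], "keys": ["travel"], "rerank": false}'

print(json.loads(payload))  # {'memories': [], 'keys': ['travel'], 'rerank': False}
eval(payload)               # NameError: name 'false' is not defined

The retry change in the last hunk serves the same goal: previously the except branch raised on the first failure, so the attempts loop never actually retried; now the error is raised only after the final attempt and earlier failures simply continue.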
