Commit 336a2be
feat: add dedup search param (#788)

* Add dedup option to search pipeline
* Fix dedup handling in simple search
* feat: optimize memory search deduplication and fix parsing bugs
  - Tune similarity threshold to 0.92 for 'dedup=sim' to preserve subtle semantic nuances.
  - Implement recall expansion (5x Top-K) when deduplicating to ensure output diversity.
  - Remove aggressive filling logic to strictly enforce the similarity threshold.
  - Fix attribute error in MultiModalStructMemReader by correctly importing parse_json_result.
  - Replace fragile eval() with robust parse_json_result in TaskGoalParser to handle JSON booleans.

---------

Co-authored-by: [email protected] <>
1 parent fac1aa7 commit 336a2be

File tree

11 files changed: +171 −63 lines changed

src/memos/api/handlers/formatters_handler.py

Lines changed: 3 additions & 2 deletions

@@ -29,7 +29,7 @@ def to_iter(running: Any) -> list[Any]:
     return list(running) if running else []


-def format_memory_item(memory_data: Any) -> dict[str, Any]:
+def format_memory_item(memory_data: Any, include_embedding: bool = False) -> dict[str, Any]:
     """
     Format a single memory item for API response.
@@ -47,7 +47,8 @@ def format_memory_item(memory_data: Any) -> dict[str, Any]:
     ref_id = f"[{memory_id.split('-')[0]}]"

     memory["ref_id"] = ref_id
-    memory["metadata"]["embedding"] = []
+    if not include_embedding:
+        memory["metadata"]["embedding"] = []
     memory["metadata"]["sources"] = []
     memory["metadata"]["usage"] = []
     memory["metadata"]["ref_id"] = ref_id

src/memos/api/handlers/search_handler.py

Lines changed: 102 additions & 0 deletions

@@ -5,9 +5,14 @@
 using dependency injection for better modularity and testability.
 """

+from typing import Any
+
 from memos.api.handlers.base_handler import BaseHandler, HandlerDependencies
 from memos.api.product_models import APISearchRequest, SearchResponse
 from memos.log import get_logger
+from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import (
+    cosine_similarity_matrix,
+)
 from memos.multi_mem_cube.composite_cube import CompositeCubeView
 from memos.multi_mem_cube.single_cube import SingleCubeView
 from memos.multi_mem_cube.views import MemCubeView
@@ -50,9 +55,19 @@ def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse
         """
         self.logger.info(f"[SearchHandler] Search Req is: {search_req}")

+        # Increase recall pool if deduplication is enabled to ensure diversity
+        original_top_k = search_req.top_k
+        if search_req.dedup == "sim":
+            search_req.top_k = original_top_k * 5
+
         cube_view = self._build_cube_view(search_req)

         results = cube_view.search_memories(search_req)
+        if search_req.dedup == "sim":
+            results = self._dedup_text_memories(results, original_top_k)
+            self._strip_embeddings(results)
+        # Restore original top_k for downstream logic or response metadata
+        search_req.top_k = original_top_k

         self.logger.info(
             f"[SearchHandler] Final search results: count={len(results)} results={results}"
@@ -63,6 +78,93 @@ def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse
             data=results,
         )

+    def _dedup_text_memories(self, results: dict[str, Any], target_top_k: int) -> dict[str, Any]:
+        buckets = results.get("text_mem", [])
+        if not buckets:
+            return results
+
+        flat: list[tuple[int, dict[str, Any], float]] = []
+        for bucket_idx, bucket in enumerate(buckets):
+            for mem in bucket.get("memories", []):
+                score = mem.get("metadata", {}).get("relativity", 0.0)
+                flat.append((bucket_idx, mem, score))
+
+        if len(flat) <= 1:
+            return results
+
+        embeddings = self._extract_embeddings([mem for _, mem, _ in flat])
+        if embeddings is None:
+            documents = [mem.get("memory", "") for _, mem, _ in flat]
+            embeddings = self.searcher.embedder.embed(documents)
+
+        similarity_matrix = cosine_similarity_matrix(embeddings)
+
+        indices_by_bucket: dict[int, list[int]] = {i: [] for i in range(len(buckets))}
+        for flat_index, (bucket_idx, _, _) in enumerate(flat):
+            indices_by_bucket[bucket_idx].append(flat_index)
+
+        selected_global: list[int] = []
+        selected_by_bucket: dict[int, list[int]] = {i: [] for i in range(len(buckets))}
+
+        ordered_indices = sorted(range(len(flat)), key=lambda idx: flat[idx][2], reverse=True)
+        for idx in ordered_indices:
+            bucket_idx = flat[idx][0]
+            if len(selected_by_bucket[bucket_idx]) >= target_top_k:
+                continue
+            # Use 0.92 threshold strictly
+            if self._is_unrelated(idx, selected_global, similarity_matrix, 0.92):
+                selected_by_bucket[bucket_idx].append(idx)
+                selected_global.append(idx)

+        # Removed the 'filling' logic that was pulling back similar items.
+        # Now it will only return items that truly pass the 0.92 threshold,
+        # up to target_top_k.
+
+        for bucket_idx, bucket in enumerate(buckets):
+            selected_indices = selected_by_bucket.get(bucket_idx, [])
+            bucket["memories"] = [flat[i][1] for i in selected_indices]
+        return results
+
+    @staticmethod
+    def _is_unrelated(
+        index: int,
+        selected_indices: list[int],
+        similarity_matrix: list[list[float]],
+        similarity_threshold: float,
+    ) -> bool:
+        return all(similarity_matrix[index][j] <= similarity_threshold for j in selected_indices)
+
+    @staticmethod
+    def _max_similarity(
+        index: int, selected_indices: list[int], similarity_matrix: list[list[float]]
+    ) -> float:
+        if not selected_indices:
+            return 0.0
+        return max(similarity_matrix[index][j] for j in selected_indices)
+
+    @staticmethod
+    def _extract_embeddings(memories: list[dict[str, Any]]) -> list[list[float]] | None:
+        embeddings: list[list[float]] = []
+        for mem in memories:
+            embedding = mem.get("metadata", {}).get("embedding")
+            if not embedding:
+                return None
+            embeddings.append(embedding)
+        return embeddings
+
+    @staticmethod
+    def _strip_embeddings(results: dict[str, Any]) -> None:
+        for bucket in results.get("text_mem", []):
+            for mem in bucket.get("memories", []):
+                metadata = mem.get("metadata", {})
+                if "embedding" in metadata:
+                    metadata["embedding"] = []
+        for bucket in results.get("tool_mem", []):
+            for mem in bucket.get("memories", []):
+                metadata = mem.get("metadata", {})
+                if "embedding" in metadata:
+                    metadata["embedding"] = []
+
     def _resolve_cube_ids(self, search_req: APISearchRequest) -> list[str]:
         """
         Normalize target cube ids from search_req.
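
The selection above is a greedy pass in descending relevance order: an item survives only if its cosine similarity to every already-kept item is at most 0.92, and the handler widens recall to 5x top_k beforehand so the filter still has enough distinct candidates. A self-contained single-bucket sketch of that core loop (cosine_similarity_matrix is re-implemented with numpy here for illustration; the real helper lives in retrieve_utils, and the real code also caps selections per bucket):

import numpy as np

def cosine_similarity_matrix(embeddings: list[list[float]]) -> np.ndarray:
    # Stand-in for the retrieve_utils helper: row-normalize, then dot.
    x = np.asarray(embeddings, dtype=float)
    x = x / np.linalg.norm(x, axis=1, keepdims=True)
    return x @ x.T

def greedy_dedup(scores, embeddings, top_k, threshold=0.92):
    sim = cosine_similarity_matrix(embeddings)
    selected = []
    # Visit candidates from most to least relevant; keep one representative
    # per near-duplicate cluster; stop once top_k survivors are collected.
    for idx in sorted(range(len(scores)), key=lambda i: scores[i], reverse=True):
        if len(selected) >= top_k:
            break
        if all(sim[idx][j] <= threshold for j in selected):
            selected.append(idx)
    return selected

emb = [[1.0, 0.0], [0.99, 0.14], [0.0, 1.0]]  # first two are near-duplicates
print(greedy_dedup([0.9, 0.8, 0.7], emb, top_k=3))  # -> [0, 2]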

src/memos/api/product_models.py

Lines changed: 9 additions & 0 deletions

@@ -319,6 +319,15 @@ class APISearchRequest(BaseRequest):
         description="Number of textual memories to retrieve (top-K). Default: 10.",
     )

+    dedup: Literal["no", "sim"] | None = Field(
+        None,
+        description=(
+            "Optional dedup option for textual memories. "
+            "Use 'no' for no dedup, 'sim' for similarity dedup. "
+            "If None, default exact-text dedup is applied."
+        ),
+    )
+
     pref_top_k: int = Field(
         6,
         ge=0,
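
A hedged example of exercising the new field; any other required APISearchRequest fields are omitted here for brevity:

# Example search payload using the new dedup field (other request fields
# omitted; treat their names as assumptions).
payload = {
    "query": "weekend hiking plans",
    "top_k": 10,
    "dedup": "sim",  # "no" skips dedup entirely; omit the key to get the
                     # default exact-text dedup
}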

src/memos/mem_reader/multi_modal_struct.py

Lines changed: 2 additions & 1 deletion

@@ -10,6 +10,7 @@
 from memos.mem_reader.read_multi_modal import MultiModalParser, detect_lang
 from memos.mem_reader.read_multi_modal.base import _derive_key
 from memos.mem_reader.simple_struct import PROMPT_DICT, SimpleStructMemReader
+from memos.mem_reader.utils import parse_json_result
 from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata
 from memos.templates.tool_mem_prompts import TOOL_TRAJECTORY_PROMPT_EN, TOOL_TRAJECTORY_PROMPT_ZH
 from memos.types import MessagesType
@@ -377,7 +378,7 @@ def _get_llm_response(
         messages = [{"role": "user", "content": prompt}]
         try:
             response_text = self.llm.generate(messages)
-            response_json = self.parse_json_result(response_text)
+            response_json = parse_json_result(response_text)
         except Exception as e:
             logger.error(f"[LLM] Exception during chat generation: {e}")
             response_json = {

src/memos/mem_scheduler/optimized_scheduler.py

Lines changed: 9 additions & 2 deletions

@@ -186,10 +186,14 @@ def mix_search_memories(
             info=info,
             search_tool_memory=search_req.search_tool_memory,
             tool_mem_top_k=search_req.tool_mem_top_k,
+            dedup=search_req.dedup,
         )
         memories = merged_memories[: search_req.top_k]

-        formatted_memories = [format_textual_memory_item(item) for item in memories]
+        formatted_memories = [
+            format_textual_memory_item(item, include_embedding=search_req.dedup == "sim")
+            for item in memories
+        ]
         self.submit_memory_history_async_task(
             search_req=search_req,
             user_context=user_context,
@@ -233,7 +237,10 @@ def update_search_memories_to_redis(
                 mem_cube=self.mem_cube,
                 mode=SearchMode.FAST,
             )
-            formatted_memories = [format_textual_memory_item(data) for data in memories]
+            formatted_memories = [
+                format_textual_memory_item(data, include_embedding=search_req.dedup == "sim")
+                for data in memories
+            ]
         else:
             memories = [
                 TextualMemoryItem.from_dict(one) for one in memories_to_store["memories"]

src/memos/mem_scheduler/utils/api_utils.py

Lines changed: 3 additions & 2 deletions

@@ -6,14 +6,15 @@
 from memos.memories.textual.tree import TextualMemoryItem


-def format_textual_memory_item(memory_data: Any) -> dict[str, Any]:
+def format_textual_memory_item(memory_data: Any, include_embedding: bool = False) -> dict[str, Any]:
     """Format a single memory item for API response."""
     memory = memory_data.model_dump()
     memory_id = memory["id"]
     ref_id = f"[{memory_id.split('-')[0]}]"

     memory["ref_id"] = ref_id
-    memory["metadata"]["embedding"] = []
+    if not include_embedding:
+        memory["metadata"]["embedding"] = []
     memory["metadata"]["sources"] = []
     memory["metadata"]["ref_id"] = ref_id
     memory["metadata"]["id"] = memory_id

src/memos/memories/textual/tree.py

Lines changed: 2 additions & 0 deletions

@@ -161,6 +161,7 @@ def search(
         user_name: str | None = None,
         search_tool_memory: bool = False,
         tool_mem_top_k: int = 6,
+        dedup: str | None = None,
         **kwargs,
     ) -> list[TextualMemoryItem]:
         """Search for memories based on a query.
@@ -207,6 +208,7 @@ def search(
             user_name=user_name,
             search_tool_memory=search_tool_memory,
             tool_mem_top_k=tool_mem_top_k,
+            dedup=dedup,
             **kwargs,
         )
src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py

Lines changed: 0 additions & 1 deletion

@@ -4,7 +4,6 @@
 from pathlib import Path
 from typing import Any

-
 import numpy as np

 from memos.dependency import require_python_package

src/memos/memories/textual/tree_text_memory/retrieve/searcher.py

Lines changed: 7 additions & 44 deletions

@@ -119,9 +119,13 @@ def post_retrieve(
         info=None,
         search_tool_memory: bool = False,
         tool_mem_top_k: int = 6,
+        dedup: str | None = None,
         plugin=False,
     ):
-        deduped = self._deduplicate_results(retrieved_results)
+        if dedup == "no":
+            deduped = retrieved_results
+        else:
+            deduped = self._deduplicate_results(retrieved_results)
         final_results = self._sort_and_trim(
             deduped, top_k, plugin, search_tool_memory, tool_mem_top_k
         )
@@ -141,6 +145,7 @@ def search(
         user_name: str | None = None,
         search_tool_memory: bool = False,
         tool_mem_top_k: int = 6,
+        dedup: str | None = None,
         **kwargs,
     ) -> list[TextualMemoryItem]:
         """
@@ -202,6 +207,7 @@ def search(
             plugin=kwargs.get("plugin", False),
             search_tool_memory=search_tool_memory,
             tool_mem_top_k=tool_mem_top_k,
+            dedup=dedup,
         )

         logger.info(f"[SEARCH] Done. Total {len(final_results)} results.")
@@ -284,49 +290,6 @@ def _parse_task(

         return parsed_goal, query_embedding, context, query

-    @timed
-    def _retrieve_simple(
-        self,
-        query: str,
-        top_k: int,
-        search_filter: dict | None = None,
-        user_name: str | None = None,
-        **kwargs,
-    ):
-        """Retrieve from by keywords and embedding"""
-        query_words = []
-        if self.tokenizer:
-            query_words = self.tokenizer.tokenize_mixed(query)
-        else:
-            query_words = query.strip().split()
-        query_words = [query, *query_words]
-        logger.info(f"[SIMPLESEARCH] Query words: {query_words}")
-        query_embeddings = self.embedder.embed(query_words)
-
-        items = self.graph_retriever.retrieve_from_mixed(
-            top_k=top_k * 2,
-            memory_scope=None,
-            query_embedding=query_embeddings,
-            search_filter=search_filter,
-            user_name=user_name,
-            use_fast_graph=self.use_fast_graph,
-        )
-        logger.info(f"[SIMPLESEARCH] Items count: {len(items)}")
-        documents = [getattr(item, "memory", "") for item in items]
-        documents_embeddings = self.embedder.embed(documents)
-        similarity_matrix = cosine_similarity_matrix(documents_embeddings)
-        selected_indices, _ = find_best_unrelated_subgroup(documents, similarity_matrix)
-        selected_items = [items[i] for i in selected_indices]
-        logger.info(
-            f"[SIMPLESEARCH] after unrelated subgroup selection items count: {len(selected_items)}"
-        )
-        return self.reranker.rerank(
-            query=query,
-            query_embedding=query_embeddings[0],
-            graph_results=selected_items,
-            top_k=top_k,
-        )
-
     @timed
     def _retrieve_paths(
         self,
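
The routing added to post_retrieve is small but deliberate: dedup="no" returns raw retrieval results (similarity dedup happens later in SearchHandler), while None keeps the pre-existing default path. A toy sketch, assuming _deduplicate_results amounts to order-preserving exact-text dedup (its internals are not shown in this diff):

def post_retrieve_dedup(results: list[str], dedup: str | None) -> list[str]:
    if dedup == "no":
        # Raw results pass through; SearchHandler applies "sim" dedup later.
        return results
    # Default path: exact-text dedup stand-in (order-preserving).
    seen: set[str] = set()
    return [r for r in results if not (r in seen or seen.add(r))]

print(post_retrieve_dedup(["a", "a", "b"], None))  # ['a', 'b']
print(post_retrieve_dedup(["a", "a", "b"], "no"))  # ['a', 'a', 'b']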

src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py

Lines changed: 13 additions & 6 deletions

@@ -5,7 +5,10 @@
 from memos.llms.base import BaseLLM
 from memos.log import get_logger
 from memos.memories.textual.tree_text_memory.retrieve.retrieval_mid_structs import ParsedTaskGoal
-from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import FastTokenizer
+from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import (
+    FastTokenizer,
+    parse_json_result,
+)
 from memos.memories.textual.tree_text_memory.retrieve.utils import TASK_PARSE_PROMPT


@@ -111,8 +114,10 @@ def _parse_response(self, response: str, **kwargs) -> ParsedTaskGoal:
         for attempt_times in range(attempts):
             try:
                 context = kwargs.get("context", "")
-                response = response.replace("```", "").replace("json", "").strip()
-                response_json = eval(response)
+                response_json = parse_json_result(response)
+                if not response_json:
+                    raise ValueError("Parsed JSON is empty")
+
                 return ParsedTaskGoal(
                     memories=response_json.get("memories", []),
                     keys=response_json.get("keys", []),
@@ -123,6 +128,8 @@ def _parse_response(self, response: str, **kwargs) -> ParsedTaskGoal:
                     context=context,
                 )
             except Exception as e:
-                raise ValueError(
-                    f"Failed to parse LLM output: {e}\nRaw response:\n{response} retried: {attempt_times + 1}/{attempts + 1}"
-                ) from e
+                if attempt_times == attempts - 1:
+                    raise ValueError(
+                        f"Failed to parse LLM output: {e}\nRaw response:\n{response} retried: {attempt_times + 1}/{attempts}"
+                    ) from e
+                continue
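
The eval() call was fragile because LLM output is JSON, where true/false/null are not valid Python names, so eval() raised NameError on any boolean field. parse_json_result's implementation is not shown in this diff; a minimal stand-in illustrating the behavior the fix relies on:

import json
import re

def parse_json_result_sketch(text: str) -> dict:
    # Strip optional ```json ... ``` fences the LLM may wrap around output.
    text = re.sub(r"^```(?:json)?\s*|\s*```$", "", text.strip())
    try:
        return json.loads(text)  # handles true/false/null, unlike eval()
    except json.JSONDecodeError:
        return {}

print(parse_json_result_sketch('```json\n{"rephrased": true, "keys": []}\n```'))
# -> {'rephrased': True, 'keys': []}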
