Skip to content

Commit ea17f6f

Browse files
Fix dedup handling in simple search
1 parent ce297ba commit ea17f6f

File tree

9 files changed

+157
-129
lines changed

9 files changed

+157
-129
lines changed

src/memos/api/handlers/formatters_handler.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def to_iter(running: Any) -> list[Any]:
2929
return list(running) if running else []
3030

3131

32-
def format_memory_item(memory_data: Any) -> dict[str, Any]:
32+
def format_memory_item(memory_data: Any, include_embedding: bool = False) -> dict[str, Any]:
3333
"""
3434
Format a single memory item for API response.
3535
@@ -47,7 +47,8 @@ def format_memory_item(memory_data: Any) -> dict[str, Any]:
4747
ref_id = f"[{memory_id.split('-')[0]}]"
4848

4949
memory["ref_id"] = ref_id
50-
memory["metadata"]["embedding"] = []
50+
if not include_embedding:
51+
memory["metadata"]["embedding"] = []
5152
memory["metadata"]["sources"] = []
5253
memory["metadata"]["usage"] = []
5354
memory["metadata"]["ref_id"] = ref_id

src/memos/api/handlers/search_handler.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,14 @@
55
using dependency injection for better modularity and testability.
66
"""
77

8+
from typing import Any
9+
810
from memos.api.handlers.base_handler import BaseHandler, HandlerDependencies
911
from memos.api.product_models import APISearchRequest, SearchResponse
1012
from memos.log import get_logger
13+
from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import (
14+
cosine_similarity_matrix,
15+
)
1116
from memos.multi_mem_cube.composite_cube import CompositeCubeView
1217
from memos.multi_mem_cube.single_cube import SingleCubeView
1318
from memos.multi_mem_cube.views import MemCubeView
@@ -53,6 +58,9 @@ def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse
5358
cube_view = self._build_cube_view(search_req)
5459

5560
results = cube_view.search_memories(search_req)
61+
if search_req.dedup == "sim":
62+
results = self._dedup_text_memories(results, search_req.top_k)
63+
self._strip_embeddings(results)
5664

5765
self.logger.info(
5866
f"[SearchHandler] Final search results: count={len(results)} results={results}"
@@ -63,6 +71,107 @@ def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse
6371
data=results,
6472
)
6573

74+
def _dedup_text_memories(self, results: dict[str, Any], target_top_k: int) -> dict[str, Any]:
75+
buckets = results.get("text_mem", [])
76+
if not buckets:
77+
return results
78+
79+
flat: list[tuple[int, dict[str, Any]]] = []
80+
for bucket_idx, bucket in enumerate(buckets):
81+
for mem in bucket.get("memories", []):
82+
flat.append((bucket_idx, mem))
83+
84+
if len(flat) <= 1:
85+
return results
86+
87+
embeddings = self._extract_embeddings([mem for _, mem in flat])
88+
if embeddings is None:
89+
documents = [mem.get("memory", "") for _, mem in flat]
90+
embeddings = self.searcher.embedder.embed(documents)
91+
92+
similarity_matrix = cosine_similarity_matrix(embeddings)
93+
94+
indices_by_bucket: dict[int, list[int]] = {i: [] for i in range(len(buckets))}
95+
for flat_index, (bucket_idx, _) in enumerate(flat):
96+
indices_by_bucket[bucket_idx].append(flat_index)
97+
98+
selected_global: list[int] = []
99+
selected_by_bucket: dict[int, list[int]] = {i: [] for i in range(len(buckets))}
100+
101+
for bucket_idx in range(len(buckets)):
102+
for idx in indices_by_bucket.get(bucket_idx, []):
103+
if len(selected_by_bucket[bucket_idx]) >= target_top_k:
104+
break
105+
if self._is_unrelated(idx, selected_global, similarity_matrix, 0.85):
106+
selected_by_bucket[bucket_idx].append(idx)
107+
selected_global.append(idx)
108+
109+
for bucket_idx in range(len(buckets)):
110+
if len(selected_by_bucket[bucket_idx]) >= min(
111+
target_top_k, len(indices_by_bucket[bucket_idx])
112+
):
113+
continue
114+
remaining_indices = [
115+
idx
116+
for idx in indices_by_bucket.get(bucket_idx, [])
117+
if idx not in selected_by_bucket[bucket_idx]
118+
]
119+
if not remaining_indices:
120+
continue
121+
# Fill to target_top_k with the least-similar candidates to preserve diversity.
122+
remaining_indices.sort(
123+
key=lambda idx: self._max_similarity(idx, selected_global, similarity_matrix)
124+
)
125+
for idx in remaining_indices:
126+
if len(selected_by_bucket[bucket_idx]) >= target_top_k:
127+
break
128+
selected_by_bucket[bucket_idx].append(idx)
129+
130+
for bucket_idx, bucket in enumerate(buckets):
131+
selected_indices = selected_by_bucket.get(bucket_idx, [])
132+
bucket["memories"] = [flat[i][1] for i in selected_indices[:target_top_k]]
133+
return results
134+
135+
@staticmethod
136+
def _is_unrelated(
137+
index: int,
138+
selected_indices: list[int],
139+
similarity_matrix: list[list[float]],
140+
similarity_threshold: float,
141+
) -> bool:
142+
return all(similarity_matrix[index][j] <= similarity_threshold for j in selected_indices)
143+
144+
@staticmethod
145+
def _max_similarity(
146+
index: int, selected_indices: list[int], similarity_matrix: list[list[float]]
147+
) -> float:
148+
if not selected_indices:
149+
return 0.0
150+
return max(similarity_matrix[index][j] for j in selected_indices)
151+
152+
@staticmethod
153+
def _extract_embeddings(memories: list[dict[str, Any]]) -> list[list[float]] | None:
154+
embeddings: list[list[float]] = []
155+
for mem in memories:
156+
embedding = mem.get("metadata", {}).get("embedding")
157+
if not embedding:
158+
return None
159+
embeddings.append(embedding)
160+
return embeddings
161+
162+
@staticmethod
163+
def _strip_embeddings(results: dict[str, Any]) -> None:
164+
for bucket in results.get("text_mem", []):
165+
for mem in bucket.get("memories", []):
166+
metadata = mem.get("metadata", {})
167+
if "embedding" in metadata:
168+
metadata["embedding"] = []
169+
for bucket in results.get("tool_mem", []):
170+
for mem in bucket.get("memories", []):
171+
metadata = mem.get("metadata", {})
172+
if "embedding" in metadata:
173+
metadata["embedding"] = []
174+
66175
def _resolve_cube_ids(self, search_req: APISearchRequest) -> list[str]:
67176
"""
68177
Normalize target cube ids from search_req.

src/memos/api/start_api.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22
import os
33

4-
from typing import Any, Generic, Literal, TypeVar
4+
from typing import Any, Generic, TypeVar
55

66
from dotenv import load_dotenv
77
from fastapi import FastAPI
@@ -145,14 +145,6 @@ class SearchRequest(BaseRequest):
145145
description="List of cube IDs to search in",
146146
json_schema_extra={"example": ["cube123", "cube456"]},
147147
)
148-
dedup: Literal["no", "sim"] | None = Field(
149-
None,
150-
description=(
151-
"Optional dedup option for textual memories. "
152-
"Use 'no' for no dedup, 'sim' for similarity dedup. "
153-
"If None, default exact-text dedup is applied."
154-
),
155-
)
156148

157149

158150
class MemCubeRegister(BaseRequest):
@@ -357,7 +349,6 @@ async def search_memories(search_req: SearchRequest):
357349
query=search_req.query,
358350
user_id=search_req.user_id,
359351
install_cube_ids=search_req.install_cube_ids,
360-
dedup=search_req.dedup,
361352
)
362353
return SearchResponse(message="Search completed successfully", data=result)
363354

src/memos/mem_os/core.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -551,7 +551,6 @@ def search(
551551
internet_search: bool = False,
552552
moscube: bool = False,
553553
session_id: str | None = None,
554-
dedup: str | None = None,
555554
**kwargs,
556555
) -> MOSSearchResult:
557556
"""
@@ -626,7 +625,6 @@ def search_textual_memory(cube_id, cube):
626625
},
627626
moscube=moscube,
628627
search_filter=search_filter,
629-
dedup=dedup,
630628
)
631629
search_time_end = time.time()
632630
logger.info(

src/memos/mem_scheduler/optimized_scheduler.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,10 @@ def mix_search_memories(
190190
)
191191
memories = merged_memories[: search_req.top_k]
192192

193-
formatted_memories = [format_textual_memory_item(item) for item in memories]
193+
formatted_memories = [
194+
format_textual_memory_item(item, include_embedding=search_req.dedup == "sim")
195+
for item in memories
196+
]
194197
self.submit_memory_history_async_task(
195198
search_req=search_req,
196199
user_context=user_context,
@@ -234,7 +237,10 @@ def update_search_memories_to_redis(
234237
mem_cube=self.mem_cube,
235238
mode=SearchMode.FAST,
236239
)
237-
formatted_memories = [format_textual_memory_item(data) for data in memories]
240+
formatted_memories = [
241+
format_textual_memory_item(data, include_embedding=search_req.dedup == "sim")
242+
for data in memories
243+
]
238244
else:
239245
memories = [
240246
TextualMemoryItem.from_dict(one) for one in memories_to_store["memories"]

src/memos/mem_scheduler/utils/api_utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,17 @@
66
from memos.memories.textual.tree import TextualMemoryItem
77

88

9-
def format_textual_memory_item(memory_data: Any) -> dict[str, Any]:
9+
def format_textual_memory_item(
10+
memory_data: Any, include_embedding: bool = False
11+
) -> dict[str, Any]:
1012
"""Format a single memory item for API response."""
1113
memory = memory_data.model_dump()
1214
memory_id = memory["id"]
1315
ref_id = f"[{memory_id.split('-')[0]}]"
1416

1517
memory["ref_id"] = ref_id
16-
memory["metadata"]["embedding"] = []
18+
if not include_embedding:
19+
memory["metadata"]["embedding"] = []
1720
memory["metadata"]["sources"] = []
1821
memory["metadata"]["ref_id"] = ref_id
1922
memory["metadata"]["id"] = memory_id

src/memos/memories/textual/tree_text_memory/retrieve/advanced_searcher.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,6 @@ def deep_search(
239239
user_name: str | None = None,
240240
**kwargs,
241241
):
242-
dedup = kwargs.get("dedup")
243242
previous_retrieval_phrases = [query]
244243
retrieved_memories = self.retrieve(
245244
query=query,
@@ -255,7 +254,6 @@ def deep_search(
255254
top_k=top_k,
256255
user_name=user_name,
257256
info=info,
258-
dedup=dedup,
259257
)
260258
if len(memories) == 0:
261259
logger.warning("Requirements not met; returning memories as-is.")

src/memos/memories/textual/tree_text_memory/retrieve/searcher.py

Lines changed: 10 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -124,8 +124,6 @@ def post_retrieve(
124124
):
125125
if dedup == "no":
126126
deduped = retrieved_results
127-
elif dedup == "sim":
128-
deduped = self._deduplicate_similar_results(retrieved_results)
129127
else:
130128
deduped = self._deduplicate_results(retrieved_results)
131129
final_results = self._sort_and_trim(
@@ -180,11 +178,7 @@ def search(
180178
if kwargs.get("plugin", False):
181179
logger.info(f"[SEARCH] Retrieve from plugin: {query}")
182180
retrieved_results = self._retrieve_simple(
183-
query=query,
184-
top_k=top_k,
185-
search_filter=search_filter,
186-
user_name=user_name,
187-
dedup=dedup,
181+
query=query, top_k=top_k, search_filter=search_filter, user_name=user_name
188182
)
189183
else:
190184
retrieved_results = self.retrieve(
@@ -213,7 +207,7 @@ def search(
213207
plugin=kwargs.get("plugin", False),
214208
search_tool_memory=search_tool_memory,
215209
tool_mem_top_k=tool_mem_top_k,
216-
dedup=None if kwargs.get("plugin", False) and dedup == "sim" else dedup,
210+
dedup=dedup,
217211
)
218212

219213
logger.info(f"[SEARCH] Done. Total {len(final_results)} results.")
@@ -296,50 +290,6 @@ def _parse_task(
296290

297291
return parsed_goal, query_embedding, context, query
298292

299-
@timed
300-
def _retrieve_simple(
301-
self,
302-
query: str,
303-
top_k: int,
304-
search_filter: dict | None = None,
305-
user_name: str | None = None,
306-
dedup: str | None = None,
307-
**kwargs,
308-
):
309-
"""Retrieve from by keywords and embedding"""
310-
query_words = []
311-
if self.tokenizer:
312-
query_words = self.tokenizer.tokenize_mixed(query)
313-
else:
314-
query_words = query.strip().split()
315-
query_words = [query, *query_words]
316-
logger.info(f"[SIMPLESEARCH] Query words: {query_words}")
317-
query_embeddings = self.embedder.embed(query_words)
318-
319-
items = self.graph_retriever.retrieve_from_mixed(
320-
top_k=top_k * 2,
321-
memory_scope=None,
322-
query_embedding=query_embeddings,
323-
search_filter=search_filter,
324-
user_name=user_name,
325-
use_fast_graph=self.use_fast_graph,
326-
)
327-
logger.info(f"[SIMPLESEARCH] Items count: {len(items)}")
328-
documents = [getattr(item, "memory", "") for item in items]
329-
documents_embeddings = self.embedder.embed(documents)
330-
similarity_matrix = cosine_similarity_matrix(documents_embeddings)
331-
selected_indices, _ = find_best_unrelated_subgroup(documents, similarity_matrix)
332-
selected_items = [items[i] for i in selected_indices]
333-
logger.info(
334-
f"[SIMPLESEARCH] after unrelated subgroup selection items count: {len(selected_items)}"
335-
)
336-
return self.reranker.rerank(
337-
query=query,
338-
query_embedding=query_embeddings[0],
339-
graph_results=selected_items,
340-
top_k=top_k,
341-
)
342-
343293
@timed
344294
def _retrieve_paths(
345295
self,
@@ -723,17 +673,14 @@ def _retrieve_simple(
723673
user_name=user_name,
724674
)
725675
logger.info(f"[SIMPLESEARCH] Items count: {len(items)}")
726-
if dedup == "no":
727-
selected_items = items
728-
else:
729-
documents = [getattr(item, "memory", "") for item in items]
730-
documents_embeddings = self.embedder.embed(documents)
731-
similarity_matrix = cosine_similarity_matrix(documents_embeddings)
732-
selected_indices, _ = find_best_unrelated_subgroup(documents, similarity_matrix)
733-
selected_items = [items[i] for i in selected_indices]
734-
logger.info(
735-
f"[SIMPLESEARCH] after unrelated subgroup selection items count: {len(selected_items)}"
736-
)
676+
documents = [getattr(item, "memory", "") for item in items]
677+
documents_embeddings = self.embedder.embed(documents)
678+
similarity_matrix = cosine_similarity_matrix(documents_embeddings)
679+
selected_indices, _ = find_best_unrelated_subgroup(documents, similarity_matrix)
680+
selected_items = [items[i] for i in selected_indices]
681+
logger.info(
682+
f"[SIMPLESEARCH] after unrelated subgroup selection items count: {len(selected_items)}"
683+
)
737684
return self.reranker.rerank(
738685
query=query,
739686
query_embedding=query_embeddings[0],
@@ -750,26 +697,6 @@ def _deduplicate_results(self, results):
750697
deduped[item.memory] = (item, score)
751698
return list(deduped.values())
752699

753-
@timed
754-
def _deduplicate_similar_results(
755-
self, results: list[tuple[TextualMemoryItem, float]], similarity_threshold: float = 0.85
756-
):
757-
"""Deduplicate results by semantic similarity while keeping higher scores."""
758-
if len(results) <= 1:
759-
return results
760-
761-
sorted_results = sorted(results, key=lambda pair: pair[1], reverse=True)
762-
documents = [getattr(item, "memory", "") for item, _ in sorted_results]
763-
embeddings = self.embedder.embed(documents)
764-
similarity_matrix = cosine_similarity_matrix(embeddings)
765-
766-
selected_indices: list[int] = []
767-
for i in range(len(sorted_results)):
768-
if all(similarity_matrix[i][j] <= similarity_threshold for j in selected_indices):
769-
selected_indices.append(i)
770-
771-
return [sorted_results[i] for i in selected_indices]
772-
773700
@timed
774701
def _sort_and_trim(
775702
self, results, top_k, plugin=False, search_tool_memory=False, tool_mem_top_k=6

0 commit comments

Comments (0)