Skip to content

Commit ce297ba

Browse files
Add dedup option to search pipeline
1 parent b11c768 commit ce297ba

File tree

8 files changed

+96
-12
lines changed

8 files changed

+96
-12
lines changed

src/memos/api/product_models.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,15 @@ class APISearchRequest(BaseRequest):
319319
description="Number of textual memories to retrieve (top-K). Default: 10.",
320320
)
321321

322+
dedup: Literal["no", "sim"] | None = Field(
323+
None,
324+
description=(
325+
"Optional dedup option for textual memories. "
326+
"Use 'no' for no dedup, 'sim' for similarity dedup. "
327+
"If None, default exact-text dedup is applied."
328+
),
329+
)
330+
322331
pref_top_k: int = Field(
323332
6,
324333
ge=0,

src/memos/api/start_api.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22
import os
33

4-
from typing import Any, Generic, TypeVar
4+
from typing import Any, Generic, Literal, TypeVar
55

66
from dotenv import load_dotenv
77
from fastapi import FastAPI
@@ -145,6 +145,14 @@ class SearchRequest(BaseRequest):
145145
description="List of cube IDs to search in",
146146
json_schema_extra={"example": ["cube123", "cube456"]},
147147
)
148+
dedup: Literal["no", "sim"] | None = Field(
149+
None,
150+
description=(
151+
"Optional dedup option for textual memories. "
152+
"Use 'no' for no dedup, 'sim' for similarity dedup. "
153+
"If None, default exact-text dedup is applied."
154+
),
155+
)
148156

149157

150158
class MemCubeRegister(BaseRequest):
@@ -349,6 +357,7 @@ async def search_memories(search_req: SearchRequest):
349357
query=search_req.query,
350358
user_id=search_req.user_id,
351359
install_cube_ids=search_req.install_cube_ids,
360+
dedup=search_req.dedup,
352361
)
353362
return SearchResponse(message="Search completed successfully", data=result)
354363

src/memos/mem_os/core.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -551,6 +551,7 @@ def search(
551551
internet_search: bool = False,
552552
moscube: bool = False,
553553
session_id: str | None = None,
554+
dedup: str | None = None,
554555
**kwargs,
555556
) -> MOSSearchResult:
556557
"""
@@ -625,6 +626,7 @@ def search_textual_memory(cube_id, cube):
625626
},
626627
moscube=moscube,
627628
search_filter=search_filter,
629+
dedup=dedup,
628630
)
629631
search_time_end = time.time()
630632
logger.info(

src/memos/mem_scheduler/optimized_scheduler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,7 @@ def mix_search_memories(
186186
info=info,
187187
search_tool_memory=search_req.search_tool_memory,
188188
tool_mem_top_k=search_req.tool_mem_top_k,
189+
dedup=search_req.dedup,
189190
)
190191
memories = merged_memories[: search_req.top_k]
191192

src/memos/memories/textual/tree.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@ def search(
161161
user_name: str | None = None,
162162
search_tool_memory: bool = False,
163163
tool_mem_top_k: int = 6,
164+
dedup: str | None = None,
164165
**kwargs,
165166
) -> list[TextualMemoryItem]:
166167
"""Search for memories based on a query.
@@ -207,6 +208,7 @@ def search(
207208
user_name=user_name,
208209
search_tool_memory=search_tool_memory,
209210
tool_mem_top_k=tool_mem_top_k,
211+
dedup=dedup,
210212
**kwargs,
211213
)
212214

src/memos/memories/textual/tree_text_memory/retrieve/advanced_searcher.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,7 @@ def deep_search(
239239
user_name: str | None = None,
240240
**kwargs,
241241
):
242+
dedup = kwargs.get("dedup")
242243
previous_retrieval_phrases = [query]
243244
retrieved_memories = self.retrieve(
244245
query=query,
@@ -254,6 +255,7 @@ def deep_search(
254255
top_k=top_k,
255256
user_name=user_name,
256257
info=info,
258+
dedup=dedup,
257259
)
258260
if len(memories) == 0:
259261
logger.warning("Requirements not met; returning memories as-is.")

src/memos/memories/textual/tree_text_memory/retrieve/searcher.py

Lines changed: 46 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -119,9 +119,15 @@ def post_retrieve(
119119
info=None,
120120
search_tool_memory: bool = False,
121121
tool_mem_top_k: int = 6,
122+
dedup: str | None = None,
122123
plugin=False,
123124
):
124-
deduped = self._deduplicate_results(retrieved_results)
125+
if dedup == "no":
126+
deduped = retrieved_results
127+
elif dedup == "sim":
128+
deduped = self._deduplicate_similar_results(retrieved_results)
129+
else:
130+
deduped = self._deduplicate_results(retrieved_results)
125131
final_results = self._sort_and_trim(
126132
deduped, top_k, plugin, search_tool_memory, tool_mem_top_k
127133
)
@@ -141,6 +147,7 @@ def search(
141147
user_name: str | None = None,
142148
search_tool_memory: bool = False,
143149
tool_mem_top_k: int = 6,
150+
dedup: str | None = None,
144151
**kwargs,
145152
) -> list[TextualMemoryItem]:
146153
"""
@@ -173,7 +180,11 @@ def search(
173180
if kwargs.get("plugin", False):
174181
logger.info(f"[SEARCH] Retrieve from plugin: {query}")
175182
retrieved_results = self._retrieve_simple(
176-
query=query, top_k=top_k, search_filter=search_filter, user_name=user_name
183+
query=query,
184+
top_k=top_k,
185+
search_filter=search_filter,
186+
user_name=user_name,
187+
dedup=dedup,
177188
)
178189
else:
179190
retrieved_results = self.retrieve(
@@ -202,6 +213,7 @@ def search(
202213
plugin=kwargs.get("plugin", False),
203214
search_tool_memory=search_tool_memory,
204215
tool_mem_top_k=tool_mem_top_k,
216+
dedup=None if kwargs.get("plugin", False) and dedup == "sim" else dedup,
205217
)
206218

207219
logger.info(f"[SEARCH] Done. Total {len(final_results)} results.")
@@ -291,6 +303,7 @@ def _retrieve_simple(
291303
top_k: int,
292304
search_filter: dict | None = None,
293305
user_name: str | None = None,
306+
dedup: str | None = None,
294307
**kwargs,
295308
):
296309
"""Retrieve from by keywords and embedding"""
@@ -710,14 +723,17 @@ def _retrieve_simple(
710723
user_name=user_name,
711724
)
712725
logger.info(f"[SIMPLESEARCH] Items count: {len(items)}")
713-
documents = [getattr(item, "memory", "") for item in items]
714-
documents_embeddings = self.embedder.embed(documents)
715-
similarity_matrix = cosine_similarity_matrix(documents_embeddings)
716-
selected_indices, _ = find_best_unrelated_subgroup(documents, similarity_matrix)
717-
selected_items = [items[i] for i in selected_indices]
718-
logger.info(
719-
f"[SIMPLESEARCH] after unrelated subgroup selection items count: {len(selected_items)}"
720-
)
726+
if dedup == "no":
727+
selected_items = items
728+
else:
729+
documents = [getattr(item, "memory", "") for item in items]
730+
documents_embeddings = self.embedder.embed(documents)
731+
similarity_matrix = cosine_similarity_matrix(documents_embeddings)
732+
selected_indices, _ = find_best_unrelated_subgroup(documents, similarity_matrix)
733+
selected_items = [items[i] for i in selected_indices]
734+
logger.info(
735+
f"[SIMPLESEARCH] after unrelated subgroup selection items count: {len(selected_items)}"
736+
)
721737
return self.reranker.rerank(
722738
query=query,
723739
query_embedding=query_embeddings[0],
@@ -734,6 +750,26 @@ def _deduplicate_results(self, results):
734750
deduped[item.memory] = (item, score)
735751
return list(deduped.values())
736752

753+
@timed
def _deduplicate_similar_results(
    self, results: list[tuple[TextualMemoryItem, float]], similarity_threshold: float = 0.85
):
    """Drop near-duplicate results, letting higher-scored results win.

    The results are ranked by score (descending), the ``memory`` text of each
    item is embedded with ``self.embedder``, and the ranked list is walked
    greedily: a result is kept only when its similarity to every result kept
    so far is at most ``similarity_threshold``.

    Args:
        results: ``(item, score)`` pairs to deduplicate.
        similarity_threshold: Largest pairwise similarity at which two
            results are still treated as distinct.

    Returns:
        The surviving ``(item, score)`` pairs in descending score order. An
        input of zero or one result is returned unchanged.
    """
    # Nothing can be a duplicate of anything else; skip the embedding call.
    if len(results) <= 1:
        return results

    ranked = sorted(results, key=lambda pair: pair[1], reverse=True)
    texts = [getattr(memory_item, "memory", "") for memory_item, _ in ranked]
    pairwise = cosine_similarity_matrix(self.embedder.embed(texts))

    kept: list[int] = []
    for candidate in range(len(ranked)):
        distinct_from_kept = all(
            pairwise[candidate][prior] <= similarity_threshold for prior in kept
        )
        if not distinct_from_kept:
            # Too close to a better-ranked result that was already kept.
            continue
        kept.append(candidate)

    return [ranked[position] for position in kept]
772+
737773
@timed
738774
def _sort_and_trim(
739775
self, results, top_k, plugin=False, search_tool_memory=False, tool_mem_top_k=6

src/memos/multi_mem_cube/single_cube.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@
2323
MEM_READ_TASK_LABEL,
2424
PREF_ADD_TASK_LABEL,
2525
)
26+
from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import (
27+
cosine_similarity_matrix,
28+
)
2629
from memos.multi_mem_cube.views import MemCubeView
2730
from memos.templates.mem_reader_prompts import PROMPT_MAPPING
2831
from memos.types.general_types import (
@@ -263,6 +266,7 @@ def _deep_search(
263266
moscube=search_req.moscube,
264267
search_filter=search_filter,
265268
info=info,
269+
dedup=search_req.dedup,
266270
)
267271
formatted_memories = [format_memory_item(data) for data in enhanced_memories]
268272
return formatted_memories
@@ -328,6 +332,7 @@ def _fine_search(
328332
top_k=search_req.top_k,
329333
user_name=user_context.mem_cube_id,
330334
info=info,
335+
dedup=search_req.dedup,
331336
)
332337

333338
# Enhance with query
@@ -378,7 +383,24 @@ def _dedup_by_content(memories: list) -> list:
378383
unique_memories.append(mem)
379384
return unique_memories
380385

381-
deduped_memories = _dedup_by_content(enhanced_memories)
386+
def _dedup_by_similarity(memories: list, similarity_threshold: float = 0.85) -> list:
    """Remove memories whose text is too similar to an earlier one.

    The ``memory`` text of every entry is embedded with
    ``self.searcher.embedder`` and the list is walked in its given order:
    an entry is kept only when its similarity to every entry kept so far
    is at most ``similarity_threshold``. The input is not re-sorted, so
    when two entries are near-duplicates the one that appears first wins.

    Args:
        memories: Candidate memories; entries without a ``memory``
            attribute are embedded as an empty string.
        similarity_threshold: Largest pairwise similarity at which two
            memories are still treated as distinct. Defaults to 0.85,
            the value that was previously hard-coded here.

    Returns:
        The retained memories in their original relative order. An input
        of zero or one memory is returned unchanged.
    """
    # Nothing to compare against; skip the embedding call.
    if len(memories) <= 1:
        return memories
    documents = [getattr(mem, "memory", "") for mem in memories]
    embeddings = self.searcher.embedder.embed(documents)
    similarity_matrix = cosine_similarity_matrix(embeddings)
    selected_indices: list[int] = []
    for i in range(len(memories)):
        # all() over an empty selection is True, so the first memory is always kept.
        if all(similarity_matrix[i][j] <= similarity_threshold for j in selected_indices):
            selected_indices.append(i)
    return [memories[i] for i in selected_indices]
397+
398+
if search_req.dedup == "no":
399+
deduped_memories = enhanced_memories
400+
elif search_req.dedup == "sim":
401+
deduped_memories = _dedup_by_similarity(enhanced_memories)
402+
else:
403+
deduped_memories = _dedup_by_content(enhanced_memories)
382404
formatted_memories = [format_memory_item(data) for data in deduped_memories]
383405

384406
logger.info(f"Found {len(formatted_memories)} memories for user {search_req.user_id}")
@@ -463,6 +485,7 @@ def _fast_search(
463485
plugin=plugin,
464486
search_tool_memory=search_req.search_tool_memory,
465487
tool_mem_top_k=search_req.tool_mem_top_k,
488+
dedup=search_req.dedup,
466489
)
467490

468491
formatted_memories = [format_memory_item(data) for data in search_results]

0 commit comments

Comments
 (0)