Skip to content

Commit d734029

Browse files
Fix dedup handling in simple search
1 parent ce297ba commit d734029

File tree

8 files changed

+61
-130
lines changed

8 files changed

+61
-130
lines changed

src/memos/api/handlers/search_handler.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,14 @@
55
using dependency injection for better modularity and testability.
66
"""
77

8+
from typing import Any
9+
810
from memos.api.handlers.base_handler import BaseHandler, HandlerDependencies
911
from memos.api.product_models import APISearchRequest, SearchResponse
1012
from memos.log import get_logger
13+
from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import (
14+
cosine_similarity_matrix,
15+
)
1116
from memos.multi_mem_cube.composite_cube import CompositeCubeView
1217
from memos.multi_mem_cube.single_cube import SingleCubeView
1318
from memos.multi_mem_cube.views import MemCubeView
@@ -53,6 +58,8 @@ def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse
5358
cube_view = self._build_cube_view(search_req)
5459

5560
results = cube_view.search_memories(search_req)
61+
if search_req.dedup == "sim":
62+
results = self._dedup_text_memories(results, search_req.top_k)
5663

5764
self.logger.info(
5865
f"[SearchHandler] Final search results: count={len(results)} results={results}"
@@ -63,6 +70,48 @@ def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse
6370
data=results,
6471
)
6572

73+
def _dedup_text_memories(self, results: dict[str, Any], target_top_k: int) -> dict[str, Any]:
74+
for bucket in results.get("text_mem", []):
75+
memories = bucket.get("memories", [])
76+
if len(memories) <= 1:
77+
continue
78+
embeddings = self._extract_embeddings(memories)
79+
if embeddings is None:
80+
documents = [mem.get("memory", "") for mem in memories]
81+
embeddings = self.searcher.embedder.embed(documents)
82+
similarity_matrix = cosine_similarity_matrix(embeddings)
83+
selected_indices = self._select_unrelated_indices(similarity_matrix, 0.85)
84+
if len(selected_indices) < min(target_top_k, len(memories)):
85+
selected = set(selected_indices)
86+
for i in range(len(memories)):
87+
if i in selected:
88+
continue
89+
selected_indices.append(i)
90+
if len(selected_indices) >= target_top_k:
91+
break
92+
bucket["memories"] = [memories[i] for i in selected_indices[:target_top_k]]
93+
return results
94+
95+
@staticmethod
96+
def _select_unrelated_indices(
97+
similarity_matrix: list[list[float]], similarity_threshold: float
98+
) -> list[int]:
99+
selected_indices: list[int] = []
100+
for i in range(len(similarity_matrix)):
101+
if all(similarity_matrix[i][j] <= similarity_threshold for j in selected_indices):
102+
selected_indices.append(i)
103+
return selected_indices
104+
105+
@staticmethod
106+
def _extract_embeddings(memories: list[dict[str, Any]]) -> list[list[float]] | None:
107+
embeddings: list[list[float]] = []
108+
for mem in memories:
109+
embedding = mem.get("metadata", {}).get("embedding")
110+
if not embedding:
111+
return None
112+
embeddings.append(embedding)
113+
return embeddings
114+
66115
def _resolve_cube_ids(self, search_req: APISearchRequest) -> list[str]:
67116
"""
68117
Normalize target cube ids from search_req.

src/memos/api/start_api.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22
import os
33

4-
from typing import Any, Generic, Literal, TypeVar
4+
from typing import Any, Generic, TypeVar
55

66
from dotenv import load_dotenv
77
from fastapi import FastAPI
@@ -145,14 +145,6 @@ class SearchRequest(BaseRequest):
145145
description="List of cube IDs to search in",
146146
json_schema_extra={"example": ["cube123", "cube456"]},
147147
)
148-
dedup: Literal["no", "sim"] | None = Field(
149-
None,
150-
description=(
151-
"Optional dedup option for textual memories. "
152-
"Use 'no' for no dedup, 'sim' for similarity dedup. "
153-
"If None, default exact-text dedup is applied."
154-
),
155-
)
156148

157149

158150
class MemCubeRegister(BaseRequest):
@@ -357,7 +349,6 @@ async def search_memories(search_req: SearchRequest):
357349
query=search_req.query,
358350
user_id=search_req.user_id,
359351
install_cube_ids=search_req.install_cube_ids,
360-
dedup=search_req.dedup,
361352
)
362353
return SearchResponse(message="Search completed successfully", data=result)
363354

src/memos/mem_os/core.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -551,7 +551,6 @@ def search(
551551
internet_search: bool = False,
552552
moscube: bool = False,
553553
session_id: str | None = None,
554-
dedup: str | None = None,
555554
**kwargs,
556555
) -> MOSSearchResult:
557556
"""
@@ -626,7 +625,6 @@ def search_textual_memory(cube_id, cube):
626625
},
627626
moscube=moscube,
628627
search_filter=search_filter,
629-
dedup=dedup,
630628
)
631629
search_time_end = time.time()
632630
logger.info(

src/memos/mem_scheduler/optimized_scheduler.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,6 @@ def mix_search_memories(
186186
info=info,
187187
search_tool_memory=search_req.search_tool_memory,
188188
tool_mem_top_k=search_req.tool_mem_top_k,
189-
dedup=search_req.dedup,
190189
)
191190
memories = merged_memories[: search_req.top_k]
192191

src/memos/memories/textual/tree.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,6 @@ def search(
161161
user_name: str | None = None,
162162
search_tool_memory: bool = False,
163163
tool_mem_top_k: int = 6,
164-
dedup: str | None = None,
165164
**kwargs,
166165
) -> list[TextualMemoryItem]:
167166
"""Search for memories based on a query.
@@ -208,7 +207,6 @@ def search(
208207
user_name=user_name,
209208
search_tool_memory=search_tool_memory,
210209
tool_mem_top_k=tool_mem_top_k,
211-
dedup=dedup,
212210
**kwargs,
213211
)
214212

src/memos/memories/textual/tree_text_memory/retrieve/advanced_searcher.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,6 @@ def deep_search(
239239
user_name: str | None = None,
240240
**kwargs,
241241
):
242-
dedup = kwargs.get("dedup")
243242
previous_retrieval_phrases = [query]
244243
retrieved_memories = self.retrieve(
245244
query=query,
@@ -255,7 +254,6 @@ def deep_search(
255254
top_k=top_k,
256255
user_name=user_name,
257256
info=info,
258-
dedup=dedup,
259257
)
260258
if len(memories) == 0:
261259
logger.warning("Requirements not met; returning memories as-is.")

src/memos/memories/textual/tree_text_memory/retrieve/searcher.py

Lines changed: 10 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -119,15 +119,9 @@ def post_retrieve(
119119
info=None,
120120
search_tool_memory: bool = False,
121121
tool_mem_top_k: int = 6,
122-
dedup: str | None = None,
123122
plugin=False,
124123
):
125-
if dedup == "no":
126-
deduped = retrieved_results
127-
elif dedup == "sim":
128-
deduped = self._deduplicate_similar_results(retrieved_results)
129-
else:
130-
deduped = self._deduplicate_results(retrieved_results)
124+
deduped = self._deduplicate_results(retrieved_results)
131125
final_results = self._sort_and_trim(
132126
deduped, top_k, plugin, search_tool_memory, tool_mem_top_k
133127
)
@@ -147,7 +141,6 @@ def search(
147141
user_name: str | None = None,
148142
search_tool_memory: bool = False,
149143
tool_mem_top_k: int = 6,
150-
dedup: str | None = None,
151144
**kwargs,
152145
) -> list[TextualMemoryItem]:
153146
"""
@@ -180,11 +173,7 @@ def search(
180173
if kwargs.get("plugin", False):
181174
logger.info(f"[SEARCH] Retrieve from plugin: {query}")
182175
retrieved_results = self._retrieve_simple(
183-
query=query,
184-
top_k=top_k,
185-
search_filter=search_filter,
186-
user_name=user_name,
187-
dedup=dedup,
176+
query=query, top_k=top_k, search_filter=search_filter, user_name=user_name
188177
)
189178
else:
190179
retrieved_results = self.retrieve(
@@ -213,7 +202,6 @@ def search(
213202
plugin=kwargs.get("plugin", False),
214203
search_tool_memory=search_tool_memory,
215204
tool_mem_top_k=tool_mem_top_k,
216-
dedup=None if kwargs.get("plugin", False) and dedup == "sim" else dedup,
217205
)
218206

219207
logger.info(f"[SEARCH] Done. Total {len(final_results)} results.")
@@ -296,50 +284,6 @@ def _parse_task(
296284

297285
return parsed_goal, query_embedding, context, query
298286

299-
@timed
300-
def _retrieve_simple(
301-
self,
302-
query: str,
303-
top_k: int,
304-
search_filter: dict | None = None,
305-
user_name: str | None = None,
306-
dedup: str | None = None,
307-
**kwargs,
308-
):
309-
"""Retrieve from by keywords and embedding"""
310-
query_words = []
311-
if self.tokenizer:
312-
query_words = self.tokenizer.tokenize_mixed(query)
313-
else:
314-
query_words = query.strip().split()
315-
query_words = [query, *query_words]
316-
logger.info(f"[SIMPLESEARCH] Query words: {query_words}")
317-
query_embeddings = self.embedder.embed(query_words)
318-
319-
items = self.graph_retriever.retrieve_from_mixed(
320-
top_k=top_k * 2,
321-
memory_scope=None,
322-
query_embedding=query_embeddings,
323-
search_filter=search_filter,
324-
user_name=user_name,
325-
use_fast_graph=self.use_fast_graph,
326-
)
327-
logger.info(f"[SIMPLESEARCH] Items count: {len(items)}")
328-
documents = [getattr(item, "memory", "") for item in items]
329-
documents_embeddings = self.embedder.embed(documents)
330-
similarity_matrix = cosine_similarity_matrix(documents_embeddings)
331-
selected_indices, _ = find_best_unrelated_subgroup(documents, similarity_matrix)
332-
selected_items = [items[i] for i in selected_indices]
333-
logger.info(
334-
f"[SIMPLESEARCH] after unrelated subgroup selection items count: {len(selected_items)}"
335-
)
336-
return self.reranker.rerank(
337-
query=query,
338-
query_embedding=query_embeddings[0],
339-
graph_results=selected_items,
340-
top_k=top_k,
341-
)
342-
343287
@timed
344288
def _retrieve_paths(
345289
self,
@@ -723,17 +667,14 @@ def _retrieve_simple(
723667
user_name=user_name,
724668
)
725669
logger.info(f"[SIMPLESEARCH] Items count: {len(items)}")
726-
if dedup == "no":
727-
selected_items = items
728-
else:
729-
documents = [getattr(item, "memory", "") for item in items]
730-
documents_embeddings = self.embedder.embed(documents)
731-
similarity_matrix = cosine_similarity_matrix(documents_embeddings)
732-
selected_indices, _ = find_best_unrelated_subgroup(documents, similarity_matrix)
733-
selected_items = [items[i] for i in selected_indices]
734-
logger.info(
735-
f"[SIMPLESEARCH] after unrelated subgroup selection items count: {len(selected_items)}"
736-
)
670+
documents = [getattr(item, "memory", "") for item in items]
671+
documents_embeddings = self.embedder.embed(documents)
672+
similarity_matrix = cosine_similarity_matrix(documents_embeddings)
673+
selected_indices, _ = find_best_unrelated_subgroup(documents, similarity_matrix)
674+
selected_items = [items[i] for i in selected_indices]
675+
logger.info(
676+
f"[SIMPLESEARCH] after unrelated subgroup selection items count: {len(selected_items)}"
677+
)
737678
return self.reranker.rerank(
738679
query=query,
739680
query_embedding=query_embeddings[0],
@@ -750,26 +691,6 @@ def _deduplicate_results(self, results):
750691
deduped[item.memory] = (item, score)
751692
return list(deduped.values())
752693

753-
@timed
754-
def _deduplicate_similar_results(
755-
self, results: list[tuple[TextualMemoryItem, float]], similarity_threshold: float = 0.85
756-
):
757-
"""Deduplicate results by semantic similarity while keeping higher scores."""
758-
if len(results) <= 1:
759-
return results
760-
761-
sorted_results = sorted(results, key=lambda pair: pair[1], reverse=True)
762-
documents = [getattr(item, "memory", "") for item, _ in sorted_results]
763-
embeddings = self.embedder.embed(documents)
764-
similarity_matrix = cosine_similarity_matrix(embeddings)
765-
766-
selected_indices: list[int] = []
767-
for i in range(len(sorted_results)):
768-
if all(similarity_matrix[i][j] <= similarity_threshold for j in selected_indices):
769-
selected_indices.append(i)
770-
771-
return [sorted_results[i] for i in selected_indices]
772-
773694
@timed
774695
def _sort_and_trim(
775696
self, results, top_k, plugin=False, search_tool_memory=False, tool_mem_top_k=6

src/memos/multi_mem_cube/single_cube.py

Lines changed: 1 addition & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,6 @@
2323
MEM_READ_TASK_LABEL,
2424
PREF_ADD_TASK_LABEL,
2525
)
26-
from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import (
27-
cosine_similarity_matrix,
28-
)
2926
from memos.multi_mem_cube.views import MemCubeView
3027
from memos.templates.mem_reader_prompts import PROMPT_MAPPING
3128
from memos.types.general_types import (
@@ -266,7 +263,6 @@ def _deep_search(
266263
moscube=search_req.moscube,
267264
search_filter=search_filter,
268265
info=info,
269-
dedup=search_req.dedup,
270266
)
271267
formatted_memories = [format_memory_item(data) for data in enhanced_memories]
272268
return formatted_memories
@@ -332,7 +328,6 @@ def _fine_search(
332328
top_k=search_req.top_k,
333329
user_name=user_context.mem_cube_id,
334330
info=info,
335-
dedup=search_req.dedup,
336331
)
337332

338333
# Enhance with query
@@ -383,24 +378,7 @@ def _dedup_by_content(memories: list) -> list:
383378
unique_memories.append(mem)
384379
return unique_memories
385380

386-
def _dedup_by_similarity(memories: list) -> list:
387-
if len(memories) <= 1:
388-
return memories
389-
documents = [getattr(mem, "memory", "") for mem in memories]
390-
embeddings = self.searcher.embedder.embed(documents)
391-
similarity_matrix = cosine_similarity_matrix(embeddings)
392-
selected_indices = []
393-
for i in range(len(memories)):
394-
if all(similarity_matrix[i][j] <= 0.85 for j in selected_indices):
395-
selected_indices.append(i)
396-
return [memories[i] for i in selected_indices]
397-
398-
if search_req.dedup == "no":
399-
deduped_memories = enhanced_memories
400-
elif search_req.dedup == "sim":
401-
deduped_memories = _dedup_by_similarity(enhanced_memories)
402-
else:
403-
deduped_memories = _dedup_by_content(enhanced_memories)
381+
deduped_memories = _dedup_by_content(enhanced_memories)
404382
formatted_memories = [format_memory_item(data) for data in deduped_memories]
405383

406384
logger.info(f"Found {len(formatted_memories)} memories for user {search_req.user_id}")
@@ -485,7 +463,6 @@ def _fast_search(
485463
plugin=plugin,
486464
search_tool_memory=search_req.search_tool_memory,
487465
tool_mem_top_k=search_req.tool_mem_top_k,
488-
dedup=search_req.dedup,
489466
)
490467

491468
formatted_memories = [format_memory_item(data) for data in search_results]

0 commit comments

Comments
 (0)