Skip to content

Commit 9273c66

Browse files
Fix dedup handling in simple search
1 parent ce297ba commit 9273c66

File tree

10 files changed

+107
-140
lines changed

10 files changed

+107
-140
lines changed

src/memos/api/handlers/formatters_handler.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def to_iter(running: Any) -> list[Any]:
2929
return list(running) if running else []
3030

3131

32-
def format_memory_item(memory_data: Any) -> dict[str, Any]:
32+
def format_memory_item(memory_data: Any, include_embedding: bool = False) -> dict[str, Any]:
3333
"""
3434
Format a single memory item for API response.
3535
@@ -47,7 +47,8 @@ def format_memory_item(memory_data: Any) -> dict[str, Any]:
4747
ref_id = f"[{memory_id.split('-')[0]}]"
4848

4949
memory["ref_id"] = ref_id
50-
memory["metadata"]["embedding"] = []
50+
if not include_embedding:
51+
memory["metadata"]["embedding"] = []
5152
memory["metadata"]["sources"] = []
5253
memory["metadata"]["usage"] = []
5354
memory["metadata"]["ref_id"] = ref_id

src/memos/api/handlers/search_handler.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,14 @@
55
using dependency injection for better modularity and testability.
66
"""
77

8+
from typing import Any
9+
810
from memos.api.handlers.base_handler import BaseHandler, HandlerDependencies
911
from memos.api.product_models import APISearchRequest, SearchResponse
1012
from memos.log import get_logger
13+
from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import (
14+
cosine_similarity_matrix,
15+
)
1116
from memos.multi_mem_cube.composite_cube import CompositeCubeView
1217
from memos.multi_mem_cube.single_cube import SingleCubeView
1318
from memos.multi_mem_cube.views import MemCubeView
@@ -53,6 +58,9 @@ def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse
5358
cube_view = self._build_cube_view(search_req)
5459

5560
results = cube_view.search_memories(search_req)
61+
if search_req.dedup == "sim":
62+
results = self._dedup_text_memories(results, search_req.top_k)
63+
self._strip_embeddings(results)
5664

5765
self.logger.info(
5866
f"[SearchHandler] Final search results: count={len(results)} results={results}"
@@ -63,6 +71,61 @@ def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse
6371
data=results,
6472
)
6573

74+
def _dedup_text_memories(self, results: dict[str, Any], target_top_k: int) -> dict[str, Any]:
75+
for bucket in results.get("text_mem", []):
76+
memories = bucket.get("memories", [])
77+
if len(memories) <= 1:
78+
continue
79+
embeddings = self._extract_embeddings(memories)
80+
if embeddings is None:
81+
documents = [mem.get("memory", "") for mem in memories]
82+
embeddings = self.searcher.embedder.embed(documents)
83+
similarity_matrix = cosine_similarity_matrix(embeddings)
84+
selected_indices = self._select_unrelated_indices(similarity_matrix, 0.85)
85+
if len(selected_indices) < min(target_top_k, len(memories)):
86+
selected = set(selected_indices)
87+
for i in range(len(memories)):
88+
if i in selected:
89+
continue
90+
selected_indices.append(i)
91+
if len(selected_indices) >= target_top_k:
92+
break
93+
bucket["memories"] = [memories[i] for i in selected_indices[:target_top_k]]
94+
return results
95+
96+
@staticmethod
97+
def _select_unrelated_indices(
98+
similarity_matrix: list[list[float]], similarity_threshold: float
99+
) -> list[int]:
100+
selected_indices: list[int] = []
101+
for i in range(len(similarity_matrix)):
102+
if all(similarity_matrix[i][j] <= similarity_threshold for j in selected_indices):
103+
selected_indices.append(i)
104+
return selected_indices
105+
106+
@staticmethod
107+
def _extract_embeddings(memories: list[dict[str, Any]]) -> list[list[float]] | None:
108+
embeddings: list[list[float]] = []
109+
for mem in memories:
110+
embedding = mem.get("metadata", {}).get("embedding")
111+
if not embedding:
112+
return None
113+
embeddings.append(embedding)
114+
return embeddings
115+
116+
@staticmethod
117+
def _strip_embeddings(results: dict[str, Any]) -> None:
118+
for bucket in results.get("text_mem", []):
119+
for mem in bucket.get("memories", []):
120+
metadata = mem.get("metadata", {})
121+
if "embedding" in metadata:
122+
metadata["embedding"] = []
123+
for bucket in results.get("tool_mem", []):
124+
for mem in bucket.get("memories", []):
125+
metadata = mem.get("metadata", {})
126+
if "embedding" in metadata:
127+
metadata["embedding"] = []
128+
66129
def _resolve_cube_ids(self, search_req: APISearchRequest) -> list[str]:
67130
"""
68131
Normalize target cube ids from search_req.

src/memos/api/start_api.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22
import os
33

4-
from typing import Any, Generic, Literal, TypeVar
4+
from typing import Any, Generic, TypeVar
55

66
from dotenv import load_dotenv
77
from fastapi import FastAPI
@@ -145,14 +145,6 @@ class SearchRequest(BaseRequest):
145145
description="List of cube IDs to search in",
146146
json_schema_extra={"example": ["cube123", "cube456"]},
147147
)
148-
dedup: Literal["no", "sim"] | None = Field(
149-
None,
150-
description=(
151-
"Optional dedup option for textual memories. "
152-
"Use 'no' for no dedup, 'sim' for similarity dedup. "
153-
"If None, default exact-text dedup is applied."
154-
),
155-
)
156148

157149

158150
class MemCubeRegister(BaseRequest):
@@ -357,7 +349,6 @@ async def search_memories(search_req: SearchRequest):
357349
query=search_req.query,
358350
user_id=search_req.user_id,
359351
install_cube_ids=search_req.install_cube_ids,
360-
dedup=search_req.dedup,
361352
)
362353
return SearchResponse(message="Search completed successfully", data=result)
363354

src/memos/mem_os/core.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -551,7 +551,6 @@ def search(
551551
internet_search: bool = False,
552552
moscube: bool = False,
553553
session_id: str | None = None,
554-
dedup: str | None = None,
555554
**kwargs,
556555
) -> MOSSearchResult:
557556
"""
@@ -626,7 +625,6 @@ def search_textual_memory(cube_id, cube):
626625
},
627626
moscube=moscube,
628627
search_filter=search_filter,
629-
dedup=dedup,
630628
)
631629
search_time_end = time.time()
632630
logger.info(

src/memos/mem_scheduler/optimized_scheduler.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -186,11 +186,13 @@ def mix_search_memories(
186186
info=info,
187187
search_tool_memory=search_req.search_tool_memory,
188188
tool_mem_top_k=search_req.tool_mem_top_k,
189-
dedup=search_req.dedup,
190189
)
191190
memories = merged_memories[: search_req.top_k]
192191

193-
formatted_memories = [format_textual_memory_item(item) for item in memories]
192+
formatted_memories = [
193+
format_textual_memory_item(item, include_embedding=search_req.dedup == "sim")
194+
for item in memories
195+
]
194196
self.submit_memory_history_async_task(
195197
search_req=search_req,
196198
user_context=user_context,
@@ -234,7 +236,10 @@ def update_search_memories_to_redis(
234236
mem_cube=self.mem_cube,
235237
mode=SearchMode.FAST,
236238
)
237-
formatted_memories = [format_textual_memory_item(data) for data in memories]
239+
formatted_memories = [
240+
format_textual_memory_item(data, include_embedding=search_req.dedup == "sim")
241+
for data in memories
242+
]
238243
else:
239244
memories = [
240245
TextualMemoryItem.from_dict(one) for one in memories_to_store["memories"]

src/memos/mem_scheduler/utils/api_utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,17 @@
66
from memos.memories.textual.tree import TextualMemoryItem
77

88

9-
def format_textual_memory_item(memory_data: Any) -> dict[str, Any]:
9+
def format_textual_memory_item(
10+
memory_data: Any, include_embedding: bool = False
11+
) -> dict[str, Any]:
1012
"""Format a single memory item for API response."""
1113
memory = memory_data.model_dump()
1214
memory_id = memory["id"]
1315
ref_id = f"[{memory_id.split('-')[0]}]"
1416

1517
memory["ref_id"] = ref_id
16-
memory["metadata"]["embedding"] = []
18+
if not include_embedding:
19+
memory["metadata"]["embedding"] = []
1720
memory["metadata"]["sources"] = []
1821
memory["metadata"]["ref_id"] = ref_id
1922
memory["metadata"]["id"] = memory_id

src/memos/memories/textual/tree.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,6 @@ def search(
161161
user_name: str | None = None,
162162
search_tool_memory: bool = False,
163163
tool_mem_top_k: int = 6,
164-
dedup: str | None = None,
165164
**kwargs,
166165
) -> list[TextualMemoryItem]:
167166
"""Search for memories based on a query.
@@ -208,7 +207,6 @@ def search(
208207
user_name=user_name,
209208
search_tool_memory=search_tool_memory,
210209
tool_mem_top_k=tool_mem_top_k,
211-
dedup=dedup,
212210
**kwargs,
213211
)
214212

src/memos/memories/textual/tree_text_memory/retrieve/advanced_searcher.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,6 @@ def deep_search(
239239
user_name: str | None = None,
240240
**kwargs,
241241
):
242-
dedup = kwargs.get("dedup")
243242
previous_retrieval_phrases = [query]
244243
retrieved_memories = self.retrieve(
245244
query=query,
@@ -255,7 +254,6 @@ def deep_search(
255254
top_k=top_k,
256255
user_name=user_name,
257256
info=info,
258-
dedup=dedup,
259257
)
260258
if len(memories) == 0:
261259
logger.warning("Requirements not met; returning memories as-is.")

src/memos/memories/textual/tree_text_memory/retrieve/searcher.py

Lines changed: 10 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -119,15 +119,9 @@ def post_retrieve(
119119
info=None,
120120
search_tool_memory: bool = False,
121121
tool_mem_top_k: int = 6,
122-
dedup: str | None = None,
123122
plugin=False,
124123
):
125-
if dedup == "no":
126-
deduped = retrieved_results
127-
elif dedup == "sim":
128-
deduped = self._deduplicate_similar_results(retrieved_results)
129-
else:
130-
deduped = self._deduplicate_results(retrieved_results)
124+
deduped = self._deduplicate_results(retrieved_results)
131125
final_results = self._sort_and_trim(
132126
deduped, top_k, plugin, search_tool_memory, tool_mem_top_k
133127
)
@@ -147,7 +141,6 @@ def search(
147141
user_name: str | None = None,
148142
search_tool_memory: bool = False,
149143
tool_mem_top_k: int = 6,
150-
dedup: str | None = None,
151144
**kwargs,
152145
) -> list[TextualMemoryItem]:
153146
"""
@@ -180,11 +173,7 @@ def search(
180173
if kwargs.get("plugin", False):
181174
logger.info(f"[SEARCH] Retrieve from plugin: {query}")
182175
retrieved_results = self._retrieve_simple(
183-
query=query,
184-
top_k=top_k,
185-
search_filter=search_filter,
186-
user_name=user_name,
187-
dedup=dedup,
176+
query=query, top_k=top_k, search_filter=search_filter, user_name=user_name
188177
)
189178
else:
190179
retrieved_results = self.retrieve(
@@ -213,7 +202,6 @@ def search(
213202
plugin=kwargs.get("plugin", False),
214203
search_tool_memory=search_tool_memory,
215204
tool_mem_top_k=tool_mem_top_k,
216-
dedup=None if kwargs.get("plugin", False) and dedup == "sim" else dedup,
217205
)
218206

219207
logger.info(f"[SEARCH] Done. Total {len(final_results)} results.")
@@ -296,50 +284,6 @@ def _parse_task(
296284

297285
return parsed_goal, query_embedding, context, query
298286

299-
@timed
300-
def _retrieve_simple(
301-
self,
302-
query: str,
303-
top_k: int,
304-
search_filter: dict | None = None,
305-
user_name: str | None = None,
306-
dedup: str | None = None,
307-
**kwargs,
308-
):
309-
"""Retrieve from by keywords and embedding"""
310-
query_words = []
311-
if self.tokenizer:
312-
query_words = self.tokenizer.tokenize_mixed(query)
313-
else:
314-
query_words = query.strip().split()
315-
query_words = [query, *query_words]
316-
logger.info(f"[SIMPLESEARCH] Query words: {query_words}")
317-
query_embeddings = self.embedder.embed(query_words)
318-
319-
items = self.graph_retriever.retrieve_from_mixed(
320-
top_k=top_k * 2,
321-
memory_scope=None,
322-
query_embedding=query_embeddings,
323-
search_filter=search_filter,
324-
user_name=user_name,
325-
use_fast_graph=self.use_fast_graph,
326-
)
327-
logger.info(f"[SIMPLESEARCH] Items count: {len(items)}")
328-
documents = [getattr(item, "memory", "") for item in items]
329-
documents_embeddings = self.embedder.embed(documents)
330-
similarity_matrix = cosine_similarity_matrix(documents_embeddings)
331-
selected_indices, _ = find_best_unrelated_subgroup(documents, similarity_matrix)
332-
selected_items = [items[i] for i in selected_indices]
333-
logger.info(
334-
f"[SIMPLESEARCH] after unrelated subgroup selection items count: {len(selected_items)}"
335-
)
336-
return self.reranker.rerank(
337-
query=query,
338-
query_embedding=query_embeddings[0],
339-
graph_results=selected_items,
340-
top_k=top_k,
341-
)
342-
343287
@timed
344288
def _retrieve_paths(
345289
self,
@@ -723,17 +667,14 @@ def _retrieve_simple(
723667
user_name=user_name,
724668
)
725669
logger.info(f"[SIMPLESEARCH] Items count: {len(items)}")
726-
if dedup == "no":
727-
selected_items = items
728-
else:
729-
documents = [getattr(item, "memory", "") for item in items]
730-
documents_embeddings = self.embedder.embed(documents)
731-
similarity_matrix = cosine_similarity_matrix(documents_embeddings)
732-
selected_indices, _ = find_best_unrelated_subgroup(documents, similarity_matrix)
733-
selected_items = [items[i] for i in selected_indices]
734-
logger.info(
735-
f"[SIMPLESEARCH] after unrelated subgroup selection items count: {len(selected_items)}"
736-
)
670+
documents = [getattr(item, "memory", "") for item in items]
671+
documents_embeddings = self.embedder.embed(documents)
672+
similarity_matrix = cosine_similarity_matrix(documents_embeddings)
673+
selected_indices, _ = find_best_unrelated_subgroup(documents, similarity_matrix)
674+
selected_items = [items[i] for i in selected_indices]
675+
logger.info(
676+
f"[SIMPLESEARCH] after unrelated subgroup selection items count: {len(selected_items)}"
677+
)
737678
return self.reranker.rerank(
738679
query=query,
739680
query_embedding=query_embeddings[0],
@@ -750,26 +691,6 @@ def _deduplicate_results(self, results):
750691
deduped[item.memory] = (item, score)
751692
return list(deduped.values())
752693

753-
@timed
754-
def _deduplicate_similar_results(
755-
self, results: list[tuple[TextualMemoryItem, float]], similarity_threshold: float = 0.85
756-
):
757-
"""Deduplicate results by semantic similarity while keeping higher scores."""
758-
if len(results) <= 1:
759-
return results
760-
761-
sorted_results = sorted(results, key=lambda pair: pair[1], reverse=True)
762-
documents = [getattr(item, "memory", "") for item, _ in sorted_results]
763-
embeddings = self.embedder.embed(documents)
764-
similarity_matrix = cosine_similarity_matrix(embeddings)
765-
766-
selected_indices: list[int] = []
767-
for i in range(len(sorted_results)):
768-
if all(similarity_matrix[i][j] <= similarity_threshold for j in selected_indices):
769-
selected_indices.append(i)
770-
771-
return [sorted_results[i] for i in selected_indices]
772-
773694
@timed
774695
def _sort_and_trim(
775696
self, results, top_k, plugin=False, search_tool_memory=False, tool_mem_top_k=6

0 commit comments

Comments
 (0)