Skip to content

Commit 92a9289

Browse files
Fix dedup handling in simple search
1 parent ce297ba commit 92a9289

File tree

4 files changed

+34
-75
lines changed

4 files changed

+34
-75
lines changed

src/memos/api/start_api.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22
import os
33

4-
from typing import Any, Generic, Literal, TypeVar
4+
from typing import Any, Generic, TypeVar
55

66
from dotenv import load_dotenv
77
from fastapi import FastAPI
@@ -145,14 +145,6 @@ class SearchRequest(BaseRequest):
145145
description="List of cube IDs to search in",
146146
json_schema_extra={"example": ["cube123", "cube456"]},
147147
)
148-
dedup: Literal["no", "sim"] | None = Field(
149-
None,
150-
description=(
151-
"Optional dedup option for textual memories. "
152-
"Use 'no' for no dedup, 'sim' for similarity dedup. "
153-
"If None, default exact-text dedup is applied."
154-
),
155-
)
156148

157149

158150
class MemCubeRegister(BaseRequest):
@@ -357,7 +349,6 @@ async def search_memories(search_req: SearchRequest):
357349
query=search_req.query,
358350
user_id=search_req.user_id,
359351
install_cube_ids=search_req.install_cube_ids,
360-
dedup=search_req.dedup,
361352
)
362353
return SearchResponse(message="Search completed successfully", data=result)
363354

src/memos/mem_os/core.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -551,7 +551,6 @@ def search(
551551
internet_search: bool = False,
552552
moscube: bool = False,
553553
session_id: str | None = None,
554-
dedup: str | None = None,
555554
**kwargs,
556555
) -> MOSSearchResult:
557556
"""
@@ -626,7 +625,6 @@ def search_textual_memory(cube_id, cube):
626625
},
627626
moscube=moscube,
628627
search_filter=search_filter,
629-
dedup=dedup,
630628
)
631629
search_time_end = time.time()
632630
logger.info(

src/memos/memories/textual/tree_text_memory/retrieve/searcher.py

Lines changed: 32 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -296,50 +296,6 @@ def _parse_task(
296296

297297
return parsed_goal, query_embedding, context, query
298298

299-
@timed
300-
def _retrieve_simple(
301-
self,
302-
query: str,
303-
top_k: int,
304-
search_filter: dict | None = None,
305-
user_name: str | None = None,
306-
dedup: str | None = None,
307-
**kwargs,
308-
):
309-
"""Retrieve from by keywords and embedding"""
310-
query_words = []
311-
if self.tokenizer:
312-
query_words = self.tokenizer.tokenize_mixed(query)
313-
else:
314-
query_words = query.strip().split()
315-
query_words = [query, *query_words]
316-
logger.info(f"[SIMPLESEARCH] Query words: {query_words}")
317-
query_embeddings = self.embedder.embed(query_words)
318-
319-
items = self.graph_retriever.retrieve_from_mixed(
320-
top_k=top_k * 2,
321-
memory_scope=None,
322-
query_embedding=query_embeddings,
323-
search_filter=search_filter,
324-
user_name=user_name,
325-
use_fast_graph=self.use_fast_graph,
326-
)
327-
logger.info(f"[SIMPLESEARCH] Items count: {len(items)}")
328-
documents = [getattr(item, "memory", "") for item in items]
329-
documents_embeddings = self.embedder.embed(documents)
330-
similarity_matrix = cosine_similarity_matrix(documents_embeddings)
331-
selected_indices, _ = find_best_unrelated_subgroup(documents, similarity_matrix)
332-
selected_items = [items[i] for i in selected_indices]
333-
logger.info(
334-
f"[SIMPLESEARCH] after unrelated subgroup selection items count: {len(selected_items)}"
335-
)
336-
return self.reranker.rerank(
337-
query=query,
338-
query_embedding=query_embeddings[0],
339-
graph_results=selected_items,
340-
top_k=top_k,
341-
)
342-
343299
@timed
344300
def _retrieve_paths(
345301
self,
@@ -699,6 +655,7 @@ def _retrieve_simple(
699655
top_k: int,
700656
search_filter: dict | None = None,
701657
user_name: str | None = None,
658+
dedup: str | None = None,
702659
**kwargs,
703660
):
704661
"""
@@ -721,10 +678,16 @@ def _retrieve_simple(
721678
query_embedding=query_embeddings,
722679
search_filter=search_filter,
723680
user_name=user_name,
681+
use_fast_graph=self.use_fast_graph,
724682
)
725683
logger.info(f"[SIMPLESEARCH] Items count: {len(items)}")
726684
if dedup == "no":
727685
selected_items = items
686+
elif dedup == "sim":
687+
selected_items = self.deduplicate_similar_items(items)
688+
logger.info(
689+
f"[SIMPLESEARCH] after similarity dedup items count: {len(selected_items)}"
690+
)
728691
else:
729692
documents = [getattr(item, "memory", "") for item in items]
730693
documents_embeddings = self.embedder.embed(documents)
@@ -763,12 +726,34 @@ def _deduplicate_similar_results(
763726
embeddings = self.embedder.embed(documents)
764727
similarity_matrix = cosine_similarity_matrix(embeddings)
765728

729+
selected_indices = self._select_unrelated_indices(
730+
similarity_matrix, similarity_threshold
731+
)
732+
return [sorted_results[i] for i in selected_indices]
733+
734+
def deduplicate_similar_items(
735+
self, items: list[TextualMemoryItem], similarity_threshold: float = 0.85
736+
) -> list[TextualMemoryItem]:
737+
"""Deduplicate memory items by semantic similarity while preserving order."""
738+
if len(items) <= 1:
739+
return items
740+
documents = [getattr(item, "memory", "") for item in items]
741+
embeddings = self.embedder.embed(documents)
742+
similarity_matrix = cosine_similarity_matrix(embeddings)
743+
selected_indices = self._select_unrelated_indices(
744+
similarity_matrix, similarity_threshold
745+
)
746+
return [items[i] for i in selected_indices]
747+
748+
@staticmethod
749+
def _select_unrelated_indices(
750+
similarity_matrix: list[list[float]], similarity_threshold: float
751+
) -> list[int]:
766752
selected_indices: list[int] = []
767-
for i in range(len(sorted_results)):
753+
for i in range(len(similarity_matrix)):
768754
if all(similarity_matrix[i][j] <= similarity_threshold for j in selected_indices):
769755
selected_indices.append(i)
770-
771-
return [sorted_results[i] for i in selected_indices]
756+
return selected_indices
772757

773758
@timed
774759
def _sort_and_trim(

src/memos/multi_mem_cube/single_cube.py

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,6 @@
2323
MEM_READ_TASK_LABEL,
2424
PREF_ADD_TASK_LABEL,
2525
)
26-
from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import (
27-
cosine_similarity_matrix,
28-
)
2926
from memos.multi_mem_cube.views import MemCubeView
3027
from memos.templates.mem_reader_prompts import PROMPT_MAPPING
3128
from memos.types.general_types import (
@@ -383,22 +380,10 @@ def _dedup_by_content(memories: list) -> list:
383380
unique_memories.append(mem)
384381
return unique_memories
385382

386-
def _dedup_by_similarity(memories: list) -> list:
387-
if len(memories) <= 1:
388-
return memories
389-
documents = [getattr(mem, "memory", "") for mem in memories]
390-
embeddings = self.searcher.embedder.embed(documents)
391-
similarity_matrix = cosine_similarity_matrix(embeddings)
392-
selected_indices = []
393-
for i in range(len(memories)):
394-
if all(similarity_matrix[i][j] <= 0.85 for j in selected_indices):
395-
selected_indices.append(i)
396-
return [memories[i] for i in selected_indices]
397-
398383
if search_req.dedup == "no":
399384
deduped_memories = enhanced_memories
400385
elif search_req.dedup == "sim":
401-
deduped_memories = _dedup_by_similarity(enhanced_memories)
386+
deduped_memories = self.searcher.deduplicate_similar_items(enhanced_memories)
402387
else:
403388
deduped_memories = _dedup_by_content(enhanced_memories)
404389
formatted_memories = [format_memory_item(data) for data in deduped_memories]

0 commit comments

Comments (0)