Skip to content

Commit a75424e

Browse files
Fix dedup handling in simple search
1 parent: ce297ba · commit: a75424e

File tree

2 files changed: 33 additions (+33) and 63 deletions (-63)

2 files changed: 33 additions (+33) and 63 deletions (-63)

src/memos/memories/textual/tree_text_memory/retrieve/searcher.py

Lines changed: 32 additions & 47 deletions
Original file line number | Diff line number | Diff line change
@@ -296,50 +296,6 @@ def _parse_task(
296296

297297
return parsed_goal, query_embedding, context, query
298298

299-
@timed
300-
def _retrieve_simple(
301-
self,
302-
query: str,
303-
top_k: int,
304-
search_filter: dict | None = None,
305-
user_name: str | None = None,
306-
dedup: str | None = None,
307-
**kwargs,
308-
):
309-
"""Retrieve from by keywords and embedding"""
310-
query_words = []
311-
if self.tokenizer:
312-
query_words = self.tokenizer.tokenize_mixed(query)
313-
else:
314-
query_words = query.strip().split()
315-
query_words = [query, *query_words]
316-
logger.info(f"[SIMPLESEARCH] Query words: {query_words}")
317-
query_embeddings = self.embedder.embed(query_words)
318-
319-
items = self.graph_retriever.retrieve_from_mixed(
320-
top_k=top_k * 2,
321-
memory_scope=None,
322-
query_embedding=query_embeddings,
323-
search_filter=search_filter,
324-
user_name=user_name,
325-
use_fast_graph=self.use_fast_graph,
326-
)
327-
logger.info(f"[SIMPLESEARCH] Items count: {len(items)}")
328-
documents = [getattr(item, "memory", "") for item in items]
329-
documents_embeddings = self.embedder.embed(documents)
330-
similarity_matrix = cosine_similarity_matrix(documents_embeddings)
331-
selected_indices, _ = find_best_unrelated_subgroup(documents, similarity_matrix)
332-
selected_items = [items[i] for i in selected_indices]
333-
logger.info(
334-
f"[SIMPLESEARCH] after unrelated subgroup selection items count: {len(selected_items)}"
335-
)
336-
return self.reranker.rerank(
337-
query=query,
338-
query_embedding=query_embeddings[0],
339-
graph_results=selected_items,
340-
top_k=top_k,
341-
)
342-
343299
@timed
344300
def _retrieve_paths(
345301
self,
@@ -699,6 +655,7 @@ def _retrieve_simple(
699655
top_k: int,
700656
search_filter: dict | None = None,
701657
user_name: str | None = None,
658+
dedup: str | None = None,
702659
**kwargs,
703660
):
704661
"""
@@ -721,10 +678,16 @@ def _retrieve_simple(
721678
query_embedding=query_embeddings,
722679
search_filter=search_filter,
723680
user_name=user_name,
681+
use_fast_graph=self.use_fast_graph,
724682
)
725683
logger.info(f"[SIMPLESEARCH] Items count: {len(items)}")
726684
if dedup == "no":
727685
selected_items = items
686+
elif dedup == "sim":
687+
selected_items = self.deduplicate_similar_items(items)
688+
logger.info(
689+
f"[SIMPLESEARCH] after similarity dedup items count: {len(selected_items)}"
690+
)
728691
else:
729692
documents = [getattr(item, "memory", "") for item in items]
730693
documents_embeddings = self.embedder.embed(documents)
@@ -763,12 +726,34 @@ def _deduplicate_similar_results(
763726
embeddings = self.embedder.embed(documents)
764727
similarity_matrix = cosine_similarity_matrix(embeddings)
765728

729+
selected_indices = self._select_unrelated_indices(
730+
similarity_matrix, similarity_threshold
731+
)
732+
return [sorted_results[i] for i in selected_indices]
733+
734+
def deduplicate_similar_items(
735+
self, items: list[TextualMemoryItem], similarity_threshold: float = 0.85
736+
) -> list[TextualMemoryItem]:
737+
"""Deduplicate memory items by semantic similarity while preserving order."""
738+
if len(items) <= 1:
739+
return items
740+
documents = [getattr(item, "memory", "") for item in items]
741+
embeddings = self.embedder.embed(documents)
742+
similarity_matrix = cosine_similarity_matrix(embeddings)
743+
selected_indices = self._select_unrelated_indices(
744+
similarity_matrix, similarity_threshold
745+
)
746+
return [items[i] for i in selected_indices]
747+
748+
@staticmethod
749+
def _select_unrelated_indices(
750+
similarity_matrix: list[list[float]], similarity_threshold: float
751+
) -> list[int]:
766752
selected_indices: list[int] = []
767-
for i in range(len(sorted_results)):
753+
for i in range(len(similarity_matrix)):
768754
if all(similarity_matrix[i][j] <= similarity_threshold for j in selected_indices):
769755
selected_indices.append(i)
770-
771-
return [sorted_results[i] for i in selected_indices]
756+
return selected_indices
772757

773758
@timed
774759
def _sort_and_trim(

src/memos/multi_mem_cube/single_cube.py

Lines changed: 1 addition & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -23,9 +23,6 @@
2323
MEM_READ_TASK_LABEL,
2424
PREF_ADD_TASK_LABEL,
2525
)
26-
from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import (
27-
cosine_similarity_matrix,
28-
)
2926
from memos.multi_mem_cube.views import MemCubeView
3027
from memos.templates.mem_reader_prompts import PROMPT_MAPPING
3128
from memos.types.general_types import (
@@ -383,22 +380,10 @@ def _dedup_by_content(memories: list) -> list:
383380
unique_memories.append(mem)
384381
return unique_memories
385382

386-
def _dedup_by_similarity(memories: list) -> list:
387-
if len(memories) <= 1:
388-
return memories
389-
documents = [getattr(mem, "memory", "") for mem in memories]
390-
embeddings = self.searcher.embedder.embed(documents)
391-
similarity_matrix = cosine_similarity_matrix(embeddings)
392-
selected_indices = []
393-
for i in range(len(memories)):
394-
if all(similarity_matrix[i][j] <= 0.85 for j in selected_indices):
395-
selected_indices.append(i)
396-
return [memories[i] for i in selected_indices]
397-
398383
if search_req.dedup == "no":
399384
deduped_memories = enhanced_memories
400385
elif search_req.dedup == "sim":
401-
deduped_memories = _dedup_by_similarity(enhanced_memories)
386+
deduped_memories = self.searcher.deduplicate_similar_items(enhanced_memories)
402387
else:
403388
deduped_memories = _dedup_by_content(enhanced_memories)
404389
formatted_memories = [format_memory_item(data) for data in deduped_memories]

0 commit comments

Comments
 (0)