using dependency injection for better modularity and testability.
"""

8+ from typing import Any
9+
810from memos .api .handlers .base_handler import BaseHandler , HandlerDependencies
911from memos .api .product_models import APISearchRequest , SearchResponse
1012from memos .log import get_logger
13+ from memos .memories .textual .tree_text_memory .retrieve .retrieve_utils import (
14+ cosine_similarity_matrix ,
15+ )
1116from memos .multi_mem_cube .composite_cube import CompositeCubeView
1217from memos .multi_mem_cube .single_cube import SingleCubeView
1318from memos .multi_mem_cube .views import MemCubeView
@@ -50,9 +55,19 @@ def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse
5055 """
5156 self .logger .info (f"[SearchHandler] Search Req is: { search_req } " )
5257
58+ # Increase recall pool if deduplication is enabled to ensure diversity
59+ original_top_k = search_req .top_k
60+ if search_req .dedup == "sim" :
61+ search_req .top_k = original_top_k * 5
62+
5363 cube_view = self ._build_cube_view (search_req )
5464
5565 results = cube_view .search_memories (search_req )
66+ if search_req .dedup == "sim" :
67+ results = self ._dedup_text_memories (results , original_top_k )
68+ self ._strip_embeddings (results )
69+ # Restore original top_k for downstream logic or response metadata
70+ search_req .top_k = original_top_k
5671
5772 self .logger .info (
5873 f"[SearchHandler] Final search results: count={ len (results )} results={ results } "
@@ -63,6 +78,93 @@ def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse
6378 data = results ,
6479 )
6580
81+ def _dedup_text_memories (self , results : dict [str , Any ], target_top_k : int ) -> dict [str , Any ]:
82+ buckets = results .get ("text_mem" , [])
83+ if not buckets :
84+ return results
85+
86+ flat : list [tuple [int , dict [str , Any ], float ]] = []
87+ for bucket_idx , bucket in enumerate (buckets ):
88+ for mem in bucket .get ("memories" , []):
89+ score = mem .get ("metadata" , {}).get ("relativity" , 0.0 )
90+ flat .append ((bucket_idx , mem , score ))
91+
92+ if len (flat ) <= 1 :
93+ return results
94+
95+ embeddings = self ._extract_embeddings ([mem for _ , mem , _ in flat ])
96+ if embeddings is None :
97+ documents = [mem .get ("memory" , "" ) for _ , mem , _ in flat ]
98+ embeddings = self .searcher .embedder .embed (documents )
99+
100+ similarity_matrix = cosine_similarity_matrix (embeddings )
101+
102+ indices_by_bucket : dict [int , list [int ]] = {i : [] for i in range (len (buckets ))}
103+ for flat_index , (bucket_idx , _ , _ ) in enumerate (flat ):
104+ indices_by_bucket [bucket_idx ].append (flat_index )
105+
106+ selected_global : list [int ] = []
107+ selected_by_bucket : dict [int , list [int ]] = {i : [] for i in range (len (buckets ))}
108+
109+ ordered_indices = sorted (range (len (flat )), key = lambda idx : flat [idx ][2 ], reverse = True )
110+ for idx in ordered_indices :
111+ bucket_idx = flat [idx ][0 ]
112+ if len (selected_by_bucket [bucket_idx ]) >= target_top_k :
113+ continue
114+ # Use 0.92 threshold strictly
115+ if self ._is_unrelated (idx , selected_global , similarity_matrix , 0.92 ):
116+ selected_by_bucket [bucket_idx ].append (idx )
117+ selected_global .append (idx )
118+
119+ # Removed the 'filling' logic that was pulling back similar items.
120+ # Now it will only return items that truly pass the 0.92 threshold,
121+ # up to target_top_k.
122+
123+ for bucket_idx , bucket in enumerate (buckets ):
124+ selected_indices = selected_by_bucket .get (bucket_idx , [])
125+ bucket ["memories" ] = [flat [i ][1 ] for i in selected_indices ]
126+ return results
127+
128+ @staticmethod
129+ def _is_unrelated (
130+ index : int ,
131+ selected_indices : list [int ],
132+ similarity_matrix : list [list [float ]],
133+ similarity_threshold : float ,
134+ ) -> bool :
135+ return all (similarity_matrix [index ][j ] <= similarity_threshold for j in selected_indices )
136+
137+ @staticmethod
138+ def _max_similarity (
139+ index : int , selected_indices : list [int ], similarity_matrix : list [list [float ]]
140+ ) -> float :
141+ if not selected_indices :
142+ return 0.0
143+ return max (similarity_matrix [index ][j ] for j in selected_indices )
144+
145+ @staticmethod
146+ def _extract_embeddings (memories : list [dict [str , Any ]]) -> list [list [float ]] | None :
147+ embeddings : list [list [float ]] = []
148+ for mem in memories :
149+ embedding = mem .get ("metadata" , {}).get ("embedding" )
150+ if not embedding :
151+ return None
152+ embeddings .append (embedding )
153+ return embeddings
154+
155+ @staticmethod
156+ def _strip_embeddings (results : dict [str , Any ]) -> None :
157+ for bucket in results .get ("text_mem" , []):
158+ for mem in bucket .get ("memories" , []):
159+ metadata = mem .get ("metadata" , {})
160+ if "embedding" in metadata :
161+ metadata ["embedding" ] = []
162+ for bucket in results .get ("tool_mem" , []):
163+ for mem in bucket .get ("memories" , []):
164+ metadata = mem .get ("metadata" , {})
165+ if "embedding" in metadata :
166+ metadata ["embedding" ] = []
167+
66168 def _resolve_cube_ids (self , search_req : APISearchRequest ) -> list [str ]:
67169 """
68170 Normalize target cube ids from search_req.
0 commit comments