Skip to content

Commit 73c9fa1

Browse files
authored
feat: recall and searcher use parallel (#337)
* feat: recall and searcher use parallel * feat: recall and searcher format
1 parent 4a4abca commit 73c9fa1

File tree

2 files changed

+59
-30
lines changed

2 files changed

+59
-30
lines changed

src/memos/memories/textual/tree_text_memory/retrieve/recall.py

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -214,24 +214,39 @@ def search_single(vec, filt=None):
214214
or []
215215
)
216216

217-
all_hits = []
218-
# Path A: without filter
219-
with ContextThreadPoolExecutor() as executor:
220-
futures = [
221-
executor.submit(search_single, vec, None) for vec in query_embedding[:max_num]
222-
]
223-
for f in concurrent.futures.as_completed(futures):
224-
all_hits.extend(f.result() or [])
225-
226-
# Path B: with filter
227-
if search_filter:
217+
def search_path_a():
218+
"""Path A: search without filter"""
219+
path_a_hits = []
220+
with ContextThreadPoolExecutor() as executor:
221+
futures = [
222+
executor.submit(search_single, vec, None) for vec in query_embedding[:max_num]
223+
]
224+
for f in concurrent.futures.as_completed(futures):
225+
path_a_hits.extend(f.result() or [])
226+
return path_a_hits
227+
228+
def search_path_b():
229+
"""Path B: search with filter"""
230+
if not search_filter:
231+
return []
232+
path_b_hits = []
228233
with ContextThreadPoolExecutor() as executor:
229234
futures = [
230235
executor.submit(search_single, vec, search_filter)
231236
for vec in query_embedding[:max_num]
232237
]
233238
for f in concurrent.futures.as_completed(futures):
234-
all_hits.extend(f.result() or [])
239+
path_b_hits.extend(f.result() or [])
240+
return path_b_hits
241+
242+
# Execute both paths concurrently
243+
all_hits = []
244+
with ContextThreadPoolExecutor(max_workers=2) as executor:
245+
path_a_future = executor.submit(search_path_a)
246+
path_b_future = executor.submit(search_path_b)
247+
248+
all_hits.extend(path_a_future.result())
249+
all_hits.extend(path_b_future.result())
235250

236251
if not all_hits:
237252
return []

src/memos/memories/textual/tree_text_memory/retrieve/searcher.py

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -269,24 +269,38 @@ def _retrieve_from_long_term_and_user(
269269
):
270270
"""Retrieve and rerank from LongTermMemory and UserMemory"""
271271
results = []
272-
if memory_type in ["All", "LongTermMemory"]:
273-
results += self.graph_retriever.retrieve(
274-
query=query,
275-
parsed_goal=parsed_goal,
276-
query_embedding=query_embedding,
277-
top_k=top_k * 2,
278-
memory_scope="LongTermMemory",
279-
search_filter=search_filter,
280-
)
281-
if memory_type in ["All", "UserMemory"]:
282-
results += self.graph_retriever.retrieve(
283-
query=query,
284-
parsed_goal=parsed_goal,
285-
query_embedding=query_embedding,
286-
top_k=top_k * 2,
287-
memory_scope="UserMemory",
288-
search_filter=search_filter,
289-
)
272+
tasks = []
273+
274+
with ContextThreadPoolExecutor(max_workers=2) as executor:
275+
if memory_type in ["All", "LongTermMemory"]:
276+
tasks.append(
277+
executor.submit(
278+
self.graph_retriever.retrieve,
279+
query=query,
280+
parsed_goal=parsed_goal,
281+
query_embedding=query_embedding,
282+
top_k=top_k * 2,
283+
memory_scope="LongTermMemory",
284+
search_filter=search_filter,
285+
)
286+
)
287+
if memory_type in ["All", "UserMemory"]:
288+
tasks.append(
289+
executor.submit(
290+
self.graph_retriever.retrieve,
291+
query=query,
292+
parsed_goal=parsed_goal,
293+
query_embedding=query_embedding,
294+
top_k=top_k * 2,
295+
memory_scope="UserMemory",
296+
search_filter=search_filter,
297+
)
298+
)
299+
300+
# Collect results from all tasks
301+
for task in tasks:
302+
results.extend(task.result())
303+
290304
return self.reranker.rerank(
291305
query=query,
292306
query_embedding=query_embedding[0],

0 commit comments

Comments
 (0)