Skip to content

Commit 20438e9

Browse files
committed
Merge branch 'dev_new' into feat/deep-search
2 parents 953872e + 7d34e65 commit 20438e9

File tree

13 files changed

+770
-49
lines changed

13 files changed

+770
-49
lines changed

examples/mem_reader/compare_simple_vs_multimodal.py

Lines changed: 461 additions & 0 deletions
Large diffs are not rendered by default.

src/memos/api/product_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -469,7 +469,7 @@ class APIADDRequest(BaseRequest):
469469
),
470470
)
471471

472-
info: dict[str, str] | None = Field(
472+
info: dict[str, Any] | None = Field(
473473
None,
474474
description=(
475475
"Additional metadata for the add request. "

src/memos/mem_reader/multi_modal_struct.py

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,57 @@ def _build_window_from_items(
171171

172172
return aggregated_item
173173

174+
def _process_string_fine(
175+
self,
176+
fast_memory_items: list[TextualMemoryItem],
177+
info: dict[str, Any],
178+
custom_tags: list[str] | None = None,
179+
) -> list[TextualMemoryItem]:
180+
"""
181+
Process fast mode memory items through LLM to generate fine mode memories.
182+
"""
183+
if not fast_memory_items:
184+
return []
185+
186+
fine_memory_items = []
187+
188+
for fast_item in fast_memory_items:
189+
# Extract memory text (string content)
190+
mem_str = fast_item.memory or ""
191+
if not mem_str.strip():
192+
continue
193+
sources = fast_item.metadata.sources or []
194+
if not isinstance(sources, list):
195+
sources = [sources]
196+
try:
197+
resp = self._get_llm_response(mem_str, custom_tags)
198+
except Exception as e:
199+
logger.error(f"[MultiModalFine] Error calling LLM: {e}")
200+
continue
201+
for m in resp.get("memory list", []):
202+
try:
203+
# Normalize memory_type (same as simple_struct)
204+
memory_type = (
205+
m.get("memory_type", "LongTermMemory")
206+
.replace("长期记忆", "LongTermMemory")
207+
.replace("用户记忆", "UserMemory")
208+
)
209+
# Create fine mode memory item (same as simple_struct)
210+
node = self._make_memory_item(
211+
value=m.get("value", ""),
212+
info=info,
213+
memory_type=memory_type,
214+
tags=m.get("tags", []),
215+
key=m.get("key", ""),
216+
sources=sources, # Preserve sources from fast item
217+
background=resp.get("summary", ""),
218+
)
219+
fine_memory_items.append(node)
220+
except Exception as e:
221+
logger.error(f"[MultiModalFine] parse error: {e}")
222+
223+
return fine_memory_items
224+
174225
@timed
175226
def _process_multi_modal_data(
176227
self, scene_data_info: MessagesType, info, mode: str = "fine", **kwargs
@@ -208,21 +259,21 @@ def _process_multi_modal_data(
208259
if mode == "fast":
209260
return fast_memory_items
210261
else:
211-
# TODO: parallel call llm and get fine multimodal items
212262
# Part A: call llm
213263
fine_memory_items = []
214-
fine_memory_items_string_parser = fast_memory_items
264+
fine_memory_items_string_parser = self._process_string_fine(
265+
fast_memory_items, info, custom_tags
266+
)
215267
fine_memory_items.extend(fine_memory_items_string_parser)
216-
# Part B: get fine multimodal items
217268

269+
# Part B: get fine multimodal items
218270
for fast_item in fast_memory_items:
219271
sources = fast_item.metadata.sources
220272
for source in sources:
221273
items = self.multi_modal_parser.process_transfer(
222274
source, context_items=[fast_item], custom_tags=custom_tags
223275
)
224276
fine_memory_items.extend(items)
225-
logger.warning("Not Implemented Now!")
226277
return fine_memory_items
227278

228279
@timed
@@ -251,7 +302,7 @@ def _process_transfer_multi_modal_data(
251302

252303
fine_memory_items = []
253304
# Part A: call llm
254-
fine_memory_items_string_parser = []
305+
fine_memory_items_string_parser = self._process_string_fine([raw_node], info, custom_tags)
255306
fine_memory_items.extend(fine_memory_items_string_parser)
256307
# Part B: get fine multimodal items
257308
for source in sources:

src/memos/mem_scheduler/optimized_scheduler.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,8 @@ def mix_search_memories(
138138
target_session_id = search_req.session_id
139139
if not target_session_id:
140140
target_session_id = "default_session"
141-
search_filter = {"session_id": search_req.session_id} if search_req.session_id else None
141+
search_priority = {"session_id": search_req.session_id} if search_req.session_id else None
142+
search_filter = search_req.filter
142143

143144
# Rerank Memories - reranker expects TextualMemoryItem objects
144145

@@ -155,6 +156,7 @@ def mix_search_memories(
155156
mode=SearchMode.FAST,
156157
manual_close_internet=not search_req.internet_search,
157158
search_filter=search_filter,
159+
search_priority=search_priority,
158160
info=info,
159161
)
160162

@@ -178,7 +180,7 @@ def mix_search_memories(
178180
query=search_req.query, # Use search_req.query instead of undefined query
179181
graph_results=history_memories, # Pass TextualMemoryItem objects directly
180182
top_k=search_req.top_k, # Use search_req.top_k instead of undefined top_k
181-
search_filter=search_filter,
183+
search_priority=search_priority,
182184
)
183185
logger.info(f"Reranked {len(sorted_history_memories)} history memories.")
184186
processed_hist_mem = self.searcher.post_retrieve(
@@ -234,6 +236,7 @@ def mix_search_memories(
234236
mode=SearchMode.FAST,
235237
memory_type="All",
236238
search_filter=search_filter,
239+
search_priority=search_priority,
237240
info=info,
238241
)
239242
else:

src/memos/memories/textual/prefer_text_memory/retrievers.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,11 @@ def __init__(self, llm_provider=None, embedder=None, reranker=None, vector_db=No
1717

1818
@abstractmethod
1919
def retrieve(
20-
self, query: str, top_k: int, info: dict[str, Any] | None = None
20+
self,
21+
query: str,
22+
top_k: int,
23+
info: dict[str, Any] | None = None,
24+
search_filter: dict[str, Any] | None = None,
2125
) -> list[TextualMemoryItem]:
2226
"""Retrieve memories from the retriever."""
2327

@@ -76,14 +80,19 @@ def _original_text_reranker(
7680
return prefs_mem
7781

7882
def retrieve(
79-
self, query: str, top_k: int, info: dict[str, Any] | None = None
83+
self,
84+
query: str,
85+
top_k: int,
86+
info: dict[str, Any] | None = None,
87+
search_filter: dict[str, Any] | None = None,
8088
) -> list[TextualMemoryItem]:
8189
"""Retrieve memories from the naive retriever."""
8290
# TODO: un-support rewrite query and session filter now
8391
if info:
8492
info = info.copy() # Create a copy to avoid modifying the original
8593
info.pop("chat_history", None)
8694
info.pop("session_id", None)
95+
search_filter = {"and": [info, search_filter]}
8796
query_embeddings = self.embedder.embed([query]) # Pass as list to get list of embeddings
8897
query_embedding = query_embeddings[0] # Get the first (and only) embedding
8998

@@ -96,15 +105,15 @@ def retrieve(
96105
query,
97106
"explicit_preference",
98107
top_k * 2,
99-
info,
108+
search_filter,
100109
)
101110
future_implicit = executor.submit(
102111
self.vector_db.search,
103112
query_embedding,
104113
query,
105114
"implicit_preference",
106115
top_k * 2,
107-
info,
116+
search_filter,
108117
)
109118

110119
# Wait for all results

src/memos/memories/textual/preference.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,9 @@ def get_memory(
7676
"""
7777
return self.extractor.extract(messages, type, info)
7878

79-
def search(self, query: str, top_k: int, info=None, **kwargs) -> list[TextualMemoryItem]:
79+
def search(
80+
self, query: str, top_k: int, info=None, search_filter=None, **kwargs
81+
) -> list[TextualMemoryItem]:
8082
"""Search for memories based on a query.
8183
Args:
8284
query (str): The query to search for.
@@ -85,7 +87,8 @@ def search(self, query: str, top_k: int, info=None, **kwargs) -> list[TextualMem
8587
Returns:
8688
list[TextualMemoryItem]: List of matching memories.
8789
"""
88-
return self.retriever.retrieve(query, top_k, info)
90+
logger.info(f"search_filter for preference memory: {search_filter}")
91+
return self.retriever.retrieve(query, top_k, info, search_filter)
8992

9093
def load(self, dir: str) -> None:
9194
"""Load memories from the specified directory.

src/memos/memories/textual/simple_preference.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,9 @@ def get_memory(
5050
"""
5151
return self.extractor.extract(messages, type, info)
5252

53-
def search(self, query: str, top_k: int, info=None, **kwargs) -> list[TextualMemoryItem]:
53+
def search(
54+
self, query: str, top_k: int, info=None, search_filter=None, **kwargs
55+
) -> list[TextualMemoryItem]:
5456
"""Search for memories based on a query.
5557
Args:
5658
query (str): The query to search for.
@@ -59,7 +61,7 @@ def search(self, query: str, top_k: int, info=None, **kwargs) -> list[TextualMem
5961
Returns:
6062
list[TextualMemoryItem]: List of matching memories.
6163
"""
62-
return self.retriever.retrieve(query, top_k, info)
64+
return self.retriever.retrieve(query, top_k, info, search_filter)
6365

6466
def add(self, memories: list[TextualMemoryItem | dict[str, Any]]) -> list[str]:
6567
"""Add memories.

src/memos/memories/textual/tree.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ def search(
162162
mode: str = "fast",
163163
memory_type: str = "All",
164164
manual_close_internet: bool = True,
165+
search_priority: dict | None = None,
165166
search_filter: dict | None = None,
166167
user_name: str | None = None,
167168
) -> list[TextualMemoryItem]:
@@ -209,7 +210,14 @@ def search(
209210
manual_close_internet=manual_close_internet,
210211
)
211212
return searcher.search(
212-
query, top_k, info, mode, memory_type, search_filter, user_name=user_name
213+
query,
214+
top_k,
215+
info,
216+
mode,
217+
memory_type,
218+
search_filter,
219+
search_priority,
220+
user_name=user_name,
213221
)
214222

215223
def get_relevant_subgraph(

src/memos/memories/textual/tree_text_memory/retrieve/recall.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def retrieve(
3838
memory_scope: str,
3939
query_embedding: list[list[float]] | None = None,
4040
search_filter: dict | None = None,
41+
search_priority: dict | None = None,
4142
user_name: str | None = None,
4243
id_filter: dict | None = None,
4344
use_fast_graph: bool = False,
@@ -62,9 +63,12 @@ def retrieve(
6263
raise ValueError(f"Unsupported memory scope: {memory_scope}")
6364

6465
if memory_scope == "WorkingMemory":
65-
# For working memory, retrieve all entries (no filtering)
66+
# For working memory, retrieve all entries (no session-oriented filtering)
6667
working_memories = self.graph_store.get_all_memory_items(
67-
scope="WorkingMemory", include_embedding=False, user_name=user_name
68+
scope="WorkingMemory",
69+
include_embedding=False,
70+
user_name=user_name,
71+
filter=search_filter,
6872
)
6973
return [TextualMemoryItem.from_dict(record) for record in working_memories[:top_k]]
7074

@@ -84,6 +88,7 @@ def retrieve(
8488
memory_scope,
8589
top_k,
8690
search_filter=search_filter,
91+
search_priority=search_priority,
8792
user_name=user_name,
8893
)
8994
if self.use_bm25:
@@ -274,6 +279,7 @@ def _vector_recall(
274279
status: str = "activated",
275280
cube_name: str | None = None,
276281
search_filter: dict | None = None,
282+
search_priority: dict | None = None,
277283
user_name: str | None = None,
278284
) -> list[TextualMemoryItem]:
279285
"""
@@ -283,39 +289,41 @@ def _vector_recall(
283289
if not query_embedding:
284290
return []
285291

286-
def search_single(vec, filt=None):
292+
def search_single(vec, search_priority=None, search_filter=None):
287293
return (
288294
self.graph_store.search_by_embedding(
289295
vector=vec,
290296
top_k=top_k,
291297
status=status,
292298
scope=memory_scope,
293299
cube_name=cube_name,
294-
search_filter=filt,
300+
search_filter=search_priority,
301+
filter=search_filter,
295302
user_name=user_name,
296303
)
297304
or []
298305
)
299306

300307
def search_path_a():
301-
"""Path A: search without filter"""
308+
"""Path A: search without priority"""
302309
path_a_hits = []
303310
with ContextThreadPoolExecutor() as executor:
304311
futures = [
305-
executor.submit(search_single, vec, None) for vec in query_embedding[:max_num]
312+
executor.submit(search_single, vec, None, search_filter)
313+
for vec in query_embedding[:max_num]
306314
]
307315
for f in concurrent.futures.as_completed(futures):
308316
path_a_hits.extend(f.result() or [])
309317
return path_a_hits
310318

311319
def search_path_b():
312-
"""Path B: search with filter"""
313-
if not search_filter:
320+
"""Path B: search with priority"""
321+
if not search_priority:
314322
return []
315323
path_b_hits = []
316324
with ContextThreadPoolExecutor() as executor:
317325
futures = [
318-
executor.submit(search_single, vec, search_filter)
326+
executor.submit(search_single, vec, search_priority, search_filter)
319327
for vec in query_embedding[:max_num]
320328
]
321329
for f in concurrent.futures.as_completed(futures):

0 commit comments

Comments (0)