Skip to content

Commit a04b5f5

Browse files
author
yuan.wang
committed
add tool search
1 parent 2e68bac commit a04b5f5

File tree

8 files changed

+236
-116
lines changed

8 files changed

+236
-116
lines changed

src/memos/api/handlers/formatters_handler.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,3 +90,37 @@ def post_process_pref_mem(
9090
memories_result["pref_note"] = pref_note
9191

9292
return memories_result
93+
94+
95+
def post_process_textual_mem(
96+
memories_result: dict[str, Any],
97+
text_formatted_mem: list[dict[str, Any]],
98+
mem_cube_id: str,
99+
) -> dict[str, Any]:
100+
"""
101+
Post-process text and tool memory results.
102+
"""
103+
fact_mem = [
104+
mem
105+
for mem in text_formatted_mem
106+
if mem["metadata"]["memory_type"] not in ["ToolSchemaMemory", "ToolTrajectoryMemory"]
107+
]
108+
tool_mem = [
109+
mem
110+
for mem in text_formatted_mem
111+
if mem["metadata"]["memory_type"] in ["ToolSchemaMemory", "ToolTrajectoryMemory"]
112+
]
113+
114+
memories_result["text_mem"].append(
115+
{
116+
"cube_id": mem_cube_id,
117+
"memories": fact_mem,
118+
}
119+
)
120+
memories_result["tool_mem"].append(
121+
{
122+
"cube_id": mem_cube_id,
123+
"memories": tool_mem,
124+
}
125+
)
126+
return memories_result

src/memos/api/product_models.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,21 @@ class APISearchRequest(BaseRequest):
326326
),
327327
)
328328

329+
search_tool_memory: bool = Field(
330+
True,
331+
description=(
332+
"Whether to retrieve tool memories along with general memories. "
333+
"If enabled, the system will automatically recall tool memories "
334+
"relevant to the query. Default: True."
335+
),
336+
)
337+
338+
tool_mem_top_k: int = Field(
339+
6,
340+
ge=0,
341+
description="Number of tool memories to retrieve (top-K). Default: 6.",
342+
)
343+
329344
# ==== Filter conditions ====
330345
# TODO: maybe add detailed description later
331346
filter: dict[str, Any] | None = Field(

src/memos/mem_reader/read_multi_modal/tool_parser.py

Lines changed: 2 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -135,91 +135,6 @@ def rebuild_from_source(
135135
) -> ChatCompletionToolMessageParam:
136136
"""Rebuild tool message from SourceMessage."""
137137

138-
# Priority 1: Use original_part if available
139-
if hasattr(source, "original_part") and source.original_part:
140-
original = source.original_part
141-
# If it's a content part, wrap it in a message
142-
if isinstance(original, dict) and "type" in original:
143-
return {
144-
"role": source.role or "user",
145-
"tool_call_id": source.tool_call_id or "",
146-
"content": [original],
147-
"chat_time": source.chat_time,
148-
"message_id": source.message_id,
149-
}
150-
# If it's already a full message, return it
151-
if isinstance(original, dict) and "role" in original:
152-
return original
153-
154-
# Priority 2: Rebuild from source fields
155-
if source.type == "text":
156-
return {
157-
"role": source.role or "tool",
158-
"content": [
159-
{
160-
"type": "text",
161-
"text": source.content or "",
162-
}
163-
],
164-
"chat_time": source.chat_time,
165-
"message_id": source.message_id,
166-
}
167-
elif source.type == "file":
168-
return {
169-
"role": source.role or "tool",
170-
"content": [
171-
{
172-
"type": "file",
173-
"file": {
174-
"file_id": source.file_id or "",
175-
"filename": source.filename or "",
176-
"file_data": source.content or "",
177-
},
178-
}
179-
],
180-
"chat_time": source.chat_time,
181-
"message_id": source.message_id,
182-
}
183-
elif source.type == "image_url":
184-
return {
185-
"role": source.role or "tool",
186-
"content": [
187-
{
188-
"type": "image_url",
189-
"image_url": {
190-
"url": source.content or "",
191-
"detail": source.detail or "auto",
192-
},
193-
}
194-
],
195-
"chat_time": source.chat_time,
196-
"message_id": source.message_id,
197-
}
198-
elif source.type == "input_audio":
199-
return {
200-
"role": source.role or "tool",
201-
"content": [
202-
{
203-
"type": "input_audio",
204-
"input_audio": {
205-
"data": source.content or "",
206-
"format": source.format or "wav",
207-
},
208-
}
209-
],
210-
"chat_time": source.chat_time,
211-
"message_id": source.message_id,
212-
}
213-
214-
# Simple text message
215-
return {
216-
"role": "tool",
217-
"content": source.content or "",
218-
"tool_call_id": source.message_id or "",
219-
"chat_time": source.chat_time,
220-
"message_id": source.message_id,
221-
}
222-
223138
def parse_fast(
224139
self,
225140
message: ChatCompletionToolMessageParam,
@@ -261,25 +176,5 @@ def parse_fine(
261176
info: dict[str, Any],
262177
**kwargs,
263178
) -> list[TextualMemoryItem]:
264-
content = message.get("content", "")
265-
if isinstance(content, list):
266-
part_type = content[0].get("type", "")
267-
if part_type == "text":
268-
# text will fine parse in full chat content, no need to parse specially
269-
return []
270-
elif part_type == "file":
271-
# use file content parser to parse file content, no need to parse here
272-
return []
273-
elif part_type == "image_url":
274-
# TODO: use multi-modal llm to generate mem by image url
275-
content = content[0].get("image_url", {}).get("url", "")
276-
return []
277-
elif part_type == "input_audio":
278-
# TODO: unsupport audio for now
279-
return []
280-
else:
281-
logger.warning(f"[ToolParser] Unsupported part type: {part_type}")
282-
return []
283-
else:
284-
# simple string content message, fine parse in full chat content, no need to parse specially
285-
return []
179+
# tool message no special multimodal handling is required in fine mode.
180+
return []

src/memos/mem_scheduler/optimized_scheduler.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,8 @@ def mix_search_memories(
159159
search_filter=search_filter,
160160
search_priority=search_priority,
161161
info=info,
162+
search_tool_memory=search_req.search_tool_memory,
163+
tool_mem_top_k=search_req.tool_mem_top_k,
162164
)
163165

164166
# Try to get pre-computed memories if available
@@ -182,6 +184,8 @@ def mix_search_memories(
182184
top_k=search_req.top_k,
183185
user_name=user_context.mem_cube_id,
184186
info=info,
187+
search_tool_memory=search_req.search_tool_memory,
188+
tool_mem_top_k=search_req.tool_mem_top_k,
185189
)
186190
memories = merged_memories[: search_req.top_k]
187191

src/memos/memories/textual/tree.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,8 @@ def search(
166166
search_priority: dict | None = None,
167167
search_filter: dict | None = None,
168168
user_name: str | None = None,
169+
search_tool_memory: bool = False,
170+
tool_mem_top_k: int = 6,
169171
**kwargs,
170172
) -> list[TextualMemoryItem]:
171173
"""Search for memories based on a query.
@@ -223,6 +225,8 @@ def search(
223225
search_priority,
224226
user_name=user_name,
225227
plugin=kwargs.get("plugin", False),
228+
search_tool_memory=search_tool_memory,
229+
tool_mem_top_k=tool_mem_top_k,
226230
)
227231

228232
def get_relevant_subgraph(

src/memos/memories/textual/tree_text_memory/retrieve/recall.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,13 @@ def retrieve(
5959
Returns:
6060
list: Combined memory items.
6161
"""
62-
if memory_scope not in ["WorkingMemory", "LongTermMemory", "UserMemory"]:
62+
if memory_scope not in [
63+
"WorkingMemory",
64+
"LongTermMemory",
65+
"UserMemory",
66+
"ToolSchemaMemory",
67+
"ToolTrajectoryMemory",
68+
]:
6369
raise ValueError(f"Unsupported memory scope: {memory_scope}")
6470

6571
if memory_scope == "WorkingMemory":

0 commit comments

Comments
 (0)