Commit 4dd7f76

Wang-Daojiyuan and yuan.wang authored

Feat/tool memory (#583)

* function call support
* add tool parser
* rename multi model to modal
* rename multi modal
* tool mem support
* modify multi-modal code
* pref support multi-modal messages
* fix bug in chat handler
* fix pre commit
* modify code
* add tool search
* tool search
* add split chunk in system and tool

Co-authored-by: yuan.wang <[email protected]>
1 parent 36d0ba0 commit 4dd7f76

File tree

22 files changed: +702 −303 lines changed


docker/requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -160,3 +160,4 @@ xlrd==2.0.2
 xlsxwriter==3.2.5
 prometheus-client==0.23.1
 pymilvus==2.5.12
+langchain-text-splitters==1.0.0
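
The new langchain-text-splitters pin backs the "add split chunk in system and tool" item in the commit message. Below is a minimal sketch of how such a splitter is typically used to chunk long system/tool text before storage; the chunk sizes and the input variable are illustrative and not taken from MemOS code.

from langchain_text_splitters import RecursiveCharacterTextSplitter

# Illustrative only: chunk a long tool/system description into overlapping pieces.
long_tool_schema_text = "tool schema and system prompt text ... " * 200  # placeholder input
splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
chunks = splitter.split_text(long_tool_schema_text)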

src/memos/api/handlers/chat_handler.py

Lines changed: 5 additions & 3 deletions
@@ -142,7 +142,9 @@ def handle_chat_complete(self, chat_req: APIChatCompleteRequest) -> dict[str, An
 
         # Step 2: Build system prompt
         system_prompt = self._build_system_prompt(
-            filtered_memories, search_response.data["pref_string"], chat_req.system_prompt
+            filtered_memories,
+            search_response.data.get("pref_string", ""),
+            chat_req.system_prompt,
         )
 
         # Prepare message history
@@ -257,7 +259,7 @@ def generate_chat_response() -> Generator[str, None, None]:
             # Step 2: Build system prompt with memories
             system_prompt = self._build_system_prompt(
                 filtered_memories,
-                search_response.data["pref_string"],
+                search_response.data.get("pref_string", ""),
                 chat_req.system_prompt,
             )
 
@@ -449,7 +451,7 @@ def generate_chat_response() -> Generator[str, None, None]:
 
             # Step 2: Build system prompt with memories
             system_prompt = self._build_enhance_system_prompt(
-                filtered_memories, search_response.data["pref_string"]
+                filtered_memories, search_response.data.get("pref_string", "")
             )
 
             # Prepare messages
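
The switch from data["pref_string"] to data.get("pref_string", "") makes prompt building tolerant of search responses that carry no preference section. A minimal sketch of the difference, using an illustrative response shape rather than the real search_response object:

# Illustrative payload: a search result with no preference memories attached.
search_data = {"text_mem": [], "tool_mem": []}
pref = search_data.get("pref_string", "")  # -> "" ; search_data["pref_string"] would raise KeyError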

src/memos/api/handlers/formatters_handler.py

Lines changed: 34 additions & 0 deletions
@@ -90,3 +90,37 @@ def post_process_pref_mem(
     memories_result["pref_note"] = pref_note
 
     return memories_result
+
+
+def post_process_textual_mem(
+    memories_result: dict[str, Any],
+    text_formatted_mem: list[dict[str, Any]],
+    mem_cube_id: str,
+) -> dict[str, Any]:
+    """
+    Post-process text and tool memory results.
+    """
+    fact_mem = [
+        mem
+        for mem in text_formatted_mem
+        if mem["metadata"]["memory_type"] not in ["ToolSchemaMemory", "ToolTrajectoryMemory"]
+    ]
+    tool_mem = [
+        mem
+        for mem in text_formatted_mem
+        if mem["metadata"]["memory_type"] in ["ToolSchemaMemory", "ToolTrajectoryMemory"]
+    ]
+
+    memories_result["text_mem"].append(
+        {
+            "cube_id": mem_cube_id,
+            "memories": fact_mem,
+        }
+    )
+    memories_result["tool_mem"].append(
+        {
+            "cube_id": mem_cube_id,
+            "memories": tool_mem,
+        }
+    )
+    return memories_result
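
A hedged usage sketch of the new post_process_textual_mem helper: it partitions formatted memories by metadata.memory_type, routing ToolSchemaMemory and ToolTrajectoryMemory entries into tool_mem and everything else into text_mem. The sample records below are illustrative, not real MemOS payloads.

from memos.api.handlers.formatters_handler import post_process_textual_mem

memories_result = {"text_mem": [], "tool_mem": []}
text_formatted_mem = [
    {"memory": "User prefers window seats", "metadata": {"memory_type": "UserMemory"}},
    {"memory": "search_flights(origin, destination, date)", "metadata": {"memory_type": "ToolSchemaMemory"}},
]
result = post_process_textual_mem(memories_result, text_formatted_mem, mem_cube_id="cube_001")
# result["text_mem"][0]["memories"] -> the UserMemory entry
# result["tool_mem"][0]["memories"] -> the ToolSchemaMemory entry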

src/memos/api/product_models.py

Lines changed: 26 additions & 11 deletions
@@ -6,7 +6,7 @@
 
 # Import message types from core types module
 from memos.log import get_logger
-from memos.types import MessageDict, PermissionDict, SearchMode
+from memos.types import MessageList, MessagesType, PermissionDict, SearchMode
 
 
 logger = get_logger(__name__)
@@ -56,7 +56,7 @@ class Message(BaseModel):
 
 class MemoryCreate(BaseRequest):
     user_id: str = Field(..., description="User ID")
-    messages: list | None = Field(None, description="List of messages to store.")
+    messages: MessageList | None = Field(None, description="List of messages to store.")
     memory_content: str | None = Field(None, description="Content to store as memory")
     doc_path: str | None = Field(None, description="Path to document to store")
     mem_cube_id: str | None = Field(None, description="ID of the memory cube")
@@ -83,7 +83,7 @@ class ChatRequest(BaseRequest):
     writable_cube_ids: list[str] | None = Field(
         None, description="List of cube IDs user can write for multi-cube chat"
     )
-    history: list | None = Field(None, description="Chat history")
+    history: MessageList | None = Field(None, description="Chat history")
     mode: SearchMode = Field(SearchMode.FAST, description="search mode: fast, fine, or mixture")
     system_prompt: str | None = Field(None, description="Base system prompt to use for chat")
     top_k: int = Field(10, description="Number of results to return")
@@ -165,7 +165,7 @@ class ChatCompleteRequest(BaseRequest):
     user_id: str = Field(..., description="User ID")
     query: str = Field(..., description="Chat query message")
     mem_cube_id: str | None = Field(None, description="Cube ID to use for chat")
-    history: list | None = Field(None, description="Chat history")
+    history: MessageList | None = Field(None, description="Chat history")
     internet_search: bool = Field(False, description="Whether to use internet search")
     system_prompt: str | None = Field(None, description="Base prompt to use for chat")
     top_k: int = Field(10, description="Number of results to return")
@@ -251,7 +251,7 @@ class MemoryCreateRequest(BaseRequest):
     """Request model for creating memories."""
 
     user_id: str = Field(..., description="User ID")
-    messages: str | list | None = Field(None, description="List of messages to store.")
+    messages: str | MessagesType | None = Field(None, description="List of messages to store.")
     memory_content: str | None = Field(None, description="Memory content to store")
     doc_path: str | None = Field(None, description="Path to document to store")
     mem_cube_id: str | None = Field(None, description="Cube ID")
@@ -326,6 +326,21 @@ class APISearchRequest(BaseRequest):
         ),
     )
 
+    search_tool_memory: bool = Field(
+        True,
+        description=(
+            "Whether to retrieve tool memories along with general memories. "
+            "If enabled, the system will automatically recall tool memories "
+            "relevant to the query. Default: True."
+        ),
+    )
+
+    tool_mem_top_k: int = Field(
+        6,
+        ge=0,
+        description="Number of tool memories to retrieve (top-K). Default: 6.",
+    )
+
     # ==== Filter conditions ====
     # TODO: maybe add detailed description later
     filter: dict[str, Any] | None = Field(
@@ -360,7 +375,7 @@ class APISearchRequest(BaseRequest):
     )
 
     # ==== Context ====
-    chat_history: list | None = Field(
+    chat_history: MessageList | None = Field(
         None,
         description=(
             "Historical chat messages used internally by algorithms. "
@@ -490,7 +505,7 @@ class APIADDRequest(BaseRequest):
     )
 
     # ==== Input content ====
-    messages: str | list | None = Field(
+    messages: MessagesType | None = Field(
         None,
         description=(
             "List of messages to store. Supports: "
@@ -506,7 +521,7 @@ class APIADDRequest(BaseRequest):
     )
 
     # ==== Chat history ====
-    chat_history: list | None = Field(
+    chat_history: MessageList | None = Field(
         None,
         description=(
             "Historical chat messages used internally by algorithms. "
@@ -636,7 +651,7 @@ class APIFeedbackRequest(BaseRequest):
         "default_session", description="Session ID for soft-filtering memories"
     )
     task_id: str | None = Field(None, description="Task ID for monitering async tasks")
-    history: list[MessageDict] | None = Field(..., description="Chat history")
+    history: MessageList | None = Field(..., description="Chat history")
     retrieved_memory_ids: list[str] | None = Field(
         None, description="Retrieved memory ids at last turn"
    )
@@ -670,7 +685,7 @@ class APIChatCompleteRequest(BaseRequest):
     writable_cube_ids: list[str] | None = Field(
         None, description="List of cube IDs user can write for multi-cube chat"
     )
-    history: list | None = Field(None, description="Chat history")
+    history: MessageList | None = Field(None, description="Chat history")
     mode: SearchMode = Field(SearchMode.FAST, description="search mode: fast, fine, or mixture")
     system_prompt: str | None = Field(None, description="Base system prompt to use for chat")
     top_k: int = Field(10, description="Number of results to return")
@@ -739,7 +754,7 @@ class SuggestionRequest(BaseRequest):
     user_id: str = Field(..., description="User ID")
     mem_cube_id: str = Field(..., description="Cube ID")
     language: Literal["zh", "en"] = Field("zh", description="Language for suggestions")
-    message: list | None = Field(None, description="List of messages to store.")
+    message: MessagesType | None = Field(None, description="List of messages to store.")
 
 
 # ─── MemOS Client Response Models ──────────────────────────────────────────────
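
APISearchRequest now exposes two tool-memory knobs alongside the existing search fields. The following is a hedged example payload; the field names come from the models above, but the full set of required fields for a real request may differ.

# Illustrative search payload exercising the new fields.
search_payload = {
    "user_id": "u_123",
    "query": "book a flight to Beijing next Friday",
    "top_k": 10,                  # general memories to return
    "search_tool_memory": True,   # also recall tool memories relevant to the query
    "tool_mem_top_k": 6,          # cap on retrieved tool memories (ge=0)
}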

src/memos/mem_reader/multi_modal_struct.py

Lines changed: 69 additions & 1 deletion
@@ -1,4 +1,5 @@
 import concurrent.futures
+import json
 import traceback
 
 from typing import Any
@@ -7,8 +8,9 @@
 from memos.configs.mem_reader import MultiModalStructMemReaderConfig
 from memos.context.context import ContextThreadPoolExecutor
 from memos.mem_reader.read_multi_modal import MultiModalParser
-from memos.mem_reader.simple_struct import SimpleStructMemReader
+from memos.mem_reader.simple_struct import SimpleStructMemReader, detect_lang
 from memos.memories.textual.item import TextualMemoryItem
+from memos.templates.tool_mem_prompts import TOOL_TRAJECTORY_PROMPT_EN, TOOL_TRAJECTORY_PROMPT_ZH
 from memos.types import MessagesType
 from memos.utils import timed
 
@@ -297,6 +299,61 @@ def _process_string_fine(
 
         return fine_memory_items
 
+    def _get_llm_tool_trajectory_response(self, mem_str: str) -> dict:
+        """
+        Generete tool trajectory experience item by llm.
+        """
+        try:
+            lang = detect_lang(mem_str)
+            template = TOOL_TRAJECTORY_PROMPT_ZH if lang == "zh" else TOOL_TRAJECTORY_PROMPT_EN
+            prompt = template.replace("{messages}", mem_str)
+            rsp = self.llm.generate([{"role": "user", "content": prompt}])
+            rsp = rsp.replace("```json", "").replace("```", "")
+            return json.loads(rsp)
+        except Exception as e:
+            logger.error(f"[MultiModalFine] Error calling LLM for tool trajectory: {e}")
+            return []
+
+    def _process_tool_trajectory_fine(
+        self,
+        fast_memory_items: list[TextualMemoryItem],
+        info: dict[str, Any],
+    ) -> list[TextualMemoryItem]:
+        """
+        Process tool trajectory memory items through LLM to generate fine mode memories.
+        """
+        if not fast_memory_items:
+            return []
+
+        fine_memory_items = []
+
+        for fast_item in fast_memory_items:
+            # Extract memory text (string content)
+            mem_str = fast_item.memory or ""
+            if not mem_str.strip() or "tool:" not in mem_str:
+                continue
+            try:
+                resp = self._get_llm_tool_trajectory_response(mem_str)
+            except Exception as e:
+                logger.error(f"[MultiModalFine] Error calling LLM for tool trajectory: {e}")
+                continue
+            for m in resp:
+                try:
+                    # Normalize memory_type (same as simple_struct)
+                    memory_type = "ToolTrajectoryMemory"
+
+                    node = self._make_memory_item(
+                        value=m.get("trajectory", ""),
+                        info=info,
+                        memory_type=memory_type,
+                        tool_used_status=m.get("tool_used_status", []),
+                    )
+                    fine_memory_items.append(node)
+                except Exception as e:
+                    logger.error(f"[MultiModalFine] parse error for tool trajectory: {e}")
+
+        return fine_memory_items
+
     @timed
     def _process_multi_modal_data(
         self, scene_data_info: MessagesType, info, mode: str = "fine", **kwargs
@@ -339,6 +396,11 @@ def _process_multi_modal_data(
             )
             fine_memory_items.extend(fine_memory_items_string_parser)
 
+            fine_memory_items_tool_trajectory_parser = self._process_tool_trajectory_fine(
+                fast_memory_items, info
+            )
+            fine_memory_items.extend(fine_memory_items_tool_trajectory_parser)
+
             # Part B: get fine multimodal items
             for fast_item in fast_memory_items:
                 sources = fast_item.metadata.sources
@@ -377,6 +439,12 @@ def _process_transfer_multi_modal_data(
         # Part A: call llm
         fine_memory_items_string_parser = self._process_string_fine([raw_node], info, custom_tags)
         fine_memory_items.extend(fine_memory_items_string_parser)
+
+        fine_memory_items_tool_trajectory_parser = self._process_tool_trajectory_fine(
+            [raw_node], info
+        )
+        fine_memory_items.extend(fine_memory_items_tool_trajectory_parser)
+
         # Part B: get fine multimodal items
         for source in sources:
             items = self.multi_modal_parser.process_transfer(
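
The fine pass for tool trajectories hinges on the JSON shape the LLM returns: a list of objects with "trajectory" and "tool_used_status" keys, which _process_tool_trajectory_fine turns into ToolTrajectoryMemory items. Below is a standalone sketch of that parse step, with the LLM call and prompt template passed in as stand-ins rather than the actual MemOS objects.

import json

def extract_tool_trajectories(llm_generate, template: str, mem_str: str) -> list[dict]:
    # Fill the prompt template and ask the model for a JSON list of trajectories.
    prompt = template.replace("{messages}", mem_str)
    raw = llm_generate([{"role": "user", "content": prompt}])
    raw = raw.replace("```json", "").replace("```", "")  # strip Markdown code fences
    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError:
        return []
    # Keep only well-formed entries with a non-empty trajectory string.
    return [
        {"trajectory": m.get("trajectory", ""), "tool_used_status": m.get("tool_used_status", [])}
        for m in parsed
        if isinstance(m, dict) and m.get("trajectory")
    ]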
