Skip to content

Commit da9d843

Browse files
author
yuan.wang
committed
tool search
1 parent a04b5f5 commit da9d843

File tree

3 files changed

+25
-15
lines changed

3 files changed

+25
-15
lines changed

src/memos/api/product_models.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
# Import message types from core types module
88
from memos.log import get_logger
9-
from memos.types import MessageDict, PermissionDict, SearchMode
9+
from memos.types import MessageList, MessagesType, PermissionDict, SearchMode
1010

1111

1212
logger = get_logger(__name__)
@@ -56,7 +56,7 @@ class Message(BaseModel):
5656

5757
class MemoryCreate(BaseRequest):
5858
user_id: str = Field(..., description="User ID")
59-
messages: list | None = Field(None, description="List of messages to store.")
59+
messages: MessageList | None = Field(None, description="List of messages to store.")
6060
memory_content: str | None = Field(None, description="Content to store as memory")
6161
doc_path: str | None = Field(None, description="Path to document to store")
6262
mem_cube_id: str | None = Field(None, description="ID of the memory cube")
@@ -83,7 +83,7 @@ class ChatRequest(BaseRequest):
8383
writable_cube_ids: list[str] | None = Field(
8484
None, description="List of cube IDs user can write for multi-cube chat"
8585
)
86-
history: list | None = Field(None, description="Chat history")
86+
history: MessageList | None = Field(None, description="Chat history")
8787
mode: SearchMode = Field(SearchMode.FAST, description="search mode: fast, fine, or mixture")
8888
system_prompt: str | None = Field(None, description="Base system prompt to use for chat")
8989
top_k: int = Field(10, description="Number of results to return")
@@ -165,7 +165,7 @@ class ChatCompleteRequest(BaseRequest):
165165
user_id: str = Field(..., description="User ID")
166166
query: str = Field(..., description="Chat query message")
167167
mem_cube_id: str | None = Field(None, description="Cube ID to use for chat")
168-
history: list | None = Field(None, description="Chat history")
168+
history: MessageList | None = Field(None, description="Chat history")
169169
internet_search: bool = Field(False, description="Whether to use internet search")
170170
system_prompt: str | None = Field(None, description="Base prompt to use for chat")
171171
top_k: int = Field(10, description="Number of results to return")
@@ -251,7 +251,7 @@ class MemoryCreateRequest(BaseRequest):
251251
"""Request model for creating memories."""
252252

253253
user_id: str = Field(..., description="User ID")
254-
messages: str | list | None = Field(None, description="List of messages to store.")
254+
messages: str | MessagesType | None = Field(None, description="List of messages to store.")
255255
memory_content: str | None = Field(None, description="Memory content to store")
256256
doc_path: str | None = Field(None, description="Path to document to store")
257257
mem_cube_id: str | None = Field(None, description="Cube ID")
@@ -375,7 +375,7 @@ class APISearchRequest(BaseRequest):
375375
)
376376

377377
# ==== Context ====
378-
chat_history: list | None = Field(
378+
chat_history: MessageList | None = Field(
379379
None,
380380
description=(
381381
"Historical chat messages used internally by algorithms. "
@@ -505,7 +505,7 @@ class APIADDRequest(BaseRequest):
505505
)
506506

507507
# ==== Input content ====
508-
messages: str | list | None = Field(
508+
messages: MessagesType | None = Field(
509509
None,
510510
description=(
511511
"List of messages to store. Supports: "
@@ -521,7 +521,7 @@ class APIADDRequest(BaseRequest):
521521
)
522522

523523
# ==== Chat history ====
524-
chat_history: list | None = Field(
524+
chat_history: MessageList | None = Field(
525525
None,
526526
description=(
527527
"Historical chat messages used internally by algorithms. "
@@ -651,7 +651,7 @@ class APIFeedbackRequest(BaseRequest):
651651
"default_session", description="Session ID for soft-filtering memories"
652652
)
653653
task_id: str | None = Field(None, description="Task ID for monitering async tasks")
654-
history: list[MessageDict] | None = Field(..., description="Chat history")
654+
history: MessageList | None = Field(..., description="Chat history")
655655
retrieved_memory_ids: list[str] | None = Field(
656656
None, description="Retrieved memory ids at last turn"
657657
)
@@ -686,7 +686,7 @@ class APIChatCompleteRequest(BaseRequest):
686686
writable_cube_ids: list[str] | None = Field(
687687
None, description="List of cube IDs user can write for multi-cube chat"
688688
)
689-
history: list | None = Field(None, description="Chat history")
689+
history: MessageList | None = Field(None, description="Chat history")
690690
mode: SearchMode = Field(SearchMode.FAST, description="search mode: fast, fine, or mixture")
691691
system_prompt: str | None = Field(None, description="Base system prompt to use for chat")
692692
top_k: int = Field(10, description="Number of results to return")
@@ -755,7 +755,7 @@ class SuggestionRequest(BaseRequest):
755755
user_id: str = Field(..., description="User ID")
756756
mem_cube_id: str = Field(..., description="Cube ID")
757757
language: Literal["zh", "en"] = Field("zh", description="Language for suggestions")
758-
message: list | None = Field(None, description="List of messages to store.")
758+
message: MessagesType | None = Field(None, description="List of messages to store.")
759759

760760

761761
# ─── MemOS Client Response Models ──────────────────────────────────────────────

src/memos/mem_reader/read_multi_modal/system_parser.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,12 +126,21 @@ def parse_fine(
126126
logger.warning(f"[SystemParser] Tool schema must be a list[dict]: {content}")
127127
return []
128128

129+
info_ = info.copy()
130+
user_id = info_.pop("user_id", "")
131+
session_id = info_.pop("session_id", "")
132+
129133
return [
130134
TextualMemoryItem(
131135
id=str(uuid.uuid4()),
132136
memory=json.dumps(schema),
133137
metadata=TreeNodeTextualMemoryMetadata(
138+
user_id=user_id,
139+
session_id=session_id,
134140
memory_type="ToolSchemaMemory",
141+
status="activated",
142+
embedding=self.embedder.embed([json.dumps(schema)])[0],
143+
info=info_,
135144
),
136145
)
137146
for schema in tool_schema

src/memos/memories/textual/tree_text_memory/organize/manager.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -181,10 +181,11 @@ def _process_memory(self, memory: TextualMemoryItem, user_name: str | None = Non
181181
working_id = str(uuid.uuid4())
182182

183183
with ContextThreadPoolExecutor(max_workers=2, thread_name_prefix="mem") as ex:
184-
f_working = ex.submit(
185-
self._add_memory_to_db, memory, "WorkingMemory", user_name, working_id
186-
)
187-
futures.append(("working", f_working))
184+
if memory.metadata.memory_type not in ("ToolSchemaMemory", "ToolTrajectoryMemory"):
185+
f_working = ex.submit(
186+
self._add_memory_to_db, memory, "WorkingMemory", user_name, working_id
187+
)
188+
futures.append(("working", f_working))
188189

189190
if memory.metadata.memory_type in (
190191
"LongTermMemory",

0 commit comments

Comments
 (0)