Skip to content

Commit 64383fb

Browse files
committed
Merge branch 'dev_new' into feat/deep-search
2 parents 78c1582 + 53aa48c commit 64383fb

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

49 files changed

+2089
-677
lines changed

docker/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,3 +160,4 @@ xlrd==2.0.2
160160
xlsxwriter==3.2.5
161161
prometheus-client==0.23.1
162162
pymilvus==2.5.12
163+
langchain-text-splitters==1.0.0

src/memos/api/handlers/add_handler.py

Lines changed: 39 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
using dependency injection for better modularity and testability.
66
"""
77

8+
from pydantic import validate_call
9+
810
from memos.api.handlers.base_handler import BaseHandler, HandlerDependencies
911
from memos.api.product_models import APIADDRequest, APIFeedbackRequest, MemoryResponse
1012
from memos.memories.textual.item import (
@@ -13,6 +15,7 @@
1315
from memos.multi_mem_cube.composite_cube import CompositeCubeView
1416
from memos.multi_mem_cube.single_cube import SingleCubeView
1517
from memos.multi_mem_cube.views import MemCubeView
18+
from memos.types import MessageList
1619

1720

1821
class AddHandler(BaseHandler):
@@ -60,38 +63,45 @@ def handle_add_memories(self, add_req: APIADDRequest) -> MemoryResponse:
6063

6164
cube_view = self._build_cube_view(add_req)
6265

66+
@validate_call
67+
def _check_messages(messages: MessageList) -> None:
68+
pass
69+
6370
if add_req.is_feedback:
64-
chat_history = add_req.chat_history
65-
messages = add_req.messages
66-
if chat_history is None:
67-
chat_history = []
68-
if messages is None:
69-
messages = []
70-
concatenate_chat = chat_history + messages
71-
72-
last_user_index = max(i for i, d in enumerate(concatenate_chat) if d["role"] == "user")
73-
feedback_content = concatenate_chat[last_user_index]["content"]
74-
feedback_history = concatenate_chat[:last_user_index]
75-
76-
feedback_req = APIFeedbackRequest(
77-
user_id=add_req.user_id,
78-
session_id=add_req.session_id,
79-
task_id=add_req.task_id,
80-
history=feedback_history,
81-
feedback_content=feedback_content,
82-
writable_cube_ids=add_req.writable_cube_ids,
83-
async_mode=add_req.async_mode,
84-
)
85-
process_record = cube_view.feedback_memories(feedback_req)
71+
try:
72+
messages = add_req.messages
73+
_check_messages(messages)
8674

87-
self.logger.info(
88-
f"[FeedbackHandler] Final feedback results count={len(process_record)}"
89-
)
75+
chat_history = add_req.chat_history if add_req.chat_history else []
76+
concatenate_chat = chat_history + messages
9077

91-
return MemoryResponse(
92-
message="Memory feedback successfully",
93-
data=[process_record],
94-
)
78+
last_user_index = max(
79+
i for i, d in enumerate(concatenate_chat) if d["role"] == "user"
80+
)
81+
feedback_content = concatenate_chat[last_user_index]["content"]
82+
feedback_history = concatenate_chat[:last_user_index]
83+
84+
feedback_req = APIFeedbackRequest(
85+
user_id=add_req.user_id,
86+
session_id=add_req.session_id,
87+
task_id=add_req.task_id,
88+
history=feedback_history,
89+
feedback_content=feedback_content,
90+
writable_cube_ids=add_req.writable_cube_ids,
91+
async_mode=add_req.async_mode,
92+
)
93+
process_record = cube_view.feedback_memories(feedback_req)
94+
95+
self.logger.info(
96+
f"[ADDFeedbackHandler] Final feedback results count={len(process_record)}"
97+
)
98+
99+
return MemoryResponse(
100+
message="Memory feedback successfully",
101+
data=[process_record],
102+
)
103+
except Exception as e:
104+
self.logger.warning(f"[ADDFeedbackHandler] Running error: {e}")
95105

96106
results = cube_view.add_memories(add_req)
97107

src/memos/api/handlers/chat_handler.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,9 @@ def handle_chat_complete(self, chat_req: APIChatCompleteRequest) -> dict[str, An
142142

143143
# Step 2: Build system prompt
144144
system_prompt = self._build_system_prompt(
145-
filtered_memories, search_response.data["pref_string"], chat_req.system_prompt
145+
filtered_memories,
146+
search_response.data.get("pref_string", ""),
147+
chat_req.system_prompt,
146148
)
147149

148150
# Prepare message history
@@ -257,7 +259,7 @@ def generate_chat_response() -> Generator[str, None, None]:
257259
# Step 2: Build system prompt with memories
258260
system_prompt = self._build_system_prompt(
259261
filtered_memories,
260-
search_response.data["pref_string"],
262+
search_response.data.get("pref_string", ""),
261263
chat_req.system_prompt,
262264
)
263265

@@ -449,7 +451,7 @@ def generate_chat_response() -> Generator[str, None, None]:
449451

450452
# Step 2: Build system prompt with memories
451453
system_prompt = self._build_enhance_system_prompt(
452-
filtered_memories, search_response.data["pref_string"]
454+
filtered_memories, search_response.data.get("pref_string", "")
453455
)
454456

455457
# Prepare messages

src/memos/api/handlers/formatters_handler.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,3 +90,37 @@ def post_process_pref_mem(
9090
memories_result["pref_note"] = pref_note
9191

9292
return memories_result
93+
94+
95+
def post_process_textual_mem(
96+
memories_result: dict[str, Any],
97+
text_formatted_mem: list[dict[str, Any]],
98+
mem_cube_id: str,
99+
) -> dict[str, Any]:
100+
"""
101+
Post-process text and tool memory results.
102+
"""
103+
fact_mem = [
104+
mem
105+
for mem in text_formatted_mem
106+
if mem["metadata"]["memory_type"] not in ["ToolSchemaMemory", "ToolTrajectoryMemory"]
107+
]
108+
tool_mem = [
109+
mem
110+
for mem in text_formatted_mem
111+
if mem["metadata"]["memory_type"] in ["ToolSchemaMemory", "ToolTrajectoryMemory"]
112+
]
113+
114+
memories_result["text_mem"].append(
115+
{
116+
"cube_id": mem_cube_id,
117+
"memories": fact_mem,
118+
}
119+
)
120+
memories_result["tool_mem"].append(
121+
{
122+
"cube_id": mem_cube_id,
123+
"memories": tool_mem,
124+
}
125+
)
126+
return memories_result

src/memos/api/product_models.py

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
# Import message types from core types module
88
from memos.log import get_logger
9-
from memos.types import MessageDict, PermissionDict, SearchMode
9+
from memos.types import MessageList, MessagesType, PermissionDict, SearchMode
1010

1111

1212
logger = get_logger(__name__)
@@ -56,7 +56,7 @@ class Message(BaseModel):
5656

5757
class MemoryCreate(BaseRequest):
5858
user_id: str = Field(..., description="User ID")
59-
messages: list | None = Field(None, description="List of messages to store.")
59+
messages: MessageList | None = Field(None, description="List of messages to store.")
6060
memory_content: str | None = Field(None, description="Content to store as memory")
6161
doc_path: str | None = Field(None, description="Path to document to store")
6262
mem_cube_id: str | None = Field(None, description="ID of the memory cube")
@@ -83,7 +83,7 @@ class ChatRequest(BaseRequest):
8383
writable_cube_ids: list[str] | None = Field(
8484
None, description="List of cube IDs user can write for multi-cube chat"
8585
)
86-
history: list | None = Field(None, description="Chat history")
86+
history: MessageList | None = Field(None, description="Chat history")
8787
mode: SearchMode = Field(SearchMode.FAST, description="search mode: fast, fine, or mixture")
8888
system_prompt: str | None = Field(None, description="Base system prompt to use for chat")
8989
top_k: int = Field(10, description="Number of results to return")
@@ -165,7 +165,7 @@ class ChatCompleteRequest(BaseRequest):
165165
user_id: str = Field(..., description="User ID")
166166
query: str = Field(..., description="Chat query message")
167167
mem_cube_id: str | None = Field(None, description="Cube ID to use for chat")
168-
history: list | None = Field(None, description="Chat history")
168+
history: MessageList | None = Field(None, description="Chat history")
169169
internet_search: bool = Field(False, description="Whether to use internet search")
170170
system_prompt: str | None = Field(None, description="Base prompt to use for chat")
171171
top_k: int = Field(10, description="Number of results to return")
@@ -251,7 +251,7 @@ class MemoryCreateRequest(BaseRequest):
251251
"""Request model for creating memories."""
252252

253253
user_id: str = Field(..., description="User ID")
254-
messages: str | list | None = Field(None, description="List of messages to store.")
254+
messages: str | MessagesType | None = Field(None, description="List of messages to store.")
255255
memory_content: str | None = Field(None, description="Memory content to store")
256256
doc_path: str | None = Field(None, description="Path to document to store")
257257
mem_cube_id: str | None = Field(None, description="Cube ID")
@@ -326,6 +326,21 @@ class APISearchRequest(BaseRequest):
326326
),
327327
)
328328

329+
search_tool_memory: bool = Field(
330+
True,
331+
description=(
332+
"Whether to retrieve tool memories along with general memories. "
333+
"If enabled, the system will automatically recall tool memories "
334+
"relevant to the query. Default: True."
335+
),
336+
)
337+
338+
tool_mem_top_k: int = Field(
339+
6,
340+
ge=0,
341+
description="Number of tool memories to retrieve (top-K). Default: 6.",
342+
)
343+
329344
# ==== Filter conditions ====
330345
# TODO: maybe add detailed description later
331346
filter: dict[str, Any] | None = Field(
@@ -360,7 +375,7 @@ class APISearchRequest(BaseRequest):
360375
)
361376

362377
# ==== Context ====
363-
chat_history: list | None = Field(
378+
chat_history: MessageList | None = Field(
364379
None,
365380
description=(
366381
"Historical chat messages used internally by algorithms. "
@@ -490,7 +505,7 @@ class APIADDRequest(BaseRequest):
490505
)
491506

492507
# ==== Input content ====
493-
messages: str | list | None = Field(
508+
messages: MessagesType | None = Field(
494509
None,
495510
description=(
496511
"List of messages to store. Supports: "
@@ -506,7 +521,7 @@ class APIADDRequest(BaseRequest):
506521
)
507522

508523
# ==== Chat history ====
509-
chat_history: list | None = Field(
524+
chat_history: MessageList | None = Field(
510525
None,
511526
description=(
512527
"Historical chat messages used internally by algorithms. "
@@ -636,21 +651,20 @@ class APIFeedbackRequest(BaseRequest):
636651
"default_session", description="Session ID for soft-filtering memories"
637652
)
638653
task_id: str | None = Field(None, description="Task ID for monitering async tasks")
639-
history: list[MessageDict] | None = Field(..., description="Chat history")
654+
history: MessageList | None = Field(..., description="Chat history")
640655
retrieved_memory_ids: list[str] | None = Field(
641656
None, description="Retrieved memory ids at last turn"
642657
)
643658
feedback_content: str | None = Field(..., description="Feedback content to process")
644659
feedback_time: str | None = Field(None, description="Feedback time")
645-
# ==== Multi-cube writing ====
646660
writable_cube_ids: list[str] | None = Field(
647661
None, description="List of cube IDs user can write for multi-cube add"
648662
)
649663
async_mode: Literal["sync", "async"] = Field(
650664
"async", description="feedback mode: sync or async"
651665
)
652666
corrected_answer: bool = Field(False, description="Whether need return corrected answer")
653-
# ==== Backward compatibility ====
667+
# ==== mem_cube_id is NOT enabled====
654668
mem_cube_id: str | None = Field(
655669
None,
656670
description=(
@@ -671,7 +685,7 @@ class APIChatCompleteRequest(BaseRequest):
671685
writable_cube_ids: list[str] | None = Field(
672686
None, description="List of cube IDs user can write for multi-cube chat"
673687
)
674-
history: list | None = Field(None, description="Chat history")
688+
history: MessageList | None = Field(None, description="Chat history")
675689
mode: SearchMode = Field(SearchMode.FAST, description="search mode: fast, fine, or mixture")
676690
system_prompt: str | None = Field(None, description="Base system prompt to use for chat")
677691
top_k: int = Field(10, description="Number of results to return")
@@ -740,7 +754,7 @@ class SuggestionRequest(BaseRequest):
740754
user_id: str = Field(..., description="User ID")
741755
mem_cube_id: str = Field(..., description="Cube ID")
742756
language: Literal["zh", "en"] = Field("zh", description="Language for suggestions")
743-
message: list | None = Field(None, description="List of messages to store.")
757+
message: MessagesType | None = Field(None, description="List of messages to store.")
744758

745759

746760
# ─── MemOS Client Response Models ──────────────────────────────────────────────

src/memos/configs/embedder.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ class BaseEmbedderConfig(BaseConfig):
1212
embedding_dims: int | None = Field(
1313
default=None, description="Number of dimensions for the embedding"
1414
)
15+
max_tokens: int | None = Field(
16+
default=8192,
17+
description="Maximum number of tokens per text. Texts exceeding this limit will be automatically truncated. Set to None to disable truncation.",
18+
)
1519
headers_extra: dict[str, Any] | None = Field(
1620
default=None,
1721
description="Extra headers for the embedding model, only for universal_api backend",

src/memos/embedders/ark.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ def embed(self, texts: list[str]) -> list[list[float]]:
4949
MultimodalEmbeddingContentPartTextParam,
5050
)
5151

52+
# Truncate texts if max_tokens is configured
53+
texts = self._truncate_texts(texts)
54+
5255
if self.config.multi_modal:
5356
texts_input = [
5457
MultimodalEmbeddingContentPartTextParam(text=text, type="text") for text in texts

0 commit comments

Comments (0)