Skip to content

Commit 1043377

Browse files
authored
feat: abstract CubeView to Add & Search Handler (#498)
* feat: abstract CubeView to Add Handler * feat: add readable and writable memcube-ids * feat: multi-cube search router
1 parent 1adf36e commit 1043377

File tree

7 files changed

+789
-496
lines changed

7 files changed

+789
-496
lines changed

src/memos/api/handlers/add_handler.py

Lines changed: 57 additions & 223 deletions
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,13 @@
55
using dependency injection for better modularity and testability.
66
"""
77

8-
import json
9-
import os
10-
118
from datetime import datetime
129

1310
from memos.api.handlers.base_handler import BaseHandler, HandlerDependencies
1411
from memos.api.product_models import APIADDRequest, MemoryResponse
15-
from memos.context.context import ContextThreadPoolExecutor
16-
from memos.mem_scheduler.schemas.general_schemas import (
17-
ADD_LABEL,
18-
MEM_READ_LABEL,
19-
PREF_ADD_LABEL,
20-
)
21-
from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
22-
from memos.types import UserContext
12+
from memos.multi_mem_cube.composite_cube import CompositeCubeView
13+
from memos.multi_mem_cube.single_cube import SingleCubeView
14+
from memos.multi_mem_cube.views import MemCubeView
2315

2416

2517
class AddHandler(BaseHandler):
@@ -52,33 +44,69 @@ def handle_add_memories(self, add_req: APIADDRequest) -> MemoryResponse:
5244
Returns:
5345
MemoryResponse with added memory information
5446
"""
55-
# Create UserContext object
56-
user_context = UserContext(
57-
user_id=add_req.user_id,
58-
mem_cube_id=add_req.mem_cube_id,
59-
session_id=add_req.session_id or "default_session",
60-
)
47+
self.logger.info(f"[AddHandler] Add Req is: {add_req}")
6148

62-
self.logger.info(f"Add Req is: {add_req}")
63-
if (not add_req.messages) and add_req.memory_content:
49+
if (not add_req.messages) and getattr(add_req, "memory_content", None):
6450
add_req.messages = self._convert_content_messsage(add_req.memory_content)
65-
self.logger.info(f"Converted Add Req content to messages: {add_req.messages}")
66-
# Process text and preference memories in parallel
67-
with ContextThreadPoolExecutor(max_workers=2) as executor:
68-
text_future = executor.submit(self._process_text_mem, add_req, user_context)
69-
pref_future = executor.submit(self._process_pref_mem, add_req, user_context)
51+
self.logger.info(f"[AddHandler] Converted content to messages: {add_req.messages}")
7052

71-
text_response_data = text_future.result()
72-
pref_response_data = pref_future.result()
53+
cube_view = self._build_cube_view(add_req)
7354

74-
self.logger.info(f"add_memories Text response data: {text_response_data}")
75-
self.logger.info(f"add_memories Pref response data: {pref_response_data}")
55+
results = cube_view.add_memories(add_req)
56+
57+
self.logger.info(f"[AddHandler] Final add results count={len(results)}")
7658

7759
return MemoryResponse(
7860
message="Memory added successfully",
79-
data=text_response_data + pref_response_data,
61+
data=results,
8062
)
8163

64+
def _resolve_cube_ids(self, add_req: APIADDRequest) -> list[str]:
65+
"""
66+
Normalize target cube ids from add_req.
67+
Priority:
68+
1) writable_cube_ids
69+
2) mem_cube_id
70+
3) fallback to user_id
71+
"""
72+
if getattr(add_req, "writable_cube_ids", None):
73+
return list(dict.fromkeys(add_req.writable_cube_ids))
74+
75+
if add_req.mem_cube_id:
76+
return [add_req.mem_cube_id]
77+
78+
return [add_req.user_id]
79+
80+
def _build_cube_view(self, add_req: APIADDRequest) -> MemCubeView:
81+
cube_ids = self._resolve_cube_ids(add_req)
82+
83+
if len(cube_ids) == 1:
84+
cube_id = cube_ids[0]
85+
return SingleCubeView(
86+
cube_id=cube_id,
87+
naive_mem_cube=self.naive_mem_cube,
88+
mem_reader=self.mem_reader,
89+
mem_scheduler=self.mem_scheduler,
90+
logger=self.logger,
91+
searcher=None,
92+
)
93+
else:
94+
single_views = [
95+
SingleCubeView(
96+
cube_id=cube_id,
97+
naive_mem_cube=self.naive_mem_cube,
98+
mem_reader=self.mem_reader,
99+
mem_scheduler=self.mem_scheduler,
100+
logger=self.logger,
101+
searcher=None,
102+
)
103+
for cube_id in cube_ids
104+
]
105+
return CompositeCubeView(
106+
cube_views=single_views,
107+
logger=self.logger,
108+
)
109+
82110
def _convert_content_messsage(self, memory_content: str) -> list[dict[str, str]]:
83111
"""
84112
Convert content string to list of message dictionaries.
@@ -98,197 +126,3 @@ def _convert_content_messsage(self, memory_content: str) -> list[dict[str, str]]
98126
]
99127
# for only user-str input and convert message
100128
return messages_list
101-
102-
def _process_text_mem(
103-
self,
104-
add_req: APIADDRequest,
105-
user_context: UserContext,
106-
) -> list[dict[str, str]]:
107-
"""
108-
Process and add text memories.
109-
110-
Extracts memories from messages and adds them to the text memory system.
111-
Handles both sync and async modes.
112-
113-
Args:
114-
add_req: Add memory request
115-
user_context: User context with IDs
116-
117-
Returns:
118-
List of formatted memory responses
119-
"""
120-
target_session_id = add_req.session_id or "default_session"
121-
122-
# Determine sync mode
123-
sync_mode = add_req.async_mode or self._get_sync_mode()
124-
125-
self.logger.info(f"Processing text memory with mode: {sync_mode}")
126-
127-
# Extract memories
128-
memories_local = self.mem_reader.get_memory(
129-
[add_req.messages],
130-
type="chat",
131-
info={
132-
"user_id": add_req.user_id,
133-
"session_id": target_session_id,
134-
},
135-
mode="fast" if sync_mode == "async" else "fine",
136-
)
137-
flattened_local = [mm for m in memories_local for mm in m]
138-
self.logger.info(f"Memory extraction completed for user {add_req.user_id}")
139-
140-
# Add memories to text_mem
141-
mem_ids_local: list[str] = self.naive_mem_cube.text_mem.add(
142-
flattened_local,
143-
user_name=user_context.mem_cube_id,
144-
)
145-
self.logger.info(
146-
f"Added {len(mem_ids_local)} memories for user {add_req.user_id} "
147-
f"in session {add_req.session_id}: {mem_ids_local}"
148-
)
149-
150-
# Schedule async/sync tasks
151-
self._schedule_memory_tasks(
152-
add_req=add_req,
153-
user_context=user_context,
154-
mem_ids=mem_ids_local,
155-
sync_mode=sync_mode,
156-
)
157-
158-
return [
159-
{
160-
"memory": memory.memory,
161-
"memory_id": memory_id,
162-
"memory_type": memory.metadata.memory_type,
163-
}
164-
for memory_id, memory in zip(mem_ids_local, flattened_local, strict=False)
165-
]
166-
167-
def _process_pref_mem(
168-
self,
169-
add_req: APIADDRequest,
170-
user_context: UserContext,
171-
) -> list[dict[str, str]]:
172-
"""
173-
Process and add preference memories.
174-
175-
Extracts preferences from messages and adds them to the preference memory system.
176-
Handles both sync and async modes.
177-
178-
Args:
179-
add_req: Add memory request
180-
user_context: User context with IDs
181-
182-
Returns:
183-
List of formatted preference responses
184-
"""
185-
if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() != "true":
186-
return []
187-
188-
# Determine sync mode
189-
sync_mode = add_req.async_mode or self._get_sync_mode()
190-
target_session_id = add_req.session_id or "default_session"
191-
192-
# Follow async behavior: enqueue when async
193-
if sync_mode == "async":
194-
try:
195-
messages_list = [add_req.messages]
196-
message_item_pref = ScheduleMessageItem(
197-
user_id=add_req.user_id,
198-
session_id=target_session_id,
199-
mem_cube_id=add_req.mem_cube_id,
200-
mem_cube=self.naive_mem_cube,
201-
label=PREF_ADD_LABEL,
202-
content=json.dumps(messages_list),
203-
timestamp=datetime.utcnow(),
204-
)
205-
self.mem_scheduler.memos_message_queue.submit_messages(messages=[message_item_pref])
206-
self.logger.info("Submitted preference add to scheduler (async mode)")
207-
except Exception as e:
208-
self.logger.error(f"Failed to submit PREF_ADD task: {e}", exc_info=True)
209-
return []
210-
else:
211-
# Sync mode: process immediately
212-
pref_memories_local = self.naive_mem_cube.pref_mem.get_memory(
213-
[add_req.messages],
214-
type="chat",
215-
info={
216-
"user_id": add_req.user_id,
217-
"session_id": target_session_id,
218-
"mem_cube_id": add_req.mem_cube_id,
219-
},
220-
)
221-
pref_ids_local: list[str] = self.naive_mem_cube.pref_mem.add(pref_memories_local)
222-
self.logger.info(
223-
f"Added {len(pref_ids_local)} preferences for user {add_req.user_id} "
224-
f"in session {add_req.session_id}: {pref_ids_local}"
225-
)
226-
return [
227-
{
228-
"memory": memory.memory,
229-
"memory_id": memory_id,
230-
"memory_type": memory.metadata.preference_type,
231-
}
232-
for memory_id, memory in zip(pref_ids_local, pref_memories_local, strict=False)
233-
]
234-
235-
def _get_sync_mode(self) -> str:
236-
"""
237-
Get synchronization mode from memory cube.
238-
239-
Returns:
240-
Sync mode string ("sync" or "async")
241-
"""
242-
try:
243-
return getattr(self.naive_mem_cube.text_mem, "mode", "sync")
244-
except Exception:
245-
return "sync"
246-
247-
def _schedule_memory_tasks(
248-
self,
249-
add_req: APIADDRequest,
250-
user_context: UserContext,
251-
mem_ids: list[str],
252-
sync_mode: str,
253-
) -> None:
254-
"""
255-
Schedule memory processing tasks based on sync mode.
256-
257-
Args:
258-
add_req: Add memory request
259-
user_context: User context
260-
mem_ids: List of memory IDs
261-
sync_mode: Synchronization mode
262-
"""
263-
target_session_id = add_req.session_id or "default_session"
264-
265-
if sync_mode == "async":
266-
# Async mode: submit MEM_READ_LABEL task
267-
try:
268-
message_item_read = ScheduleMessageItem(
269-
user_id=add_req.user_id,
270-
session_id=target_session_id,
271-
mem_cube_id=add_req.mem_cube_id,
272-
mem_cube=self.naive_mem_cube,
273-
label=MEM_READ_LABEL,
274-
content=json.dumps(mem_ids),
275-
timestamp=datetime.utcnow(),
276-
user_name=add_req.mem_cube_id,
277-
)
278-
self.mem_scheduler.memos_message_queue.submit_messages(messages=[message_item_read])
279-
self.logger.info(f"Submitted async memory read task: {json.dumps(mem_ids)}")
280-
except Exception as e:
281-
self.logger.error(f"Failed to submit async memory tasks: {e}", exc_info=True)
282-
else:
283-
# Sync mode: submit ADD_LABEL task
284-
message_item_add = ScheduleMessageItem(
285-
user_id=add_req.user_id,
286-
session_id=target_session_id,
287-
mem_cube_id=add_req.mem_cube_id,
288-
mem_cube=self.naive_mem_cube,
289-
label=ADD_LABEL,
290-
content=json.dumps(mem_ids),
291-
timestamp=datetime.utcnow(),
292-
user_name=add_req.mem_cube_id,
293-
)
294-
self.mem_scheduler.memos_message_queue.submit_messages(messages=[message_item_add])

0 commit comments

Comments
 (0)