|
| 1 | +""" |
| 2 | +Add handler for memory addition functionality (Class-based version). |
| 3 | +
|
| 4 | +This module provides a class-based implementation of add handlers, |
| 5 | +using dependency injection for better modularity and testability. |
| 6 | +""" |
| 7 | + |
| 8 | +import json |
| 9 | +import os |
| 10 | + |
| 11 | +from datetime import datetime |
| 12 | + |
| 13 | +from memos.api.handlers.base_handler import BaseHandler, HandlerDependencies |
| 14 | +from memos.api.product_models import APIADDRequest, MemoryResponse |
| 15 | +from memos.context.context import ContextThreadPoolExecutor |
| 16 | +from memos.mem_scheduler.schemas.general_schemas import ( |
| 17 | + ADD_LABEL, |
| 18 | + MEM_READ_LABEL, |
| 19 | + PREF_ADD_LABEL, |
| 20 | +) |
| 21 | +from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem |
| 22 | +from memos.types import UserContext |
| 23 | + |
| 24 | + |
| 25 | +class AddHandler(BaseHandler): |
| 26 | + """ |
| 27 | + Handler for memory addition operations. |
| 28 | +
|
| 29 | + Handles both text and preference memory additions with sync/async support. |
| 30 | + """ |
| 31 | + |
| 32 | + def __init__(self, dependencies: HandlerDependencies): |
| 33 | + """ |
| 34 | + Initialize add handler. |
| 35 | +
|
| 36 | + Args: |
| 37 | + dependencies: HandlerDependencies instance |
| 38 | + """ |
| 39 | + super().__init__(dependencies) |
| 40 | + self._validate_dependencies("naive_mem_cube", "mem_reader", "mem_scheduler") |
| 41 | + |
| 42 | + def handle_add_memories(self, add_req: APIADDRequest) -> MemoryResponse: |
| 43 | + """ |
| 44 | + Main handler for add memories endpoint. |
| 45 | +
|
| 46 | + Orchestrates the addition of both text and preference memories, |
| 47 | + supporting concurrent processing. |
| 48 | +
|
| 49 | + Args: |
| 50 | + add_req: Add memory request |
| 51 | +
|
| 52 | + Returns: |
| 53 | + MemoryResponse with added memory information |
| 54 | + """ |
| 55 | + # Create UserContext object |
| 56 | + user_context = UserContext( |
| 57 | + user_id=add_req.user_id, |
| 58 | + mem_cube_id=add_req.mem_cube_id, |
| 59 | + session_id=add_req.session_id or "default_session", |
| 60 | + ) |
| 61 | + |
| 62 | + self.logger.info(f"Add Req is: {add_req}") |
| 63 | + if (not add_req.messages) and add_req.memory_content: |
| 64 | + add_req.messages = self._convert_content_messsage(add_req.memory_content) |
| 65 | + self.logger.info(f"Converted Add Req content to messages: {add_req.messages}") |
| 66 | + # Process text and preference memories in parallel |
| 67 | + with ContextThreadPoolExecutor(max_workers=2) as executor: |
| 68 | + text_future = executor.submit(self._process_text_mem, add_req, user_context) |
| 69 | + pref_future = executor.submit(self._process_pref_mem, add_req, user_context) |
| 70 | + |
| 71 | + text_response_data = text_future.result() |
| 72 | + pref_response_data = pref_future.result() |
| 73 | + |
| 74 | + self.logger.info(f"add_memories Text response data: {text_response_data}") |
| 75 | + self.logger.info(f"add_memories Pref response data: {pref_response_data}") |
| 76 | + |
| 77 | + return MemoryResponse( |
| 78 | + message="Memory added successfully", |
| 79 | + data=text_response_data + pref_response_data, |
| 80 | + ) |
| 81 | + |
| 82 | + def _convert_content_messsage(self, memory_content: str) -> list[dict[str, str]]: |
| 83 | + """ |
| 84 | + Convert content string to list of message dictionaries. |
| 85 | +
|
| 86 | + Args: |
| 87 | + content: add content string |
| 88 | +
|
| 89 | + Returns: |
| 90 | + List of message dictionaries |
| 91 | + """ |
| 92 | + messages_list = [ |
| 93 | + { |
| 94 | + "role": "user", |
| 95 | + "content": memory_content, |
| 96 | + "chat_time": str(datetime.now().strftime("%Y-%m-%d %H:%M:%S")), |
| 97 | + } |
| 98 | + ] |
| 99 | + # for only user-str input and convert message |
| 100 | + return messages_list |
| 101 | + |
| 102 | + def _process_text_mem( |
| 103 | + self, |
| 104 | + add_req: APIADDRequest, |
| 105 | + user_context: UserContext, |
| 106 | + ) -> list[dict[str, str]]: |
| 107 | + """ |
| 108 | + Process and add text memories. |
| 109 | +
|
| 110 | + Extracts memories from messages and adds them to the text memory system. |
| 111 | + Handles both sync and async modes. |
| 112 | +
|
| 113 | + Args: |
| 114 | + add_req: Add memory request |
| 115 | + user_context: User context with IDs |
| 116 | +
|
| 117 | + Returns: |
| 118 | + List of formatted memory responses |
| 119 | + """ |
| 120 | + target_session_id = add_req.session_id or "default_session" |
| 121 | + |
| 122 | + # Determine sync mode |
| 123 | + sync_mode = add_req.async_mode or self._get_sync_mode() |
| 124 | + |
| 125 | + self.logger.info(f"Processing text memory with mode: {sync_mode}") |
| 126 | + |
| 127 | + # Extract memories |
| 128 | + memories_local = self.mem_reader.get_memory( |
| 129 | + [add_req.messages], |
| 130 | + type="chat", |
| 131 | + info={ |
| 132 | + "user_id": add_req.user_id, |
| 133 | + "session_id": target_session_id, |
| 134 | + }, |
| 135 | + mode="fast" if sync_mode == "async" else "fine", |
| 136 | + ) |
| 137 | + flattened_local = [mm for m in memories_local for mm in m] |
| 138 | + self.logger.info(f"Memory extraction completed for user {add_req.user_id}") |
| 139 | + |
| 140 | + # Add memories to text_mem |
| 141 | + mem_ids_local: list[str] = self.naive_mem_cube.text_mem.add( |
| 142 | + flattened_local, |
| 143 | + user_name=user_context.mem_cube_id, |
| 144 | + ) |
| 145 | + self.logger.info( |
| 146 | + f"Added {len(mem_ids_local)} memories for user {add_req.user_id} " |
| 147 | + f"in session {add_req.session_id}: {mem_ids_local}" |
| 148 | + ) |
| 149 | + |
| 150 | + # Schedule async/sync tasks |
| 151 | + self._schedule_memory_tasks( |
| 152 | + add_req=add_req, |
| 153 | + user_context=user_context, |
| 154 | + mem_ids=mem_ids_local, |
| 155 | + sync_mode=sync_mode, |
| 156 | + ) |
| 157 | + |
| 158 | + return [ |
| 159 | + { |
| 160 | + "memory": memory.memory, |
| 161 | + "memory_id": memory_id, |
| 162 | + "memory_type": memory.metadata.memory_type, |
| 163 | + } |
| 164 | + for memory_id, memory in zip(mem_ids_local, flattened_local, strict=False) |
| 165 | + ] |
| 166 | + |
| 167 | + def _process_pref_mem( |
| 168 | + self, |
| 169 | + add_req: APIADDRequest, |
| 170 | + user_context: UserContext, |
| 171 | + ) -> list[dict[str, str]]: |
| 172 | + """ |
| 173 | + Process and add preference memories. |
| 174 | +
|
| 175 | + Extracts preferences from messages and adds them to the preference memory system. |
| 176 | + Handles both sync and async modes. |
| 177 | +
|
| 178 | + Args: |
| 179 | + add_req: Add memory request |
| 180 | + user_context: User context with IDs |
| 181 | +
|
| 182 | + Returns: |
| 183 | + List of formatted preference responses |
| 184 | + """ |
| 185 | + if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() != "true": |
| 186 | + return [] |
| 187 | + |
| 188 | + # Determine sync mode |
| 189 | + sync_mode = add_req.async_mode or self._get_sync_mode() |
| 190 | + target_session_id = add_req.session_id or "default_session" |
| 191 | + |
| 192 | + # Follow async behavior: enqueue when async |
| 193 | + if sync_mode == "async": |
| 194 | + try: |
| 195 | + messages_list = [add_req.messages] |
| 196 | + message_item_pref = ScheduleMessageItem( |
| 197 | + user_id=add_req.user_id, |
| 198 | + session_id=target_session_id, |
| 199 | + mem_cube_id=add_req.mem_cube_id, |
| 200 | + mem_cube=self.naive_mem_cube, |
| 201 | + label=PREF_ADD_LABEL, |
| 202 | + content=json.dumps(messages_list), |
| 203 | + timestamp=datetime.utcnow(), |
| 204 | + ) |
| 205 | + self.mem_scheduler.submit_messages(messages=[message_item_pref]) |
| 206 | + self.logger.info("Submitted preference add to scheduler (async mode)") |
| 207 | + except Exception as e: |
| 208 | + self.logger.error(f"Failed to submit PREF_ADD task: {e}", exc_info=True) |
| 209 | + return [] |
| 210 | + else: |
| 211 | + # Sync mode: process immediately |
| 212 | + pref_memories_local = self.naive_mem_cube.pref_mem.get_memory( |
| 213 | + [add_req.messages], |
| 214 | + type="chat", |
| 215 | + info={ |
| 216 | + "user_id": add_req.user_id, |
| 217 | + "session_id": target_session_id, |
| 218 | + "mem_cube_id": add_req.mem_cube_id, |
| 219 | + }, |
| 220 | + ) |
| 221 | + pref_ids_local: list[str] = self.naive_mem_cube.pref_mem.add(pref_memories_local) |
| 222 | + self.logger.info( |
| 223 | + f"Added {len(pref_ids_local)} preferences for user {add_req.user_id} " |
| 224 | + f"in session {add_req.session_id}: {pref_ids_local}" |
| 225 | + ) |
| 226 | + return [ |
| 227 | + { |
| 228 | + "memory": memory.memory, |
| 229 | + "memory_id": memory_id, |
| 230 | + "memory_type": memory.metadata.preference_type, |
| 231 | + } |
| 232 | + for memory_id, memory in zip(pref_ids_local, pref_memories_local, strict=False) |
| 233 | + ] |
| 234 | + |
| 235 | + def _get_sync_mode(self) -> str: |
| 236 | + """ |
| 237 | + Get synchronization mode from memory cube. |
| 238 | +
|
| 239 | + Returns: |
| 240 | + Sync mode string ("sync" or "async") |
| 241 | + """ |
| 242 | + try: |
| 243 | + return getattr(self.naive_mem_cube.text_mem, "mode", "sync") |
| 244 | + except Exception: |
| 245 | + return "sync" |
| 246 | + |
| 247 | + def _schedule_memory_tasks( |
| 248 | + self, |
| 249 | + add_req: APIADDRequest, |
| 250 | + user_context: UserContext, |
| 251 | + mem_ids: list[str], |
| 252 | + sync_mode: str, |
| 253 | + ) -> None: |
| 254 | + """ |
| 255 | + Schedule memory processing tasks based on sync mode. |
| 256 | +
|
| 257 | + Args: |
| 258 | + add_req: Add memory request |
| 259 | + user_context: User context |
| 260 | + mem_ids: List of memory IDs |
| 261 | + sync_mode: Synchronization mode |
| 262 | + """ |
| 263 | + target_session_id = add_req.session_id or "default_session" |
| 264 | + |
| 265 | + if sync_mode == "async": |
| 266 | + # Async mode: submit MEM_READ_LABEL task |
| 267 | + try: |
| 268 | + message_item_read = ScheduleMessageItem( |
| 269 | + user_id=add_req.user_id, |
| 270 | + session_id=target_session_id, |
| 271 | + mem_cube_id=add_req.mem_cube_id, |
| 272 | + mem_cube=self.naive_mem_cube, |
| 273 | + label=MEM_READ_LABEL, |
| 274 | + content=json.dumps(mem_ids), |
| 275 | + timestamp=datetime.utcnow(), |
| 276 | + user_name=add_req.mem_cube_id, |
| 277 | + ) |
| 278 | + self.mem_scheduler.submit_messages(messages=[message_item_read]) |
| 279 | + self.logger.info(f"Submitted async memory read task: {json.dumps(mem_ids)}") |
| 280 | + except Exception as e: |
| 281 | + self.logger.error(f"Failed to submit async memory tasks: {e}", exc_info=True) |
| 282 | + else: |
| 283 | + # Sync mode: submit ADD_LABEL task |
| 284 | + message_item_add = ScheduleMessageItem( |
| 285 | + user_id=add_req.user_id, |
| 286 | + session_id=target_session_id, |
| 287 | + mem_cube_id=add_req.mem_cube_id, |
| 288 | + mem_cube=self.naive_mem_cube, |
| 289 | + label=ADD_LABEL, |
| 290 | + content=json.dumps(mem_ids), |
| 291 | + timestamp=datetime.utcnow(), |
| 292 | + user_name=add_req.mem_cube_id, |
| 293 | + ) |
| 294 | + self.mem_scheduler.submit_messages(messages=[message_item_add]) |
0 commit comments