5555 ExtractorFactory ,
5656 RetrieverFactory ,
5757)
58+ from memos .memories .textual .simple_preference import SimplePreferenceTextMemory
59+ from memos .memories .textual .simple_tree import SimpleTreeTextMemory
5860from memos .memories .textual .tree_text_memory .organize .manager import MemoryManager
5961from memos .memories .textual .tree_text_memory .retrieve .internet_retriever_factory import (
6062 InternetRetrieverFactory ,
@@ -195,18 +197,43 @@ def init_server():
195197 internet_retriever = InternetRetrieverFactory .from_config (
196198 internet_retriever_config , embedder = embedder
197199 )
200+
201+ # Initialize memory manager
202+ memory_manager = MemoryManager (
203+ graph_db ,
204+ embedder ,
205+ llm ,
206+ memory_size = _get_default_memory_size (default_cube_config ),
207+ is_reorganize = getattr (default_cube_config .text_mem .config , "reorganize" , False ),
208+ )
209+
210+ # Initialize text memory
211+ text_mem = SimpleTreeTextMemory (
212+ llm = llm ,
213+ embedder = embedder ,
214+ mem_reader = mem_reader ,
215+ graph_db = graph_db ,
216+ reranker = reranker ,
217+ memory_manager = memory_manager ,
218+ config = default_cube_config .text_mem .config ,
219+ internet_retriever = internet_retriever ,
220+ )
221+
198222 pref_extractor = ExtractorFactory .from_config (
199223 config_factory = pref_extractor_config ,
200224 llm_provider = llm ,
201225 embedder = embedder ,
202226 vector_db = vector_db ,
203227 )
228+
204229 pref_adder = AdderFactory .from_config (
205230 config_factory = pref_adder_config ,
206231 llm_provider = llm ,
207232 embedder = embedder ,
208233 vector_db = vector_db ,
234+ text_mem = text_mem ,
209235 )
236+
210237 pref_retriever = RetrieverFactory .from_config (
211238 config_factory = pref_retriever_config ,
212239 llm_provider = llm ,
@@ -215,33 +242,29 @@ def init_server():
215242 vector_db = vector_db ,
216243 )
217244
218- # Initialize memory manager
219- memory_manager = MemoryManager (
220- graph_db ,
221- embedder ,
222- llm ,
223- memory_size = _get_default_memory_size (default_cube_config ),
224- is_reorganize = getattr (default_cube_config .text_mem .config , "reorganize" , False ),
245+ # Initialize preference memory
246+ pref_mem = SimplePreferenceTextMemory (
247+ extractor_llm = llm ,
248+ vector_db = vector_db ,
249+ embedder = embedder ,
250+ reranker = reranker ,
251+ extractor = pref_extractor ,
252+ adder = pref_adder ,
253+ retriever = pref_retriever ,
225254 )
255+
226256 mos_server = MOSServer (
227257 mem_reader = mem_reader ,
228258 llm = llm ,
229259 online_bot = False ,
230260 )
231261
262+ # Create MemCube with pre-initialized memory instances
232263 naive_mem_cube = NaiveMemCube (
233- llm = llm ,
234- embedder = embedder ,
235- mem_reader = mem_reader ,
236- graph_db = graph_db ,
237- reranker = reranker ,
238- internet_retriever = internet_retriever ,
239- memory_manager = memory_manager ,
240- default_cube_config = default_cube_config ,
241- vector_db = vector_db ,
242- pref_extractor = pref_extractor ,
243- pref_adder = pref_adder ,
244- pref_retriever = pref_retriever ,
264+ text_mem = text_mem ,
265+ pref_mem = pref_mem ,
266+ act_mem = None ,
267+ para_mem = None ,
245268 )
246269
247270 # Initialize Scheduler
@@ -279,6 +302,8 @@ def init_server():
279302 pref_extractor ,
280303 pref_adder ,
281304 pref_retriever ,
305+ text_mem ,
306+ pref_mem ,
282307 )
283308
284309
@@ -300,6 +325,8 @@ def init_server():
300325 pref_extractor ,
301326 pref_adder ,
302327 pref_retriever ,
328+ text_mem ,
329+ pref_mem ,
303330) = init_server ()
304331
305332
@@ -361,36 +388,46 @@ def search_memories(search_req: APISearchRequest):
361388 search_mode = search_req .mode
362389
363390 def _search_text ():
364- if search_mode == SearchMode .FAST :
365- formatted_memories = fast_search_memories (
366- search_req = search_req , user_context = user_context
367- )
368- elif search_mode == SearchMode .FINE :
369- formatted_memories = fine_search_memories (
370- search_req = search_req , user_context = user_context
371- )
372- elif search_mode == SearchMode .MIXTURE :
373- formatted_memories = mix_search_memories (
374- search_req = search_req , user_context = user_context
375- )
376- else :
377- logger .error (f"Unsupported search mode: { search_mode } " )
378- raise HTTPException (status_code = 400 , detail = f"Unsupported search mode: { search_mode } " )
379- return formatted_memories
391+ try :
392+ if search_mode == SearchMode .FAST :
393+ formatted_memories = fast_search_memories (
394+ search_req = search_req , user_context = user_context
395+ )
396+ elif search_mode == SearchMode .FINE :
397+ formatted_memories = fine_search_memories (
398+ search_req = search_req , user_context = user_context
399+ )
400+ elif search_mode == SearchMode .MIXTURE :
401+ formatted_memories = mix_search_memories (
402+ search_req = search_req , user_context = user_context
403+ )
404+ else :
405+ logger .error (f"Unsupported search mode: { search_mode } " )
406+ raise HTTPException (
407+ status_code = 400 , detail = f"Unsupported search mode: { search_mode } "
408+ )
409+ return formatted_memories
410+ except Exception as e :
411+ logger .error ("Error in search_text: %s; traceback: %s" , e , traceback .format_exc ())
412+ return []
380413
381414 def _search_pref ():
382415 if os .getenv ("ENABLE_PREFERENCE_MEMORY" , "false" ).lower () != "true" :
383416 return []
384- results = naive_mem_cube .pref_mem .search (
385- query = search_req .query ,
386- top_k = search_req .pref_top_k ,
387- info = {
388- "user_id" : search_req .user_id ,
389- "session_id" : search_req .session_id ,
390- "chat_history" : search_req .chat_history ,
391- },
392- )
393- return [_format_memory_item (data ) for data in results ]
417+ try :
418+ results = naive_mem_cube .pref_mem .search (
419+ query = search_req .query ,
420+ top_k = search_req .pref_top_k ,
421+ info = {
422+ "user_id" : search_req .user_id ,
423+ "session_id" : search_req .session_id ,
424+ "chat_history" : search_req .chat_history ,
425+ },
426+ )
427+ return [_format_memory_item (data ) for data in results ]
428+ except Exception as e :
429+ logger .error ("Error in _search_pref: %s; traceback: %s" , e , traceback .format_exc ())
430+ return []
394431
395432 with ContextThreadPoolExecutor (max_workers = 2 ) as executor :
396433 text_future = executor .submit (_search_text )
@@ -601,6 +638,7 @@ def _process_pref_mem() -> list[dict[str, str]]:
601638 info = {
602639 "user_id" : add_req .user_id ,
603640 "session_id" : target_session_id ,
641+ "mem_cube_id" : add_req .mem_cube_id ,
604642 },
605643 )
606644 pref_ids_local : list [str ] = naive_mem_cube .pref_mem .add (pref_memories_local )
0 commit comments