|
1 | | -import json |
2 | 1 | import os |
3 | 2 | import traceback |
4 | 3 |
|
5 | 4 | from concurrent.futures import ThreadPoolExecutor |
6 | | -from typing import Any |
| 5 | +from typing import TYPE_CHECKING, Any |
7 | 6 |
|
8 | 7 | from fastapi import APIRouter, HTTPException |
9 | 8 |
|
|
33 | 32 | from memos.mem_scheduler.orm_modules.base_model import BaseDBManager |
34 | 33 | from memos.mem_scheduler.scheduler_factory import SchedulerFactory |
35 | 34 | from memos.mem_scheduler.schemas.general_schemas import ( |
36 | | - API_MIX_SEARCH_LABEL, |
37 | 35 | SearchMode, |
38 | 36 | ) |
39 | | -from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem |
40 | | -from memos.mem_scheduler.utils.db_utils import get_utc_now |
41 | 37 | from memos.memories.textual.prefer_text_memory.config import ( |
42 | 38 | AdderConfigFactory, |
43 | 39 | ExtractorConfigFactory, |
|
54 | 50 | ) |
55 | 51 | from memos.reranker.factory import RerankerFactory |
56 | 52 | from memos.templates.instruction_completion import instruct_completion |
| 53 | + |
| 54 | + |
| 55 | +if TYPE_CHECKING: |
| 56 | + from memos.mem_scheduler.optimized_scheduler import OptimizedScheduler |
57 | 57 | from memos.types import MOSSearchResult, UserContext |
58 | 58 | from memos.vec_dbs.factory import VecDBFactory |
59 | 59 |
|
@@ -208,36 +208,53 @@ def init_server(): |
208 | 208 | online_bot=False, |
209 | 209 | ) |
210 | 210 |
|
| 211 | + naive_mem_cube = NaiveMemCube( |
| 212 | + llm=llm, |
| 213 | + embedder=embedder, |
| 214 | + mem_reader=mem_reader, |
| 215 | + graph_db=graph_db, |
| 216 | + reranker=reranker, |
| 217 | + internet_retriever=internet_retriever, |
| 218 | + memory_manager=memory_manager, |
| 219 | + default_cube_config=default_cube_config, |
| 220 | + vector_db=vector_db, |
| 221 | + pref_extractor=pref_extractor, |
| 222 | + pref_adder=pref_adder, |
| 223 | + pref_retriever=pref_retriever, |
| 224 | + ) |
| 225 | + |
211 | 226 | # Initialize Scheduler |
212 | 227 | scheduler_config_dict = APIConfig.get_scheduler_config() |
213 | 228 | scheduler_config = SchedulerConfigFactory( |
214 | 229 | backend="optimized_scheduler", config=scheduler_config_dict |
215 | 230 | ) |
216 | | - mem_scheduler = SchedulerFactory.from_config(scheduler_config) |
| 231 | + mem_scheduler: OptimizedScheduler = SchedulerFactory.from_config(scheduler_config) |
217 | 232 | mem_scheduler.initialize_modules( |
218 | 233 | chat_llm=llm, |
219 | 234 | process_llm=mem_reader.llm, |
220 | 235 | db_engine=BaseDBManager.create_default_sqlite_engine(), |
221 | 236 | ) |
| 237 | + mem_scheduler.current_mem_cube = naive_mem_cube |
222 | 238 | mem_scheduler.start() |
223 | 239 |
|
224 | 240 | # Initialize SchedulerAPIModule |
225 | 241 | api_module = mem_scheduler.api_module |
226 | 242 |
|
227 | | - naive_mem_cube = NaiveMemCube( |
228 | | - llm=llm, |
229 | | - embedder=embedder, |
230 | | - mem_reader=mem_reader, |
231 | | - graph_db=graph_db, |
232 | | - reranker=reranker, |
233 | | - internet_retriever=internet_retriever, |
234 | | - memory_manager=memory_manager, |
235 | | - default_cube_config=default_cube_config, |
236 | | - vector_db=vector_db, |
237 | | - pref_extractor=pref_extractor, |
238 | | - pref_adder=pref_adder, |
239 | | - pref_retriever=pref_retriever, |
240 | | - )
241 | 258 |
|
242 | 259 | return ( |
243 | 260 | graph_db, |
@@ -398,96 +415,12 @@ def mix_search_memories( |
398 | 415 | """ |
399 | 416 | Mix search memories: fast search + async fine search |
400 | 417 | """ |
401 | | - # Get fast memories first |
402 | | - fast_memories = fast_search_memories(search_req, user_context) |
403 | | - |
404 | | - # Check if scheduler and dispatcher are available for async execution |
405 | | - if mem_scheduler and hasattr(mem_scheduler, "dispatcher") and mem_scheduler.dispatcher: |
406 | | - try: |
407 | | - # Create message for async fine search |
408 | | - message_content = { |
409 | | - "search_req": { |
410 | | - "query": search_req.query, |
411 | | - "user_id": search_req.user_id, |
412 | | - "session_id": search_req.session_id, |
413 | | - "top_k": search_req.top_k, |
414 | | - "internet_search": search_req.internet_search, |
415 | | - "moscube": search_req.moscube, |
416 | | - "chat_history": search_req.chat_history, |
417 | | - }, |
418 | | - "user_context": {"mem_cube_id": user_context.mem_cube_id}, |
419 | | - } |
420 | | - |
421 | | - message = ScheduleMessageItem( |
422 | | - item_id=f"mix_search_{search_req.user_id}_{get_utc_now().timestamp()}", |
423 | | - user_id=search_req.user_id, |
424 | | - mem_cube_id=user_context.mem_cube_id, |
425 | | - label=API_MIX_SEARCH_LABEL, |
426 | | - mem_cube=naive_mem_cube, |
427 | | - content=json.dumps(message_content), |
428 | | - timestamp=get_utc_now(), |
429 | | - ) |
430 | 418 |
|
431 | | - # Submit async task |
432 | | - mem_scheduler.dispatcher.submit_message(message) |
433 | | - logger.info(f"Submitted async fine search task for user {search_req.user_id}") |
434 | | - |
435 | | - # Try to get pre-computed fine memories if available |
436 | | - try: |
437 | | - pre_fine_memories = api_module.get_pre_fine_memories( |
438 | | - user_id=search_req.user_id, mem_cube_id=user_context.mem_cube_id |
439 | | - ) |
440 | | - if pre_fine_memories: |
441 | | - # Merge fast and pre-computed fine memories |
442 | | - all_memories = fast_memories + pre_fine_memories |
443 | | - # Remove duplicates based on content |
444 | | - seen_contents = set() |
445 | | - unique_memories = [] |
446 | | - for memory in all_memories: |
447 | | - content_key = memory.get("content", "") |
448 | | - if content_key not in seen_contents: |
449 | | - seen_contents.add(content_key) |
450 | | - unique_memories.append(memory) |
451 | | - return unique_memories |
452 | | - except Exception as e: |
453 | | - logger.warning(f"Failed to get pre-computed fine memories: {e}") |
454 | | - |
455 | | - except Exception as e: |
456 | | - logger.error(f"Failed to submit async fine search task: {e}") |
457 | | - # Fall back to synchronous execution |
458 | | - |
459 | | - # Fallback: synchronous fine search |
460 | | - try: |
461 | | - fine_memories = fine_search_memories(search_req, user_context) |
462 | | - |
463 | | - # Merge fast and fine memories |
464 | | - all_memories = fast_memories + fine_memories |
465 | | - |
466 | | - # Remove duplicates based on content |
467 | | - seen_contents = set() |
468 | | - unique_memories = [] |
469 | | - for memory in all_memories: |
470 | | - content_key = memory.get("content", "") |
471 | | - if content_key not in seen_contents: |
472 | | - seen_contents.add(content_key) |
473 | | - unique_memories.append(memory) |
474 | | - |
475 | | - # Sync search data to Redis |
476 | | - try: |
477 | | - api_module.sync_search_data( |
478 | | - user_id=search_req.user_id, |
479 | | - mem_cube_id=user_context.mem_cube_id, |
480 | | - query=search_req.query, |
481 | | - formatted_memories=unique_memories, |
482 | | - ) |
483 | | - except Exception as e: |
484 | | - logger.error(f"Failed to sync search data: {e}") |
485 | | - |
486 | | - return unique_memories |
487 | | - |
488 | | - except Exception as e: |
489 | | - logger.error(f"Fine search failed: {e}") |
490 | | - return fast_memories |
| 419 | + formatted_memories = mem_scheduler.mix_search_memories( |
| 420 | + search_req=search_req, |
| 421 | + user_context=user_context, |
| 422 | + ) |
| 423 | + return formatted_memories |
491 | 424 |
|
492 | 425 |
|
493 | 426 | def fine_search_memories( |
|
0 commit comments