diff --git a/examples/mem_scheduler/api_w_scheduler.py b/examples/mem_scheduler/api_w_scheduler.py
new file mode 100644
index 000000000..11f0ebb81
--- /dev/null
+++ b/examples/mem_scheduler/api_w_scheduler.py
@@ -0,0 +1,62 @@
+from memos.api.routers.server_router import mem_scheduler
+from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
+
+
+# Debug: print the scheduler configuration
+print("=== Scheduler Configuration Debug ===")
+print(f"Scheduler type: {type(mem_scheduler).__name__}")
+print(f"Config: {mem_scheduler.config}")
+print(f"use_redis_queue: {mem_scheduler.use_redis_queue}")
+print(f"Queue type: {type(mem_scheduler.memos_message_queue).__name__}")
+print(f"Queue maxsize: {getattr(mem_scheduler.memos_message_queue, 'maxsize', 'N/A')}")
+
+# Check whether the Redis queue is connected
+if hasattr(mem_scheduler.memos_message_queue, "_is_connected"):
+    print(f"Redis connected: {mem_scheduler.memos_message_queue._is_connected}")
+if hasattr(mem_scheduler.memos_message_queue, "_redis_conn"):
+    print(f"Redis connection: {mem_scheduler.memos_message_queue._redis_conn}")
+print("=====================================\n")
+
+queue = mem_scheduler.memos_message_queue
+queue.clear()
+
+
+# 1. Define a handler function
+def my_test_handler(messages: list[ScheduleMessageItem]):
+    print(f"My test handler received {len(messages)} messages:")
+    for msg in messages:
+        print(f"  my_test_handler - {msg.item_id}: {msg.content}")
+    print(
+        f"{queue._redis_conn.xinfo_groups(queue.stream_name)} qsize: {queue.qsize()} messages:{messages}"
+    )
+
+
+# 2. Register the handler
+TEST_HANDLER_LABEL = "test_handler"
+mem_scheduler.register_handlers({TEST_HANDLER_LABEL: my_test_handler})
+
+# 3. Create messages
+messages_to_send = [
+    ScheduleMessageItem(
+        item_id=f"test_item_{i}",
+        user_id="test_user",
+        mem_cube_id="test_mem_cube",
+        label=TEST_HANDLER_LABEL,
+        content=f"This is test message {i}",
+    )
+    for i in range(5)
+]
+
+# 4. Submit messages
+for mes in messages_to_send:
+    print(f"Submitting message {mes.item_id} to the scheduler...")
+    mem_scheduler.submit_messages([mes])
+
+# 5. Wait for messages to be processed (limited to 100 checks)
+print("Waiting for messages to be consumed (max 100 checks)...")
+mem_scheduler.mem_scheduler_wait()
+
+
+# 6. Stop the scheduler
+print("Stopping the scheduler...")
+mem_scheduler.stop()
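The example relies on `mem_scheduler.mem_scheduler_wait()` to block until the handler has drained the queue. For readers who want that wait step spelled out, here is a minimal polling sketch over the same `queue` object; the 100-check cap mirrors the print statement above, and `qsize()` is the only queue method assumed:

```python
import time

# Hypothetical stand-in for mem_scheduler.mem_scheduler_wait(): poll the queue
# until it drains, giving up after 100 checks.
for _ in range(100):
    if queue.qsize() == 0:
        print("All messages consumed.")
        break
    time.sleep(0.1)
else:
    print(f"Queue still holds {queue.qsize()} messages after 100 checks.")
```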
diff --git a/examples/mem_scheduler/memos_w_optimized_scheduler.py b/examples/mem_scheduler/memos_w_optimized_scheduler.py
deleted file mode 100644
index 664168f62..000000000
--- a/examples/mem_scheduler/memos_w_optimized_scheduler.py
+++ /dev/null
@@ -1,85 +0,0 @@
-import shutil
-import sys
-
-from pathlib import Path
-
-from memos_w_scheduler import init_task, show_web_logs
-
-from memos.configs.mem_cube import GeneralMemCubeConfig
-from memos.configs.mem_os import MOSConfig
-from memos.configs.mem_scheduler import AuthConfig
-from memos.log import get_logger
-from memos.mem_cube.general import GeneralMemCube
-from memos.mem_os.main import MOS
-
-
-FILE_PATH = Path(__file__).absolute()
-BASE_DIR = FILE_PATH.parent.parent.parent
-sys.path.insert(0, str(BASE_DIR))  # Enable execution from any working directory
-
-logger = get_logger(__name__)
-
-
-def run_with_scheduler_init():
-    print("==== run_with_automatic_scheduler_init ====")
-    conversations, questions = init_task()
-
-    # set configs
-    mos_config = MOSConfig.from_yaml_file(
-        f"{BASE_DIR}/examples/data/config/mem_scheduler/memos_config_w_optimized_scheduler.yaml"
-    )
-
-    mem_cube_config = GeneralMemCubeConfig.from_yaml_file(
-        f"{BASE_DIR}/examples/data/config/mem_scheduler/mem_cube_config_neo4j.yaml"
-    )
-
-    # default local graphdb uri
-    if AuthConfig.default_config_exists():
-        auth_config = AuthConfig.from_local_config()
-
-        mos_config.mem_reader.config.llm.config.api_key = auth_config.openai.api_key
-        mos_config.mem_reader.config.llm.config.api_base = auth_config.openai.base_url
-
-        mem_cube_config.text_mem.config.graph_db.config.uri = auth_config.graph_db.uri
-        mem_cube_config.text_mem.config.graph_db.config.user = auth_config.graph_db.user
-        mem_cube_config.text_mem.config.graph_db.config.password = auth_config.graph_db.password
-        mem_cube_config.text_mem.config.graph_db.config.db_name = auth_config.graph_db.db_name
-        mem_cube_config.text_mem.config.graph_db.config.auto_create = (
-            auth_config.graph_db.auto_create
-        )
-
-    # Initialization
-    mos = MOS(mos_config)
-
-    user_id = "user_1"
-    mos.create_user(user_id)
-
-    mem_cube_id = "mem_cube_5"
-    mem_cube_name_or_path = f"{BASE_DIR}/outputs/mem_scheduler/{user_id}/{mem_cube_id}"
-
-    if Path(mem_cube_name_or_path).exists():
-        shutil.rmtree(mem_cube_name_or_path)
-        print(f"{mem_cube_name_or_path} is not empty, and has been removed.")
-
-    mem_cube = GeneralMemCube(mem_cube_config)
-    mem_cube.dump(mem_cube_name_or_path)
-    mos.register_mem_cube(
-        mem_cube_name_or_path=mem_cube_name_or_path, mem_cube_id=mem_cube_id, user_id=user_id
-    )
-
-    mos.add(conversations, user_id=user_id, mem_cube_id=mem_cube_id)
-
-    for item in questions:
-        print("===== Chat Start =====")
-        query = item["question"]
-        print(f"Query:\n {query}\n")
-        response = mos.chat(query=query, user_id=user_id)
-        print(f"Answer:\n {response}\n")
-
-    show_web_logs(mem_scheduler=mos.mem_scheduler)
-
-    mos.mem_scheduler.stop()
-
-
-if __name__ == "__main__":
-    run_with_scheduler_init()
diff --git a/examples/mem_scheduler/memos_w_optimized_scheduler_for_test.py b/examples/mem_scheduler/memos_w_optimized_scheduler_for_test.py
deleted file mode 100644
index ed4f721ad..000000000
--- a/examples/mem_scheduler/memos_w_optimized_scheduler_for_test.py
+++ /dev/null
@@ -1,87 +0,0 @@
-import json
-import shutil
-import sys
-
-from pathlib import Path
-
-from memos_w_scheduler_for_test import init_task
-
-from memos.configs.mem_cube import GeneralMemCubeConfig
-from
memos.configs.mem_os import MOSConfig -from memos.configs.mem_scheduler import AuthConfig -from memos.log import get_logger -from memos.mem_cube.general import GeneralMemCube -from memos.mem_scheduler.analyzer.mos_for_test_scheduler import MOSForTestScheduler - - -FILE_PATH = Path(__file__).absolute() -BASE_DIR = FILE_PATH.parent.parent.parent -sys.path.insert(0, str(BASE_DIR)) - -# Enable execution from any working directory - -logger = get_logger(__name__) - -if __name__ == "__main__": - # set up data - conversations, questions = init_task() - - # set configs - mos_config = MOSConfig.from_yaml_file( - f"{BASE_DIR}/examples/data/config/mem_scheduler/memos_config_w_optimized_scheduler.yaml" - ) - - mem_cube_config = GeneralMemCubeConfig.from_yaml_file( - f"{BASE_DIR}/examples/data/config/mem_scheduler/mem_cube_config_neo4j.yaml" - ) - - # default local graphdb uri - if AuthConfig.default_config_exists(): - auth_config = AuthConfig.from_local_config() - - mos_config.mem_reader.config.llm.config.api_key = auth_config.openai.api_key - mos_config.mem_reader.config.llm.config.api_base = auth_config.openai.base_url - - mem_cube_config.text_mem.config.graph_db.config.uri = auth_config.graph_db.uri - mem_cube_config.text_mem.config.graph_db.config.user = auth_config.graph_db.user - mem_cube_config.text_mem.config.graph_db.config.password = auth_config.graph_db.password - mem_cube_config.text_mem.config.graph_db.config.db_name = auth_config.graph_db.db_name - mem_cube_config.text_mem.config.graph_db.config.auto_create = ( - auth_config.graph_db.auto_create - ) - - # Initialization - mos = MOSForTestScheduler(mos_config) - - user_id = "user_1" - mos.create_user(user_id) - - mem_cube_id = "mem_cube_5" - mem_cube_name_or_path = f"{BASE_DIR}/outputs/mem_scheduler/{user_id}/{mem_cube_id}" - - if Path(mem_cube_name_or_path).exists(): - shutil.rmtree(mem_cube_name_or_path) - print(f"{mem_cube_name_or_path} is not empty, and has been removed.") - - mem_cube = GeneralMemCube(mem_cube_config) - mem_cube.dump(mem_cube_name_or_path) - mos.register_mem_cube( - mem_cube_name_or_path=mem_cube_name_or_path, mem_cube_id=mem_cube_id, user_id=user_id - ) - - mos.add(conversations, user_id=user_id, mem_cube_id=mem_cube_id) - - # Add interfering conversations - file_path = Path(f"{BASE_DIR}/examples/data/mem_scheduler/scene_data.json") - scene_data = json.load(file_path.open("r", encoding="utf-8")) - mos.add(scene_data[0], user_id=user_id, mem_cube_id=mem_cube_id) - mos.add(scene_data[1], user_id=user_id, mem_cube_id=mem_cube_id) - - for item in questions: - print("===== Chat Start =====") - query = item["question"] - print(f"Query:\n {query}\n") - response = mos.chat(query=query, user_id=user_id) - print(f"Answer:\n {response}\n") - - mos.mem_scheduler.stop() diff --git a/examples/mem_scheduler/memos_w_scheduler.py b/examples/mem_scheduler/memos_w_scheduler.py index dc196b85a..c523a8667 100644 --- a/examples/mem_scheduler/memos_w_scheduler.py +++ b/examples/mem_scheduler/memos_w_scheduler.py @@ -70,13 +70,48 @@ def init_task(): return conversations, questions +def show_web_logs(mem_scheduler: GeneralScheduler): + """Display all web log entries from the scheduler's log queue. 
+ + Args: + mem_scheduler: The scheduler instance containing web logs to display + """ + if mem_scheduler._web_log_message_queue.empty(): + print("Web log queue is currently empty.") + return + + print("\n" + "=" * 50 + " WEB LOGS " + "=" * 50) + + # Create a temporary queue to preserve the original queue contents + temp_queue = Queue() + log_count = 0 + + while not mem_scheduler._web_log_message_queue.empty(): + log_item: ScheduleLogForWebItem = mem_scheduler._web_log_message_queue.get() + temp_queue.put(log_item) + log_count += 1 + + # Print log entry details + print(f"\nLog Entry #{log_count}:") + print(f'- "{log_item.label}" log: {log_item}') + + print("-" * 50) + + # Restore items back to the original queue + while not temp_queue.empty(): + mem_scheduler._web_log_message_queue.put(temp_queue.get()) + + print(f"\nTotal {log_count} web log entries displayed.") + print("=" * 110 + "\n") + + def run_with_scheduler_init(): print("==== run_with_automatic_scheduler_init ====") conversations, questions = init_task() # set configs mos_config = MOSConfig.from_yaml_file( - f"{BASE_DIR}/examples/data/config/mem_scheduler/memos_config_w_scheduler.yaml" + f"{BASE_DIR}/examples/data/config/mem_scheduler/memos_config_w_optimized_scheduler.yaml" ) mem_cube_config = GeneralMemCubeConfig.from_yaml_file( @@ -118,6 +153,7 @@ def run_with_scheduler_init(): ) mos.add(conversations, user_id=user_id, mem_cube_id=mem_cube_id) + mos.mem_scheduler.current_mem_cube = mem_cube for item in questions: print("===== Chat Start =====") @@ -131,40 +167,5 @@ def run_with_scheduler_init(): mos.mem_scheduler.stop() -def show_web_logs(mem_scheduler: GeneralScheduler): - """Display all web log entries from the scheduler's log queue. - - Args: - mem_scheduler: The scheduler instance containing web logs to display - """ - if mem_scheduler._web_log_message_queue.empty(): - print("Web log queue is currently empty.") - return - - print("\n" + "=" * 50 + " WEB LOGS " + "=" * 50) - - # Create a temporary queue to preserve the original queue contents - temp_queue = Queue() - log_count = 0 - - while not mem_scheduler._web_log_message_queue.empty(): - log_item: ScheduleLogForWebItem = mem_scheduler._web_log_message_queue.get() - temp_queue.put(log_item) - log_count += 1 - - # Print log entry details - print(f"\nLog Entry #{log_count}:") - print(f'- "{log_item.label}" log: {log_item}') - - print("-" * 50) - - # Restore items back to the original queue - while not temp_queue.empty(): - mem_scheduler._web_log_message_queue.put(temp_queue.get()) - - print(f"\nTotal {log_count} web log entries displayed.") - print("=" * 110 + "\n") - - if __name__ == "__main__": run_with_scheduler_init() diff --git a/examples/mem_scheduler/memos_w_scheduler_for_test.py b/examples/mem_scheduler/memos_w_scheduler_for_test.py index 6faac98af..2e135f127 100644 --- a/examples/mem_scheduler/memos_w_scheduler_for_test.py +++ b/examples/mem_scheduler/memos_w_scheduler_for_test.py @@ -1,10 +1,11 @@ import json import shutil import sys -import time from pathlib import Path +from memos_w_scheduler import init_task + from memos.configs.mem_cube import GeneralMemCubeConfig from memos.configs.mem_os import MOSConfig from memos.configs.mem_scheduler import AuthConfig @@ -15,155 +16,19 @@ FILE_PATH = Path(__file__).absolute() BASE_DIR = FILE_PATH.parent.parent.parent -sys.path.insert(0, str(BASE_DIR)) # Enable execution from any working directory - -logger = get_logger(__name__) - - -def display_memory_cube_stats(mos, user_id, mem_cube_id): - """Display detailed memory 
cube statistics.""" - print(f"\n๐Ÿ“Š MEMORY CUBE STATISTICS for {mem_cube_id}:") - print("-" * 60) - - mem_cube = mos.mem_cubes.get(mem_cube_id) - if not mem_cube: - print(" โŒ Memory cube not found") - return - - # Text memory stats - if mem_cube.text_mem: - text_mem = mem_cube.text_mem - working_memories = text_mem.get_working_memory() - all_memories = text_mem.get_all() - - print(" ๐Ÿ“ Text Memory:") - print(f" โ€ข Working Memory Items: {len(working_memories)}") - print( - f" โ€ข Total Memory Items: {len(all_memories) if isinstance(all_memories, list) else 'N/A'}" - ) - - if working_memories: - print(" โ€ข Working Memory Content Preview:") - for i, mem in enumerate(working_memories[:2]): - content = mem.memory[:60] + "..." if len(mem.memory) > 60 else mem.memory - print(f" {i + 1}. {content}") - - # Activation memory stats - if mem_cube.act_mem: - act_mem = mem_cube.act_mem - act_memories = list(act_mem.get_all()) - print(" โšก Activation Memory:") - print(f" โ€ข KV Cache Items: {len(act_memories)}") - if act_memories: - print( - f" โ€ข Latest Cache Size: {len(act_memories[-1].memory) if hasattr(act_memories[-1], 'memory') else 'N/A'}" - ) - - print("-" * 60) - - -def display_scheduler_status(mos): - """Display current scheduler status and configuration.""" - print("\nโš™๏ธ SCHEDULER STATUS:") - print("-" * 60) - - if not mos.mem_scheduler: - print(" โŒ Memory scheduler not initialized") - return - - scheduler = mos.mem_scheduler - print(f" ๐Ÿ”„ Scheduler Running: {scheduler._running}") - print(f" ๐Ÿ“Š Internal Queue Size: {scheduler.memos_message_queue.qsize()}") - print(f" ๐Ÿงต Parallel Dispatch: {scheduler.enable_parallel_dispatch}") - print(f" ๐Ÿ‘ฅ Max Workers: {scheduler.thread_pool_max_workers}") - print(f" โฑ๏ธ Consume Interval: {scheduler._consume_interval}s") - - if scheduler.monitor: - print(" ๐Ÿ“ˆ Monitor Active: โœ…") - print(f" ๐Ÿ—„๏ธ Database Engine: {'โœ…' if scheduler.db_engine else 'โŒ'}") - - if scheduler.dispatcher: - print(" ๐Ÿš€ Dispatcher Active: โœ…") - print( - f" ๐Ÿ”ง Dispatcher Status: {scheduler.dispatcher.status if hasattr(scheduler.dispatcher, 'status') else 'Unknown'}" - ) +sys.path.insert(0, str(BASE_DIR)) - print("-" * 60) - - -def init_task(): - conversations = [ - { - "role": "user", - "content": "I have two dogs - Max (golden retriever) and Bella (pug). We live in Seattle.", - }, - {"role": "assistant", "content": "Great! Any special care for them?"}, - { - "role": "user", - "content": "Max needs joint supplements. Actually, we're moving to Chicago next month.", - }, - { - "role": "user", - "content": "Correction: Bella is 6, not 5. And she's allergic to chicken.", - }, - { - "role": "user", - "content": "My partner's cat Whiskers visits weekends. Bella chases her sometimes.", - }, - ] - - questions = [ - # 1. Basic factual recall (simple) - { - "question": "What breed is Max?", - "category": "Pet", - "expected": "golden retriever", - "difficulty": "easy", - }, - # 2. Temporal context (medium) - { - "question": "Where will I live next month?", - "category": "Location", - "expected": "Chicago", - "difficulty": "medium", - }, - # 3. Information correction (hard) - { - "question": "How old is Bella really?", - "category": "Pet", - "expected": "6", - "difficulty": "hard", - "hint": "User corrected the age later", - }, - # 4. Relationship inference (harder) - { - "question": "Why might Whiskers be nervous around my pets?", - "category": "Behavior", - "expected": "Bella chases her sometimes", - "difficulty": "harder", - }, - # 5. 
Combined medical info (hardest) - { - "question": "Which pets have health considerations?", - "category": "Health", - "expected": "Max needs joint supplements, Bella is allergic to chicken", - "difficulty": "hardest", - "requires": ["combining multiple facts", "ignoring outdated info"], - }, - ] - return conversations, questions +# Enable execution from any working directory +logger = get_logger(__name__) if __name__ == "__main__": - print("๐Ÿš€ Starting Enhanced Memory Scheduler Test...") - print("=" * 80) - # set up data conversations, questions = init_task() # set configs mos_config = MOSConfig.from_yaml_file( - f"{BASE_DIR}/examples/data/config/mem_scheduler/memos_config_w_scheduler.yaml" + f"{BASE_DIR}/examples/data/config/mem_scheduler/memos_config_w_optimized_scheduler.yaml" ) mem_cube_config = GeneralMemCubeConfig.from_yaml_file( @@ -186,7 +51,6 @@ def init_task(): ) # Initialization - print("๐Ÿ”ง Initializing MOS with Scheduler...") mos = MOSForTestScheduler(mos_config) user_id = "user_1" @@ -197,15 +61,15 @@ def init_task(): if Path(mem_cube_name_or_path).exists(): shutil.rmtree(mem_cube_name_or_path) - print(f"๐Ÿ—‘๏ธ {mem_cube_name_or_path} is not empty, and has been removed.") + print(f"{mem_cube_name_or_path} is not empty, and has been removed.") mem_cube = GeneralMemCube(mem_cube_config) mem_cube.dump(mem_cube_name_or_path) mos.register_mem_cube( mem_cube_name_or_path=mem_cube_name_or_path, mem_cube_id=mem_cube_id, user_id=user_id ) + mos.mem_scheduler.current_mem_cube = mem_cube - print("๐Ÿ“š Adding initial conversations...") mos.add(conversations, user_id=user_id, mem_cube_id=mem_cube_id) # Add interfering conversations @@ -214,77 +78,11 @@ def init_task(): mos.add(scene_data[0], user_id=user_id, mem_cube_id=mem_cube_id) mos.add(scene_data[1], user_id=user_id, mem_cube_id=mem_cube_id) - # Display initial status - print("\n๐Ÿ“Š INITIAL SYSTEM STATUS:") - display_scheduler_status(mos) - display_memory_cube_stats(mos, user_id, mem_cube_id) - - # Process questions with enhanced monitoring - print(f"\n๐ŸŽฏ Starting Question Processing ({len(questions)} questions)...") - question_start_time = time.time() - - for i, item in enumerate(questions, 1): - print(f"\n{'=' * 20} Question {i}/{len(questions)} {'=' * 20}") - print(f"๐Ÿ“ Category: {item['category']} | Difficulty: {item['difficulty']}") - print(f"๐ŸŽฏ Expected: {item['expected']}") - if "hint" in item: - print(f"๐Ÿ’ก Hint: {item['hint']}") - if "requires" in item: - print(f"๐Ÿ” Requires: {', '.join(item['requires'])}") - - print(f"\n๐Ÿš€ Processing Query: {item['question']}") - query_start_time = time.time() - - response = mos.chat(query=item["question"], user_id=user_id) - - query_time = time.time() - query_start_time - print(f"โฑ๏ธ Query Processing Time: {query_time:.3f}s") - print(f"๐Ÿค– Response: {response}") - - # Display intermediate status every 2 questions - if i % 2 == 0: - print(f"\n๐Ÿ“Š INTERMEDIATE STATUS (Question {i}):") - display_scheduler_status(mos) - display_memory_cube_stats(mos, user_id, mem_cube_id) - - total_processing_time = time.time() - question_start_time - print(f"\nโฑ๏ธ Total Question Processing Time: {total_processing_time:.3f}s") - - # Display final scheduler performance summary - print("\n" + "=" * 80) - print("๐Ÿ“Š FINAL SCHEDULER PERFORMANCE SUMMARY") - print("=" * 80) - - summary = mos.get_scheduler_summary() - print(f"๐Ÿ”ข Total Queries Processed: {summary['total_queries']}") - print(f"โšก Total Scheduler Calls: {summary['total_scheduler_calls']}") - print(f"โฑ๏ธ Average Scheduler 
Response Time: {summary['average_scheduler_response_time']:.3f}s") - print(f"๐Ÿง  Memory Optimizations Applied: {summary['memory_optimization_count']}") - print(f"๐Ÿ”„ Working Memory Updates: {summary['working_memory_updates']}") - print(f"โšก Activation Memory Updates: {summary['activation_memory_updates']}") - print(f"๐Ÿ“ˆ Average Query Processing Time: {summary['average_query_processing_time']:.3f}s") - - # Performance insights - print("\n๐Ÿ’ก PERFORMANCE INSIGHTS:") - if summary["total_scheduler_calls"] > 0: - optimization_rate = ( - summary["memory_optimization_count"] / summary["total_scheduler_calls"] - ) * 100 - print(f" โ€ข Memory Optimization Rate: {optimization_rate:.1f}%") - - if summary["average_scheduler_response_time"] < 0.1: - print(" โ€ข Scheduler Performance: ๐ŸŸข Excellent (< 100ms)") - elif summary["average_scheduler_response_time"] < 0.5: - print(" โ€ข Scheduler Performance: ๐ŸŸก Good (100-500ms)") - else: - print(" โ€ข Scheduler Performance: ๐Ÿ”ด Needs Improvement (> 500ms)") - - # Final system status - print("\n๐Ÿ” FINAL SYSTEM STATUS:") - display_scheduler_status(mos) - display_memory_cube_stats(mos, user_id, mem_cube_id) - - print("=" * 80) - print("๐Ÿ Test completed successfully!") + for item in questions: + print("===== Chat Start =====") + query = item["question"] + print(f"Query:\n {query}\n") + response = mos.chat(query=query, user_id=user_id) + print(f"Answer:\n {response}\n") mos.mem_scheduler.stop() diff --git a/examples/mem_scheduler/orm_examples.py b/examples/mem_scheduler/orm_examples.py deleted file mode 100644 index bbb57b4ab..000000000 --- a/examples/mem_scheduler/orm_examples.py +++ /dev/null @@ -1,374 +0,0 @@ -#!/usr/bin/env python3 -""" -ORM Examples for MemScheduler - -This script demonstrates how to use the BaseDBManager's new environment variable loading methods -for MySQL and Redis connections. 
-""" - -import multiprocessing -import os -import sys - -from pathlib import Path - - -# Add the src directory to the Python path -sys.path.insert(0, str(Path(__file__).parent.parent.parent / "src")) - -from memos.log import get_logger -from memos.mem_scheduler.orm_modules.base_model import BaseDBManager, DatabaseError -from memos.mem_scheduler.orm_modules.redis_model import RedisDBManager, SimpleListManager - - -logger = get_logger(__name__) - - -def test_mysql_engine_from_env(): - """Test loading MySQL engine from environment variables""" - print("\n" + "=" * 60) - print("Testing MySQL Engine from Environment Variables") - print("=" * 60) - - try: - # Test loading MySQL engine from current environment variables - mysql_engine = BaseDBManager.load_mysql_engine_from_env() - if mysql_engine is None: - print("โŒ Failed to create MySQL engine - check environment variables") - return - - print(f"โœ… Successfully created MySQL engine: {mysql_engine}") - print(f" Engine URL: {mysql_engine.url}") - - # Test connection - with mysql_engine.connect() as conn: - from sqlalchemy import text - - result = conn.execute(text("SELECT 'MySQL connection test successful' as message")) - message = result.fetchone()[0] - print(f" Connection test: {message}") - - mysql_engine.dispose() - print(" MySQL engine disposed successfully") - - except DatabaseError as e: - print(f"โŒ DatabaseError: {e}") - except Exception as e: - print(f"โŒ Unexpected error: {e}") - - -def test_redis_connection_from_env(): - """Test loading Redis connection from environment variables""" - print("\n" + "=" * 60) - print("Testing Redis Connection from Environment Variables") - print("=" * 60) - - try: - # Test loading Redis connection from current environment variables - redis_client = BaseDBManager.load_redis_engine_from_env() - if redis_client is None: - print("โŒ Failed to create Redis connection - check environment variables") - return - - print(f"โœ… Successfully created Redis connection: {redis_client}") - - # Test basic Redis operations - redis_client.set("test_key", "Hello from ORM Examples!") - value = redis_client.get("test_key") - print(f" Redis test - Set/Get: {value}") - - # Test Redis info - info = redis_client.info("server") - redis_version = info.get("redis_version", "unknown") - print(f" Redis server version: {redis_version}") - - # Clean up test key - redis_client.delete("test_key") - print(" Test key cleaned up") - - redis_client.close() - print(" Redis connection closed successfully") - - except DatabaseError as e: - print(f"โŒ DatabaseError: {e}") - except Exception as e: - print(f"โŒ Unexpected error: {e}") - - -def test_environment_variables(): - """Test and display current environment variables""" - print("\n" + "=" * 60) - print("Current Environment Variables") - print("=" * 60) - - # MySQL environment variables - mysql_vars = [ - "MYSQL_HOST", - "MYSQL_PORT", - "MYSQL_USERNAME", - "MYSQL_PASSWORD", - "MYSQL_DATABASE", - "MYSQL_CHARSET", - ] - - print("\nMySQL Environment Variables:") - for var in mysql_vars: - value = os.getenv(var, "Not set") - # Mask password for security - if "PASSWORD" in var and value != "Not set": - value = "*" * len(value) - print(f" {var}: {value}") - - # Redis environment variables - redis_vars = [ - "REDIS_HOST", - "REDIS_PORT", - "REDIS_DB", - "REDIS_PASSWORD", - "MEMSCHEDULER_REDIS_HOST", - "MEMSCHEDULER_REDIS_PORT", - "MEMSCHEDULER_REDIS_DB", - "MEMSCHEDULER_REDIS_PASSWORD", - ] - - print("\nRedis Environment Variables:") - for var in redis_vars: - value = os.getenv(var, "Not 
set") - # Mask password for security - if "PASSWORD" in var and value != "Not set": - value = "*" * len(value) - print(f" {var}: {value}") - - -def test_manual_env_loading(): - """Test loading environment variables manually from .env file""" - print("\n" + "=" * 60) - print("Testing Manual Environment Loading") - print("=" * 60) - - env_file_path = "/Users/travistang/Documents/codes/memos/.env" - - if not os.path.exists(env_file_path): - print(f"โŒ Environment file not found: {env_file_path}") - return - - try: - from dotenv import load_dotenv - - # Load environment variables - load_dotenv(env_file_path) - print(f"โœ… Successfully loaded environment variables from {env_file_path}") - - # Test some key variables - test_vars = ["OPENAI_API_KEY", "MOS_CHAT_MODEL", "TZ"] - for var in test_vars: - value = os.getenv(var, "Not set") - if "KEY" in var and value != "Not set": - value = f"{value[:10]}..." if len(value) > 10 else value - print(f" {var}: {value}") - - except ImportError: - print("โŒ python-dotenv not installed. Install with: pip install python-dotenv") - except Exception as e: - print(f"โŒ Error loading environment file: {e}") - - -def test_redis_lockable_orm_with_list(): - """Test RedisDBManager with list[str] type synchronization""" - print("\n" + "=" * 60) - print("Testing RedisDBManager with list[str]") - print("=" * 60) - - try: - from memos.mem_scheduler.orm_modules.redis_model import RedisDBManager - - # Create a simple list manager instance - list_manager = SimpleListManager(["apple", "banana", "cherry"]) - print(f"Original list manager: {list_manager}") - - # Create RedisDBManager instance - redis_client = BaseDBManager.load_redis_engine_from_env() - if redis_client is None: - print("โŒ Failed to create Redis connection - check environment variables") - return - - db_manager = RedisDBManager( - redis_client=redis_client, - user_id="test_user", - mem_cube_id="test_list_cube", - obj=list_manager, - ) - - # Save to Redis - db_manager.save_to_db(list_manager) - print("โœ… List manager saved to Redis") - - # Load from Redis - loaded_manager = db_manager.load_from_db() - if loaded_manager: - print(f"Loaded list manager: {loaded_manager}") - print(f"Items match: {list_manager.items == loaded_manager.items}") - else: - print("โŒ Failed to load list manager from Redis") - - # Clean up - redis_client.delete("lockable_orm:test_user:test_list_cube:data") - redis_client.delete("lockable_orm:test_user:test_list_cube:lock") - redis_client.delete("lockable_orm:test_user:test_list_cube:version") - redis_client.close() - - except Exception as e: - print(f"โŒ Error in RedisDBManager test: {e}") - - -def modify_list_process(process_id: int, items_to_add: list[str]): - """Function to be run in separate processes to modify the list using merge_items""" - try: - from memos.mem_scheduler.orm_modules.redis_model import RedisDBManager - - # Create Redis connection - redis_client = BaseDBManager.load_redis_engine_from_env() - if redis_client is None: - print(f"Process {process_id}: Failed to create Redis connection") - return - - # Create a temporary list manager for this process with items to add - temp_manager = SimpleListManager() - - db_manager = RedisDBManager( - redis_client=redis_client, - user_id="test_user", - mem_cube_id="multiprocess_list", - obj=temp_manager, - ) - - print(f"Process {process_id}: Starting modification with items: {items_to_add}") - for item in items_to_add: - db_manager.obj.add_item(item) - # Use sync_with_orm which internally uses merge_items - 
db_manager.sync_with_orm(size_limit=None) - - print(f"Process {process_id}: Successfully synchronized with Redis") - - redis_client.close() - - except Exception as e: - print(f"Process {process_id}: Error - {e}") - import traceback - - traceback.print_exc() - - -def test_multiprocess_synchronization(): - """Test multiprocess synchronization with RedisDBManager""" - print("\n" + "=" * 60) - print("Testing Multiprocess Synchronization") - print("=" * 60) - - try: - # Initialize Redis with empty list - redis_client = BaseDBManager.load_redis_engine_from_env() - if redis_client is None: - print("โŒ Failed to create Redis connection") - return - - # Initialize with empty list - initial_manager = SimpleListManager([]) - db_manager = RedisDBManager( - redis_client=redis_client, - user_id="test_user", - mem_cube_id="multiprocess_list", - obj=initial_manager, - ) - db_manager.save_to_db(initial_manager) - print("โœ… Initialized empty list manager in Redis") - - # Define items for each process to add - process_items = [ - ["item1", "item2"], - ["item3", "item4"], - ["item5", "item6"], - ["item1", "item7"], # item1 is duplicate, should not be added twice - ] - - # Create and start processes - processes = [] - for i, items in enumerate(process_items): - p = multiprocessing.Process(target=modify_list_process, args=(i + 1, items)) - processes.append(p) - p.start() - - # Wait for all processes to complete - for p in processes: - p.join() - - print("\n" + "-" * 40) - print("All processes completed. Checking final result...") - - # Load final result - final_db_manager = RedisDBManager( - redis_client=redis_client, - user_id="test_user", - mem_cube_id="multiprocess_list", - obj=SimpleListManager([]), - ) - final_manager = final_db_manager.load_from_db() - - if final_manager: - print(f"Final synchronized list manager: {final_manager}") - print(f"Final list length: {len(final_manager)}") - print("Expected items: {'item1', 'item2', 'item3', 'item4', 'item5', 'item6', 'item7'}") - print(f"Actual items: {set(final_manager.items)}") - - # Check if all unique items are present - expected_items = {"item1", "item2", "item3", "item4", "item5", "item6", "item7"} - actual_items = set(final_manager.items) - - if expected_items == actual_items: - print("โœ… All processes contributed correctly - synchronization successful!") - else: - print(f"โŒ Expected items: {expected_items}") - print(f" Actual items: {actual_items}") - else: - print("โŒ Failed to load final result") - - # Clean up - redis_client.delete("lockable_orm:test_user:multiprocess_list:data") - redis_client.delete("lockable_orm:test_user:multiprocess_list:lock") - redis_client.delete("lockable_orm:test_user:multiprocess_list:version") - redis_client.close() - - except Exception as e: - print(f"โŒ Error in multiprocess synchronization test: {e}") - - -def main(): - """Main function to run all tests""" - print("ORM Examples - Environment Variable Loading Tests") - print("=" * 80) - - # Test environment variables display - test_environment_variables() - - # Test manual environment loading - test_manual_env_loading() - - # Test MySQL engine loading - test_mysql_engine_from_env() - - # Test Redis connection loading - test_redis_connection_from_env() - - # Test RedisLockableORM with list[str] - test_redis_lockable_orm_with_list() - - # Test multiprocess synchronization - test_multiprocess_synchronization() - - print("\n" + "=" * 80) - print("All tests completed!") - print("=" * 80) - - -if __name__ == "__main__": - main() diff --git 
a/examples/mem_scheduler/redis_example.py b/examples/mem_scheduler/redis_example.py
index 1660d6c02..2c3801539 100644
--- a/examples/mem_scheduler/redis_example.py
+++ b/examples/mem_scheduler/redis_example.py
@@ -22,7 +22,7 @@
 sys.path.insert(0, str(BASE_DIR))  # Enable execution from any working directory
 
 
-async def service_run():
+def service_run():
     # Init
     example_scheduler_config_path = (
         f"{BASE_DIR}/examples/data/config/mem_scheduler/general_scheduler_config.yaml"
     )
@@ -60,11 +60,11 @@ async def service_run():
             content=query,
             timestamp=datetime.now(),
         )
-        res = await mem_scheduler.redis_add_message_stream(message=message_item.to_dict())
+        res = mem_scheduler.redis_add_message_stream(message=message_item.to_dict())
         print(
             f"Added: {res}",
         )
-        await asyncio.sleep(0.5)
+        time.sleep(0.5)  # see note below: a bare asyncio.sleep() would never be awaited
 
     mem_scheduler.redis_stop_listening()
 
 
@@ -72,4 +72,4 @@
 if __name__ == "__main__":
-    asyncio.run(service_run())
+    service_run()
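A caution on the hunk above: inside a plain (non-async) function, a bare `asyncio.sleep(0.5)` only creates a coroutine that is never awaited, so the loop would not pause at all. The rewrite therefore switches to a blocking sleep, which assumes `time` is imported at the top of the file, a region this hunk does not show. A minimal illustration of the difference:

```python
import asyncio
import time


def sync_loop():
    asyncio.sleep(0.5)  # no-op: returns an un-awaited coroutine (RuntimeWarning on GC)
    time.sleep(0.5)     # actually blocks for half a second
```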
diff --git a/examples/mem_scheduler/try_schedule_modules.py b/examples/mem_scheduler/try_schedule_modules.py
index de99f1c95..4aedac711 100644
--- a/examples/mem_scheduler/try_schedule_modules.py
+++ b/examples/mem_scheduler/try_schedule_modules.py
@@ -176,6 +176,7 @@ def show_web_logs(mem_scheduler: GeneralScheduler):
     mos.register_mem_cube(
         mem_cube_name_or_path=mem_cube_name_or_path, mem_cube_id=mem_cube_id, user_id=user_id
     )
+    mos.mem_scheduler.current_mem_cube = mem_cube
 
     mos.add(conversations, user_id=user_id, mem_cube_id=mem_cube_id)
 
diff --git a/src/memos/api/config.py b/src/memos/api/config.py
index f02edaad6..a276fa63d 100644
--- a/src/memos/api/config.py
+++ b/src/memos/api/config.py
@@ -175,7 +175,7 @@ def start_config_watch(cls):
     @classmethod
     def start_watch_if_enabled(cls) -> None:
         enable = os.getenv("NACOS_ENABLE_WATCH", "false").lower() == "true"
-        print("enable:", enable)
+        logger.info(f"NACOS_ENABLE_WATCH: {enable}")
         if not enable:
             return
         interval = int(os.getenv("NACOS_WATCH_INTERVAL", "60"))
@@ -623,7 +623,10 @@ def get_scheduler_config() -> dict[str, Any]:
                     "MOS_SCHEDULER_ENABLE_PARALLEL_DISPATCH", "true"
                 ).lower()
                 == "true",
-                "enable_activation_memory": True,
+                "enable_activation_memory": os.getenv(
+                    "MOS_SCHEDULER_ENABLE_ACTIVATION_MEMORY", "false"
+                ).lower()
+                == "true",
             },
         }
diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py
index 0412754c3..3b1ce2fc9 100644
--- a/src/memos/api/product_models.py
+++ b/src/memos/api/product_models.py
@@ -171,7 +171,9 @@ class APISearchRequest(BaseRequest):
     query: str = Field(..., description="Search query")
     user_id: str = Field(None, description="User ID")
     mem_cube_id: str | None = Field(None, description="Cube ID to search in")
-    mode: SearchMode = Field(SearchMode.FAST, description="search mode: fast, fine, or mixture")
+    mode: SearchMode = Field(
+        SearchMode.NOT_INITIALIZED, description="search mode: fast, fine, or mixture"
+    )
     internet_search: bool = Field(False, description="Whether to use internet search")
     moscube: bool = Field(False, description="Whether to use MemOSCube")
     top_k: int = Field(10, description="Number of results to return")
diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py
index 8df383bfb..ad43a07e4 100644
--- a/src/memos/api/routers/server_router.py
+++ b/src/memos/api/routers/server_router.py
@@ -256,12 +256,14 @@ def init_server():
         db_engine=BaseDBManager.create_default_sqlite_engine(),
         mem_reader=mem_reader,
     )
-    mem_scheduler.current_mem_cube = naive_mem_cube
-    mem_scheduler.start()
+    mem_scheduler.init_mem_cube(mem_cube=naive_mem_cube)
 
     # Initialize SchedulerAPIModule
    api_module = mem_scheduler.api_module
 
+    if os.getenv("API_SCHEDULER_ON", "true").lower() == "true":  # getenv returns strings; compare explicitly
+        mem_scheduler.start()
+
     return (
         graph_db,
         mem_reader,
@@ -357,8 +359,10 @@ def search_memories(search_req: APISearchRequest):
         "pref_mem": [],
         "pref_note": "",
     }
-
-    search_mode = search_req.mode
+    if search_req.mode == SearchMode.NOT_INITIALIZED:
+        search_mode = os.getenv("SEARCH_MODE", SearchMode.FAST)
+    else:
+        search_mode = search_req.mode
 
     def _search_text():
         if search_mode == SearchMode.FAST:
@@ -444,22 +448,38 @@ def fine_search_memories(
     target_session_id = "default_session"
     search_filter = {"session_id": search_req.session_id} if search_req.session_id else None
 
-    # Create MemCube and perform search
-    search_results = naive_mem_cube.text_mem.search(
+    searcher = mem_scheduler.searcher
+
+    info = {
+        "user_id": search_req.user_id,
+        "session_id": target_session_id,
+        "chat_history": search_req.chat_history,
+    }
+
+    fast_retrieved_memories = searcher.retrieve(
         query=search_req.query,
         user_name=user_context.mem_cube_id,
         top_k=search_req.top_k,
-        mode=SearchMode.FINE,
+        mode=SearchMode.FAST,
         manual_close_internet=not search_req.internet_search,
         moscube=search_req.moscube,
         search_filter=search_filter,
-        info={
-            "user_id": search_req.user_id,
-            "session_id": target_session_id,
-            "chat_history": search_req.chat_history,
-        },
+        info=info,
     )
-    formatted_memories = [_format_memory_item(data) for data in search_results]
+
+    fast_memories = searcher.post_retrieve(
+        retrieved_results=fast_retrieved_memories,
+        top_k=search_req.top_k,
+        user_name=user_context.mem_cube_id,
+        info=info,
+    )
+
+    enhanced_results, _ = mem_scheduler.retriever.enhance_memories_with_query(
+        query_history=[search_req.query],
+        memories=fast_memories,
+    )
+
+    formatted_memories = [_format_memory_item(data) for data in enhanced_results]
 
     return formatted_memories
diff --git a/src/memos/configs/mem_scheduler.py b/src/memos/configs/mem_scheduler.py
index e757f243b..afdaf6871 100644
--- a/src/memos/configs/mem_scheduler.py
+++ b/src/memos/configs/mem_scheduler.py
@@ -12,10 +12,13 @@
     BASE_DIR,
     DEFAULT_ACT_MEM_DUMP_PATH,
     DEFAULT_ACTIVATION_MEM_MONITOR_SIZE_LIMIT,
+    DEFAULT_CONSUME_BATCH,
     DEFAULT_CONSUME_INTERVAL_SECONDS,
     DEFAULT_CONTEXT_WINDOW_SIZE,
     DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE,
     DEFAULT_MULTI_TASK_RUNNING_TIMEOUT,
+    DEFAULT_SCHEDULER_RETRIEVER_BATCH_SIZE,
+    DEFAULT_SCHEDULER_RETRIEVER_RETRIES,
     DEFAULT_THREAD_POOL_MAX_WORKERS,
     DEFAULT_TOP_K,
     DEFAULT_USE_REDIS_QUEUE,
@@ -43,6 +46,11 @@ class BaseSchedulerConfig(BaseConfig):
         gt=0,
         description=f"Interval for consuming messages from queue in seconds (default: {DEFAULT_CONSUME_INTERVAL_SECONDS})",
     )
+    consume_batch: int = Field(
+        default=DEFAULT_CONSUME_BATCH,
+        gt=0,
+        description=f"Number of messages to consume in each batch (default: {DEFAULT_CONSUME_BATCH})",
+    )
     auth_config_path: str | None = Field(
         default=None,
         description="Path to the authentication configuration file containing private credentials",
@@ -91,6 +99,17 @@ class GeneralSchedulerConfig(BaseSchedulerConfig):
         description="Capacity of the activation memory monitor",
     )
 
+    # Memory enhancement concurrency & retries configuration
+    enhance_batch_size: int | None = Field(
+        default=DEFAULT_SCHEDULER_RETRIEVER_BATCH_SIZE,
+        description="Batch size for concurrent memory enhancement; None or <=1 disables batching",
+    )
+    enhance_retries: int = Field(
+        default=DEFAULT_SCHEDULER_RETRIEVER_RETRIES,
+        ge=0,
+        description="Number of retry attempts per enhancement batch",
+    )
+
     # Database configuration for ORM persistence
     db_path: str | None = Field(
         default=None,
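To make the new `consume_batch` knob concrete, the sketch below shows the consumption pattern the two config fields describe: drain up to `consume_batch` messages per tick, then sleep for the consume interval. This is illustrative only; the scheduler's real loop lives elsewhere, and `queue` and `handler` are placeholders:

```python
import time
from queue import Empty, Queue


def consume_loop(queue: Queue, handler, consume_batch: int, consume_interval: float) -> None:
    """Illustrative batch-consume loop; runs until interrupted."""
    while True:
        batch = []
        for _ in range(consume_batch):
            try:
                batch.append(queue.get_nowait())
            except Empty:
                break
        if batch:
            handler(batch)  # e.g., a handler registered via register_handlers
        time.sleep(consume_interval)
```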
enhancement batch", + ) + # Database configuration for ORM persistence db_path: str | None = Field( default=None, diff --git a/src/memos/mem_os/core.py b/src/memos/mem_os/core.py index 97ff9879f..1b6d4e126 100644 --- a/src/memos/mem_os/core.py +++ b/src/memos/mem_os/core.py @@ -283,7 +283,6 @@ def chat(self, query: str, user_id: str | None = None, base_prompt: str | None = message_item = ScheduleMessageItem( user_id=target_user_id, mem_cube_id=mem_cube_id, - mem_cube=mem_cube, label=QUERY_LABEL, content=query, timestamp=datetime.utcnow(), @@ -344,7 +343,6 @@ def chat(self, query: str, user_id: str | None = None, base_prompt: str | None = message_item = ScheduleMessageItem( user_id=target_user_id, mem_cube_id=mem_cube_id, - mem_cube=mem_cube, label=ANSWER_LABEL, content=response, timestamp=datetime.utcnow(), @@ -768,12 +766,10 @@ def process_textual_memory(): ) # submit messages for scheduler if self.enable_mem_scheduler and self.mem_scheduler is not None: - mem_cube = self.mem_cubes[mem_cube_id] if sync_mode == "async": message_item = ScheduleMessageItem( user_id=target_user_id, mem_cube_id=mem_cube_id, - mem_cube=mem_cube, label=MEM_READ_LABEL, content=json.dumps(mem_ids), timestamp=datetime.utcnow(), @@ -783,7 +779,6 @@ def process_textual_memory(): message_item = ScheduleMessageItem( user_id=target_user_id, mem_cube_id=mem_cube_id, - mem_cube=mem_cube, label=ADD_LABEL, content=json.dumps(mem_ids), timestamp=datetime.utcnow(), @@ -797,7 +792,6 @@ def process_preference_memory(): and self.mem_cubes[mem_cube_id].pref_mem ): messages_list = [messages] - mem_cube = self.mem_cubes[mem_cube_id] if sync_mode == "sync": pref_memories = self.mem_cubes[mem_cube_id].pref_mem.get_memory( messages_list, @@ -816,7 +810,6 @@ def process_preference_memory(): user_id=target_user_id, session_id=target_session_id, mem_cube_id=mem_cube_id, - mem_cube=mem_cube, label=PREF_ADD_LABEL, content=json.dumps(messages_list), timestamp=datetime.utcnow(), @@ -867,12 +860,10 @@ def process_preference_memory(): # submit messages for scheduler if self.enable_mem_scheduler and self.mem_scheduler is not None: - mem_cube = self.mem_cubes[mem_cube_id] if sync_mode == "async": message_item = ScheduleMessageItem( user_id=target_user_id, mem_cube_id=mem_cube_id, - mem_cube=mem_cube, label=MEM_READ_LABEL, content=json.dumps(mem_ids), timestamp=datetime.utcnow(), @@ -882,7 +873,6 @@ def process_preference_memory(): message_item = ScheduleMessageItem( user_id=target_user_id, mem_cube_id=mem_cube_id, - mem_cube=mem_cube, label=ADD_LABEL, content=json.dumps(mem_ids), timestamp=datetime.utcnow(), @@ -909,11 +899,9 @@ def process_preference_memory(): # submit messages for scheduler if self.enable_mem_scheduler and self.mem_scheduler is not None: - mem_cube = self.mem_cubes[mem_cube_id] message_item = ScheduleMessageItem( user_id=target_user_id, mem_cube_id=mem_cube_id, - mem_cube=mem_cube, label=ADD_LABEL, content=json.dumps(mem_ids), timestamp=datetime.utcnow(), diff --git a/src/memos/mem_os/main.py b/src/memos/mem_os/main.py index 6fc64c5e3..0114fc0da 100644 --- a/src/memos/mem_os/main.py +++ b/src/memos/mem_os/main.py @@ -205,7 +205,6 @@ def _chat_with_cot_enhancement( # Step 7: Submit message to scheduler (same as core method) if len(accessible_cubes) == 1: mem_cube_id = accessible_cubes[0].cube_id - mem_cube = self.mem_cubes[mem_cube_id] if self.enable_mem_scheduler and self.mem_scheduler is not None: from datetime import datetime @@ -217,7 +216,6 @@ def _chat_with_cot_enhancement( message_item = ScheduleMessageItem( 
diff --git a/src/memos/mem_scheduler/analyzer/api_analyzer.py b/src/memos/mem_scheduler/analyzer/api_analyzer.py
index 28ca182e5..085025b7f 100644
--- a/src/memos/mem_scheduler/analyzer/api_analyzer.py
+++ b/src/memos/mem_scheduler/analyzer/api_analyzer.py
@@ -7,7 +7,6 @@
 
 import http.client
 import json
-import time
 
 from typing import Any
 from urllib.parse import urlparse
@@ -15,6 +14,7 @@
 import requests
 
 from memos.log import get_logger
+from memos.mem_scheduler.schemas.general_schemas import SearchMode
 
 
 logger = get_logger(__name__)
@@ -487,7 +487,7 @@ def search_in_conversation(self, query, mode="fast", top_k=10, include_history=T
 
         return result
 
-    def test_continuous_conversation(self):
+    def test_continuous_conversation(self, mode=SearchMode.MIXTURE):
         """Test continuous conversation functionality"""
         print("=" * 80)
         print("Testing Continuous Conversation Functionality")
@@ -542,15 +542,15 @@ def test_continuous_conversation(self):
 
         # Search for trip-related information
         self.search_in_conversation(
-            query="New Year's Eve Shanghai recommendations", mode="mixture", top_k=5
+            query="New Year's Eve Shanghai recommendations", mode=mode, top_k=5
         )
 
         # Search for food-related information
-        self.search_in_conversation(query="budget food Shanghai", mode="mixture", top_k=3)
+        self.search_in_conversation(query="budget food Shanghai", mode=mode, top_k=3)
 
         # Search without conversation history
         self.search_in_conversation(
-            query="Shanghai travel", mode="mixture", top_k=3, include_history=False
+            query="Shanghai travel", mode=mode, top_k=3, include_history=False
         )
 
         print("\n✅ Continuous conversation test completed successfully!")
@@ -645,7 +645,7 @@ def create_test_add_request(
             operation=None,
         )
 
-    def run_all_tests(self):
+    def run_all_tests(self, mode=SearchMode.MIXTURE):
         """Run all available tests"""
         print("🚀 Starting comprehensive test suite")
         print("=" * 80)
@@ -653,8 +653,7 @@ def run_all_tests(self):
         # Test continuous conversation functionality
         print("\n💬 Testing CONTINUOUS CONVERSATION functions:")
         try:
-            self.test_continuous_conversation()
-            time.sleep(5)
+            self.test_continuous_conversation(mode=mode)
             print("✅ Continuous conversation test completed successfully")
         except Exception as e:
             print(f"❌ Continuous conversation test failed: {e}")
@@ -682,7 +681,7 @@
         print("Using direct test mode")
         try:
             direct_analyzer = DirectSearchMemoriesAnalyzer()
-            direct_analyzer.run_all_tests()
+            direct_analyzer.run_all_tests(mode=SearchMode.MIXTURE)
         except Exception as e:
             print(f"Direct test mode failed: {e}")
             import traceback
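The new module below wires several LLM calls into a bad-case analysis pipeline. As a preview of how its pieces chain together, here is a hypothetical driver; the JSON paths are invented, and the `verdict` keys follow the response schema defined in `analyze_memory_sufficiency`:

```python
# Hypothetical usage of the EvalAnalyzer defined below; paths are placeholders.
analyzer = EvalAnalyzer(output_dir="./tmp/eval_analyzer")
bad_cases = analyzer.extract_bad_cases(
    judged_file="eval/judged_results.json",
    search_results_file="eval/search_results.json",
)
for case in bad_cases:
    verdict = analyzer.analyze_memory_sufficiency(
        query=case["query"],
        golden_answer=case["golden_answer"],
        memories=case["memories"],
    )
    print(case["query"], "->", verdict["sufficient"], f"({verdict['confidence']:.2f})")
```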
diff --git a/src/memos/mem_scheduler/analyzer/eval_analyzer.py b/src/memos/mem_scheduler/analyzer/eval_analyzer.py
new file mode 100644
index 000000000..d37e17456
--- /dev/null
+++ b/src/memos/mem_scheduler/analyzer/eval_analyzer.py
@@ -0,0 +1,1322 @@
+"""
+Evaluation Analyzer for Bad Cases
+
+This module provides the EvalAnalyzer class that extracts bad cases from evaluation results
+and analyzes whether memories contain sufficient information to answer golden answers.
+"""
+
+import json
+import os
+import sys
+
+from pathlib import Path
+from typing import Any
+
+from openai import OpenAI
+
+from memos.api.routers.server_router import mem_scheduler
+from memos.log import get_logger
+from memos.memories.textual.item import TextualMemoryMetadata
+from memos.memories.textual.tree import TextualMemoryItem
+
+
+FILE_PATH = Path(__file__).absolute()
+BASE_DIR = FILE_PATH.parent.parent.parent.parent.parent  # Go up to project root
+sys.path.insert(0, str(BASE_DIR))  # Enable execution from any working directory
+
+logger = get_logger(__name__)
+
+
+class EvalAnalyzer:
+    """
+    Evaluation Analyzer class for extracting and analyzing bad cases.
+
+    This class extracts bad cases from evaluation results and uses LLM to analyze
+    whether memories contain sufficient information to answer golden answers.
+    """
+
+    def __init__(
+        self,
+        openai_api_key: str | None = None,
+        openai_base_url: str | None = None,
+        openai_model: str = "gpt-4o-mini",
+        output_dir: str = "./tmp/eval_analyzer",
+    ):
+        """
+        Initialize the EvalAnalyzer.
+
+        Args:
+            openai_api_key: OpenAI API key
+            openai_base_url: OpenAI base URL
+            openai_model: OpenAI model to use
+            output_dir: Output directory for results
+        """
+        self.output_dir = Path(output_dir)
+        self.output_dir.mkdir(parents=True, exist_ok=True)
+
+        # Initialize OpenAI client
+        self.openai_client = OpenAI(
+            api_key=openai_api_key or os.getenv("MEMSCHEDULER_OPENAI_API_KEY"),
+            base_url=openai_base_url or os.getenv("MEMSCHEDULER_OPENAI_BASE_URL"),
+        )
+        self.openai_model = openai_model or os.getenv(
+            "MEMSCHEDULER_OPENAI_DEFAULT_MODEL", "gpt-4o-mini"
+        )
+
+        logger.info(f"EvalAnalyzer initialized with model: {self.openai_model}")
+
+    def load_json_file(self, filepath: str) -> Any:
+        """Load JSON file safely."""
+        try:
+            with open(filepath, encoding="utf-8") as f:
+                return json.load(f)
+        except FileNotFoundError:
+            logger.error(f"File not found: {filepath}")
+            return None
+        except json.JSONDecodeError as e:
+            logger.error(f"JSON decode error in {filepath}: {e}")
+            return None
+
+    def extract_bad_cases(self, judged_file: str, search_results_file: str) -> list[dict[str, Any]]:
+        """
+        Extract bad cases from judged results and corresponding search results.
+ + Args: + judged_file: Path to the judged results JSON file + search_results_file: Path to the search results JSON file + + Returns: + List of bad cases with their memories + """ + logger.info(f"Loading judged results from: {judged_file}") + judged_data = self.load_json_file(judged_file) + if not judged_data: + return [] + + logger.info(f"Loading search results from: {search_results_file}") + search_data = self.load_json_file(search_results_file) + if not search_data: + return [] + + bad_cases = [] + + # Process each user's data + for user_id, user_judged_results in judged_data.items(): + user_search_results = search_data.get(user_id, []) + + # Create a mapping from query to search context + search_context_map = {} + for search_result in user_search_results: + query = search_result.get("query", "") + context = search_result.get("context", "") + search_context_map[query] = context + + # Process each question for this user + for result in user_judged_results: + # Check if this is a bad case (all judgments are False) + judgments = result.get("llm_judgments", {}) + is_bad_case = all(not judgment for judgment in judgments.values()) + + if is_bad_case: + question = result.get("question", "") + answer = result.get("answer", "") + golden_answer = result.get("golden_answer", "") + + # Find corresponding memories from search results + memories = search_context_map.get(question, "") + + bad_case = { + "user_id": user_id, + "query": question, + "answer": answer, + "golden_answer": golden_answer, + "memories": memories, + "category": result.get("category", 0), + "nlp_metrics": result.get("nlp_metrics", {}), + "response_duration_ms": result.get("response_duration_ms", 0), + "search_duration_ms": result.get("search_duration_ms", 0), + "total_duration_ms": result.get("total_duration_ms", 0), + } + + bad_cases.append(bad_case) + + logger.info(f"Extracted {len(bad_cases)} bad cases") + return bad_cases + + def analyze_memory_sufficiency( + self, query: str, golden_answer: str, memories: str + ) -> dict[str, Any]: + """ + Use LLM to analyze whether memories contain sufficient information to answer the golden answer. + + Args: + query: The original query + golden_answer: The correct answer + memories: The memory context + + Returns: + Analysis result containing sufficiency judgment and relevant memory indices + """ + prompt = f""" +You are an expert analyst tasked with determining whether the provided memories contain sufficient information to answer a specific question correctly. + +**Question:** {query} + +**Golden Answer (Correct Answer):** {golden_answer} + +**Available Memories:** +{memories} + +**Task:** +1. Analyze whether the memories contain enough information to derive the golden answer +2. Identify which specific memory entries (if any) contain relevant information +3. 
Provide a clear judgment: True if sufficient, False if insufficient + +**Response Format (JSON):** +{{ + "sufficient": true/false, + "confidence": 0.0-1.0, + "relevant_memories": ["memory_1", "memory_2", ...], + "reasoning": "Detailed explanation of your analysis", + "missing_information": "What key information is missing (if insufficient)" +}} + +**Guidelines:** +- Be strict in your evaluation - only mark as sufficient if the memories clearly contain the information needed +- Consider both direct and indirect information that could lead to the golden answer +- Pay attention to dates, names, events, and specific details +- If information is ambiguous or requires significant inference, lean towards insufficient +""" + + try: + response = self.openai_client.chat.completions.create( + model=self.openai_model, + messages=[ + { + "role": "system", + "content": "You are a precise analyst who evaluates information sufficiency.", + }, + {"role": "user", "content": prompt}, + ], + temperature=0.1, + max_tokens=1000, + ) + + content = response.choices[0].message.content.strip() + + # Try to parse JSON response + try: + # Remove markdown code blocks if present + if content.startswith("```json"): + content = content[7:] + if content.endswith("```"): + content = content[:-3] + content = content.strip() + + analysis = json.loads(content) + return analysis + + except json.JSONDecodeError: + logger.warning(f"Failed to parse LLM response as JSON: {content}") + return { + "sufficient": False, + "confidence": 0.0, + "relevant_memories": [], + "reasoning": f"Failed to parse LLM response: {content}", + "missing_information": "Analysis failed", + } + + except Exception as e: + logger.error(f"Error in LLM analysis: {e}") + return { + "sufficient": False, + "confidence": 0.0, + "relevant_memories": [], + "reasoning": f"Error occurred: {e!s}", + "missing_information": "Analysis failed due to error", + } + + def process_memories_with_llm( + self, memories: str, query: str, processing_type: str = "summarize" + ) -> dict[str, Any]: + """ + Use LLM to process memories for better question answering. + + Args: + memories: The raw memory content + query: The query that will be answered using these memories + processing_type: Type of processing ("summarize", "restructure", "enhance") + + Returns: + Dictionary containing processed memories and processing metadata + """ + if processing_type == "summarize": + prompt = f""" +You are an expert at summarizing and organizing information to help answer specific questions. + +**Target Question:** {query} + +**Raw Memories:** +{memories} + +**Task:** +Summarize and organize the above memories in a way that would be most helpful for answering the target question. Focus on: +1. Key facts and information relevant to the question +2. Important relationships and connections +3. Chronological or logical organization where applicable +4. Remove redundant or irrelevant information + +**Processed Memories:** +""" + elif processing_type == "restructure": + prompt = f""" +You are an expert at restructuring information to optimize question answering. + +**Target Question:** {query} + +**Raw Memories:** +{memories} + +**Task:** +Restructure the above memories into a clear, logical format that directly supports answering the target question. Organize by: +1. Most relevant information first +2. Supporting details and context +3. Clear categorization of different types of information +4. 
Logical flow that leads to the answer + +**Restructured Memories:** +""" + elif processing_type == "enhance": + prompt = f""" +You are an expert at enhancing information by adding context and making connections. + +**Target Question:** {query} + +**Raw Memories:** +{memories} + +**Task:** +Enhance the above memories by: +1. Making implicit connections explicit +2. Adding relevant context that helps answer the question +3. Highlighting key relationships between different pieces of information +4. Organizing information in a question-focused manner + +**Enhanced Memories:** +""" + else: + raise ValueError(f"Unknown processing_type: {processing_type}") + + try: + response = self.openai_client.chat.completions.create( + model=self.openai_model, + messages=[ + { + "role": "system", + "content": "You are an expert information processor who optimizes content for question answering.", + }, + {"role": "user", "content": prompt}, + ], + temperature=0.3, + max_tokens=2000, + ) + + processed_memories = response.choices[0].message.content.strip() + + return { + "processed_memories": processed_memories, + "processing_type": processing_type, + "original_length": len(memories), + "processed_length": len(processed_memories), + "compression_ratio": len(processed_memories) / len(memories) + if len(memories) > 0 + else 0, + } + + except Exception as e: + logger.error(f"Error in memory processing: {e}") + return { + "processed_memories": memories, # Fallback to original + "processing_type": processing_type, + "original_length": len(memories), + "processed_length": len(memories), + "compression_ratio": 1.0, + "error": str(e), + } + + def generate_answer_with_memories( + self, query: str, memories: str, memory_type: str = "original" + ) -> dict[str, Any]: + """ + Generate an answer to the query using the provided memories. + + Args: + query: The question to answer + memories: The memory content to use + memory_type: Type of memories ("original", "processed") + + Returns: + Dictionary containing the generated answer and metadata + """ + prompt = f""" + You are a knowledgeable and helpful AI assistant. + + # CONTEXT: + You have access to memories from two speakers in a conversation. These memories contain + timestamped information that may be relevant to answering the question. + + # INSTRUCTIONS: + 1. Carefully analyze all provided memories. Synthesize information across different entries if needed to form a complete answer. + 2. Pay close attention to the timestamps to determine the answer. If memories contain contradictory information, the **most recent memory** is the source of truth. + 3. If the question asks about a specific event or fact, look for direct evidence in the memories. + 4. Your answer must be grounded in the memories. However, you may use general world knowledge to interpret or complete information found within a memory (e.g., identifying a landmark mentioned by description). + 5. If the question involves time references (like "last year", "two months ago", etc.), you **must** calculate the actual date based on the memory's timestamp. For example, if a memory from 4 May 2022 mentions "went to India last year," then the trip occurred in 2021. + 6. Always convert relative time references to specific dates, months, or years in your final answer. + 7. Do not confuse character names mentioned in memories with the actual users who created them. + 8. The answer must be brief (under 5-6 words) and direct, with no extra description. + + # APPROACH (Think step by step): + 1. 
First, examine all memories that contain information related to the question. + 2. Synthesize findings from multiple memories if a single entry is insufficient. + 3. Examine timestamps and content carefully, looking for explicit dates, times, locations, or events. + 4. If the answer requires calculation (e.g., converting relative time references), perform the calculation. + 5. Formulate a precise, concise answer based on the evidence from the memories (and allowed world knowledge). + 6. Double-check that your answer directly addresses the question asked and adheres to all instructions. + 7. Ensure your final answer is specific and avoids vague time references. + + {memories} + + Question: {query} + + Answer: +""" + + try: + response = self.openai_client.chat.completions.create( + model=self.openai_model, + messages=[ + { + "role": "system", + "content": "You are a precise assistant who answers questions based only on provided information.", + }, + {"role": "user", "content": prompt}, + ], + temperature=0.1, + max_tokens=1000, + ) + + answer = response.choices[0].message.content.strip() + + return { + "answer": answer, + "memory_type": memory_type, + "query": query, + "memory_length": len(memories), + "answer_length": len(answer), + } + + except Exception as e: + logger.error(f"Error in answer generation: {e}") + return { + "answer": f"Error generating answer: {e!s}", + "memory_type": memory_type, + "query": query, + "memory_length": len(memories), + "answer_length": 0, + "error": str(e), + } + + def compare_answer_quality( + self, query: str, golden_answer: str, original_answer: str, processed_answer: str + ) -> dict[str, Any]: + """ + Compare the quality of answers generated from original vs processed memories. + + Args: + query: The original query + golden_answer: The correct/expected answer + original_answer: Answer generated from original memories + processed_answer: Answer generated from processed memories + + Returns: + Dictionary containing comparison results + """ + prompt = f""" +You are an expert evaluator comparing the quality of two answers against a golden standard. + +**Question:** {query} + +**Golden Answer (Correct):** {golden_answer} + +**Answer A (Original Memories):** {original_answer} + +**Answer B (Processed Memories):** {processed_answer} + +**Task:** +Compare both answers against the golden answer and evaluate: +1. Accuracy: How correct is each answer? +2. Completeness: How complete is each answer? +3. Relevance: How relevant is each answer to the question? +4. Clarity: How clear and well-structured is each answer? 
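Instruction 5 above anchors relative time references to the memory's timestamp rather than to the present. A minimal arithmetic sketch of that rule, reusing the prompt's own example:

from datetime import date

memory_date = date(2022, 5, 4)    # memory says "went to India last year"
trip_year = memory_date.year - 1  # resolves to 2021, as the instructions require
assert trip_year == 2021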
+ +**Response Format (JSON):** +{{ + "original_scores": {{ + "accuracy": 0.0-1.0, + "completeness": 0.0-1.0, + "relevance": 0.0-1.0, + "clarity": 0.0-1.0, + "overall": 0.0-1.0 + }}, + "processed_scores": {{ + "accuracy": 0.0-1.0, + "completeness": 0.0-1.0, + "relevance": 0.0-1.0, + "clarity": 0.0-1.0, + "overall": 0.0-1.0 + }}, + "winner": "original|processed|tie", + "improvement": 0.0-1.0, + "reasoning": "Detailed explanation of the comparison" +}} +""" + + try: + response = self.openai_client.chat.completions.create( + model=self.openai_model, + messages=[ + { + "role": "system", + "content": "You are an expert evaluator who compares answer quality objectively.", + }, + {"role": "user", "content": prompt}, + ], + temperature=0.1, + max_tokens=1500, + ) + + content = response.choices[0].message.content.strip() + + # Try to parse JSON response + try: + if content.startswith("```json"): + content = content[7:] + if content.endswith("```"): + content = content[:-3] + content = content.strip() + + comparison = json.loads(content) + return comparison + + except json.JSONDecodeError: + logger.warning(f"Failed to parse comparison response as JSON: {content}") + return { + "original_scores": { + "accuracy": 0.5, + "completeness": 0.5, + "relevance": 0.5, + "clarity": 0.5, + "overall": 0.5, + }, + "processed_scores": { + "accuracy": 0.5, + "completeness": 0.5, + "relevance": 0.5, + "clarity": 0.5, + "overall": 0.5, + }, + "winner": "tie", + "improvement": 0.0, + "reasoning": f"Failed to parse comparison: {content}", + } + + except Exception as e: + logger.error(f"Error in answer comparison: {e}") + return { + "original_scores": { + "accuracy": 0.0, + "completeness": 0.0, + "relevance": 0.0, + "clarity": 0.0, + "overall": 0.0, + }, + "processed_scores": { + "accuracy": 0.0, + "completeness": 0.0, + "relevance": 0.0, + "clarity": 0.0, + "overall": 0.0, + }, + "winner": "tie", + "improvement": 0.0, + "reasoning": f"Error occurred: {e!s}", + } + + def analyze_memory_processing_effectiveness( + self, + bad_cases: list[dict[str, Any]], + processing_types: list[str] | None = None, + ) -> dict[str, Any]: + """ + Analyze the effectiveness of different memory processing techniques. 
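Note the two deliberate fallbacks above: an unparseable comparison reply yields neutral 0.5 scores and a "tie", while a hard error yields 0.0 scores, so downstream statistics can always rely on a complete schema. A small consumer sketch; the helper and its threshold are assumptions for illustration, not part of the class:

def improvement_signal(comparison: dict, min_gain: float = 0.1) -> bool:
    # True only when processing clearly beat the original answer.
    return comparison["winner"] == "processed" and comparison["improvement"] >= min_gain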
+ + Args: + bad_cases: List of bad cases to analyze + processing_types: List of processing types to test + + Returns: + Dictionary containing comprehensive analysis results + """ + if processing_types is None: + processing_types = ["summarize", "restructure", "enhance"] + results = {"processing_results": [], "statistics": {}, "processing_types": processing_types} + + for i, case in enumerate(bad_cases): + logger.info(f"Processing case {i + 1}/{len(bad_cases)}: {case['query'][:50]}...") + + case_result = { + "case_id": i, + "query": case["query"], + "golden_answer": case["golden_answer"], + "original_memories": case["memories"], + "processing_results": {}, + } + + # Generate answer with original memories + original_answer_result = self.generate_answer_with_memories( + case["query"], case["memories"], "original" + ) + case_result["original_answer"] = original_answer_result + + # Test each processing type + for processing_type in processing_types: + logger.info(f" Testing {processing_type} processing...") + + # Process memories + processing_result = self.process_memories_with_llm( + case["memories"], case["query"], processing_type + ) + + # Generate answer with processed memories + processed_answer_result = self.generate_answer_with_memories( + case["query"], + processing_result["processed_memories"], + f"processed_{processing_type}", + ) + + # Compare answer quality + comparison_result = self.compare_answer_quality( + case["query"], + case["golden_answer"], + original_answer_result["answer"], + processed_answer_result["answer"], + ) + + case_result["processing_results"][processing_type] = { + "processing": processing_result, + "answer": processed_answer_result, + "comparison": comparison_result, + } + + results["processing_results"].append(case_result) + + # Calculate statistics + self._calculate_processing_statistics(results) + + return results + + def _calculate_processing_statistics(self, results: dict[str, Any]) -> None: + """Calculate statistics for processing effectiveness analysis.""" + processing_types = results["processing_types"] + processing_results = results["processing_results"] + + if not processing_results: + results["statistics"] = {} + return + + stats = {"total_cases": len(processing_results), "processing_type_stats": {}} + + for processing_type in processing_types: + type_stats = { + "wins": 0, + "ties": 0, + "losses": 0, + "avg_improvement": 0.0, + "avg_compression_ratio": 0.0, + "avg_scores": { + "accuracy": 0.0, + "completeness": 0.0, + "relevance": 0.0, + "clarity": 0.0, + "overall": 0.0, + }, + } + + valid_cases = [] + for case in processing_results: + if processing_type in case["processing_results"]: + result = case["processing_results"][processing_type] + comparison = result["comparison"] + + # Count wins/ties/losses + if comparison["winner"] == "processed": + type_stats["wins"] += 1 + elif comparison["winner"] == "tie": + type_stats["ties"] += 1 + else: + type_stats["losses"] += 1 + + valid_cases.append(result) + + if valid_cases: + # Calculate averages + type_stats["avg_improvement"] = sum( + case["comparison"]["improvement"] for case in valid_cases + ) / len(valid_cases) + + type_stats["avg_compression_ratio"] = sum( + case["processing"]["compression_ratio"] for case in valid_cases + ) / len(valid_cases) + + # Calculate average scores + for score_type in type_stats["avg_scores"]: + type_stats["avg_scores"][score_type] = sum( + case["comparison"]["processed_scores"][score_type] for case in valid_cases + ) / len(valid_cases) + + # Calculate win rate + 
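# Note: win_rate below counts only outright wins, while success_rate also
# credits ties, since a tie means processing at least did not hurt answer
# quality. Example: wins=3, ties=1, losses=1 gives win_rate = 3/5 = 0.60
# and success_rate = (3 + 1)/5 = 0.80.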
total_decisions = type_stats["wins"] + type_stats["ties"] + type_stats["losses"] + type_stats["win_rate"] = ( + type_stats["wins"] / total_decisions if total_decisions > 0 else 0.0 + ) + type_stats["success_rate"] = ( + (type_stats["wins"] + type_stats["ties"]) / total_decisions + if total_decisions > 0 + else 0.0 + ) + + stats["processing_type_stats"][processing_type] = type_stats + + results["statistics"] = stats + + def analyze_bad_cases(self, bad_cases: list[dict[str, Any]]) -> list[dict[str, Any]]: + """ + Analyze all bad cases to determine memory sufficiency. + + Args: + bad_cases: List of bad cases to analyze + + Returns: + List of analyzed bad cases with sufficiency information + """ + analyzed_cases = [] + + for i, case in enumerate(bad_cases): + logger.info(f"Analyzing bad case {i + 1}/{len(bad_cases)}: {case['query'][:50]}...") + + analysis = self.analyze_memory_sufficiency( + case["query"], case["golden_answer"], case["memories"] + ) + + # Add analysis results to the case + analyzed_case = case.copy() + analyzed_case.update( + { + "memory_analysis": analysis, + "has_sufficient_memories": analysis["sufficient"], + "analysis_confidence": analysis["confidence"], + "relevant_memory_count": len(analysis["relevant_memories"]), + } + ) + + analyzed_cases.append(analyzed_case) + + return analyzed_cases + + def collect_bad_cases(self, eval_result_dir: str | None = None) -> dict[str, Any]: + """ + Main method to collect and analyze bad cases from evaluation results. + + Args: + eval_result_dir: Directory containing evaluation results + + Returns: + Dictionary containing analysis results and statistics + """ + if eval_result_dir is None: + eval_result_dir = f"{BASE_DIR}/evaluation/results/locomo/memos-api-072005-fast" + + judged_file = os.path.join(eval_result_dir, "memos-api_locomo_judged.json") + search_results_file = os.path.join(eval_result_dir, "memos-api_locomo_search_results.json") + + # Extract bad cases + bad_cases = self.extract_bad_cases(judged_file, search_results_file) + + if not bad_cases: + logger.warning("No bad cases found") + return {"bad_cases": [], "statistics": {}} + + # Analyze bad cases + analyzed_cases = self.analyze_bad_cases(bad_cases) + + # Calculate statistics + total_cases = len(analyzed_cases) + sufficient_cases = sum( + 1 for case in analyzed_cases if case.get("has_sufficient_memories", False) + ) + insufficient_cases = total_cases - sufficient_cases + + avg_confidence = ( + sum(case["analysis_confidence"] for case in analyzed_cases) / total_cases + if total_cases > 0 + else 0 + ) + avg_relevant_memories = ( + sum(case["relevant_memory_count"] for case in analyzed_cases) / total_cases + if total_cases > 0 + else 0 + ) + + statistics = { + "total_bad_cases": total_cases, + "sufficient_memory_cases": sufficient_cases, + "insufficient_memory_cases": insufficient_cases, + "sufficiency_rate": sufficient_cases / total_cases if total_cases > 0 else 0, + "average_confidence": avg_confidence, + "average_relevant_memories": avg_relevant_memories, + } + + # Save results + results = { + "bad_cases": analyzed_cases, + "statistics": statistics, + "metadata": { + "eval_result_dir": eval_result_dir, + "judged_file": judged_file, + "search_results_file": search_results_file, + "analysis_model": self.openai_model, + }, + } + + output_file = self.output_dir / "bad_cases_analysis.json" + with open(output_file, "w", encoding="utf-8") as f: + json.dump(results, f, indent=2, ensure_ascii=False) + + logger.info(f"Analysis complete. 
Results saved to: {output_file}") + logger.info(f"Statistics: {statistics}") + + return results + + def _parse_json_response(self, response_text: str) -> dict: + """ + Parse JSON response from LLM, handling various formats and potential errors. + + Args: + response_text: Raw response text from LLM + + Returns: + Parsed JSON dictionary + + Raises: + ValueError: If JSON cannot be parsed + """ + import re + + # Try to extract JSON from response text + # Look for JSON blocks between ```json and ``` or just {} blocks + json_patterns = [r"```json\s*(\{.*?\})\s*```", r"```\s*(\{.*?\})\s*```", r"(\{.*\})"] + + for pattern in json_patterns: + matches = re.findall(pattern, response_text, re.DOTALL) + if matches: + json_str = matches[0].strip() + try: + return json.loads(json_str) + except json.JSONDecodeError: + continue + + # If no JSON pattern found, try parsing the entire response + try: + return json.loads(response_text.strip()) + except json.JSONDecodeError as e: + logger.error(f"Failed to parse JSON response: {response_text[:200]}...") + raise ValueError(f"Invalid JSON response: {e!s}") from e + + def filter_memories_with_llm(self, memories: list[str], query: str) -> tuple[list[str], bool]: + """ + Use LLM to filter memories based on relevance to the query. + + Args: + memories: List of memory strings + query: Query to filter memories against + + Returns: + Tuple of (filtered_memories, success_flag) + """ + if not memories: + return [], True + + # Build prompt for memory filtering + memories_text = "\n".join([f"{i + 1}. {memory}" for i, memory in enumerate(memories)]) + + prompt = f"""You are a memory filtering system. Given a query and a list of memories, identify which memories are relevant and non-redundant for answering the query. + +Query: {query} + +Memories: +{memories_text} + +Please analyze each memory and return a JSON response with the following format: +{{ + "relevant_memory_indices": [list of indices (1-based) of memories that are relevant to the query], + "reasoning": "Brief explanation of your filtering decisions" +}} + +Only include memories that are directly relevant to answering the query. Remove redundant or unrelated memories.""" + + try: + response = self.openai_client.chat.completions.create( + model=self.openai_model, + messages=[{"role": "user", "content": prompt}], + temperature=0.1, + ) + + response_text = response.choices[0].message.content + + # Extract JSON from response + result = self._parse_json_response(response_text) + + if "relevant_memory_indices" in result: + relevant_indices = result["relevant_memory_indices"] + filtered_memories = [] + + for idx in relevant_indices: + if 1 <= idx <= len(memories): + filtered_memories.append(memories[idx - 1]) + + logger.info(f"Filtered memories: {len(memories)} -> {len(filtered_memories)}") + return filtered_memories, True + else: + logger.warning("Invalid response format from memory filtering LLM") + return memories, False + + except Exception as e: + logger.error(f"Error in memory filtering: {e}") + return memories, False + + def evaluate_answer_ability_with_llm(self, query: str, memories: list[str]) -> bool: + """ + Use LLM to evaluate whether the given memories can answer the query. + + Args: + query: Query to evaluate + memories: List of memory strings + + Returns: + Boolean indicating whether memories can answer the query + """ + if not memories: + return False + + memories_text = "\n".join([f"- {memory}" for memory in memories]) + + prompt = f"""You are an answer ability evaluator. 
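The regex fallbacks in _parse_json_response matter when a model replies without fences. A behavior sketch with the pattern list copied from the method and an invented reply:

import json
import re

patterns = [r"```json\s*(\{.*?\})\s*```", r"```\s*(\{.*?\})\s*```", r"(\{.*\})"]
text = 'Sure! {"relevant_memory_indices": [1, 3], "reasoning": "ok"}'
for pattern in patterns:
    matches = re.findall(pattern, text, re.DOTALL)
    if matches:
        print(json.loads(matches[0].strip()))  # {'relevant_memory_indices': [1, 3], 'reasoning': 'ok'}
        break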
Given a query and a list of memories, determine whether the memories contain sufficient information to answer the query. + +Query: {query} + +Available Memories: +{memories_text} + +Please analyze the memories and return a JSON response with the following format: +{{ + "can_answer": true/false, + "confidence": 0.0-1.0, + "reasoning": "Brief explanation of your decision" +}} + +Consider whether the memories contain the specific information needed to provide a complete and accurate answer to the query.""" + + try: + response = self.openai_client.chat.completions.create( + model=self.openai_model, + messages=[{"role": "user", "content": prompt}], + temperature=0.1, + ) + + response_text = response.choices[0].message.content + result = self._parse_json_response(response_text) + + if "can_answer" in result: + can_answer = result["can_answer"] + confidence = result.get("confidence", 0.5) + reasoning = result.get("reasoning", "No reasoning provided") + + logger.info( + f"Answer ability evaluation: {can_answer} (confidence: {confidence:.2f}) - {reasoning}" + ) + return can_answer + else: + logger.warning("Invalid response format from answer ability evaluation") + return False + + except Exception as e: + logger.error(f"Error in answer ability evaluation: {e}") + return False + + def memory_llm_processing_analysis( + self, bad_cases: list[dict[str, Any]], use_llm_filtering: bool = True + ) -> list[dict[str, Any]]: + """ + Analyze bad cases by processing memories with LLM filtering and testing answer ability. + + This method: + 1. Parses memory strings from bad cases + 2. Uses LLM to filter unrelated and redundant memories + 3. Tests whether processed memories can help answer questions correctly + 4. Compares results before and after LLM processing + + Args: + bad_cases: List of bad cases to analyze + use_llm_filtering: Whether to use LLM filtering + + Returns: + List of analyzed bad cases with LLM processing results + """ + analyzed_cases = [] + + for i, case in enumerate(bad_cases): + logger.info(f"Processing bad case {i + 1}/{len(bad_cases)}: {case['query'][:50]}...") + + try: + # Parse memory string + memories_text = case.get("memories", "") + if not memories_text: + logger.warning(f"No memories found for case {i + 1}") + analyzed_case = case.copy() + analyzed_case.update( + { + "llm_processing_analysis": { + "error": "No memories available", + "original_memories_count": 0, + "processed_memories_count": 0, + "can_answer_with_original": False, + "can_answer_with_processed": False, + "processing_improved_answer": False, + } + } + ) + analyzed_cases.append(analyzed_case) + continue + + # Split memories by lines + memory_lines = [line.strip() for line in memories_text.split("\n") if line.strip()] + original_memories = [line for line in memory_lines if line] + + logger.info(f"Parsed {len(original_memories)} memories from text") + + # Test answer ability with original memories + can_answer_original = self.evaluate_answer_ability_with_llm( + query=case["query"], memories=original_memories + ) + + # Process memories with LLM filtering if enabled + processed_memories = original_memories + processing_success = False + + if use_llm_filtering and len(original_memories) > 0: + processed_memories, processing_success = self.filter_memories_with_llm( + memories=original_memories, query=case["query"] + ) + logger.info( + f"LLM filtering: {len(original_memories)} -> {len(processed_memories)} memories, success: {processing_success}" + ) + + # Test answer ability with processed memories + can_answer_processed = 
self.evaluate_answer_ability_with_llm( + query=case["query"], memories=processed_memories + ) + + # Determine if processing improved answer ability + processing_improved = can_answer_processed and not can_answer_original + + # Create analysis result + llm_analysis = { + "processing_success": processing_success, + "original_memories_count": len(original_memories), + "processed_memories_count": len(processed_memories), + "memories_removed_count": len(original_memories) - len(processed_memories), + "can_answer_with_original": can_answer_original, + "can_answer_with_processed": can_answer_processed, + "processing_improved_answer": processing_improved, + "original_memories": original_memories, + "processed_memories": processed_memories, + } + + # Add analysis to case + analyzed_case = case.copy() + analyzed_case["llm_processing_analysis"] = llm_analysis + + logger.info( + f"Case {i + 1} analysis complete: " + f"Original: {can_answer_original}, " + f"Processed: {can_answer_processed}, " + f"Improved: {processing_improved}" + ) + + except Exception as e: + logger.error(f"Error processing case {i + 1}: {e}") + analyzed_case = case.copy() + analyzed_case["llm_processing_analysis"] = { + "error": str(e), + "processing_success": False, + "original_memories_count": 0, + "processed_memories_count": 0, + "can_answer_with_original": False, + "can_answer_with_processed": False, + "processing_improved_answer": False, + } + + analyzed_cases.append(analyzed_case) + + return analyzed_cases + + def scheduler_mem_process(self, query, memories): + from memos.mem_scheduler.utils.misc_utils import extract_list_items_in_answer + + _memories = [] + for mem in memories: + mem_item = TextualMemoryItem(memory=mem, metadata=TextualMemoryMetadata()) + _memories.append(mem_item) + prompt = mem_scheduler.retriever._build_enhancement_prompt( + query_history=[query], batch_texts=memories + ) + logger.debug( + f"[Enhance][batch={0}] Prompt (first 200 chars, len={len(prompt)}): {prompt[:200]}..." + ) + + response = mem_scheduler.retriever.process_llm.generate( + [{"role": "user", "content": prompt}] + ) + logger.debug(f"[Enhance][batch={0}] Response (first 200 chars): {response[:200]}...") + + processed_results = extract_list_items_in_answer(response) + + return { + "processed_memories": processed_results, + "processing_type": "enhance", + "original_length": len("\n".join(memories)), + "processed_length": len("\n".join(processed_results)), + "compression_ratio": len("\n".join(processed_results)) / len("\n".join(memories)) + if len(memories) > 0 + else 0, + } + + def analyze_bad_cases_with_llm_processing( + self, + bad_cases: list[dict[str, Any]], + save_results: bool = True, + output_file: str | None = None, + ) -> dict[str, Any]: + """ + Comprehensive analysis of bad cases with LLM memory processing. + + This method performs a complete analysis including: + 1. Basic bad case analysis + 2. LLM memory processing analysis + 3. Statistical summary of improvements + 4. 
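A round-trip sketch of the filter-then-evaluate flow implemented above, assuming an initialized EvalAnalyzer named analyzer; the memories and query are invented:

memories = ["User moved to Berlin in March 2021.", "User prefers tea over coffee."]
query = "When did the user move to Berlin?"

filtered, ok = analyzer.filter_memories_with_llm(memories, query)
if ok and analyzer.evaluate_answer_ability_with_llm(query, filtered):
    print(f"{len(filtered)} filtered memories still answer the query")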
Detailed reporting + + Args: + bad_cases: List of bad cases to analyze + save_results: Whether to save results to file + output_file: Optional output file path + + Returns: + Dictionary containing comprehensive analysis results + """ + from datetime import datetime + + logger.info( + f"Starting comprehensive analysis of {len(bad_cases)} bad cases with LLM processing" + ) + + # Perform LLM memory processing analysis + analyzed_cases = self.memory_llm_processing_analysis( + bad_cases=bad_cases, use_llm_filtering=True + ) + + # Calculate statistics + total_cases = len(analyzed_cases) + successful_processing = 0 + improved_cases = 0 + original_answerable = 0 + processed_answerable = 0 + total_memories_before = 0 + total_memories_after = 0 + + for case in analyzed_cases: + llm_analysis = case.get("llm_processing_analysis", {}) + + if llm_analysis.get("processing_success", False): + successful_processing += 1 + + if llm_analysis.get("processing_improved_answer", False): + improved_cases += 1 + + if llm_analysis.get("can_answer_with_original", False): + original_answerable += 1 + + if llm_analysis.get("can_answer_with_processed", False): + processed_answerable += 1 + + total_memories_before += llm_analysis.get("original_memories_count", 0) + total_memories_after += llm_analysis.get("processed_memories_count", 0) + + # Calculate improvement metrics + processing_success_rate = successful_processing / total_cases if total_cases > 0 else 0 + improvement_rate = improved_cases / total_cases if total_cases > 0 else 0 + original_answer_rate = original_answerable / total_cases if total_cases > 0 else 0 + processed_answer_rate = processed_answerable / total_cases if total_cases > 0 else 0 + memory_reduction_rate = ( + (total_memories_before - total_memories_after) / total_memories_before + if total_memories_before > 0 + else 0 + ) + + # Create comprehensive results + results = { + "analysis_metadata": { + "total_cases_analyzed": total_cases, + "analysis_timestamp": datetime.now().isoformat(), + "llm_model_used": self.openai_model, + }, + "processing_statistics": { + "successful_processing_count": successful_processing, + "processing_success_rate": processing_success_rate, + "cases_with_improvement": improved_cases, + "improvement_rate": improvement_rate, + "original_answerable_cases": original_answerable, + "original_answer_rate": original_answer_rate, + "processed_answerable_cases": processed_answerable, + "processed_answer_rate": processed_answer_rate, + "answer_rate_improvement": processed_answer_rate - original_answer_rate, + }, + "memory_statistics": { + "total_memories_before_processing": total_memories_before, + "total_memories_after_processing": total_memories_after, + "memories_removed": total_memories_before - total_memories_after, + "memory_reduction_rate": memory_reduction_rate, + "average_memories_per_case_before": total_memories_before / total_cases + if total_cases > 0 + else 0, + "average_memories_per_case_after": total_memories_after / total_cases + if total_cases > 0 + else 0, + }, + "analyzed_cases": analyzed_cases, + } + + # Log summary + logger.info("LLM Processing Analysis Summary:") + logger.info(f" - Total cases: {total_cases}") + logger.info(f" - Processing success rate: {processing_success_rate:.2%}") + logger.info(f" - Cases with improvement: {improved_cases} ({improvement_rate:.2%})") + logger.info(f" - Original answer rate: {original_answer_rate:.2%}") + logger.info(f" - Processed answer rate: {processed_answer_rate:.2%}") + logger.info( + f" - Answer rate improvement: 
{processed_answer_rate - original_answer_rate:.2%}" + ) + logger.info(f" - Memory reduction: {memory_reduction_rate:.2%}") + + # Save results if requested + if save_results: + if output_file is None: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + output_file = f"llm_processing_analysis_{timestamp}.json" + + try: + with open(output_file, "w", encoding="utf-8") as f: + json.dump(results, f, indent=2, ensure_ascii=False) + logger.info(f"Analysis results saved to: {output_file}") + except Exception as e: + logger.error(f"Failed to save results to {output_file}: {e}") + + return results + + +def main(): + """Main test function.""" + print("=== EvalAnalyzer Simple Test ===") + + # Initialize analyzer + analyzer = EvalAnalyzer(output_dir="./tmp/eval_analyzer") + + print("Analyzer initialized") + + # Test file paths + eval_result_dir = f"{BASE_DIR}/evaluation/results/locomo/memos-api-xcy-1030-2114-locomo" + judged_file = os.path.join(eval_result_dir, "memos-api_locomo_judged.json") + search_results_file = os.path.join(eval_result_dir, "memos-api_locomo_search_results.json") + + print("Testing with files:") + print(f" Judged file: {judged_file}") + print(f" Search results file: {search_results_file}") + + # Check if files exist + if not os.path.exists(judged_file): + print(f"โŒ Judged file not found: {judged_file}") + return + + if not os.path.exists(search_results_file): + print(f"โŒ Search results file not found: {search_results_file}") + return + + print("โœ… Both files exist") + + # Test bad case extraction only + try: + print("\n=== Testing Bad Case Extraction ===") + bad_cases = analyzer.extract_bad_cases(judged_file, search_results_file) + + print(f"โœ… Successfully extracted {len(bad_cases)} bad cases") + + if bad_cases: + print("\n=== Sample Bad Cases ===") + for i, case in enumerate(bad_cases[:3]): # Show first 3 cases + print(f"\nBad Case {i + 1}:") + print(f" User ID: {case['user_id']}") + print(f" Query: {case['query'][:100]}...") + print(f" Golden Answer: {case['golden_answer']}...") + print(f" Answer: {case['answer']}...") + print(f" Has Memories: {len(case['memories']) > 0}") + print(f" Memory Length: {len(case['memories'])} chars") + + # Save basic results without LLM analysis + basic_results = { + "bad_cases_count": len(bad_cases), + "bad_cases": bad_cases, + "metadata": { + "eval_result_dir": eval_result_dir, + "judged_file": judged_file, + "search_results_file": search_results_file, + "extraction_only": True, + }, + } + + output_file = analyzer.output_dir / "bad_cases_extraction_only.json" + import json + + with open(output_file, "w", encoding="utf-8") as f: + json.dump(basic_results, f, indent=2, ensure_ascii=False) + + print(f"\nโœ… Basic extraction results saved to: {output_file}") + + except Exception as e: + print(f"โŒ Error during extraction: {e}") + import traceback + + traceback.print_exc() + + +if __name__ == "__main__": + main() diff --git a/src/memos/mem_scheduler/analyzer/memory_processing.py b/src/memos/mem_scheduler/analyzer/memory_processing.py new file mode 100644 index 000000000..b692341c2 --- /dev/null +++ b/src/memos/mem_scheduler/analyzer/memory_processing.py @@ -0,0 +1,246 @@ +#!/usr/bin/env python3 +""" +Test script for memory processing functionality in eval_analyzer.py + +This script demonstrates how to use the new LLM memory processing features +to analyze and improve memory-based question answering. 
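The rate metrics assembled by analyze_bad_cases_with_llm_processing are plain ratios; a worked example with assumed counts:

total_cases = 20
original_answerable, processed_answerable = 8, 12
memories_before, memories_after = 500, 350

original_answer_rate = original_answerable / total_cases                      # 0.40
processed_answer_rate = processed_answerable / total_cases                    # 0.60
answer_rate_improvement = processed_answer_rate - original_answer_rate        # 0.20
memory_reduction_rate = (memories_before - memories_after) / memories_before  # 0.30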
+""" + +import json +import os +import sys + +from pathlib import Path +from typing import Any + +from memos.log import get_logger +from memos.mem_scheduler.analyzer.eval_analyzer import EvalAnalyzer + + +FILE_PATH = Path(__file__).absolute() +BASE_DIR = FILE_PATH.parent # Go up to project root +sys.path.insert(0, str(BASE_DIR)) # Enable execution from any working directory + + +logger = get_logger(__name__) + + +def create_sample_bad_cases() -> list[dict[str, Any]]: + """Create sample bad cases for testing memory processing.""" + return [ + { + "query": "What is the capital of France?", + "golden_answer": "Paris", + "memories": """ + Memory 1: France is a country in Western Europe. + Memory 2: The Eiffel Tower is located in Paris. + Memory 3: Paris is known for its art museums and fashion. + Memory 4: French cuisine is famous worldwide. + Memory 5: The Seine River flows through Paris. + """, + }, + { + "query": "When was the iPhone first released?", + "golden_answer": "June 29, 2007", + "memories": """ + Memory 1: Apple Inc. was founded by Steve Jobs, Steve Wozniak, and Ronald Wayne. + Memory 2: The iPhone was announced by Steve Jobs at the Macworld Conference & Expo on January 9, 2007. + Memory 3: The iPhone went on sale on June 29, 2007. + Memory 4: The original iPhone had a 3.5-inch screen. + Memory 5: Apple's stock price increased significantly after the iPhone launch. + """, + }, + { + "query": "What is photosynthesis?", + "golden_answer": "Photosynthesis is the process by which plants use sunlight, water, and carbon dioxide to produce glucose and oxygen.", + "memories": """ + Memory 1: Plants are living organisms that need sunlight to grow. + Memory 2: Chlorophyll is the green pigment in plants. + Memory 3: Plants take in carbon dioxide from the air. + Memory 4: Water is absorbed by plant roots from the soil. + Memory 5: Oxygen is released by plants during the day. + Memory 6: Glucose is a type of sugar that plants produce. + """, + }, + ] + + +def memory_processing(bad_cases): + """ + Test the memory processing functionality with cover rate and acc rate analysis. + + This function analyzes: + 1. Cover rate: Whether memories contain all information needed to answer the query + 2. 
Acc rate: Whether processed memories can correctly answer the query + """ + print("๐Ÿงช Testing Memory Processing Functionality with Cover Rate & Acc Rate Analysis") + print("=" * 80) + + # Initialize analyzer + analyzer = EvalAnalyzer() + + print(f"๐Ÿ“Š Testing with {len(bad_cases)} sample cases") + print() + + # Initialize counters for real-time statistics + total_cases = 0 + cover_count = 0 # Cases where memories cover all needed information + acc_count = 0 # Cases where processed memories can correctly answer + + # Process each case + for i, case in enumerate(bad_cases): + total_cases += 1 + + # Safely handle query display + query_display = str(case.get("query", "Unknown query")) + print(f"๐Ÿ” Case {i + 1}/{len(bad_cases)}: {query_display}...") + + # Safely handle golden_answer display (convert to string if needed) + golden_answer = case.get("golden_answer", "Unknown answer") + golden_answer_str = str(golden_answer) if golden_answer is not None else "Unknown answer" + print(f"๐Ÿ“ Golden Answer: {golden_answer_str}") + print() + + # Step 1: Analyze if memories contain sufficient information (Cover Rate) + print(" ๐Ÿ“‹ Step 1: Analyzing memory coverage...") + coverage_analysis = analyzer.analyze_memory_sufficiency( + case["query"], + golden_answer_str, # Use the string version + case["memories"], + ) + + has_coverage = coverage_analysis.get("sufficient", False) + if has_coverage: + cover_count += 1 + + print(f" โœ… Memory Coverage: {'SUFFICIENT' if has_coverage else 'INSUFFICIENT'}") + print(f" ๐ŸŽฏ Confidence: {coverage_analysis.get('confidence', 0):.2f}") + print(f" ๐Ÿ’ญ Reasoning: {coverage_analysis.get('reasoning', 'N/A')}...") + if not has_coverage: + print( + f" โŒ Missing Info: {coverage_analysis.get('missing_information', 'N/A')[:100]}..." 
+ ) + continue + print() + + # Step 2: Process memories and test answer ability (Acc Rate) + print(" ๐Ÿ”„ Step 2: Processing memories and testing answer ability...") + + processing_result = analyzer.scheduler_mem_process( + query=case["query"], + memories=case["memories"], + ) + print(f"Original Memories: {case['memories']}") + print(f"Processed Memories: {processing_result['processed_memories']}") + print(f" ๐Ÿ“ Compression ratio: {processing_result['compression_ratio']:.2f}") + print(f" ๐Ÿ“„ Processed memories length: {processing_result['processed_length']} chars") + + # Generate answer with processed memories + answer_result = analyzer.generate_answer_with_memories( + case["query"], processing_result["processed_memories"], "processed_enhanced" + ) + + # Evaluate if the generated answer is correct + print(" ๐ŸŽฏ Step 3: Evaluating answer correctness...") + answer_evaluation = analyzer.compare_answer_quality( + case["query"], + golden_answer_str, # Use the string version + "No original answer available", # We don't have original answer + answer_result["answer"], + ) + + # Determine if processed memories can correctly answer (simplified logic) + processed_accuracy = answer_evaluation.get("processed_scores", {}).get("accuracy", 0) + can_answer_correctly = processed_accuracy >= 0.7 # Threshold for "correct" answer + + if can_answer_correctly: + acc_count += 1 + + print(f" ๐Ÿ’ฌ Generated Answer: {answer_result['answer']}...") + print( + f" โœ… Answer Accuracy: {'CORRECT' if can_answer_correctly else 'INCORRECT'} (score: {processed_accuracy:.2f})" + ) + print() + + # Calculate and print real-time rates + current_cover_rate = cover_count / total_cases + current_acc_rate = acc_count / total_cases + + print(" ๐Ÿ“Š REAL-TIME STATISTICS:") + print(f" ๐ŸŽฏ Cover Rate: {current_cover_rate:.2%} ({cover_count}/{total_cases})") + print(f" โœ… Acc Rate: {current_acc_rate:.2%} ({acc_count}/{total_cases})") + print() + + print("-" * 80) + print() + + # Final summary + print("๐Ÿ FINAL ANALYSIS SUMMARY") + print("=" * 80) + print(f"๐Ÿ“Š Total Cases Processed: {total_cases}") + print(f"๐ŸŽฏ Final Cover Rate: {cover_count / total_cases:.2%} ({cover_count}/{total_cases})") + print(f" - Cases with sufficient memory coverage: {cover_count}") + print(f" - Cases with insufficient memory coverage: {total_cases - cover_count}") + print() + print(f"โœ… Final Acc Rate: {acc_count / total_cases:.2%} ({acc_count}/{total_cases})") + print(f" - Cases where processed memories can answer correctly: {acc_count}") + print(f" - Cases where processed memories cannot answer correctly: {total_cases - acc_count}") + print() + + # Additional insights + if cover_count > 0: + effective_processing_rate = acc_count / cover_count if cover_count > 0 else 0 + print(f"๐Ÿ”„ Processing Effectiveness: {effective_processing_rate:.2%}") + print( + f" - Among cases with sufficient coverage, {effective_processing_rate:.1%} can be answered correctly after processing" + ) + + print("=" * 80) + + +def load_real_bad_cases(file_path: str) -> list[dict[str, Any]]: + """Load real bad cases from JSON file.""" + print(f"๐Ÿ“‚ Loading bad cases from: {file_path}") + + with open(file_path, encoding="utf-8") as f: + data = json.load(f) + + bad_cases = data.get("bad_cases", []) + print(f"โœ… Loaded {len(bad_cases)} bad cases") + + return bad_cases + + +def main(): + """Main test function.""" + print("๐Ÿš€ Memory Processing Test Suite") + print("=" * 60) + print() + + # Check if OpenAI API key is set + if not os.getenv("OPENAI_API_KEY"): + print("โš ๏ธ 
Warning: OPENAI_API_KEY not found in environment variables") + print(" Please set your OpenAI API key to run the tests") + return + + try: + bad_cases_file = f"{BASE_DIR}/tmp/eval_analyzer/bad_cases_extraction_only.json" + bad_cases = load_real_bad_cases(bad_cases_file) + + print(f"โœ… Loaded {len(bad_cases)} bad cases for processing") + print() + + # Run memory processing tests + memory_processing(bad_cases) + + print("โœ… All tests completed successfully!") + + except Exception as e: + print(f"โŒ Test failed with error: {e}") + import traceback + + traceback.print_exc() + + +if __name__ == "__main__": + main() diff --git a/src/memos/mem_scheduler/analyzer/mos_for_test_scheduler.py b/src/memos/mem_scheduler/analyzer/mos_for_test_scheduler.py index ace67eff6..03e1fc778 100644 --- a/src/memos/mem_scheduler/analyzer/mos_for_test_scheduler.py +++ b/src/memos/mem_scheduler/analyzer/mos_for_test_scheduler.py @@ -427,7 +427,6 @@ def chat(self, query: str, user_id: str | None = None) -> str: message_item = ScheduleMessageItem( user_id=target_user_id, mem_cube_id=mem_cube_id, - mem_cube=mem_cube, label=QUERY_LABEL, content=query, timestamp=datetime.now(), @@ -518,7 +517,6 @@ def chat(self, query: str, user_id: str | None = None) -> str: message_item = ScheduleMessageItem( user_id=target_user_id, mem_cube_id=mem_cube_id, - mem_cube=mem_cube, label=ANSWER_LABEL, content=response, timestamp=datetime.now(), diff --git a/src/memos/mem_scheduler/analyzer/scheduler_for_eval.py b/src/memos/mem_scheduler/analyzer/scheduler_for_eval.py index 7c0fa5a4a..3d0235871 100644 --- a/src/memos/mem_scheduler/analyzer/scheduler_for_eval.py +++ b/src/memos/mem_scheduler/analyzer/scheduler_for_eval.py @@ -226,9 +226,9 @@ def evaluate_memory_answer_ability( try: # Extract JSON response - from memos.mem_scheduler.utils.misc_utils import extract_json_dict + from memos.mem_scheduler.utils.misc_utils import extract_json_obj - result = extract_json_dict(response) + result = extract_json_obj(response) # Validate response structure if "result" in result: diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py index 028fe8e3f..eb49d0238 100644 --- a/src/memos/mem_scheduler/base_scheduler.py +++ b/src/memos/mem_scheduler/base_scheduler.py @@ -1,6 +1,5 @@ import contextlib import multiprocessing -import queue import threading import time @@ -18,15 +17,18 @@ from memos.mem_cube.general import GeneralMemCube from memos.mem_scheduler.general_modules.dispatcher import SchedulerDispatcher from memos.mem_scheduler.general_modules.misc import AutoDroppingQueue as Queue +from memos.mem_scheduler.general_modules.redis_queue import SchedulerRedisQueue from memos.mem_scheduler.general_modules.scheduler_logger import SchedulerLoggerModule from memos.mem_scheduler.memory_manage_modules.retriever import SchedulerRetriever from memos.mem_scheduler.monitors.dispatcher_monitor import SchedulerDispatcherMonitor from memos.mem_scheduler.monitors.general_monitor import SchedulerGeneralMonitor from memos.mem_scheduler.schemas.general_schemas import ( DEFAULT_ACT_MEM_DUMP_PATH, + DEFAULT_CONSUME_BATCH, DEFAULT_CONSUME_INTERVAL_SECONDS, DEFAULT_CONTEXT_WINDOW_SIZE, DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE, + DEFAULT_MAX_WEB_LOG_QUEUE_SIZE, DEFAULT_STARTUP_MODE, DEFAULT_THREAD_POOL_MAX_WORKERS, DEFAULT_TOP_K, @@ -86,6 +88,22 @@ def __init__(self, config: BaseSchedulerConfig): "scheduler_startup_mode", DEFAULT_STARTUP_MODE ) + # message queue configuration + self.use_redis_queue = self.config.get("use_redis_queue",
DEFAULT_USE_REDIS_QUEUE) + self.max_internal_message_queue_size = self.config.get( + "max_internal_message_queue_size", DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE + ) + + # Initialize message queue based on configuration + if self.use_redis_queue: + self.memos_message_queue = SchedulerRedisQueue( + maxsize=self.max_internal_message_queue_size + ) + else: + self.memos_message_queue: Queue[ScheduleMessageItem] = Queue( + maxsize=self.max_internal_message_queue_size + ) + self.retriever: SchedulerRetriever | None = None self.db_engine: Engine | None = None self.monitor: SchedulerGeneralMonitor | None = None @@ -93,6 +111,8 @@ def __init__(self, config: BaseSchedulerConfig): self.mem_reader = None # Will be set by MOSCore self.dispatcher = SchedulerDispatcher( config=self.config, + memos_message_queue=self.memos_message_queue, + use_redis_queue=self.use_redis_queue, max_workers=self.thread_pool_max_workers, enable_parallel_dispatch=self.enable_parallel_dispatch, ) @@ -100,23 +120,9 @@ def __init__(self, config: BaseSchedulerConfig): # optional configs self.disable_handlers: list | None = self.config.get("disable_handlers", None) - # message queue configuration - self.use_redis_queue = self.config.get("use_redis_queue", DEFAULT_USE_REDIS_QUEUE) - self.max_internal_message_queue_size = self.config.get( - "max_internal_message_queue_size", DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE + self.max_web_log_queue_size = self.config.get( + "max_web_log_queue_size", DEFAULT_MAX_WEB_LOG_QUEUE_SIZE ) - - # Initialize message queue based on configuration - if self.use_redis_queue: - self.memos_message_queue = None # Will use Redis instead - # Initialize Redis if using Redis queue with auto-initialization - self.auto_initialize_redis() - else: - self.memos_message_queue: Queue[ScheduleMessageItem] = Queue( - maxsize=self.max_internal_message_queue_size - ) - - self.max_web_log_queue_size = self.config.get("max_web_log_queue_size", 50) self._web_log_message_queue: Queue[ScheduleLogForWebItem] = Queue( maxsize=self.max_web_log_queue_size ) @@ -126,6 +132,7 @@ def __init__(self, config: BaseSchedulerConfig): self._consume_interval = self.config.get( "consume_interval_seconds", DEFAULT_CONSUME_INTERVAL_SECONDS ) + self.consume_batch = self.config.get("consume_batch", DEFAULT_CONSUME_BATCH) # other attributes self._context_lock = threading.Lock() @@ -216,7 +223,7 @@ def _set_current_context_from_message(self, msg: ScheduleMessageItem) -> None: with self._context_lock: self.current_user_id = msg.user_id self.current_mem_cube_id = msg.mem_cube_id - self.current_mem_cube = msg.mem_cube + self.current_mem_cube = self.get_mem_cube(msg.mem_cube_id) def transform_working_memories_to_monitors( self, query_keywords, memories: list[TextualMemoryItem] @@ -533,17 +540,9 @@ def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageIt if self.disable_handlers and message.label in self.disable_handlers: logger.info(f"Skipping disabled handler: {message.label} - {message.content}") continue + self.memos_message_queue.put(message) + logger.info(f"Submitted message to local queue: {message.label} - {message.content}") - if self.use_redis_queue: - # Use Redis stream for message queue - self.redis_add_message_stream(message.to_dict()) - logger.info(f"Submitted message to Redis: {message.label} - {message.content}") - else: - # Use local queue - self.memos_message_queue.put(message) - logger.info( - f"Submitted message to local queue: {message.label} - {message.content}" - ) with contextlib.suppress(Exception): if 
messages: self.dispatcher.on_messages_enqueued(messages) @@ -590,7 +589,7 @@ def get_web_log_messages(self) -> list[dict]: try: item = self._web_log_message_queue.get_nowait() # Thread-safe get messages.append(item.to_dict()) - except queue.Empty: + except Exception: break return messages @@ -601,62 +600,28 @@ def _message_consumer(self) -> None: Runs in a dedicated thread to process messages at regular intervals. For Redis queue, this method starts the Redis listener. """ - if self.use_redis_queue: - # For Redis queue, start the Redis listener - def redis_message_handler(message_data): - """Handler for Redis messages""" - try: - # Redis message data needs to be decoded from bytes to string - decoded_data = {} - for key, value in message_data.items(): - if isinstance(key, bytes): - key = key.decode("utf-8") - if isinstance(value, bytes): - value = value.decode("utf-8") - decoded_data[key] = value - - message = ScheduleMessageItem.from_dict(decoded_data) - self.dispatcher.dispatch([message]) - except Exception as e: - logger.error(f"Error processing Redis message: {e}") - logger.error(f"Message data: {message_data}") - - self.redis_start_listening(handler=redis_message_handler) - - # Keep the thread alive while Redis listener is running - while self._running: - time.sleep(self._consume_interval) - else: - # Original local queue logic - while self._running: # Use a running flag for graceful shutdown - try: - # Get all available messages at once (thread-safe approach) - messages = [] - while True: - try: - # Use get_nowait() directly without empty() check to avoid race conditions - message = self.memos_message_queue.get_nowait() - messages.append(message) - except queue.Empty: - # No more messages available - break - if messages: - try: - self.dispatcher.dispatch(messages) - except Exception as e: - logger.error(f"Error dispatching messages: {e!s}") - finally: - # Mark all messages as processed - for _ in messages: - self.memos_message_queue.task_done() + # Original local queue logic + while self._running: # Use a running flag for graceful shutdown + try: + # Get messages in batches based on consume_batch setting + + messages = self.memos_message_queue.get(block=True, batch_size=self.consume_batch) + + if messages: + try: + self.dispatcher.dispatch(messages) + except Exception as e: + logger.error(f"Error dispatching messages: {e!s}") - # Sleep briefly to prevent busy waiting - time.sleep(self._consume_interval) # Adjust interval as needed + # Sleep briefly to prevent busy waiting + time.sleep(self._consume_interval) # Adjust interval as needed - except Exception as e: + except Exception as e: + # Don't log error for "No messages available in Redis queue" as it's expected + if "No messages available in Redis queue" not in str(e): logger.error(f"Unexpected error in message consumer: {e!s}") - time.sleep(self._consume_interval) # Prevent tight error loops + time.sleep(self._consume_interval) # Prevent tight error loops def start(self) -> None: """ @@ -666,16 +631,25 @@ def start(self) -> None: 1. Message consumer thread or process (based on startup_mode) 2. Dispatcher thread pool (if parallel dispatch enabled) """ - if self._running: - logger.warning("Memory Scheduler is already running") - return - # Initialize dispatcher resources if self.enable_parallel_dispatch: logger.info( f"Initializing dispatcher thread pool with {self.thread_pool_max_workers} workers" ) + self.start_consumer() + + def start_consumer(self) -> None: + """ + Start only the message consumer thread/process. 
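A configuration sketch for the queue and consumer options read above; the keys come from the code, the values are illustrative:

scheduler_config = {
    "use_redis_queue": True,                   # SchedulerRedisQueue instead of the local AutoDroppingQueue
    "max_internal_message_queue_size": 10000,
    "consume_interval_seconds": 0.5,           # sleep between consumer polls
    "consume_batch": 16,                       # messages pulled per get() in _message_consumer
}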
+ + This method can be used to restart the consumer after it has been stopped + with stop_consumer(), without affecting other scheduler components. + """ + if self._running: + logger.warning("Memory Scheduler consumer is already running") + return + # Start consumer based on startup mode self._running = True @@ -698,15 +672,15 @@ def start(self) -> None: self._consumer_thread.start() logger.info("Message consumer thread started") - def stop(self) -> None: - """Stop all scheduler components gracefully. + def stop_consumer(self) -> None: + """Stop only the message consumer thread/process gracefully. - 1. Stops message consumer thread/process - 2. Shuts down dispatcher thread pool - 3. Cleans up resources + This method stops the consumer without affecting other components like + dispatcher or monitors. Useful when you want to pause message processing + while keeping other scheduler components running. """ if not self._running: - logger.warning("Memory Scheduler is not running") + logger.warning("Memory Scheduler consumer is not running") return # Signal consumer thread/process to stop @@ -726,12 +700,30 @@ def stop(self) -> None: logger.info("Consumer process terminated") else: logger.info("Consumer process stopped") + self._consumer_process = None elif self._consumer_thread and self._consumer_thread.is_alive(): self._consumer_thread.join(timeout=5.0) if self._consumer_thread.is_alive(): logger.warning("Consumer thread did not stop gracefully") else: logger.info("Consumer thread stopped") + self._consumer_thread = None + + logger.info("Memory Scheduler consumer stopped") + + def stop(self) -> None: + """Stop all scheduler components gracefully. + + 1. Stops message consumer thread/process + 2. Shuts down dispatcher thread pool + 3. Cleans up resources + """ + if not self._running: + logger.warning("Memory Scheduler is not running") + return + + # Stop consumer first + self.stop_consumer() # Shutdown dispatcher if self.dispatcher: @@ -743,10 +735,6 @@ def stop(self) -> None: logger.info("Shutting down monitor...") self.dispatcher_monitor.stop() - # Clean up queues - self._cleanup_queues() - logger.info("Memory Scheduler stopped completely") - @property def handlers(self) -> dict[str, Callable]: """ @@ -819,30 +807,6 @@ def get_running_tasks(self, filter_func: Callable | None = None) -> dict[str, di return result - def _cleanup_queues(self) -> None: - """Ensure all queues are emptied and marked as closed.""" - if self.use_redis_queue: - # For Redis queue, stop the listener and close connection - try: - self.redis_stop_listening() - self.redis_close() - except Exception as e: - logger.error(f"Error cleaning up Redis connection: {e}") - else: - # Original local queue cleanup - try: - while not self.memos_message_queue.empty(): - self.memos_message_queue.get_nowait() - self.memos_message_queue.task_done() - except queue.Empty: - pass - - try: - while not self._web_log_message_queue.empty(): - self._web_log_message_queue.get_nowait() - except queue.Empty: - pass - def mem_scheduler_wait( self, timeout: float = 180.0, poll: float = 0.1, log_every: float = 0.01 ) -> bool: @@ -906,11 +870,24 @@ def _fmt_eta(seconds: float | None) -> str: st = ( stats_fn() ) # expected: {'pending':int,'running':int,'done':int?,'rate':float?} - pend = int(st.get("pending", 0)) run = int(st.get("running", 0)) + except Exception: pass + if isinstance(self.memos_message_queue, SchedulerRedisQueue): + # For Redis queue, prefer XINFO GROUPS to compute pending + groups_info = self.memos_message_queue.redis.xinfo_groups( + 
self.memos_message_queue.stream_name + ) + if groups_info: + for group in groups_info: + if group.get("name") == self.memos_message_queue.consumer_group: + pend = int(group.get("pending", pend)) + break + else: + pend = run + # 2) dynamic total (allows new tasks queued while waiting) total_now = max(init_unfinished, done_total + curr_unfinished) done_total = max(0, total_now - curr_unfinished) diff --git a/src/memos/mem_scheduler/general_modules/dispatcher.py b/src/memos/mem_scheduler/general_modules/dispatcher.py index c2407b9e6..b74529c8c 100644 --- a/src/memos/mem_scheduler/general_modules/dispatcher.py +++ b/src/memos/mem_scheduler/general_modules/dispatcher.py @@ -10,7 +10,9 @@ from memos.context.context import ContextThreadPoolExecutor from memos.log import get_logger from memos.mem_scheduler.general_modules.base import BaseSchedulerModule +from memos.mem_scheduler.general_modules.redis_queue import SchedulerRedisQueue from memos.mem_scheduler.general_modules.task_threads import ThreadManager +from memos.mem_scheduler.schemas.general_schemas import DEFAULT_STOP_WAIT from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem from memos.mem_scheduler.schemas.task_schemas import RunningTaskItem from memos.mem_scheduler.utils.metrics import MetricsRegistry @@ -32,13 +34,23 @@ class SchedulerDispatcher(BaseSchedulerModule): - Thread race competition for parallel task execution """ - def __init__(self, max_workers=30, enable_parallel_dispatch=True, config=None): + def __init__( + self, + max_workers: int = 30, + memos_message_queue: Any | None = None, + use_redis_queue: bool | None = None, + enable_parallel_dispatch: bool = True, + config=None, + ): super().__init__() self.config = config # Main dispatcher thread pool self.max_workers = max_workers + self.memos_message_queue = memos_message_queue + self.use_redis_queue = use_redis_queue + # Get multi-task timeout from config self.multi_task_running_timeout = ( self.config.get("multi_task_running_timeout") if self.config else None @@ -73,6 +85,11 @@ def __init__(self, max_workers=30, enable_parallel_dispatch=True, config=None): self._completed_tasks = [] self.completed_tasks_max_show_size = 10 + # Configure shutdown wait behavior from config or default + self.stop_wait = ( + self.config.get("stop_wait", DEFAULT_STOP_WAIT) if self.config else DEFAULT_STOP_WAIT + ) + self.metrics = MetricsRegistry( topk_per_label=(self.config or {}).get("metrics_topk_per_label", 50) ) @@ -131,6 +148,19 @@ def wrapped_handler(messages: list[ScheduleMessageItem]): # --- mark done --- for m in messages: self.metrics.on_done(label=m.label, mem_cube_id=m.mem_cube_id, now=time.time()) + + # acknowledge redis messages + + if ( + self.use_redis_queue + and self.memos_message_queue is not None + and isinstance(self.memos_message_queue, SchedulerRedisQueue) + ): + for msg in messages: + redis_message_id = msg.redis_message_id + # Acknowledge message processing + self.memos_message_queue.ack_message(redis_message_id=redis_message_id) + # Mark task as completed and remove from tracking with self._task_lock: if task_item.item_id in self._running_tasks: @@ -138,7 +168,7 @@ def wrapped_handler(messages: list[ScheduleMessageItem]): del self._running_tasks[task_item.item_id] self._completed_tasks.append(task_item) if len(self._completed_tasks) > self.completed_tasks_max_show_size: - self._completed_tasks[-self.completed_tasks_max_show_size :] + self._completed_tasks.pop(0) logger.info(f"Task completed: {task_item.get_execution_info()}") return result @@ 
-152,7 +182,7 @@ def wrapped_handler(messages: list[ScheduleMessageItem]): task_item.mark_failed(str(e)) del self._running_tasks[task_item.item_id] if len(self._completed_tasks) > self.completed_tasks_max_show_size: - self._completed_tasks[-self.completed_tasks_max_show_size :] + self._completed_tasks.pop(0) logger.error(f"Task failed: {task_item.get_execution_info()}, Error: {e}") raise @@ -381,17 +411,13 @@ def dispatch(self, msg_list: list[ScheduleMessageItem]): wrapped_handler = self._create_task_wrapper(handler, task_item) # dispatch to different handler - logger.debug( - f"Dispatch {len(msgs)} message(s) to {label} handler for user {user_id} and mem_cube {mem_cube_id}." - ) - logger.info(f"Task started: {task_item.get_execution_info()}") - + logger.debug(f"Task started: {task_item.get_execution_info()}") if self.enable_parallel_dispatch and self.dispatcher_executor is not None: # Capture variables in lambda to avoid loop variable issues - future = self.dispatcher_executor.submit(wrapped_handler, msgs) - self._futures.add(future) - future.add_done_callback(self._handle_future_result) - logger.info(f"Dispatched {len(msgs)} message(s) as future task") + _ = self.dispatcher_executor.submit(wrapped_handler, msgs) + logger.info( + f"Dispatch {len(msgs)} message(s) to {label} handler for user {user_id} and mem_cube {mem_cube_id}." + ) else: wrapped_handler(msgs) @@ -484,17 +510,9 @@ def shutdown(self) -> None: """Gracefully shutdown the dispatcher.""" self._running = False - if self.dispatcher_executor is not None: - # Cancel pending tasks - cancelled = 0 - for future in self._futures: - if future.cancel(): - cancelled += 1 - logger.info(f"Cancelled {cancelled}/{len(self._futures)} pending tasks") - # Shutdown executor try: - self.dispatcher_executor.shutdown(wait=True) + self.dispatcher_executor.shutdown(wait=self.stop_wait, cancel_futures=True) except Exception as e: logger.error(f"Executor shutdown error: {e}", exc_info=True) finally: diff --git a/src/memos/mem_scheduler/general_modules/misc.py b/src/memos/mem_scheduler/general_modules/misc.py index b6f48d043..e4e7edb89 100644 --- a/src/memos/mem_scheduler/general_modules/misc.py +++ b/src/memos/mem_scheduler/general_modules/misc.py @@ -199,6 +199,9 @@ class AutoDroppingQueue(Queue[T]): """A thread-safe queue that automatically drops the oldest item when full.""" def __init__(self, maxsize: int = 0): + # If maxsize <= 0, set to 0 (unlimited queue size) + if maxsize <= 0: + maxsize = 0 super().__init__(maxsize=maxsize) def put(self, item: T, block: bool = False, timeout: float | None = None) -> None: @@ -218,7 +221,7 @@ def put(self, item: T, block: bool = False, timeout: float | None = None) -> Non # First try non-blocking put super().put(item, block=block, timeout=timeout) except Full: - # Remove oldest item and mark it done to avoid leaking unfinished_tasks + # Remove the oldest item and mark it done to avoid leaking unfinished_tasks with suppress(Empty): _ = self.get_nowait() # If the removed item had previously incremented unfinished_tasks, @@ -228,12 +231,70 @@ def put(self, item: T, block: bool = False, timeout: float | None = None) -> Non # Retry putting the new item super().put(item, block=block, timeout=timeout) + def get( + self, block: bool = True, timeout: float | None = None, batch_size: int | None = None + ) -> list[T] | T: + """Get items from the queue. 
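A note on the reworked shutdown above: with cancel_futures=True, tasks still waiting in the executor queue are cancelled, and stop_wait only controls whether shutdown() blocks on tasks that are already running. A sketch of wiring it through the dispatcher config; the dict literal stands in for a real scheduler config object:

dispatcher = SchedulerDispatcher(
    max_workers=8,
    enable_parallel_dispatch=True,
    config={"stop_wait": False},  # return from shutdown() without joining running tasks
)
dispatcher.shutdown()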
+
+        Args:
+            block: Whether to block if no items are available (default: True)
+            timeout: Timeout in seconds for blocking operations (default: None)
+            batch_size: Number of items to retrieve (default: None, meaning a single item)
+
+        Returns:
+            A single item when ``batch_size`` is None; otherwise a list of up to
+            ``batch_size`` items
+
+        Raises:
+            Empty: If no items are available and block=False or timeout expires
+        """
+
+        if batch_size is None:
+            return super().get(block=block, timeout=timeout)
+        items = []
+        for _ in range(batch_size):
+            try:
+                items.append(super().get(block=block, timeout=timeout))
+            except Empty:
+                if not items and block:
+                    # If we haven't gotten any items and we're blocking, re-raise Empty
+                    raise
+                break
+        return items
+
+    def get_nowait(self, batch_size: int | None = None) -> list[T] | T:
+        """Get items from the queue without blocking.
+
+        Args:
+            batch_size: Number of items to retrieve (default: None, meaning a single item)
+
+        Returns:
+            A single item when ``batch_size`` is None; otherwise a list of up to
+            ``batch_size`` items
+        """
+        if batch_size is None:
+            return super().get_nowait()
+
+        items = []
+        for _ in range(batch_size):
+            try:
+                items.append(super().get_nowait())
+            except Empty:
+                break
+        return items
+
     def get_queue_content_without_pop(self) -> list[T]:
         """Return a copy of the queue's contents without modifying it."""
         # Ensure a consistent snapshot by holding the mutex
         with self.mutex:
             return list(self.queue)

+    def qsize(self) -> int:
+        """Return the approximate size of the queue.
+
+        Returns:
+            Number of items currently in the queue
+        """
+        return super().qsize()
+
     def clear(self) -> None:
         """Remove all items from the queue.
diff --git a/src/memos/mem_scheduler/general_modules/redis_queue.py b/src/memos/mem_scheduler/general_modules/redis_queue.py
new file mode 100644
index 000000000..c10765d05
--- /dev/null
+++ b/src/memos/mem_scheduler/general_modules/redis_queue.py
@@ -0,0 +1,460 @@
+"""
+Redis Queue implementation for ScheduleMessageItem objects.
+
+This module provides a Redis-based queue implementation that can replace
+the local memos_message_queue functionality in BaseScheduler.
+"""
+
+import time
+
+from collections.abc import Callable
+from queue import Empty
+from uuid import uuid4
+
+from memos.log import get_logger
+from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
+from memos.mem_scheduler.webservice_modules.redis_service import RedisSchedulerModule
+
+
+logger = get_logger(__name__)
+
+
+class SchedulerRedisQueue(RedisSchedulerModule):
+    """
+    Redis-based queue for storing and processing ScheduleMessageItem objects.
+
+    This class provides a Redis Stream-based implementation that can replace
+    the local memos_message_queue functionality, offering better scalability
+    and persistence for message processing.
+
+    Inherits from RedisSchedulerModule to leverage existing Redis connection
+    and initialization functionality.
+    """
+
+    def __init__(
+        self,
+        stream_name: str = "scheduler:messages:stream",
+        consumer_group: str = "scheduler_group",
+        consumer_name: str | None = "scheduler_consumer",
+        max_len: int = 10000,
+        maxsize: int = 0,  # For Queue compatibility
+        auto_delete_acked: bool = True,  # Whether to automatically delete acknowledged messages
+    ):
+        """
+        Initialize the Redis queue.
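+
+        Example (sketch; assumes a reachable Redis instance auto-initialized
+        by the inherited RedisSchedulerModule):
+            queue = SchedulerRedisQueue(stream_name="scheduler:messages:stream")
+            queue.put(ScheduleMessageItem(...))                 # XADD to the stream
+            items = queue.get(block=True, timeout=1.0, batch_size=10)
+            for item in items:
+                queue.ack_message(item.redis_message_id)        # XACK (+ optional XDEL)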
+
+        Args:
+            stream_name: Name of the Redis stream
+            consumer_group: Name of the consumer group
+            consumer_name: Name of the consumer (auto-generated if None)
+            max_len: Maximum length of the stream (for memory management)
+            maxsize: Maximum size of the queue (kept for Queue compatibility; only
+                consulted by full())
+            auto_delete_acked: Whether to automatically delete acknowledged messages from stream
+        """
+        super().__init__()
+
+        # Normalize negative values to 0 (unlimited queue size)
+        if maxsize <= 0:
+            maxsize = 0
+
+        # Stream configuration
+        self.stream_name = stream_name
+        self.consumer_group = consumer_group
+        self.consumer_name = consumer_name or f"consumer_{uuid4().hex[:8]}"
+        self.max_len = max_len
+        self.maxsize = maxsize  # For Queue compatibility
+        self.auto_delete_acked = auto_delete_acked  # Whether to delete acknowledged messages
+
+        # Consumer state
+        self._is_listening = False
+        self._message_handler: Callable[[ScheduleMessageItem], None] | None = None
+
+        # Connection state
+        self._is_connected = False
+
+        # Task tracking for mem_scheduler_wait compatibility
+        self._unfinished_tasks = 0
+
+        # Auto-initialize Redis connection
+        if self.auto_initialize_redis():
+            self._is_connected = True
+            self._ensure_consumer_group()
+
+    def _ensure_consumer_group(self) -> None:
+        """Ensure the consumer group exists for the stream."""
+        if not self._redis_conn:
+            return
+
+        try:
+            self._redis_conn.xgroup_create(
+                self.stream_name, self.consumer_group, id="0", mkstream=True
+            )
+            logger.debug(
+                f"Created consumer group '{self.consumer_group}' for stream '{self.stream_name}'"
+            )
+        except Exception as e:
+            # Check if it's a "consumer group already exists" error
+            error_msg = str(e).lower()
+            if "busygroup" in error_msg or "already exists" in error_msg:
+                logger.info(
+                    f"Consumer group '{self.consumer_group}' already exists for stream '{self.stream_name}'"
+                )
+            else:
+                logger.error(f"Error creating consumer group: {e}", exc_info=True)
+
+    def put(
+        self, message: ScheduleMessageItem, block: bool = True, timeout: float | None = None
+    ) -> None:
+        """
+        Add a message to the Redis queue (Queue-compatible interface).
+
+        Args:
+            message: ScheduleMessageItem to add to the queue
+            block: Ignored for Redis implementation (always non-blocking)
+            timeout: Ignored for Redis implementation
+
+        Raises:
+            ConnectionError: If not connected to Redis
+            TypeError: If message is not a ScheduleMessageItem
+        """
+        if not self._redis_conn:
+            raise ConnectionError("Not connected to Redis. Redis connection not available.")
+
+        if not isinstance(message, ScheduleMessageItem):
+            raise TypeError(f"Expected ScheduleMessageItem, got {type(message)}")
+
+        try:
+            # Convert message to dictionary for Redis storage
+            message_data = message.to_dict()
+
+            # Add to Redis stream with automatic trimming
+            message_id = self._redis_conn.xadd(
+                self.stream_name, message_data, maxlen=self.max_len, approximate=True
+            )
+
+            logger.info(
+                f"Added message {message_id} to Redis stream: {message.label} - {message.content[:100]}..."
+            )
+
+        except Exception as e:
+            logger.error(f"Failed to add message to Redis queue: {e}")
+            raise
+
+    def put_nowait(self, message: ScheduleMessageItem) -> None:
+        """
+        Add a message to the Redis queue without blocking (Queue-compatible interface).
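+
+        Equivalent to ``put(message, block=False)``; the underlying XADD call
+        does not block either way.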
+
+        Args:
+            message: ScheduleMessageItem to add to the queue
+        """
+        self.put(message, block=False)
+
+    def ack_message(self, redis_message_id: str) -> None:
+        """Acknowledge a processed message, optionally deleting it from the stream."""
+        self._redis_conn.xack(self.stream_name, self.consumer_group, redis_message_id)
+
+        # Optionally delete the message from the stream to keep it clean
+        if self.auto_delete_acked:
+            try:
+                self._redis_conn.xdel(self.stream_name, redis_message_id)
+                logger.info(f"Successfully deleted acknowledged message {redis_message_id}")
+            except Exception as e:
+                logger.warning(f"Failed to delete acknowledged message {redis_message_id}: {e}")
+
+    def get(
+        self,
+        block: bool = True,
+        timeout: float | None = None,
+        batch_size: int | None = None,
+    ) -> list[ScheduleMessageItem] | ScheduleMessageItem:
+        """Read messages for this consumer group (Queue-compatible interface)."""
+        if not self._redis_conn:
+            raise ConnectionError("Not connected to Redis. Redis connection not available.")
+
+        try:
+            # Calculate timeout for Redis
+            redis_timeout = None
+            if block and timeout is not None:
+                redis_timeout = int(timeout * 1000)
+            elif not block:
+                redis_timeout = None  # Non-blocking
+
+            # Read messages from the consumer group
+            try:
+                messages = self._redis_conn.xreadgroup(
+                    self.consumer_group,
+                    self.consumer_name,
+                    {self.stream_name: ">"},
+                    count=batch_size if batch_size else 1,
+                    block=redis_timeout,
+                )
+            except Exception as read_err:
+                # Handle missing group/stream by creating and retrying once
+                err_msg = str(read_err).lower()
+                if "nogroup" in err_msg or "no such key" in err_msg:
+                    logger.warning(
+                        f"Consumer group or stream missing for '{self.stream_name}/{self.consumer_group}'. Attempting to create and retry."
+                    )
+                    self._ensure_consumer_group()
+                    messages = self._redis_conn.xreadgroup(
+                        self.consumer_group,
+                        self.consumer_name,
+                        {self.stream_name: ">"},
+                        count=batch_size if batch_size else 1,
+                        block=redis_timeout,
+                    )
+                else:
+                    raise
+            result_messages = []
+
+            for _stream, stream_messages in messages:
+                for message_id, fields in stream_messages:
+                    try:
+                        # Convert Redis message back to ScheduleMessageItem
+                        message = ScheduleMessageItem.from_dict(fields)
+                        message.redis_message_id = message_id
+
+                        result_messages.append(message)
+
+                    except Exception as e:
+                        logger.error(f"Failed to parse message {message_id}: {e}")
+
+            if not result_messages:
+                if not block:
+                    return []  # Return an empty list for non-blocking calls
+                # No messages were found while blocking; raise Empty
+                raise Empty("No messages available in Redis queue")
+
+            # A list when batch_size is given; a single item otherwise
+            return result_messages if batch_size is not None else result_messages[0]
+
+        except Empty:
+            raise
+        except Exception as e:
+            logger.error(f"Failed to get message from Redis queue: {e}")
+            raise
+
+    def get_nowait(
+        self, batch_size: int | None = None
+    ) -> list[ScheduleMessageItem] | ScheduleMessageItem:
+        """
+        Get messages from the Redis queue without blocking (Queue-compatible interface).
+
+        Returns:
+            ScheduleMessageItem objects; an empty list when nothing is available
+        """
+        return self.get(block=False, batch_size=batch_size)
+
+    def qsize(self) -> int:
+        """
+        Get the current size of the Redis queue (Queue-compatible interface).
+
+        Returns the number of pending (unacknowledged) messages in the consumer group,
+        which represents the actual queue size for processing.
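+
+        Resolution order, as implemented below: the XPENDING count first, then
+        entries newer than the group's last-delivered-id (via XRANGE), then
+        plain XLEN as a last resort.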
+ + Returns: + Number of pending messages in the queue + """ + if not self._redis_conn: + return 0 + + try: + # Get pending messages info for the consumer group + # XPENDING returns info about pending messages that haven't been acknowledged + pending_info = self._redis_conn.xpending(self.stream_name, self.consumer_group) + + # pending_info[0] contains the count of pending messages + if pending_info and len(pending_info) > 0 and pending_info[0] is not None: + pending_count = int(pending_info[0]) + if pending_count > 0: + return pending_count + + # If no pending messages, check if there are new messages in the stream + # that haven't been read by any consumer yet + try: + # Get the last delivered ID for the consumer group + groups_info = self._redis_conn.xinfo_groups(self.stream_name) + if not groups_info: + # No groups exist, check total stream length + return self._redis_conn.xlen(self.stream_name) or 0 + + last_delivered_id = "0-0" + + for group_info in groups_info: + if group_info and group_info.get("name") == self.consumer_group: + last_delivered_id = group_info.get("last-delivered-id", "0-0") + break + + # Count messages after the last delivered ID + new_messages = self._redis_conn.xrange( + self.stream_name, + f"({last_delivered_id}", # Exclusive start + "+", # End at the latest message + count=1000, # Limit to avoid memory issues + ) + + return len(new_messages) if new_messages else 0 + + except Exception as inner_e: + logger.debug(f"Failed to get new messages count: {inner_e}") + # Fallback: return stream length + try: + stream_len = self._redis_conn.xlen(self.stream_name) + return stream_len if stream_len is not None else 0 + except Exception: + return 0 + + except Exception as e: + logger.debug(f"Failed to get Redis queue size via XPENDING: {e}") + # Fallback to stream length if pending check fails + try: + stream_len = self._redis_conn.xlen(self.stream_name) + return stream_len if stream_len is not None else 0 + except Exception as fallback_e: + logger.error(f"Failed to get Redis queue size (all methods failed): {fallback_e}") + return 0 + + def size(self) -> int: + """ + Get the current size of the Redis queue (alias for qsize). + + Returns: + Number of messages in the queue + """ + return self.qsize() + + def empty(self) -> bool: + """ + Check if the Redis queue is empty (Queue-compatible interface). + + Returns: + True if the queue is empty, False otherwise + """ + return self.qsize() == 0 + + def full(self) -> bool: + """ + Check if the Redis queue is full (Queue-compatible interface). + + For Redis streams, we consider the queue full if it exceeds maxsize. + If maxsize is 0 or None, the queue is never considered full. + + Returns: + True if the queue is full, False otherwise + """ + if self.maxsize <= 0: + return False + return self.qsize() >= self.maxsize + + def join(self) -> None: + """ + Block until all items in the queue have been gotten and processed (Queue-compatible interface). + + For Redis streams, this would require tracking pending messages, + which is complex. For now, this is a no-op. 
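+
+        A polling variant would be straightforward to add if needed (sketch):
+            while not self.empty():
+                time.sleep(0.05)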
+ """ + + def clear(self) -> None: + """Clear all messages from the queue.""" + if not self._is_connected or not self._redis_conn: + return + + try: + # Delete the entire stream + self._redis_conn.delete(self.stream_name) + logger.info(f"Cleared Redis stream: {self.stream_name}") + + # Recreate the consumer group + self._ensure_consumer_group() + except Exception as e: + logger.error(f"Failed to clear Redis queue: {e}") + + def start_listening( + self, + handler: Callable[[ScheduleMessageItem], None], + batch_size: int = 10, + poll_interval: float = 0.1, + ) -> None: + """ + Start listening for messages and process them with the provided handler. + + Args: + handler: Function to call for each received message + batch_size: Number of messages to process in each batch + poll_interval: Interval between polling attempts in seconds + """ + if not self._is_connected: + raise ConnectionError("Not connected to Redis. Call connect() first.") + + self._message_handler = handler + self._is_listening = True + + logger.info(f"Started listening on Redis stream: {self.stream_name}") + + try: + while self._is_listening: + messages = self.get(timeout=poll_interval, count=batch_size) + + for message in messages: + try: + self._message_handler(message) + except Exception as e: + logger.error(f"Error processing message {message.item_id}: {e}") + + # Small sleep to prevent excessive CPU usage + if not messages: + time.sleep(poll_interval) + + except KeyboardInterrupt: + logger.info("Received interrupt signal, stopping listener") + except Exception as e: + logger.error(f"Error in message listener: {e}") + finally: + self._is_listening = False + logger.info("Stopped listening for messages") + + def stop_listening(self) -> None: + """Stop the message listener.""" + self._is_listening = False + logger.info("Requested stop for message listener") + + def connect(self) -> None: + """Establish connection to Redis and set up the queue.""" + if self._redis_conn is not None: + try: + # Test the connection + self._redis_conn.ping() + self._is_connected = True + logger.debug("Redis connection established successfully") + except Exception as e: + logger.error(f"Failed to connect to Redis: {e}") + self._is_connected = False + else: + logger.error("Redis connection not initialized") + self._is_connected = False + + def disconnect(self) -> None: + """Disconnect from Redis and clean up resources.""" + self._is_connected = False + if self._is_listening: + self.stop_listening() + logger.debug("Disconnected from Redis") + + def __enter__(self): + """Context manager entry.""" + self.connect() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Context manager exit.""" + self.stop_listening() + self.disconnect() + + def __del__(self): + """Cleanup when object is destroyed.""" + if self._is_connected: + self.disconnect() + + @property + def unfinished_tasks(self) -> int: + return self.qsize() diff --git a/src/memos/mem_scheduler/general_scheduler.py b/src/memos/mem_scheduler/general_scheduler.py index 6840adc2b..32fefce63 100644 --- a/src/memos/mem_scheduler/general_scheduler.py +++ b/src/memos/mem_scheduler/general_scheduler.py @@ -51,7 +51,7 @@ def __init__(self, config: GeneralSchedulerConfig): def long_memory_update_process( self, user_id: str, mem_cube_id: str, messages: list[ScheduleMessageItem] ): - mem_cube = messages[0].mem_cube + mem_cube = self.current_mem_cube # for status update self._set_current_context_from_message(msg=messages[0]) @@ -140,7 +140,7 @@ def long_memory_update_process( label=QUERY_LABEL, 
user_id=user_id, mem_cube_id=mem_cube_id, - mem_cube=messages[0].mem_cube, + mem_cube=self.current_mem_cube, ) def _query_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: @@ -212,7 +212,7 @@ def _add_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: logger.error(f"Error: {e}. Content: {msg.content}", exc_info=True) userinput_memory_ids = [] - mem_cube = msg.mem_cube + mem_cube = self.current_mem_cube for memory_id in userinput_memory_ids: try: mem_item: TextualMemoryItem = mem_cube.text_mem.get( @@ -234,7 +234,7 @@ def _add_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: memory_type=mem_type, user_id=msg.user_id, mem_cube_id=msg.mem_cube_id, - mem_cube=msg.mem_cube, + mem_cube=self.current_mem_cube, log_func_callback=self._submit_web_logs, ) @@ -248,7 +248,7 @@ def process_message(message: ScheduleMessageItem): try: user_id = message.user_id mem_cube_id = message.mem_cube_id - mem_cube = message.mem_cube + mem_cube = self.current_mem_cube content = message.content user_name = message.user_name @@ -412,7 +412,7 @@ def process_message(message: ScheduleMessageItem): try: user_id = message.user_id mem_cube_id = message.mem_cube_id - mem_cube = message.mem_cube + mem_cube = self.current_mem_cube content = message.content user_name = message.user_name @@ -516,7 +516,7 @@ def process_message(message: ScheduleMessageItem): user_id = message.user_id session_id = message.session_id mem_cube_id = message.mem_cube_id - mem_cube = message.mem_cube + mem_cube = self.current_mem_cube content = message.content messages_list = json.loads(content) diff --git a/src/memos/mem_scheduler/memory_manage_modules/memory_filter.py b/src/memos/mem_scheduler/memory_manage_modules/memory_filter.py index e18c6e51a..25b9a98f3 100644 --- a/src/memos/mem_scheduler/memory_manage_modules/memory_filter.py +++ b/src/memos/mem_scheduler/memory_manage_modules/memory_filter.py @@ -2,7 +2,7 @@ from memos.llms.base import BaseLLM from memos.log import get_logger from memos.mem_scheduler.general_modules.base import BaseSchedulerModule -from memos.mem_scheduler.utils.misc_utils import extract_json_dict +from memos.mem_scheduler.utils.misc_utils import extract_json_obj from memos.memories.textual.tree import TextualMemoryItem @@ -66,7 +66,7 @@ def filter_unrelated_memories( try: # Parse JSON response - response = extract_json_dict(response) + response = extract_json_obj(response) logger.debug(f"Parsed JSON response: {response}") relevant_indices = response["relevant_memories"] filtered_count = response["filtered_count"] @@ -164,7 +164,7 @@ def filter_redundant_memories( try: # Parse JSON response - response = extract_json_dict(response) + response = extract_json_obj(response) logger.debug(f"Parsed JSON response: {response}") kept_indices = response["kept_memories"] redundant_groups = response.get("redundant_groups", []) @@ -226,8 +226,6 @@ def filter_unrelated_and_redundant_memories( Note: If LLM filtering fails, returns all memories (conservative approach) """ - success_flag = False - if not memories: logger.info("No memories to filter for unrelated and redundant - returning empty list") return [], True @@ -265,7 +263,7 @@ def filter_unrelated_and_redundant_memories( try: # Parse JSON response - response = extract_json_dict(response) + response = extract_json_obj(response) logger.debug(f"Parsed JSON response: {response}") kept_indices = response["kept_memories"] unrelated_removed_count = response.get("unrelated_removed_count", 0) diff --git 
a/src/memos/mem_scheduler/memory_manage_modules/retriever.py b/src/memos/mem_scheduler/memory_manage_modules/retriever.py
index b766f0010..848b1d257 100644
--- a/src/memos/mem_scheduler/memory_manage_modules/retriever.py
+++ b/src/memos/mem_scheduler/memory_manage_modules/retriever.py
@@ -1,9 +1,14 @@
+from concurrent.futures import as_completed
+
 from memos.configs.mem_scheduler import BaseSchedulerConfig
+from memos.context.context import ContextThreadPoolExecutor
 from memos.llms.base import BaseLLM
 from memos.log import get_logger
 from memos.mem_cube.general import GeneralMemCube
 from memos.mem_scheduler.general_modules.base import BaseSchedulerModule
 from memos.mem_scheduler.schemas.general_schemas import (
+    DEFAULT_SCHEDULER_RETRIEVER_BATCH_SIZE,
+    DEFAULT_SCHEDULER_RETRIEVER_RETRIES,
     TreeTextMemory_FINE_SEARCH_METHOD,
     TreeTextMemory_SEARCH_METHOD,
 )
@@ -12,11 +17,11 @@
     filter_vector_based_similar_memories,
     transform_name_to_key,
 )
-from memos.mem_scheduler.utils.misc_utils import (
-    extract_json_dict,
-)
+from memos.mem_scheduler.utils.misc_utils import extract_json_obj, extract_list_items_in_answer
+from memos.memories.textual.item import TextualMemoryMetadata
 from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory

 from .memory_filter import MemoryFilter
@@ -30,12 +35,213 @@ def __init__(self, process_llm: BaseLLM, config: BaseSchedulerConfig):
         # hyper-parameters
         self.filter_similarity_threshold = 0.75
         self.filter_min_length_threshold = 6
-
-        self.config: BaseSchedulerConfig = config
+        self.memory_filter = MemoryFilter(process_llm=process_llm, config=config)
         self.process_llm = process_llm
+        self.config = config

-        # Initialize memory filter
-        self.memory_filter = MemoryFilter(process_llm=process_llm, config=config)
+        # Configure enhancement batching & retries from config with safe defaults
+        self.batch_size: int | None = getattr(
+            config, "scheduler_retriever_batch_size", DEFAULT_SCHEDULER_RETRIEVER_BATCH_SIZE
+        )
+        self.retries: int = getattr(
+            config, "scheduler_retriever_enhance_retries", DEFAULT_SCHEDULER_RETRIEVER_RETRIES
+        )
+
+    def evaluate_memory_answer_ability(
+        self, query: str, memory_texts: list[str], top_k: int | None = None
+    ) -> bool:
+        """Ask the process LLM whether the given memories can answer the query."""
+        limited_memories = memory_texts[:top_k] if top_k is not None else memory_texts
+        # Build prompt using the template
+        prompt = self.build_prompt(
+            template_name="memory_answer_ability_evaluation",
+            query=query,
+            memory_list="\n".join([f"- {memory}" for memory in limited_memories])
+            if limited_memories
+            else "No memories available",
+        )
+
+        # Use the process LLM to generate response
+        response = self.process_llm.generate([{"role": "user", "content": prompt}])
+
+        try:
+            result = extract_json_obj(response)
+
+            # Validate response structure
+            if "result" in result:
+                logger.info(
+                    f"Answerability: result={result['result']}; reason={result.get('reason', 'n/a')}; evaluated={len(limited_memories)}"
+                )
+                return result["result"]
+            else:
+                logger.warning(f"Answerability: invalid LLM JSON structure; payload={result}")
+                return False
+
+        except Exception as e:
+            logger.error(f"Answerability: parse failed; err={e}; raw={str(response)[:200]}...")
+            # Fallback: return False if we can't determine answer ability
+            return False
+
+    # ---------------------- Enhancement helpers ----------------------
+    def _build_enhancement_prompt(self, query_history: list[str], batch_texts: list[str]) -> str:
+        if len(query_history) == 1:
+            query_history = query_history[0]
+        else:
+            query_history = [f"[{i}] {query}" for i, query in enumerate(query_history)]
+        text_memories = "\n".join([f"- {mem}" for mem in batch_texts])
+        return self.build_prompt(
+            "memory_enhancement",
+            query_history=query_history,
+            memories=text_memories,
+        )
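+
+    # Round-trip this helper supports (sketch; hypothetical values):
+    #   prompt = self._build_enhancement_prompt(["q1"], ["fact A", "fact B"])
+    #   response = self.process_llm.generate([{"role": "user", "content": prompt}])
+    #   # The memory_enhancement template asks the LLM to reply with, e.g.:
+    #   #   <answer>\n- enhanced fact A\n- enhanced fact B\n</answer>
+    #   items = extract_list_items_in_answer(response)  # ["enhanced fact A", ...]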
{query}" for i, query in enumerate(query_history)] + if len(query_history) > 1 + else query_history[0] + ) + text_memories = "\n".join([f"- {mem}" for i, mem in enumerate(batch_texts)]) + return self.build_prompt( + "memory_enhancement", + query_history=query_history, + memories=text_memories, + ) + + def _process_enhancement_batch( + self, + batch_index: int, + query_history: list[str], + memories: list[TextualMemoryItem], + retries: int, + ) -> tuple[list[TextualMemoryItem], bool]: + attempt = 0 + text_memories = [one.memory for one in memories] + while attempt <= max(0, retries) + 1: + try: + prompt = self._build_enhancement_prompt( + query_history=query_history, batch_texts=text_memories + ) + logger.debug( + f"[Enhance][batch={batch_index}] Prompt (first 200 chars, len={len(prompt)}): " + f"{prompt[:200]}]..." + ) + + response = self.process_llm.generate([{"role": "user", "content": prompt}]) + logger.debug( + f"[Enhance][batch={batch_index}] Response (first 200 chars): {response}..." + ) + + processed_text_memories = extract_list_items_in_answer(response) + if len(processed_text_memories) == len(memories): + # Update + for i, new_mem in enumerate(processed_text_memories): + memories[i].memory = new_mem + enhanced_memories = memories + else: + # create new + enhanced_memories = [] + user_id = memories[0].metadata.user_id + for new_mem in processed_text_memories: + enhanced_memories.append( + TextualMemoryItem( + memory=new_mem, metadata=TextualMemoryMetadata(user_id=user_id) + ) + ) + enhanced_memories = ( + enhanced_memories + memories[: len(memories) - len(enhanced_memories)] + ) + + logger.info( + f"[Enhance]: processed_text_memories: {len(processed_text_memories)}; padded with original memories to preserve total count" + ) + + return enhanced_memories, True + except Exception as e: + attempt += 1 + logger.debug( + f"[Enhance][batch={batch_index}] ๐Ÿ” retry {attempt}/{max(1, retries) + 1} failed: {e}" + ) + logger.error( + f"Fail to run memory enhancement; original memories: {memories}", exc_info=True + ) + return memories, False + + @staticmethod + def _split_batches( + memories: list[TextualMemoryItem], batch_size: int + ) -> list[tuple[int, int, list[TextualMemoryItem]]]: + batches: list[tuple[int, int, list[TextualMemoryItem]]] = [] + start = 0 + n = len(memories) + while start < n: + end = min(start + batch_size, n) + batches.append((start, end, memories[start:end])) + start = end + return batches + + def enhance_memories_with_query( + self, + query_history: list[str], + memories: list[TextualMemoryItem], + ) -> (list[TextualMemoryItem], bool): + """ + Enhance memories by adding context and making connections to better answer queries. 
+
+        Args:
+            query_history: List of user queries in chronological order
+            memories: List of memory items to enhance
+
+        Returns:
+            Tuple of (enhanced_memories, success_flag)
+        """
+        if not memories:
+            logger.warning("[Enhance] ⚠️ skipped (no memories to process)")
+            return memories, True
+
+        batch_size = self.batch_size
+        retries = self.retries
+        num_of_memories = len(memories)
+        try:
+            if batch_size is None or num_of_memories <= batch_size:
+                # Single batch path with retry (no parallelism needed)
+                enhanced_memories, success_flag = self._process_enhancement_batch(
+                    batch_index=0,
+                    query_history=query_history,
+                    memories=memories,
+                    retries=retries,
+                )
+
+                all_success = success_flag
+            else:
+                # Parallel path: split into batches preserving order
+                batches = self._split_batches(memories=memories, batch_size=batch_size)
+
+                # Process batches concurrently
+                all_success = True
+                failed_batches = 0
+                with ContextThreadPoolExecutor(max_workers=len(batches)) as executor:
+                    future_map = {
+                        executor.submit(
+                            self._process_enhancement_batch, bi, query_history, texts, retries
+                        ): (bi, s, e)
+                        for bi, (s, e, texts) in enumerate(batches)
+                    }
+                    enhanced_memories = []
+                    for fut in as_completed(future_map):
+                        bi, s, e = future_map[fut]
+
+                        batch_memories, ok = fut.result()
+                        enhanced_memories.extend(batch_memories)
+                        if not ok:
+                            all_success = False
+                            failed_batches += 1
+                logger.info(
+                    f"[Enhance] ✅ multi-batch done | batches={len(batches)} | enhanced={len(enhanced_memories)} |"
+                    f" failed_batches={failed_batches} | success={all_success}"
+                )
+
+        except Exception as e:
+            logger.error(f"[Enhance] ❌ fatal error: {e}", exc_info=True)
+            all_success = False
+            enhanced_memories = memories
+
+        if len(enhanced_memories) == 0:
+            enhanced_memories = memories
+            logger.error("[Enhance] ❌ fatal error: enhanced_memories is empty")
+        return enhanced_memories, all_success

     def search(
         self,
@@ -115,7 +321,7 @@ def rerank_memories(

         try:
             # Parse JSON response
-            response = extract_json_dict(response)
+            response = extract_json_obj(response)
             new_order = response["new_order"][:top_k]
             text_memories_with_new_order = [original_memories[idx] for idx in new_order]
             logger.info(
diff --git a/src/memos/mem_scheduler/monitors/dispatcher_monitor.py b/src/memos/mem_scheduler/monitors/dispatcher_monitor.py
index 46c4e2d49..99982d2e6 100644
--- a/src/memos/mem_scheduler/monitors/dispatcher_monitor.py
+++ b/src/memos/mem_scheduler/monitors/dispatcher_monitor.py
@@ -11,6 +11,7 @@
 from memos.mem_scheduler.schemas.general_schemas import (
     DEFAULT_DISPATCHER_MONITOR_CHECK_INTERVAL,
     DEFAULT_DISPATCHER_MONITOR_MAX_FAILURES,
+    DEFAULT_STOP_WAIT,
     DEFAULT_STUCK_THREAD_TOLERANCE,
 )
 from memos.mem_scheduler.utils.db_utils import get_utc_now
@@ -46,6 +47,11 @@ def __init__(self, config: BaseSchedulerConfig):
         self.dispatcher: SchedulerDispatcher | None = None
         self.dispatcher_pool_name = "dispatcher"

+        # Configure shutdown wait behavior from config or default
+        self.stop_wait = (
+            self.config.get("stop_wait", DEFAULT_STOP_WAIT) if self.config else DEFAULT_STOP_WAIT
+        )
+
     def initialize(self, dispatcher: SchedulerDispatcher):
         self.dispatcher = dispatcher
         self.register_pool(
@@ -367,12 +373,9 @@ def stop(self) -> None:
             if not executor._shutdown:  # pylint: disable=protected-access
                 try:
                     logger.info(f"Shutting down thread pool '{name}'")
-                    executor.shutdown(wait=True, cancel_futures=True)
+                    executor.shutdown(wait=self.stop_wait, cancel_futures=True)
                     logger.info(f"Successfully shut down thread pool '{name}'")
                 except 
Exception as e: logger.error(f"Error shutting down pool '{name}': {e!s}", exc_info=True) - # Clear the pool registry - self._pools.clear() - logger.info("Thread pool monitor and all pools stopped") diff --git a/src/memos/mem_scheduler/monitors/general_monitor.py b/src/memos/mem_scheduler/monitors/general_monitor.py index a789d581e..a5f1c0097 100644 --- a/src/memos/mem_scheduler/monitors/general_monitor.py +++ b/src/memos/mem_scheduler/monitors/general_monitor.py @@ -29,7 +29,7 @@ QueryMonitorQueue, ) from memos.mem_scheduler.utils.db_utils import get_utc_now -from memos.mem_scheduler.utils.misc_utils import extract_json_dict +from memos.mem_scheduler.utils.misc_utils import extract_json_obj from memos.memories.textual.tree import TreeTextMemory @@ -92,7 +92,7 @@ def extract_query_keywords(self, query: str) -> list: llm_response = self._process_llm.generate([{"role": "user", "content": prompt}]) try: # Parse JSON output from LLM response - keywords = extract_json_dict(llm_response) + keywords = extract_json_obj(llm_response) assert isinstance(keywords, list) except Exception as e: logger.error( @@ -206,7 +206,7 @@ def update_working_memory_monitors( self.working_mem_monitor_capacity = min( DEFAULT_WORKING_MEM_MONITOR_SIZE_LIMIT, ( - text_mem_base.memory_manager.memory_size["WorkingMemory"] + int(text_mem_base.memory_manager.memory_size["WorkingMemory"]) + self.partial_retention_number ), ) @@ -353,7 +353,7 @@ def detect_intent( ) response = self._process_llm.generate([{"role": "user", "content": prompt}]) try: - response = extract_json_dict(response) + response = extract_json_obj(response) assert ("trigger_retrieval" in response) and ("missing_evidences" in response) except Exception: logger.error(f"Fail to extract json dict from response: {response}") diff --git a/src/memos/mem_scheduler/optimized_scheduler.py b/src/memos/mem_scheduler/optimized_scheduler.py index a087ab2df..b62b1e51d 100644 --- a/src/memos/mem_scheduler/optimized_scheduler.py +++ b/src/memos/mem_scheduler/optimized_scheduler.py @@ -2,7 +2,7 @@ import os from collections import OrderedDict -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from memos.api.product_models import APISearchRequest from memos.configs.mem_scheduler import GeneralSchedulerConfig @@ -52,11 +52,24 @@ def __init__(self, config: GeneralSchedulerConfig): API_MIX_SEARCH_LABEL: self._api_mix_search_message_consumer, } ) + self.searcher = None + self.reranker = None + self.text_mem = None + + def init_mem_cube(self, mem_cube): + self.current_mem_cube = mem_cube + self.text_mem: TreeTextMemory = self.current_mem_cube.text_mem + self.searcher: Searcher = self.text_mem.get_searcher( + manual_close_internet=False, + moscube=False, + ) + self.reranker: HTTPBGEReranker = self.text_mem.reranker def submit_memory_history_async_task( self, search_req: APISearchRequest, user_context: UserContext, + memories_to_store: dict | None = None, session_id: str | None = None, ): # Create message for async fine search @@ -71,19 +84,16 @@ def submit_memory_history_async_task( "chat_history": search_req.chat_history, }, "user_context": {"mem_cube_id": user_context.mem_cube_id}, + "memories_to_store": memories_to_store, } async_task_id = f"mix_search_{search_req.user_id}_{get_utc_now().timestamp()}" - # Get mem_cube for the message - mem_cube = self.current_mem_cube - message = ScheduleMessageItem( item_id=async_task_id, user_id=search_req.user_id, mem_cube_id=user_context.mem_cube_id, label=API_MIX_SEARCH_LABEL, - mem_cube=mem_cube, 
content=json.dumps(message_content),
             timestamp=get_utc_now(),
         )
@@ -127,33 +137,26 @@ def mix_search_memories(
         self,
         search_req: APISearchRequest,
         user_context: UserContext,
-    ):
+    ) -> list[dict[str, Any]]:
         """
         Mix search memories: fast search + async fine search
         """
-        # Get mem_cube for fast search
-        mem_cube = self.current_mem_cube
-
         target_session_id = search_req.session_id
         if not target_session_id:
             target_session_id = "default_session"
         search_filter = {"session_id": search_req.session_id} if search_req.session_id else None

-        text_mem: TreeTextMemory = mem_cube.text_mem
-        searcher: Searcher = text_mem.get_searcher(
-            manual_close_internet=not search_req.internet_search,
-            moscube=False,
-        )
-        # Rerank Memories - reranker expects TextualMemoryItem objects
-        reranker: HTTPBGEReranker = text_mem.reranker
+
         info = {
             "user_id": search_req.user_id,
             "session_id": target_session_id,
             "chat_history": search_req.chat_history,
         }
-        fast_retrieved_memories = searcher.retrieve(
+        fast_retrieved_memories = self.searcher.retrieve(
             query=search_req.query,
             user_name=user_context.mem_cube_id,
             top_k=search_req.top_k,
@@ -164,13 +167,7 @@ def mix_search_memories(
             info=info,
         )

-        self.submit_memory_history_async_task(
-            search_req=search_req,
-            user_context=user_context,
-            session_id=search_req.session_id,
-        )
-
-        # Try to get pre-computed fine memories if available
+        # Try to get pre-computed memories if available
         history_memories = self.api_module.get_history_memories(
             user_id=search_req.user_id,
             mem_cube_id=user_context.mem_cube_id,
@@ -178,7 +175,7 @@
         )

         if not history_memories:
-            fast_memories = searcher.post_retrieve(
+            fast_memories = self.searcher.post_retrieve(
                 retrieved_results=fast_retrieved_memories,
                 top_k=search_req.top_k,
                 user_name=user_context.mem_cube_id,
@@ -187,39 +184,72 @@
             # Format fast memories for return
             formatted_memories = [format_textual_memory_item(data) for data in fast_memories]
             return formatted_memories
+        else:
+            # Rerank the stored history; it may already be able to answer directly
+            sorted_history_memories = self.reranker.rerank(
+                query=search_req.query,  # Use search_req.query instead of undefined query
+                graph_results=history_memories,  # Pass TextualMemoryItem objects directly
+                top_k=search_req.top_k,  # Use search_req.top_k instead of undefined top_k
+                search_filter=search_filter,
+            )

-            sorted_history_memories = reranker.rerank(
-                query=search_req.query,  # Use search_req.query instead of undefined query
-                graph_results=history_memories,  # Pass TextualMemoryItem objects directly
-                top_k=search_req.top_k,  # Use search_req.top_k instead of undefined top_k
-                search_filter=search_filter,
-            )
+            processed_hist_mem = self.searcher.post_retrieve(
+                retrieved_results=sorted_history_memories,
+                top_k=search_req.top_k,
+                user_name=user_context.mem_cube_id,
+                info=info,
+            )

-            sorted_results = fast_retrieved_memories + sorted_history_memories
-            final_results = searcher.post_retrieve(
-                retrieved_results=sorted_results,
-                top_k=search_req.top_k,
-                user_name=user_context.mem_cube_id,
-                info=info,
-            )
+            can_answer = self.retriever.evaluate_memory_answer_ability(
+                query=search_req.query, memory_texts=[one.memory for one in processed_hist_mem]
+            )

-            formatted_memories = [
-                format_textual_memory_item(item) for item in final_results[: search_req.top_k]
-            ]
+            # Merge fast results with the reranked history, then post-process once
+            sorted_results = fast_retrieved_memories + sorted_history_memories
+            combined_results = self.searcher.post_retrieve(
+                retrieved_results=sorted_results,
+                top_k=search_req.top_k,
+                user_name=user_context.mem_cube_id,
+                info=info,
+            )
+            if can_answer:
+                logger.info("[MixSearch] history memories can answer the query directly")
+                memories = combined_results[: search_req.top_k]
+            else:
+                logger.info("[MixSearch] history memories cannot answer; enhancing merged results")
+                enhanced_results, _ = self.retriever.enhance_memories_with_query(
+                    query_history=[search_req.query],
+                    memories=combined_results,
+                )
+                memories = enhanced_results[: search_req.top_k]
+            formatted_memories = [format_textual_memory_item(item) for item in memories]
+
+            self.submit_memory_history_async_task(
+                search_req=search_req,
+                user_context=user_context,
+                memories_to_store={
+                    "memories": [one.to_dict() for one in memories],
+                    "formatted_memories": formatted_memories,
+                },
+            )

-        return formatted_memories
+        return formatted_memories
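+
+    # Flow sketch for mix_search_memories (hypothetical request values):
+    #   req = APISearchRequest(user_id="u1", query="where did I park?", top_k=5)
+    #   formatted = scheduler.mix_search_memories(req, user_context)
+    # 1) fast retrieve via self.searcher
+    # 2) no stored history yet -> post-retrieve and return the fast results
+    # 3) stored history -> rerank and merge it with the fast results; return the
+    #    merged list if the history alone can answer, otherwise enhance it first,
+    #    then persist the chosen memories via submit_memory_history_async_task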
    def update_search_memories_to_redis(
        self,
        messages: list[ScheduleMessageItem],
    ):
-        mem_cube: NaiveMemCube = self.current_mem_cube
-
        for msg in messages:
            content_dict = json.loads(msg.content)
            search_req = content_dict["search_req"]
            user_context = content_dict["user_context"]
            session_id = search_req.get("session_id")
            if session_id:
                if session_id not in self.session_counter:
@@ -237,13 +267,20 @@
            else:
                session_turn = 0

-            memories: list[TextualMemoryItem] = self.search_memories(
-                search_req=APISearchRequest(**content_dict["search_req"]),
-                user_context=UserContext(**content_dict["user_context"]),
-                mem_cube=mem_cube,
-                mode=SearchMode.FAST,
-            )
-            formatted_memories = [format_textual_memory_item(data) for data in memories]
+            memories_to_store = content_dict["memories_to_store"]
+            if memories_to_store is None:
+                memories: list[TextualMemoryItem] = self.search_memories(
+                    search_req=APISearchRequest(**content_dict["search_req"]),
+                    user_context=UserContext(**content_dict["user_context"]),
+                    mem_cube=self.current_mem_cube,
+                    mode=SearchMode.FAST,
+                )
+                formatted_memories = [format_textual_memory_item(data) for data in memories]
+            else:
+                memories = [
+                    TextualMemoryItem.from_dict(one) for one in memories_to_store["memories"]
+                ]
+                formatted_memories = memories_to_store["formatted_memories"]

            # Sync search data to Redis
            self.api_module.sync_search_data(
diff --git a/src/memos/mem_scheduler/schemas/general_schemas.py b/src/memos/mem_scheduler/schemas/general_schemas.py
index f3d2191f8..7f2c09b7d 100644
--- a/src/memos/mem_scheduler/schemas/general_schemas.py
+++ b/src/memos/mem_scheduler/schemas/general_schemas.py
@@ -6,6 +6,7 @@
 class SearchMode(str, Enum):
     """Enumeration for search modes."""

+    NOT_INITIALIZED = "not_initialized"
     FAST = "fast"
     FINE = "fine"
     MIXTURE = "mixture"
@@ -32,14 +33,18 @@ class SearchMode(str, Enum):
 DEFAULT_ACT_MEM_DUMP_PATH = f"{BASE_DIR}/outputs/mem_scheduler/mem_cube_scheduler_test.kv_cache"
 DEFAULT_THREAD_POOL_MAX_WORKERS = 50
 DEFAULT_CONSUME_INTERVAL_SECONDS = 0.05
+DEFAULT_CONSUME_BATCH = 1
 DEFAULT_DISPATCHER_MONITOR_CHECK_INTERVAL = 300
 DEFAULT_DISPATCHER_MONITOR_MAX_FAILURES = 2
 DEFAULT_STUCK_THREAD_TOLERANCE = 10
-DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE = 1000000
+DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE = 0
 DEFAULT_TOP_K = 10
 DEFAULT_CONTEXT_WINDOW_SIZE = 5
-DEFAULT_USE_REDIS_QUEUE = False
+DEFAULT_USE_REDIS_QUEUE = True
 DEFAULT_MULTI_TASK_RUNNING_TIMEOUT = 30
+DEFAULT_SCHEDULER_RETRIEVER_BATCH_SIZE = 10
+DEFAULT_SCHEDULER_RETRIEVER_RETRIES = 1
+DEFAULT_STOP_WAIT = False

 # startup mode configuration
 STARTUP_BY_THREAD = "thread"
@@ -64,6 +69,7 @@ class SearchMode(str, Enum):
 MONITOR_ACTIVATION_MEMORY_TYPE = "MonitorActivationMemoryType"
 DEFAULT_MAX_QUERY_KEY_WORDS = 1000
 DEFAULT_WEIGHT_VECTOR_FOR_RANKING = [0.9, 0.05, 0.05]
+DEFAULT_MAX_WEB_LOG_QUEUE_SIZE = 50

 # new types
diff --git a/src/memos/mem_scheduler/schemas/message_schemas.py b/src/memos/mem_scheduler/schemas/message_schemas.py
index 7f328474f..f1d48f3f1 100644
--- a/src/memos/mem_scheduler/schemas/message_schemas.py
+++ b/src/memos/mem_scheduler/schemas/message_schemas.py
@@ -2,11 +2,10 @@
 from typing import Any
 from uuid import uuid4

-from pydantic import BaseModel, ConfigDict, Field, field_serializer
+from pydantic import BaseModel, ConfigDict, Field
 from typing_extensions import TypedDict

 from memos.log import get_logger
-from memos.mem_cube.base import BaseMemCube
 from memos.mem_scheduler.general_modules.misc import DictConversionMixin
 from memos.mem_scheduler.utils.db_utils import get_utc_now
@@ -34,22 +33,19 @@ class ScheduleMessageItem(BaseModel, DictConversionMixin):

     item_id: str = Field(description="uuid", default_factory=lambda: str(uuid4()))
+    redis_message_id: str = Field(default="", description="id of the entry read from the Redis stream")
     user_id: str = Field(..., description="user id")
     mem_cube_id: str = Field(..., description="memcube id")
+    session_id: str = Field(default="", description="Session ID for soft-filtering memories")
     label: str = Field(..., description="Label of the schedule message")
-    mem_cube: BaseMemCube | str = Field(..., description="memcube for schedule")
     content: str = Field(..., description="Content of the schedule message")
     timestamp: datetime = Field(
         default_factory=get_utc_now, description="submit time for schedule_messages"
     )
-    user_name: str | None = Field(
-        default=None,
+    user_name: str = Field(
+        default="",
         description="user name / display name (optional)",
     )
-    session_id: str | None = Field(
-        default=None,
-        description="session_id (optional)",
-    )

     # Pydantic V2 model configuration
     model_config = ConfigDict(
@@ -65,7 +61,6 @@ class ScheduleMessageItem(BaseModel, DictConversionMixin):
             "user_id": "user123",  # Example user identifier
             "mem_cube_id": "cube456",  # Sample memory cube ID
             "label": "sample_label",  # Demonstration label value
-            "mem_cube": "obj of GeneralMemCube",  # Added mem_cube example
             "content": "sample content",  # Example message content
             "timestamp": "2024-07-22T12:00:00Z",  # Added timestamp example
             "user_name": "Alice",  # Added username example
         },
     )
@@ -73,13 +68,6 @@
-    @field_serializer("mem_cube")
-    def serialize_mem_cube(self, cube: BaseMemCube | str, _info) -> str:
-        """Custom serializer for BaseMemCube objects to string representation"""
-        if isinstance(cube, str):
-            return cube
-        return f"<{type(cube).__name__}:{id(cube)}>"
-
     def to_dict(self) -> dict:
         """Convert model to dictionary suitable for Redis Stream"""
         return {
@@ -101,7 +89,6 @@ def from_dict(cls, data: dict) -> "ScheduleMessageItem":
             user_id=data["user_id"],
             mem_cube_id=data["cube_id"],
             label=data["label"],
-            mem_cube="Not Applicable",  # Custom cube deserialization
             content=data["content"],
             timestamp=datetime.fromisoformat(data["timestamp"]),
-            user_name=data.get("user_name"),
+            user_name=data.get("user_name", ""),
diff --git a/src/memos/mem_scheduler/utils/metrics.py b/src/memos/mem_scheduler/utils/metrics.py
index 5155c98b3..45abc5b36 100644
--- a/src/memos/mem_scheduler/utils/metrics.py
+++ 
b/src/memos/mem_scheduler/utils/metrics.py
@@ -6,10 +6,14 @@
 from dataclasses import dataclass, field

+from memos.log import get_logger
+
 # ==== global window config ====
 WINDOW_SEC = 120  # 2 minutes sliding window

+logger = get_logger(__name__)
+
 # ---------- O(1) EWMA ----------
 class Ewma:
@@ -187,7 +191,7 @@ def on_enqueue(
         old_lam = ls.lambda_ewma.value_at(now)
         ls.lambda_ewma.update(inst_rate, now)
         new_lam = ls.lambda_ewma.value_at(now)
-        print(
+        logger.info(
             f"[DEBUG enqueue] {label} backlog={ls.backlog} dt={dt if dt is not None else '—'}s inst={inst_rate:.3f} λ {old_lam:.3f}→{new_lam:.3f}"
         )
         self._label_topk[label].add(mem_cube_id)
@@ -225,7 +229,7 @@ def on_done(
         old_mu = ls.mu_ewma.value_at(now)
         ls.mu_ewma.update(inst_rate, now)
         new_mu = ls.mu_ewma.value_at(now)
-        print(
+        logger.info(
             f"[DEBUG done] {label} backlog={ls.backlog} dt={dt if dt is not None else '—'}s inst={inst_rate:.3f} μ {old_mu:.3f}→{new_mu:.3f}"
         )
         ds = self._detail_stats.get((label, mem_cube_id))
diff --git a/src/memos/mem_scheduler/utils/misc_utils.py b/src/memos/mem_scheduler/utils/misc_utils.py
index aa9b5c489..cce1286bb 100644
--- a/src/memos/mem_scheduler/utils/misc_utils.py
+++ b/src/memos/mem_scheduler/utils/misc_utils.py
@@ -1,5 +1,6 @@
 import json
 import re
+import traceback

 from functools import wraps
 from pathlib import Path
@@ -12,7 +13,7 @@
 logger = get_logger(__name__)

-def extract_json_dict(text: str):
+def extract_json_obj(text: str):
     """
     Safely extracts JSON from LLM response text with robust error handling.
@@ -40,7 +41,7 @@
     try:
         return json.loads(text.strip())
     except json.JSONDecodeError as e:
-        logger.error(f"Failed to parse JSON from text: {text}. Error: {e!s}", exc_info=True)
+        logger.info(f"Failed to parse JSON from text: {text}. Error: {e!s}", exc_info=True)

     # Fallback 1: Extract JSON using regex
     json_pattern = r"\{[\s\S]*\}|\[[\s\S]*\]"
@@ -49,7 +50,7 @@
         try:
             return json.loads(matches[0])
         except json.JSONDecodeError as e:
-            logger.error(f"Failed to parse JSON from text: {text}. Error: {e!s}", exc_info=True)
+            logger.info(f"Failed to parse JSON from text: {text}. Error: {e!s}", exc_info=True)

     # Fallback 2: Handle malformed JSON (common LLM issues)
     try:
@@ -57,10 +58,125 @@
         text = re.sub(r"([\{\s,])(\w+)(:)", r'\1"\2"\3', text)
         return json.loads(text)
     except json.JSONDecodeError as e:
-        logger.error(f"Failed to parse JSON from text: {text}. Error: {e!s}", exc_info=True)
+        logger.error(f"Failed to parse JSON from text: {text}. Error: {e!s}")
+        logger.error("Full traceback:\n" + traceback.format_exc())
         raise ValueError(text) from e

+def extract_list_items(text: str, bullet_prefixes: tuple[str, ...] = ("- ",)) -> list[str]:
+    """
+    Extract bullet list items from LLM output where each item is on a single line
+    starting with a given bullet prefix (default: "- ").
+
+    This function is designed to be robust to common LLM formatting variations,
+    following similar normalization practices as `extract_json_obj`.
+
+    Behavior:
+    - Strips common code-fence markers (```json, ```python, ``` etc.).
+    - Collects all lines that start with any of the provided `bullet_prefixes`.
+    - Tolerates the "• " bullet as a loose fallback.
+    - Unescapes common sequences like "\\n" and "\\t" within items.
+    - If no bullet lines are found, falls back to attempting to parse a JSON array
+      (using `extract_json_obj`) and returns its string elements.
+
+    Args:
+        text: Raw text response from LLM.
+        bullet_prefixes: Tuple of accepted bullet line prefixes.
+
+    Returns:
+        List of extracted items (strings). Returns an empty list if none can be parsed.
+    """
+    if not text:
+        return []
+
+    # Normalize the text similar to extract_json_obj
+    normalized = text.strip()
+    patterns_to_remove = ["json```", "```python", "```json", "latex```", "```latex", "```"]
+    for pattern in patterns_to_remove:
+        normalized = normalized.replace(pattern, "")
+    normalized = normalized.replace("\r\n", "\n")
+
+    # Accept the "• " bullet as a loose fallback prefix
+    prefixes = tuple(bullet_prefixes) + ("• ",)
+
+    lines = normalized.splitlines()
+    items: list[str] = []
+    seen: set[str] = set()
+
+    for raw in lines:
+        line = raw.strip()
+        if not line:
+            continue
+
+        for prefix in prefixes:
+            if line.startswith(prefix):
+                content = line[len(prefix) :].strip()
+                content = content.replace("\\n", "\n").replace("\\t", "\t").replace("\\r", "\r")
+                if content and content not in seen:
+                    items.append(content)
+                    seen.add(content)
+                break
+
+    if items:
+        return items
+
+    # Fallback: attempt to parse a JSON array and keep its string elements
+    try:
+        parsed = extract_json_obj(normalized)
+        if isinstance(parsed, list):
+            return [str(item) for item in parsed]
+    except Exception:
+        pass
+
+    logger.error(f"Fail to parse {text}")
+    return []
+
+
+def extract_list_items_in_answer(
+    text: str, bullet_prefixes: tuple[str, ...] = ("- ",)
+) -> list[str]:
+    """
+    Extract list items specifically from content enclosed within `<answer>...</answer>` tags.
+
+    - When one or more `<answer>...</answer>` blocks are present, concatenates their inner
+      contents with newlines and parses using `extract_list_items`.
+    - When no `<answer>` block is found, falls back to parsing the entire input with
+      `extract_list_items`.
+    - Case-insensitive matching of the `<answer>` tag.
+
+    Args:
+        text: Raw text that may contain `<answer>...</answer>` blocks.
+        bullet_prefixes: Accepted bullet prefixes (default: strictly `"- "`).
+
+    Returns:
+        List of extracted items (strings), or an empty list when nothing is parseable.
+    """
+    if not text:
+        return []
+
+    try:
+        normalized = text.strip().replace("\r\n", "\n")
+        # Ordered, exact-case matching for blocks: answer -> Answer -> ANSWER
+        tag_variants = ["answer", "Answer", "ANSWER"]
+        matches: list[str] = []
+        for tag in tag_variants:
+            matches = re.findall(rf"<{tag}>([\s\S]*?)</{tag}>", normalized)
+            if matches:
+                break
+        # Fallback: case-insensitive matching if none of the exact-case variants matched
+        if not matches:
+            matches = re.findall(r"<answer>([\s\S]*?)</answer>", normalized, flags=re.IGNORECASE)
+
+        if matches:
+            combined = "\n".join(m.strip() for m in matches if m is not None)
+            return extract_list_items(combined, bullet_prefixes=bullet_prefixes)
+
+        # Fallback: parse the whole text if <answer> tags are absent
+        return extract_list_items(normalized, bullet_prefixes=bullet_prefixes)
+    except Exception as e:
+        logger.info(f"Failed to extract items within <answer> tags: {e!s}", exc_info=True)
+        # Final fallback: attempt direct list extraction
+        try:
+            return extract_list_items(text, bullet_prefixes=bullet_prefixes)
+        except Exception:
+            return []
+
+
 def parse_yaml(yaml_file: str | Path):
     yaml_path = Path(yaml_file)
     if not yaml_path.is_file():
diff --git a/src/memos/mem_scheduler/webservice_modules/redis_service.py b/src/memos/mem_scheduler/webservice_modules/redis_service.py
index 5439af9c6..e79553f33 100644
--- a/src/memos/mem_scheduler/webservice_modules/redis_service.py
+++ b/src/memos/mem_scheduler/webservice_modules/redis_service.py
@@ -333,6 +333,15 @@ def redis_start_listening(self, handler: Callable | None = None):
             logger.warning("Listener is already running")
             return

+        # Check Redis connection before starting listener
+        if self.redis is None:
+            logger.warning(
+                "Redis connection is None, attempting to auto-initialize before starting 
listener..." + ) + if not self.auto_initialize_redis(): + logger.error("Failed to initialize Redis connection, cannot start listener") + return + if handler is None: handler = self.redis_consume_message_stream diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py b/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py index 55e33494c..b9814f079 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py @@ -22,6 +22,7 @@ class TaskGoalParser: def __init__(self, llm=BaseLLM): self.llm = llm self.tokenizer = FastTokenizer() + self.retries = 1 def parse( self, @@ -103,18 +104,24 @@ def _parse_response(self, response: str, **kwargs) -> ParsedTaskGoal: """ Parse LLM JSON output safely. """ - try: - context = kwargs.get("context", "") - response = response.replace("```", "").replace("json", "").strip() - response_json = eval(response) - return ParsedTaskGoal( - memories=response_json.get("memories", []), - keys=response_json.get("keys", []), - tags=response_json.get("tags", []), - rephrased_query=response_json.get("rephrased_instruction", None), - internet_search=response_json.get("internet_search", False), - goal_type=response_json.get("goal_type", "default"), - context=context, - ) - except Exception as e: - raise ValueError(f"Failed to parse LLM output: {e}\nRaw response:\n{response}") from e + # Ensure at least one attempt + attempts = max(1, getattr(self, "retries", 1)) + + for attempt_times in range(attempts): + try: + context = kwargs.get("context", "") + response = response.replace("```", "").replace("json", "").strip() + response_json = eval(response) + return ParsedTaskGoal( + memories=response_json.get("memories", []), + keys=response_json.get("keys", []), + tags=response_json.get("tags", []), + rephrased_query=response_json.get("rephrased_instruction", None), + internet_search=response_json.get("internet_search", False), + goal_type=response_json.get("goal_type", "default"), + context=context, + ) + except Exception as e: + raise ValueError( + f"Failed to parse LLM output: {e}\nRaw response:\n{response} retried: {attempt_times + 1}/{attempts + 1}" + ) from e diff --git a/src/memos/templates/mem_scheduler_prompts.py b/src/memos/templates/mem_scheduler_prompts.py index b4d091c1f..197a2c1a7 100644 --- a/src/memos/templates/mem_scheduler_prompts.py +++ b/src/memos/templates/mem_scheduler_prompts.py @@ -390,6 +390,45 @@ - Focus on whether the memories can fully answer the query without additional information """ +MEMORY_ENHANCEMENT_PROMPT = """ +You are a knowledgeable and precise AI assistant. + +# GOAL +Transform each raw memory into an enhanced version that preserves all relevant factual details and makes the information directly useful for answering the user's query. + +# CORE PRINCIPLE +Focus on **relevance** โ€” the enhanced memories should highlight, clarify, and preserve the information that most directly supports answering the current query. + +# RULES & THINKING STEPS +1. Read the user query carefully and identify what specific facts are needed to answer it. +2. Go through each memory and: + - Keep only details directly relevant to the query (dates, actions, entities, outcomes). + - Remove unrelated or background details. + - If nothing in a memory relates to the query, delete the entire memory. +3. Do not add or infer new facts. +4. Keep facts accurate and phrased clearly. +5. 
Each resulting line should stand alone as a usable fact for answering the query.
+
+# OUTPUT FORMAT (STRICT)
+Return ONLY the following block, with **one enhanced memory per line**.
+Each line MUST start with "- " (dash + space).
+
+Wrap the final output inside <answer> tags:
+
+<answer>
+- enhanced memory 1
+- enhanced memory 2
+...
+</answer>
+
+## User Query
+{query_history}
+
+## Available Memories
+{memories}
+
+Answer:
+"""
+
 PROMPT_MAPPING = {
     "intent_recognizing": INTENT_RECOGNIZING_PROMPT,
     "memory_reranking": MEMORY_RERANKING_PROMPT,
@@ -398,6 +437,7 @@
     "memory_redundancy_filtering": MEMORY_REDUNDANCY_FILTERING_PROMPT,
     "memory_combined_filtering": MEMORY_COMBINED_FILTERING_PROMPT,
     "memory_answer_ability_evaluation": MEMORY_ANSWER_ABILITY_EVALUATION_PROMPT,
+    "memory_enhancement": MEMORY_ENHANCEMENT_PROMPT,
 }

 MEMORY_ASSEMBLY_TEMPLATE = """The retrieved memories are listed as follows:\n\n {memory_text}"""
diff --git a/tests/mem_scheduler/test_dispatcher.py b/tests/mem_scheduler/test_dispatcher.py
index e3064660b..fc154e013 100644
--- a/tests/mem_scheduler/test_dispatcher.py
+++ b/tests/mem_scheduler/test_dispatcher.py
@@ -90,7 +90,6 @@ def setUp(self):
             ScheduleMessageItem(
                 item_id="msg1",
                 user_id="user1",
-                mem_cube="cube1",
                 mem_cube_id="msg1",
                 label="label1",
                 content="Test content 1",
             ),
             ScheduleMessageItem(
                 item_id="msg2",
                 user_id="user1",
-                mem_cube="cube1",
                 mem_cube_id="msg2",
                 label="label2",
                 content="Test content 2",
             ),
             ScheduleMessageItem(
                 item_id="msg3",
                 user_id="user2",
-                mem_cube="cube2",
                 mem_cube_id="msg3",
                 label="label1",
                 content="Test content 3",
             ),
@@ -193,46 +190,6 @@ def test_dispatch_serial(self):
         self.assertEqual(len(label2_messages), 1)
         self.assertEqual(label2_messages[0].item_id, "msg2")

-    def test_dispatch_parallel(self):
-        """Test dispatching messages in parallel mode."""
-        # Create fresh mock handlers for this test
-        mock_handler1 = MagicMock()
-        mock_handler2 = MagicMock()
-
-        # Create a new dispatcher for this test to avoid interference
-        parallel_dispatcher = SchedulerDispatcher(max_workers=2, enable_parallel_dispatch=True)
-        parallel_dispatcher.register_handler("label1", mock_handler1)
-        parallel_dispatcher.register_handler("label2", mock_handler2)
-
-        # Dispatch messages
-        parallel_dispatcher.dispatch(self.test_messages)
-
-        # Wait for all futures to complete
-        parallel_dispatcher.join(timeout=1.0)
-
-        # Verify handlers were called - label1 handler should be called twice (for user1 and user2)
-        # label2 handler should be called once (only for user1)
-        self.assertEqual(mock_handler1.call_count, 2)  # Called for user1/msg1 and user2/msg3
-        mock_handler2.assert_called_once()  # Called for user1/msg2
-
-        # Check that each handler received the correct messages
-        # For label1: should have two calls, each with one message
-        label1_calls = mock_handler1.call_args_list
-        self.assertEqual(len(label1_calls), 2)
-
-        # Extract messages from calls
-        call1_messages = label1_calls[0][0][0]  # First call, first argument (messages list)
-        call2_messages = label1_calls[1][0][0]  # Second call, first argument (messages list)
-
-        # Verify the messages in each call
-        self.assertEqual(len(call1_messages), 1)
-        self.assertEqual(len(call2_messages), 1)
-
-        # For label2: should have one call with [msg2]
-        label2_messages = mock_handler2.call_args[0][0]
-        self.assertEqual(len(label2_messages), 1)
-        self.assertEqual(label2_messages[0].item_id, "msg2")
-
     def test_group_messages_by_user_and_mem_cube(self):
         """Test grouping messages by user and cube."""
-        # Check 
diff --git a/tests/mem_scheduler/test_dispatcher.py b/tests/mem_scheduler/test_dispatcher.py
index e3064660b..fc154e013 100644
--- a/tests/mem_scheduler/test_dispatcher.py
+++ b/tests/mem_scheduler/test_dispatcher.py
@@ -90,7 +90,6 @@ def setUp(self):
             ScheduleMessageItem(
                 item_id="msg1",
                 user_id="user1",
-                mem_cube="cube1",
                 mem_cube_id="msg1",
                 label="label1",
                 content="Test content 1",
@@ -99,7 +98,6 @@
             ScheduleMessageItem(
                 item_id="msg2",
                 user_id="user1",
-                mem_cube="cube1",
                 mem_cube_id="msg2",
                 label="label2",
                 content="Test content 2",
@@ -108,7 +106,6 @@
             ScheduleMessageItem(
                 item_id="msg3",
                 user_id="user2",
-                mem_cube="cube2",
                 mem_cube_id="msg3",
                 label="label1",
                 content="Test content 3",
@@ -193,46 +190,6 @@ def test_dispatch_serial(self):
         self.assertEqual(len(label2_messages), 1)
         self.assertEqual(label2_messages[0].item_id, "msg2")
 
-    def test_dispatch_parallel(self):
-        """Test dispatching messages in parallel mode."""
-        # Create fresh mock handlers for this test
-        mock_handler1 = MagicMock()
-        mock_handler2 = MagicMock()
-
-        # Create a new dispatcher for this test to avoid interference
-        parallel_dispatcher = SchedulerDispatcher(max_workers=2, enable_parallel_dispatch=True)
-        parallel_dispatcher.register_handler("label1", mock_handler1)
-        parallel_dispatcher.register_handler("label2", mock_handler2)
-
-        # Dispatch messages
-        parallel_dispatcher.dispatch(self.test_messages)
-
-        # Wait for all futures to complete
-        parallel_dispatcher.join(timeout=1.0)
-
-        # Verify handlers were called - label1 handler should be called twice (for user1 and user2)
-        # label2 handler should be called once (only for user1)
-        self.assertEqual(mock_handler1.call_count, 2)  # Called for user1/msg1 and user2/msg3
-        mock_handler2.assert_called_once()  # Called for user1/msg2
-
-        # Check that each handler received the correct messages
-        # For label1: should have two calls, each with one message
-        label1_calls = mock_handler1.call_args_list
-        self.assertEqual(len(label1_calls), 2)
-
-        # Extract messages from calls
-        call1_messages = label1_calls[0][0][0]  # First call, first argument (messages list)
-        call2_messages = label1_calls[1][0][0]  # Second call, first argument (messages list)
-
-        # Verify the messages in each call
-        self.assertEqual(len(call1_messages), 1)
-        self.assertEqual(len(call2_messages), 1)
-
-        # For label2: should have one call with [msg2]
-        label2_messages = mock_handler2.call_args[0][0]
-        self.assertEqual(len(label2_messages), 1)
-        self.assertEqual(label2_messages[0].item_id, "msg2")
-
     def test_group_messages_by_user_and_mem_cube(self):
         """Test grouping messages by user and cube."""
         # Check actual grouping logic
diff --git a/tests/mem_scheduler/test_scheduler.py b/tests/mem_scheduler/test_scheduler.py
index 03a8e4318..fed1e8500 100644
--- a/tests/mem_scheduler/test_scheduler.py
+++ b/tests/mem_scheduler/test_scheduler.py
@@ -1,7 +1,6 @@
 import sys
 import unittest
 
-from contextlib import suppress
 from datetime import datetime
 from pathlib import Path
 from unittest.mock import MagicMock, patch
@@ -21,12 +20,9 @@
 from memos.mem_scheduler.schemas.general_schemas import (
     ANSWER_LABEL,
     QUERY_LABEL,
-    STARTUP_BY_PROCESS,
-    STARTUP_BY_THREAD,
 )
 from memos.mem_scheduler.schemas.message_schemas import (
     ScheduleLogForWebItem,
-    ScheduleMessageItem,
 )
 from memos.memories.textual.tree import TreeTextMemory
@@ -182,124 +178,6 @@ def test_submit_web_logs(self):
         self.assertTrue(hasattr(actual_message, "timestamp"))
         self.assertTrue(isinstance(actual_message.timestamp, datetime))
 
-    def test_scheduler_startup_mode_default(self):
-        """Test that scheduler has default startup mode set to thread."""
-        self.assertEqual(self.scheduler.scheduler_startup_mode, STARTUP_BY_THREAD)
-
-    def test_scheduler_startup_mode_thread(self):
-        """Test scheduler with thread startup mode."""
-        # Set scheduler startup mode to thread
-        self.scheduler.scheduler_startup_mode = STARTUP_BY_THREAD
-
-        # Start the scheduler
-        self.scheduler.start()
-
-        # Verify that consumer thread is created and process is None
-        self.assertIsNotNone(self.scheduler._consumer_thread)
-        self.assertIsNone(self.scheduler._consumer_process)
-        self.assertTrue(self.scheduler._running)
-
-        # Stop the scheduler
-        self.scheduler.stop()
-
-    def test_redis_message_queue(self):
-        """Test Redis message queue functionality for sending and receiving messages."""
-        import time
-
-        from unittest.mock import MagicMock, patch
-
-        # Mock Redis connection and operations
-        mock_redis = MagicMock()
-        mock_redis.xadd = MagicMock(return_value=b"1234567890-0")
-
-        # Track received messages
-        received_messages = []
-
-        def redis_handler(messages: list[ScheduleMessageItem]) -> None:
-            """Handler for Redis messages."""
-            received_messages.extend(messages)
-
-        # Register Redis handler
-        redis_label = "test_redis"
-        handlers = {redis_label: redis_handler}
-        self.scheduler.register_handlers(handlers)
-
-        # Enable Redis queue for this test
-        with (
-            patch.object(self.scheduler, "use_redis_queue", True),
-            patch.object(self.scheduler, "_redis_conn", mock_redis),
-        ):
-            # Start scheduler
-            self.scheduler.start()
-
-            # Create test message for Redis
-            redis_message = ScheduleMessageItem(
-                label=redis_label,
-                content="Redis test message",
-                user_id="redis_user",
-                mem_cube_id="redis_cube",
-                mem_cube="redis_mem_cube_obj",
-                timestamp=datetime.now(),
-            )
-
-            # Submit message to Redis queue
-            self.scheduler.submit_messages(redis_message)
-
-            # Verify Redis xadd was called
-            mock_redis.xadd.assert_called_once()
-            call_args = mock_redis.xadd.call_args
-            self.assertEqual(call_args[0][0], "user:queries:stream")
-
-            # Verify message data was serialized correctly
-            message_data = call_args[0][1]
-            self.assertEqual(message_data["label"], redis_label)
-            self.assertEqual(message_data["content"], "Redis test message")
-            self.assertEqual(message_data["user_id"], "redis_user")
-            self.assertEqual(message_data["cube_id"], "redis_cube")  # Note: to_dict uses cube_id
-
-            # Simulate Redis message consumption
-            # This would normally be handled by the Redis consumer in the scheduler
-            time.sleep(0.1)  # Brief wait for async operations
-
-            # Stop scheduler
-            self.scheduler.stop()
-
-        print("Redis message queue test completed successfully!")
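The deleted Redis test pinned down the production path this queue takes: with `use_redis_queue` enabled, a submitted `ScheduleMessageItem` is serialized and appended to the `user:queries:stream` stream via `XADD`, with the cube id stored under the `cube_id` key. Below is a standalone sketch of that round trip with plain redis-py; only the stream name and the `cube_id` field name come from the removed assertions, while the connection details and field values are assumptions (it needs a local Redis server to run).

```python
# Standalone redis-py sketch of the stream round trip the removed test mocked.
# Assumes a Redis server on localhost:6379; field values are illustrative.
import redis

r = redis.Redis(host="localhost", port=6379, decode_responses=True)
stream = "user:queries:stream"  # stream name asserted in the removed test

# Producer side: roughly what submit_messages does when use_redis_queue is on.
entry_id = r.xadd(
    stream,
    {
        "label": "test_redis",
        "content": "Redis test message",
        "user_id": "redis_user",
        "cube_id": "redis_cube",  # serialized dict uses cube_id, not mem_cube_id
    },
)

# Consumer side: read everything appended after the start of the stream.
for _name, entries in r.xread({stream: "0-0"}, count=10, block=1000):
    for eid, fields in entries:
        print(eid, fields)
```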
-
-    # Removed test_robustness method - was too time-consuming for CI/CD pipeline
-
-    def test_scheduler_startup_mode_process(self):
-        """Test scheduler with process startup mode."""
-        # Set scheduler startup mode to process
-        self.scheduler.scheduler_startup_mode = STARTUP_BY_PROCESS
-
-        # Start the scheduler
-        try:
-            self.scheduler.start()
-
-            # Verify that consumer process is created and thread is None
-            self.assertIsNotNone(self.scheduler._consumer_process)
-            self.assertIsNone(self.scheduler._consumer_thread)
-            self.assertTrue(self.scheduler._running)
-
-        except Exception as e:
-            # Process mode may fail due to pickling issues in test environment
-            # This is expected behavior - we just verify the startup mode is set correctly
-            self.assertEqual(self.scheduler.scheduler_startup_mode, STARTUP_BY_PROCESS)
-            print(f"Process mode test encountered expected pickling issue: {e}")
-        finally:
-            # Always attempt to stop the scheduler
-            with suppress(Exception):
-                self.scheduler.stop()
-
-            # Verify cleanup attempt was made
-            self.assertEqual(self.scheduler.scheduler_startup_mode, STARTUP_BY_PROCESS)
-
-    def test_scheduler_startup_mode_constants(self):
-        """Test that startup mode constants are properly defined."""
-        self.assertEqual(STARTUP_BY_THREAD, "thread")
-        self.assertEqual(STARTUP_BY_PROCESS, "process")
-
     def test_activation_memory_update(self):
         """Test activation memory update functionality with DynamicCache handling."""
         if not self.RUN_ACTIVATION_MEMORY_TESTS:
@@ -401,130 +279,3 @@ def test_dynamic_cache_layers_access(self):
         # If layers attribute doesn't exist, verify our fix handles this case
         print("⚠️ DynamicCache doesn't have 'layers' attribute in this transformers version")
         print("✅ Test passed - our code should handle this gracefully")
-
-    def test_get_running_tasks_with_filter(self):
-        """Test get_running_tasks method with filter function."""
-        # Mock dispatcher and its get_running_tasks method
-        mock_task_item1 = MagicMock()
-        mock_task_item1.item_id = "task_1"
-        mock_task_item1.user_id = "user_1"
-        mock_task_item1.mem_cube_id = "cube_1"
-        mock_task_item1.task_info = {"type": "query"}
-        mock_task_item1.task_name = "test_task_1"
-        mock_task_item1.start_time = datetime.now()
-        mock_task_item1.end_time = None
-        mock_task_item1.status = "running"
-        mock_task_item1.result = None
-        mock_task_item1.error_message = None
-        mock_task_item1.messages = []
-
-        # Define a filter function
-        def user_filter(task):
-            return task.user_id == "user_1"
-
-        # Mock the filtered result (only task_1 matches the filter)
-        with patch.object(
-            self.scheduler.dispatcher, "get_running_tasks", return_value={"task_1": mock_task_item1}
-        ) as mock_get_running_tasks:
-            # Call get_running_tasks with filter
-            result = self.scheduler.get_running_tasks(filter_func=user_filter)
-
-            # Verify result
-            self.assertIsInstance(result, dict)
-            self.assertIn("task_1", result)
-            self.assertEqual(len(result), 1)
-
-            # Verify dispatcher method was called with filter
-            mock_get_running_tasks.assert_called_once_with(filter_func=user_filter)
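The removed filter tests still document the `get_running_tasks` contract: the scheduler forwards `filter_func` to its dispatcher, which applies the predicate per task and returns a dict keyed by `item_id`. A self-contained sketch of that pattern follows; the `Task` dataclass is a stand-in for the real task item type, modeling only the fields the tests asserted on.

```python
# Minimal model of the filter_func behavior the removed tests exercised.
from dataclasses import dataclass, field


@dataclass
class Task:
    item_id: str
    user_id: str
    status: str = "running"
    messages: list = field(default_factory=list)


def get_running_tasks(tasks, filter_func=None):
    """Keep tasks passing the predicate; key the result by item_id."""
    kept = tasks if filter_func is None else [t for t in tasks if filter_func(t)]
    return {t.item_id: t for t in kept}


tasks = [Task("task_1", "user_1"), Task("task_2", "user_2", status="completed")]
only_user_1 = get_running_tasks(tasks, filter_func=lambda t: t.user_id == "user_1")
print(sorted(only_user_1))  # ['task_1']
```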
-
-    def test_get_running_tasks_empty_result(self):
-        """Test get_running_tasks method when no tasks are running."""
-        # Mock dispatcher to return empty dict
-        with patch.object(
-            self.scheduler.dispatcher, "get_running_tasks", return_value={}
-        ) as mock_get_running_tasks:
-            # Call get_running_tasks
-            result = self.scheduler.get_running_tasks()
-
-            # Verify empty result
-            self.assertIsInstance(result, dict)
-            self.assertEqual(len(result), 0)
-
-            # Verify dispatcher method was called
-            mock_get_running_tasks.assert_called_once_with(filter_func=None)
-
-    def test_get_running_tasks_no_dispatcher(self):
-        """Test get_running_tasks method when dispatcher is None."""
-        # Temporarily set dispatcher to None
-        original_dispatcher = self.scheduler.dispatcher
-        self.scheduler.dispatcher = None
-
-        # Call get_running_tasks
-        result = self.scheduler.get_running_tasks()
-
-        # Verify empty result and warning behavior
-        self.assertIsInstance(result, dict)
-        self.assertEqual(len(result), 0)
-
-        # Restore dispatcher
-        self.scheduler.dispatcher = original_dispatcher
-
-    def test_get_running_tasks_multiple_tasks(self):
-        """Test get_running_tasks method with multiple tasks."""
-        # Mock multiple task items
-        mock_task_item1 = MagicMock()
-        mock_task_item1.item_id = "task_1"
-        mock_task_item1.user_id = "user_1"
-        mock_task_item1.mem_cube_id = "cube_1"
-        mock_task_item1.task_info = {"type": "query"}
-        mock_task_item1.task_name = "test_task_1"
-        mock_task_item1.start_time = datetime.now()
-        mock_task_item1.end_time = None
-        mock_task_item1.status = "running"
-        mock_task_item1.result = None
-        mock_task_item1.error_message = None
-        mock_task_item1.messages = []
-
-        mock_task_item2 = MagicMock()
-        mock_task_item2.item_id = "task_2"
-        mock_task_item2.user_id = "user_2"
-        mock_task_item2.mem_cube_id = "cube_2"
-        mock_task_item2.task_info = {"type": "answer"}
-        mock_task_item2.task_name = "test_task_2"
-        mock_task_item2.start_time = datetime.now()
-        mock_task_item2.end_time = None
-        mock_task_item2.status = "completed"
-        mock_task_item2.result = "success"
-        mock_task_item2.error_message = None
-        mock_task_item2.messages = ["message1", "message2"]
-
-        with patch.object(
-            self.scheduler.dispatcher,
-            "get_running_tasks",
-            return_value={"task_1": mock_task_item1, "task_2": mock_task_item2},
-        ) as mock_get_running_tasks:
-            # Call get_running_tasks
-            result = self.scheduler.get_running_tasks()
-
-            # Verify result structure
-            self.assertIsInstance(result, dict)
-            self.assertEqual(len(result), 2)
-            self.assertIn("task_1", result)
-            self.assertIn("task_2", result)
-
-            # Verify task_1 details
-            task1_dict = result["task_1"]
-            self.assertEqual(task1_dict["item_id"], "task_1")
-            self.assertEqual(task1_dict["user_id"], "user_1")
-            self.assertEqual(task1_dict["status"], "running")
-
-            # Verify task_2 details
-            task2_dict = result["task_2"]
-            self.assertEqual(task2_dict["item_id"], "task_2")
-            self.assertEqual(task2_dict["user_id"], "user_2")
-            self.assertEqual(task2_dict["status"], "completed")
-            self.assertEqual(task2_dict["result"], "success")
-            self.assertEqual(task2_dict["messages"], ["message1", "message2"])
-
-            # Verify dispatcher method was called
-            mock_get_running_tasks.assert_called_once_with(filter_func=None)
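For reference, the removed multi-task test was the one place that spelled out the serialized shape `get_running_tasks` returns: a dict keyed by `item_id` whose values expose the task fields as plain entries. Roughly, as a sketch (field values illustrative; `start_time`/`end_time` omitted for brevity):

```python
# Result shape asserted by the removed tests; the values are illustrative.
running_tasks = {
    "task_2": {
        "item_id": "task_2",
        "user_id": "user_2",
        "mem_cube_id": "cube_2",
        "task_info": {"type": "answer"},
        "task_name": "test_task_2",
        "status": "completed",
        "result": "success",
        "error_message": None,
        "messages": ["message1", "message2"],
    },
}
assert running_tasks["task_2"]["result"] == "success"
```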