diff --git a/examples/mem_scheduler/api_w_scheduler.py b/examples/mem_scheduler/api_w_scheduler.py
new file mode 100644
index 000000000..11f0ebb81
--- /dev/null
+++ b/examples/mem_scheduler/api_w_scheduler.py
@@ -0,0 +1,62 @@
+from memos.api.routers.server_router import mem_scheduler
+from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
+
+
+# Debug: Print scheduler configuration
+print("=== Scheduler Configuration Debug ===")
+print(f"Scheduler type: {type(mem_scheduler).__name__}")
+print(f"Config: {mem_scheduler.config}")
+print(f"use_redis_queue: {mem_scheduler.use_redis_queue}")
+print(f"Queue type: {type(mem_scheduler.memos_message_queue).__name__}")
+print(f"Queue maxsize: {getattr(mem_scheduler.memos_message_queue, 'maxsize', 'N/A')}")
+
+# Check if Redis queue is connected
+if hasattr(mem_scheduler.memos_message_queue, "_is_connected"):
+ print(f"Redis connected: {mem_scheduler.memos_message_queue._is_connected}")
+if hasattr(mem_scheduler.memos_message_queue, "_redis_conn"):
+ print(f"Redis connection: {mem_scheduler.memos_message_queue._redis_conn}")
+print("=====================================\n")
+
+queue = mem_scheduler.memos_message_queue
+queue.clear()
+
+
+# 1. Define a handler function
+def my_test_handler(messages: list[ScheduleMessageItem]):
+ print(f"My test handler received {len(messages)} messages:")
+ for msg in messages:
+ print(f" my_test_handler - {msg.item_id}: {msg.content}")
+ print(
+ f"{queue._redis_conn.xinfo_groups(queue.stream_name)} qsize: {queue.qsize()} messages:{messages}"
+ )
+
+
+# 2. Register the handler
+TEST_HANDLER_LABEL = "test_handler"
+mem_scheduler.register_handlers({TEST_HANDLER_LABEL: my_test_handler})
+
+# 3. Create messages
+messages_to_send = [
+ ScheduleMessageItem(
+ item_id=f"test_item_{i}",
+ user_id="test_user",
+ mem_cube_id="test_mem_cube",
+ label=TEST_HANDLER_LABEL,
+ content=f"This is test message {i}",
+ )
+ for i in range(5)
+]
+
+# 4. Submit messages
+for mes in messages_to_send:
+ print(f"Submitting message {mes.item_id} to the scheduler...")
+ mem_scheduler.submit_messages([mes])
+
+# 5. Wait for messages to be processed (limited to 100 checks)
+print("Waiting for messages to be consumed (max 100 checks)...")
+mem_scheduler.mem_scheduler_wait()
+
+
+# 6. Stop the scheduler
+print("Stopping the scheduler...")
+mem_scheduler.stop()
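Note on the example above: the debug block reaches into Redis-specific internals (`stream_name`, `_redis_conn`), which only exist when `use_redis_queue` is enabled. Below is a minimal sketch of the same handler that degrades gracefully when the scheduler falls back to a local in-process queue; it reuses the imports from the example, and the helper name `safe_queue_stats` is illustrative rather than part of the API.

def safe_queue_stats() -> str:
    """Summarize queue state without assuming a Redis-backed queue."""
    q = mem_scheduler.memos_message_queue
    stats = f"qsize={q.qsize()}"
    # Redis stream details are only available when the Redis queue is active.
    if getattr(q, "_redis_conn", None) is not None and hasattr(q, "stream_name"):
        stats += f", groups={q._redis_conn.xinfo_groups(q.stream_name)}"
    return stats


def my_test_handler(messages: list[ScheduleMessageItem]):
    print(f"my_test_handler received {len(messages)} messages ({safe_queue_stats()})")
    for msg in messages:
        print(f"  {msg.item_id}: {msg.content}")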
diff --git a/examples/mem_scheduler/memos_w_optimized_scheduler.py b/examples/mem_scheduler/memos_w_optimized_scheduler.py
deleted file mode 100644
index 664168f62..000000000
--- a/examples/mem_scheduler/memos_w_optimized_scheduler.py
+++ /dev/null
@@ -1,85 +0,0 @@
-import shutil
-import sys
-
-from pathlib import Path
-
-from memos_w_scheduler import init_task, show_web_logs
-
-from memos.configs.mem_cube import GeneralMemCubeConfig
-from memos.configs.mem_os import MOSConfig
-from memos.configs.mem_scheduler import AuthConfig
-from memos.log import get_logger
-from memos.mem_cube.general import GeneralMemCube
-from memos.mem_os.main import MOS
-
-
-FILE_PATH = Path(__file__).absolute()
-BASE_DIR = FILE_PATH.parent.parent.parent
-sys.path.insert(0, str(BASE_DIR)) # Enable execution from any working directory
-
-logger = get_logger(__name__)
-
-
-def run_with_scheduler_init():
- print("==== run_with_automatic_scheduler_init ====")
- conversations, questions = init_task()
-
- # set configs
- mos_config = MOSConfig.from_yaml_file(
- f"{BASE_DIR}/examples/data/config/mem_scheduler/memos_config_w_optimized_scheduler.yaml"
- )
-
- mem_cube_config = GeneralMemCubeConfig.from_yaml_file(
- f"{BASE_DIR}/examples/data/config/mem_scheduler/mem_cube_config_neo4j.yaml"
- )
-
- # default local graphdb uri
- if AuthConfig.default_config_exists():
- auth_config = AuthConfig.from_local_config()
-
- mos_config.mem_reader.config.llm.config.api_key = auth_config.openai.api_key
- mos_config.mem_reader.config.llm.config.api_base = auth_config.openai.base_url
-
- mem_cube_config.text_mem.config.graph_db.config.uri = auth_config.graph_db.uri
- mem_cube_config.text_mem.config.graph_db.config.user = auth_config.graph_db.user
- mem_cube_config.text_mem.config.graph_db.config.password = auth_config.graph_db.password
- mem_cube_config.text_mem.config.graph_db.config.db_name = auth_config.graph_db.db_name
- mem_cube_config.text_mem.config.graph_db.config.auto_create = (
- auth_config.graph_db.auto_create
- )
-
- # Initialization
- mos = MOS(mos_config)
-
- user_id = "user_1"
- mos.create_user(user_id)
-
- mem_cube_id = "mem_cube_5"
- mem_cube_name_or_path = f"{BASE_DIR}/outputs/mem_scheduler/{user_id}/{mem_cube_id}"
-
- if Path(mem_cube_name_or_path).exists():
- shutil.rmtree(mem_cube_name_or_path)
- print(f"{mem_cube_name_or_path} is not empty, and has been removed.")
-
- mem_cube = GeneralMemCube(mem_cube_config)
- mem_cube.dump(mem_cube_name_or_path)
- mos.register_mem_cube(
- mem_cube_name_or_path=mem_cube_name_or_path, mem_cube_id=mem_cube_id, user_id=user_id
- )
-
- mos.add(conversations, user_id=user_id, mem_cube_id=mem_cube_id)
-
- for item in questions:
- print("===== Chat Start =====")
- query = item["question"]
- print(f"Query:\n {query}\n")
- response = mos.chat(query=query, user_id=user_id)
- print(f"Answer:\n {response}\n")
-
- show_web_logs(mem_scheduler=mos.mem_scheduler)
-
- mos.mem_scheduler.stop()
-
-
-if __name__ == "__main__":
- run_with_scheduler_init()
diff --git a/examples/mem_scheduler/memos_w_optimized_scheduler_for_test.py b/examples/mem_scheduler/memos_w_optimized_scheduler_for_test.py
deleted file mode 100644
index ed4f721ad..000000000
--- a/examples/mem_scheduler/memos_w_optimized_scheduler_for_test.py
+++ /dev/null
@@ -1,87 +0,0 @@
-import json
-import shutil
-import sys
-
-from pathlib import Path
-
-from memos_w_scheduler_for_test import init_task
-
-from memos.configs.mem_cube import GeneralMemCubeConfig
-from memos.configs.mem_os import MOSConfig
-from memos.configs.mem_scheduler import AuthConfig
-from memos.log import get_logger
-from memos.mem_cube.general import GeneralMemCube
-from memos.mem_scheduler.analyzer.mos_for_test_scheduler import MOSForTestScheduler
-
-
-FILE_PATH = Path(__file__).absolute()
-BASE_DIR = FILE_PATH.parent.parent.parent
-sys.path.insert(0, str(BASE_DIR))
-
-# Enable execution from any working directory
-
-logger = get_logger(__name__)
-
-if __name__ == "__main__":
- # set up data
- conversations, questions = init_task()
-
- # set configs
- mos_config = MOSConfig.from_yaml_file(
- f"{BASE_DIR}/examples/data/config/mem_scheduler/memos_config_w_optimized_scheduler.yaml"
- )
-
- mem_cube_config = GeneralMemCubeConfig.from_yaml_file(
- f"{BASE_DIR}/examples/data/config/mem_scheduler/mem_cube_config_neo4j.yaml"
- )
-
- # default local graphdb uri
- if AuthConfig.default_config_exists():
- auth_config = AuthConfig.from_local_config()
-
- mos_config.mem_reader.config.llm.config.api_key = auth_config.openai.api_key
- mos_config.mem_reader.config.llm.config.api_base = auth_config.openai.base_url
-
- mem_cube_config.text_mem.config.graph_db.config.uri = auth_config.graph_db.uri
- mem_cube_config.text_mem.config.graph_db.config.user = auth_config.graph_db.user
- mem_cube_config.text_mem.config.graph_db.config.password = auth_config.graph_db.password
- mem_cube_config.text_mem.config.graph_db.config.db_name = auth_config.graph_db.db_name
- mem_cube_config.text_mem.config.graph_db.config.auto_create = (
- auth_config.graph_db.auto_create
- )
-
- # Initialization
- mos = MOSForTestScheduler(mos_config)
-
- user_id = "user_1"
- mos.create_user(user_id)
-
- mem_cube_id = "mem_cube_5"
- mem_cube_name_or_path = f"{BASE_DIR}/outputs/mem_scheduler/{user_id}/{mem_cube_id}"
-
- if Path(mem_cube_name_or_path).exists():
- shutil.rmtree(mem_cube_name_or_path)
- print(f"{mem_cube_name_or_path} is not empty, and has been removed.")
-
- mem_cube = GeneralMemCube(mem_cube_config)
- mem_cube.dump(mem_cube_name_or_path)
- mos.register_mem_cube(
- mem_cube_name_or_path=mem_cube_name_or_path, mem_cube_id=mem_cube_id, user_id=user_id
- )
-
- mos.add(conversations, user_id=user_id, mem_cube_id=mem_cube_id)
-
- # Add interfering conversations
- file_path = Path(f"{BASE_DIR}/examples/data/mem_scheduler/scene_data.json")
- scene_data = json.load(file_path.open("r", encoding="utf-8"))
- mos.add(scene_data[0], user_id=user_id, mem_cube_id=mem_cube_id)
- mos.add(scene_data[1], user_id=user_id, mem_cube_id=mem_cube_id)
-
- for item in questions:
- print("===== Chat Start =====")
- query = item["question"]
- print(f"Query:\n {query}\n")
- response = mos.chat(query=query, user_id=user_id)
- print(f"Answer:\n {response}\n")
-
- mos.mem_scheduler.stop()
diff --git a/examples/mem_scheduler/memos_w_scheduler.py b/examples/mem_scheduler/memos_w_scheduler.py
index dc196b85a..c523a8667 100644
--- a/examples/mem_scheduler/memos_w_scheduler.py
+++ b/examples/mem_scheduler/memos_w_scheduler.py
@@ -70,13 +70,48 @@ def init_task():
return conversations, questions
+def show_web_logs(mem_scheduler: GeneralScheduler):
+ """Display all web log entries from the scheduler's log queue.
+
+ Args:
+ mem_scheduler: The scheduler instance containing web logs to display
+ """
+ if mem_scheduler._web_log_message_queue.empty():
+ print("Web log queue is currently empty.")
+ return
+
+ print("\n" + "=" * 50 + " WEB LOGS " + "=" * 50)
+
+ # Create a temporary queue to preserve the original queue contents
+ temp_queue = Queue()
+ log_count = 0
+
+ while not mem_scheduler._web_log_message_queue.empty():
+ log_item: ScheduleLogForWebItem = mem_scheduler._web_log_message_queue.get()
+ temp_queue.put(log_item)
+ log_count += 1
+
+ # Print log entry details
+ print(f"\nLog Entry #{log_count}:")
+ print(f'- "{log_item.label}" log: {log_item}')
+
+ print("-" * 50)
+
+ # Restore items back to the original queue
+ while not temp_queue.empty():
+ mem_scheduler._web_log_message_queue.put(temp_queue.get())
+
+ print(f"\nTotal {log_count} web log entries displayed.")
+ print("=" * 110 + "\n")
+
+
def run_with_scheduler_init():
print("==== run_with_automatic_scheduler_init ====")
conversations, questions = init_task()
# set configs
mos_config = MOSConfig.from_yaml_file(
- f"{BASE_DIR}/examples/data/config/mem_scheduler/memos_config_w_scheduler.yaml"
+ f"{BASE_DIR}/examples/data/config/mem_scheduler/memos_config_w_optimized_scheduler.yaml"
)
mem_cube_config = GeneralMemCubeConfig.from_yaml_file(
@@ -118,6 +153,7 @@ def run_with_scheduler_init():
)
mos.add(conversations, user_id=user_id, mem_cube_id=mem_cube_id)
+ mos.mem_scheduler.current_mem_cube = mem_cube
for item in questions:
print("===== Chat Start =====")
@@ -131,40 +167,5 @@ def run_with_scheduler_init():
mos.mem_scheduler.stop()
-def show_web_logs(mem_scheduler: GeneralScheduler):
- """Display all web log entries from the scheduler's log queue.
-
- Args:
- mem_scheduler: The scheduler instance containing web logs to display
- """
- if mem_scheduler._web_log_message_queue.empty():
- print("Web log queue is currently empty.")
- return
-
- print("\n" + "=" * 50 + " WEB LOGS " + "=" * 50)
-
- # Create a temporary queue to preserve the original queue contents
- temp_queue = Queue()
- log_count = 0
-
- while not mem_scheduler._web_log_message_queue.empty():
- log_item: ScheduleLogForWebItem = mem_scheduler._web_log_message_queue.get()
- temp_queue.put(log_item)
- log_count += 1
-
- # Print log entry details
- print(f"\nLog Entry #{log_count}:")
- print(f'- "{log_item.label}" log: {log_item}')
-
- print("-" * 50)
-
- # Restore items back to the original queue
- while not temp_queue.empty():
- mem_scheduler._web_log_message_queue.put(temp_queue.get())
-
- print(f"\nTotal {log_count} web log entries displayed.")
- print("=" * 110 + "\n")
-
-
if __name__ == "__main__":
run_with_scheduler_init()
diff --git a/examples/mem_scheduler/memos_w_scheduler_for_test.py b/examples/mem_scheduler/memos_w_scheduler_for_test.py
index 6faac98af..2e135f127 100644
--- a/examples/mem_scheduler/memos_w_scheduler_for_test.py
+++ b/examples/mem_scheduler/memos_w_scheduler_for_test.py
@@ -1,10 +1,11 @@
import json
import shutil
import sys
-import time
from pathlib import Path
+from memos_w_scheduler import init_task
+
from memos.configs.mem_cube import GeneralMemCubeConfig
from memos.configs.mem_os import MOSConfig
from memos.configs.mem_scheduler import AuthConfig
@@ -15,155 +16,19 @@
FILE_PATH = Path(__file__).absolute()
BASE_DIR = FILE_PATH.parent.parent.parent
-sys.path.insert(0, str(BASE_DIR)) # Enable execution from any working directory
-
-logger = get_logger(__name__)
-
-
-def display_memory_cube_stats(mos, user_id, mem_cube_id):
- """Display detailed memory cube statistics."""
- print(f"\nMEMORY CUBE STATISTICS for {mem_cube_id}:")
- print("-" * 60)
-
- mem_cube = mos.mem_cubes.get(mem_cube_id)
- if not mem_cube:
- print(" ❌ Memory cube not found")
- return
-
- # Text memory stats
- if mem_cube.text_mem:
- text_mem = mem_cube.text_mem
- working_memories = text_mem.get_working_memory()
- all_memories = text_mem.get_all()
-
- print(" Text Memory:")
- print(f" • Working Memory Items: {len(working_memories)}")
- print(
- f" • Total Memory Items: {len(all_memories) if isinstance(all_memories, list) else 'N/A'}"
- )
-
- if working_memories:
- print(" • Working Memory Content Preview:")
- for i, mem in enumerate(working_memories[:2]):
- content = mem.memory[:60] + "..." if len(mem.memory) > 60 else mem.memory
- print(f" {i + 1}. {content}")
-
- # Activation memory stats
- if mem_cube.act_mem:
- act_mem = mem_cube.act_mem
- act_memories = list(act_mem.get_all())
- print(" Activation Memory:")
- print(f" • KV Cache Items: {len(act_memories)}")
- if act_memories:
- print(
- f" • Latest Cache Size: {len(act_memories[-1].memory) if hasattr(act_memories[-1], 'memory') else 'N/A'}"
- )
-
- print("-" * 60)
-
-
-def display_scheduler_status(mos):
- """Display current scheduler status and configuration."""
- print("\nSCHEDULER STATUS:")
- print("-" * 60)
-
- if not mos.mem_scheduler:
- print(" ❌ Memory scheduler not initialized")
- return
-
- scheduler = mos.mem_scheduler
- print(f" Scheduler Running: {scheduler._running}")
- print(f" Internal Queue Size: {scheduler.memos_message_queue.qsize()}")
- print(f" Parallel Dispatch: {scheduler.enable_parallel_dispatch}")
- print(f" Max Workers: {scheduler.thread_pool_max_workers}")
- print(f" Consume Interval: {scheduler._consume_interval}s")
-
- if scheduler.monitor:
- print(" Monitor Active: ✅")
- print(f" Database Engine: {'✅' if scheduler.db_engine else '❌'}")
-
- if scheduler.dispatcher:
- print(" Dispatcher Active: ✅")
- print(
- f" Dispatcher Status: {scheduler.dispatcher.status if hasattr(scheduler.dispatcher, 'status') else 'Unknown'}"
- )
+sys.path.insert(0, str(BASE_DIR))
- print("-" * 60)
-
-
-def init_task():
- conversations = [
- {
- "role": "user",
- "content": "I have two dogs - Max (golden retriever) and Bella (pug). We live in Seattle.",
- },
- {"role": "assistant", "content": "Great! Any special care for them?"},
- {
- "role": "user",
- "content": "Max needs joint supplements. Actually, we're moving to Chicago next month.",
- },
- {
- "role": "user",
- "content": "Correction: Bella is 6, not 5. And she's allergic to chicken.",
- },
- {
- "role": "user",
- "content": "My partner's cat Whiskers visits weekends. Bella chases her sometimes.",
- },
- ]
-
- questions = [
- # 1. Basic factual recall (simple)
- {
- "question": "What breed is Max?",
- "category": "Pet",
- "expected": "golden retriever",
- "difficulty": "easy",
- },
- # 2. Temporal context (medium)
- {
- "question": "Where will I live next month?",
- "category": "Location",
- "expected": "Chicago",
- "difficulty": "medium",
- },
- # 3. Information correction (hard)
- {
- "question": "How old is Bella really?",
- "category": "Pet",
- "expected": "6",
- "difficulty": "hard",
- "hint": "User corrected the age later",
- },
- # 4. Relationship inference (harder)
- {
- "question": "Why might Whiskers be nervous around my pets?",
- "category": "Behavior",
- "expected": "Bella chases her sometimes",
- "difficulty": "harder",
- },
- # 5. Combined medical info (hardest)
- {
- "question": "Which pets have health considerations?",
- "category": "Health",
- "expected": "Max needs joint supplements, Bella is allergic to chicken",
- "difficulty": "hardest",
- "requires": ["combining multiple facts", "ignoring outdated info"],
- },
- ]
- return conversations, questions
+# Enable execution from any working directory
+logger = get_logger(__name__)
if __name__ == "__main__":
- print("Starting Enhanced Memory Scheduler Test...")
- print("=" * 80)
-
# set up data
conversations, questions = init_task()
# set configs
mos_config = MOSConfig.from_yaml_file(
- f"{BASE_DIR}/examples/data/config/mem_scheduler/memos_config_w_scheduler.yaml"
+ f"{BASE_DIR}/examples/data/config/mem_scheduler/memos_config_w_optimized_scheduler.yaml"
)
mem_cube_config = GeneralMemCubeConfig.from_yaml_file(
@@ -186,7 +51,6 @@ def init_task():
)
# Initialization
- print("Initializing MOS with Scheduler...")
mos = MOSForTestScheduler(mos_config)
user_id = "user_1"
@@ -197,15 +61,15 @@ def init_task():
if Path(mem_cube_name_or_path).exists():
shutil.rmtree(mem_cube_name_or_path)
- print(f"{mem_cube_name_or_path} is not empty, and has been removed.")
+ print(f"{mem_cube_name_or_path} is not empty, and has been removed.")
mem_cube = GeneralMemCube(mem_cube_config)
mem_cube.dump(mem_cube_name_or_path)
mos.register_mem_cube(
mem_cube_name_or_path=mem_cube_name_or_path, mem_cube_id=mem_cube_id, user_id=user_id
)
+ mos.mem_scheduler.current_mem_cube = mem_cube
- print("Adding initial conversations...")
mos.add(conversations, user_id=user_id, mem_cube_id=mem_cube_id)
# Add interfering conversations
@@ -214,77 +78,11 @@ def init_task():
mos.add(scene_data[0], user_id=user_id, mem_cube_id=mem_cube_id)
mos.add(scene_data[1], user_id=user_id, mem_cube_id=mem_cube_id)
- # Display initial status
- print("\nINITIAL SYSTEM STATUS:")
- display_scheduler_status(mos)
- display_memory_cube_stats(mos, user_id, mem_cube_id)
-
- # Process questions with enhanced monitoring
- print(f"\nStarting Question Processing ({len(questions)} questions)...")
- question_start_time = time.time()
-
- for i, item in enumerate(questions, 1):
- print(f"\n{'=' * 20} Question {i}/{len(questions)} {'=' * 20}")
- print(f"Category: {item['category']} | Difficulty: {item['difficulty']}")
- print(f"Expected: {item['expected']}")
- if "hint" in item:
- print(f"Hint: {item['hint']}")
- if "requires" in item:
- print(f"Requires: {', '.join(item['requires'])}")
-
- print(f"\nProcessing Query: {item['question']}")
- query_start_time = time.time()
-
- response = mos.chat(query=item["question"], user_id=user_id)
-
- query_time = time.time() - query_start_time
- print(f"Query Processing Time: {query_time:.3f}s")
- print(f"Response: {response}")
-
- # Display intermediate status every 2 questions
- if i % 2 == 0:
- print(f"\nINTERMEDIATE STATUS (Question {i}):")
- display_scheduler_status(mos)
- display_memory_cube_stats(mos, user_id, mem_cube_id)
-
- total_processing_time = time.time() - question_start_time
- print(f"\nTotal Question Processing Time: {total_processing_time:.3f}s")
-
- # Display final scheduler performance summary
- print("\n" + "=" * 80)
- print("FINAL SCHEDULER PERFORMANCE SUMMARY")
- print("=" * 80)
-
- summary = mos.get_scheduler_summary()
- print(f"Total Queries Processed: {summary['total_queries']}")
- print(f"Total Scheduler Calls: {summary['total_scheduler_calls']}")
- print(f"Average Scheduler Response Time: {summary['average_scheduler_response_time']:.3f}s")
- print(f"Memory Optimizations Applied: {summary['memory_optimization_count']}")
- print(f"Working Memory Updates: {summary['working_memory_updates']}")
- print(f"Activation Memory Updates: {summary['activation_memory_updates']}")
- print(f"Average Query Processing Time: {summary['average_query_processing_time']:.3f}s")
-
- # Performance insights
- print("\nPERFORMANCE INSIGHTS:")
- if summary["total_scheduler_calls"] > 0:
- optimization_rate = (
- summary["memory_optimization_count"] / summary["total_scheduler_calls"]
- ) * 100
- print(f" • Memory Optimization Rate: {optimization_rate:.1f}%")
-
- if summary["average_scheduler_response_time"] < 0.1:
- print(" • Scheduler Performance: Excellent (< 100ms)")
- elif summary["average_scheduler_response_time"] < 0.5:
- print(" • Scheduler Performance: Good (100-500ms)")
- else:
- print(" • Scheduler Performance: Needs Improvement (> 500ms)")
-
- # Final system status
- print("\nFINAL SYSTEM STATUS:")
- display_scheduler_status(mos)
- display_memory_cube_stats(mos, user_id, mem_cube_id)
-
- print("=" * 80)
- print("Test completed successfully!")
+ for item in questions:
+ print("===== Chat Start =====")
+ query = item["question"]
+ print(f"Query:\n {query}\n")
+ response = mos.chat(query=query, user_id=user_id)
+ print(f"Answer:\n {response}\n")
mos.mem_scheduler.stop()
diff --git a/examples/mem_scheduler/orm_examples.py b/examples/mem_scheduler/orm_examples.py
deleted file mode 100644
index bbb57b4ab..000000000
--- a/examples/mem_scheduler/orm_examples.py
+++ /dev/null
@@ -1,374 +0,0 @@
-#!/usr/bin/env python3
-"""
-ORM Examples for MemScheduler
-
-This script demonstrates how to use the BaseDBManager's new environment variable loading methods
-for MySQL and Redis connections.
-"""
-
-import multiprocessing
-import os
-import sys
-
-from pathlib import Path
-
-
-# Add the src directory to the Python path
-sys.path.insert(0, str(Path(__file__).parent.parent.parent / "src"))
-
-from memos.log import get_logger
-from memos.mem_scheduler.orm_modules.base_model import BaseDBManager, DatabaseError
-from memos.mem_scheduler.orm_modules.redis_model import RedisDBManager, SimpleListManager
-
-
-logger = get_logger(__name__)
-
-
-def test_mysql_engine_from_env():
- """Test loading MySQL engine from environment variables"""
- print("\n" + "=" * 60)
- print("Testing MySQL Engine from Environment Variables")
- print("=" * 60)
-
- try:
- # Test loading MySQL engine from current environment variables
- mysql_engine = BaseDBManager.load_mysql_engine_from_env()
- if mysql_engine is None:
- print("❌ Failed to create MySQL engine - check environment variables")
- return
-
- print(f"✅ Successfully created MySQL engine: {mysql_engine}")
- print(f" Engine URL: {mysql_engine.url}")
-
- # Test connection
- with mysql_engine.connect() as conn:
- from sqlalchemy import text
-
- result = conn.execute(text("SELECT 'MySQL connection test successful' as message"))
- message = result.fetchone()[0]
- print(f" Connection test: {message}")
-
- mysql_engine.dispose()
- print(" MySQL engine disposed successfully")
-
- except DatabaseError as e:
- print(f"❌ DatabaseError: {e}")
- except Exception as e:
- print(f"❌ Unexpected error: {e}")
-
-
-def test_redis_connection_from_env():
- """Test loading Redis connection from environment variables"""
- print("\n" + "=" * 60)
- print("Testing Redis Connection from Environment Variables")
- print("=" * 60)
-
- try:
- # Test loading Redis connection from current environment variables
- redis_client = BaseDBManager.load_redis_engine_from_env()
- if redis_client is None:
- print("❌ Failed to create Redis connection - check environment variables")
- return
-
- print(f"✅ Successfully created Redis connection: {redis_client}")
-
- # Test basic Redis operations
- redis_client.set("test_key", "Hello from ORM Examples!")
- value = redis_client.get("test_key")
- print(f" Redis test - Set/Get: {value}")
-
- # Test Redis info
- info = redis_client.info("server")
- redis_version = info.get("redis_version", "unknown")
- print(f" Redis server version: {redis_version}")
-
- # Clean up test key
- redis_client.delete("test_key")
- print(" Test key cleaned up")
-
- redis_client.close()
- print(" Redis connection closed successfully")
-
- except DatabaseError as e:
- print(f"❌ DatabaseError: {e}")
- except Exception as e:
- print(f"❌ Unexpected error: {e}")
-
-
-def test_environment_variables():
- """Test and display current environment variables"""
- print("\n" + "=" * 60)
- print("Current Environment Variables")
- print("=" * 60)
-
- # MySQL environment variables
- mysql_vars = [
- "MYSQL_HOST",
- "MYSQL_PORT",
- "MYSQL_USERNAME",
- "MYSQL_PASSWORD",
- "MYSQL_DATABASE",
- "MYSQL_CHARSET",
- ]
-
- print("\nMySQL Environment Variables:")
- for var in mysql_vars:
- value = os.getenv(var, "Not set")
- # Mask password for security
- if "PASSWORD" in var and value != "Not set":
- value = "*" * len(value)
- print(f" {var}: {value}")
-
- # Redis environment variables
- redis_vars = [
- "REDIS_HOST",
- "REDIS_PORT",
- "REDIS_DB",
- "REDIS_PASSWORD",
- "MEMSCHEDULER_REDIS_HOST",
- "MEMSCHEDULER_REDIS_PORT",
- "MEMSCHEDULER_REDIS_DB",
- "MEMSCHEDULER_REDIS_PASSWORD",
- ]
-
- print("\nRedis Environment Variables:")
- for var in redis_vars:
- value = os.getenv(var, "Not set")
- # Mask password for security
- if "PASSWORD" in var and value != "Not set":
- value = "*" * len(value)
- print(f" {var}: {value}")
-
-
-def test_manual_env_loading():
- """Test loading environment variables manually from .env file"""
- print("\n" + "=" * 60)
- print("Testing Manual Environment Loading")
- print("=" * 60)
-
- env_file_path = "/Users/travistang/Documents/codes/memos/.env"
-
- if not os.path.exists(env_file_path):
- print(f"❌ Environment file not found: {env_file_path}")
- return
-
- try:
- from dotenv import load_dotenv
-
- # Load environment variables
- load_dotenv(env_file_path)
- print(f"✅ Successfully loaded environment variables from {env_file_path}")
-
- # Test some key variables
- test_vars = ["OPENAI_API_KEY", "MOS_CHAT_MODEL", "TZ"]
- for var in test_vars:
- value = os.getenv(var, "Not set")
- if "KEY" in var and value != "Not set":
- value = f"{value[:10]}..." if len(value) > 10 else value
- print(f" {var}: {value}")
-
- except ImportError:
- print("❌ python-dotenv not installed. Install with: pip install python-dotenv")
- except Exception as e:
- print(f"❌ Error loading environment file: {e}")
-
-
-def test_redis_lockable_orm_with_list():
- """Test RedisDBManager with list[str] type synchronization"""
- print("\n" + "=" * 60)
- print("Testing RedisDBManager with list[str]")
- print("=" * 60)
-
- try:
- from memos.mem_scheduler.orm_modules.redis_model import RedisDBManager
-
- # Create a simple list manager instance
- list_manager = SimpleListManager(["apple", "banana", "cherry"])
- print(f"Original list manager: {list_manager}")
-
- # Create RedisDBManager instance
- redis_client = BaseDBManager.load_redis_engine_from_env()
- if redis_client is None:
- print("❌ Failed to create Redis connection - check environment variables")
- return
-
- db_manager = RedisDBManager(
- redis_client=redis_client,
- user_id="test_user",
- mem_cube_id="test_list_cube",
- obj=list_manager,
- )
-
- # Save to Redis
- db_manager.save_to_db(list_manager)
- print("✅ List manager saved to Redis")
-
- # Load from Redis
- loaded_manager = db_manager.load_from_db()
- if loaded_manager:
- print(f"Loaded list manager: {loaded_manager}")
- print(f"Items match: {list_manager.items == loaded_manager.items}")
- else:
- print("❌ Failed to load list manager from Redis")
-
- # Clean up
- redis_client.delete("lockable_orm:test_user:test_list_cube:data")
- redis_client.delete("lockable_orm:test_user:test_list_cube:lock")
- redis_client.delete("lockable_orm:test_user:test_list_cube:version")
- redis_client.close()
-
- except Exception as e:
- print(f"❌ Error in RedisDBManager test: {e}")
-
-
-def modify_list_process(process_id: int, items_to_add: list[str]):
- """Function to be run in separate processes to modify the list using merge_items"""
- try:
- from memos.mem_scheduler.orm_modules.redis_model import RedisDBManager
-
- # Create Redis connection
- redis_client = BaseDBManager.load_redis_engine_from_env()
- if redis_client is None:
- print(f"Process {process_id}: Failed to create Redis connection")
- return
-
- # Create a temporary list manager for this process with items to add
- temp_manager = SimpleListManager()
-
- db_manager = RedisDBManager(
- redis_client=redis_client,
- user_id="test_user",
- mem_cube_id="multiprocess_list",
- obj=temp_manager,
- )
-
- print(f"Process {process_id}: Starting modification with items: {items_to_add}")
- for item in items_to_add:
- db_manager.obj.add_item(item)
- # Use sync_with_orm which internally uses merge_items
- db_manager.sync_with_orm(size_limit=None)
-
- print(f"Process {process_id}: Successfully synchronized with Redis")
-
- redis_client.close()
-
- except Exception as e:
- print(f"Process {process_id}: Error - {e}")
- import traceback
-
- traceback.print_exc()
-
-
-def test_multiprocess_synchronization():
- """Test multiprocess synchronization with RedisDBManager"""
- print("\n" + "=" * 60)
- print("Testing Multiprocess Synchronization")
- print("=" * 60)
-
- try:
- # Initialize Redis with empty list
- redis_client = BaseDBManager.load_redis_engine_from_env()
- if redis_client is None:
- print("❌ Failed to create Redis connection")
- return
-
- # Initialize with empty list
- initial_manager = SimpleListManager([])
- db_manager = RedisDBManager(
- redis_client=redis_client,
- user_id="test_user",
- mem_cube_id="multiprocess_list",
- obj=initial_manager,
- )
- db_manager.save_to_db(initial_manager)
- print("✅ Initialized empty list manager in Redis")
-
- # Define items for each process to add
- process_items = [
- ["item1", "item2"],
- ["item3", "item4"],
- ["item5", "item6"],
- ["item1", "item7"], # item1 is duplicate, should not be added twice
- ]
-
- # Create and start processes
- processes = []
- for i, items in enumerate(process_items):
- p = multiprocessing.Process(target=modify_list_process, args=(i + 1, items))
- processes.append(p)
- p.start()
-
- # Wait for all processes to complete
- for p in processes:
- p.join()
-
- print("\n" + "-" * 40)
- print("All processes completed. Checking final result...")
-
- # Load final result
- final_db_manager = RedisDBManager(
- redis_client=redis_client,
- user_id="test_user",
- mem_cube_id="multiprocess_list",
- obj=SimpleListManager([]),
- )
- final_manager = final_db_manager.load_from_db()
-
- if final_manager:
- print(f"Final synchronized list manager: {final_manager}")
- print(f"Final list length: {len(final_manager)}")
- print("Expected items: {'item1', 'item2', 'item3', 'item4', 'item5', 'item6', 'item7'}")
- print(f"Actual items: {set(final_manager.items)}")
-
- # Check if all unique items are present
- expected_items = {"item1", "item2", "item3", "item4", "item5", "item6", "item7"}
- actual_items = set(final_manager.items)
-
- if expected_items == actual_items:
- print("✅ All processes contributed correctly - synchronization successful!")
- else:
- print(f"❌ Expected items: {expected_items}")
- print(f" Actual items: {actual_items}")
- else:
- print("❌ Failed to load final result")
-
- # Clean up
- redis_client.delete("lockable_orm:test_user:multiprocess_list:data")
- redis_client.delete("lockable_orm:test_user:multiprocess_list:lock")
- redis_client.delete("lockable_orm:test_user:multiprocess_list:version")
- redis_client.close()
-
- except Exception as e:
- print(f"❌ Error in multiprocess synchronization test: {e}")
-
-
-def main():
- """Main function to run all tests"""
- print("ORM Examples - Environment Variable Loading Tests")
- print("=" * 80)
-
- # Test environment variables display
- test_environment_variables()
-
- # Test manual environment loading
- test_manual_env_loading()
-
- # Test MySQL engine loading
- test_mysql_engine_from_env()
-
- # Test Redis connection loading
- test_redis_connection_from_env()
-
- # Test RedisLockableORM with list[str]
- test_redis_lockable_orm_with_list()
-
- # Test multiprocess synchronization
- test_multiprocess_synchronization()
-
- print("\n" + "=" * 80)
- print("All tests completed!")
- print("=" * 80)
-
-
-if __name__ == "__main__":
- main()
diff --git a/examples/mem_scheduler/redis_example.py b/examples/mem_scheduler/redis_example.py
index 1660d6c02..2c3801539 100644
--- a/examples/mem_scheduler/redis_example.py
+++ b/examples/mem_scheduler/redis_example.py
@@ -22,7 +22,7 @@
sys.path.insert(0, str(BASE_DIR)) # Enable execution from any working directory
-async def service_run():
+def service_run():
# Init
example_scheduler_config_path = (
f"{BASE_DIR}/examples/data/config/mem_scheduler/general_scheduler_config.yaml"
@@ -60,11 +60,11 @@ async def service_run():
content=query,
timestamp=datetime.now(),
)
- res = await mem_scheduler.redis_add_message_stream(message=message_item.to_dict())
+ res = mem_scheduler.redis_add_message_stream(message=message_item.to_dict())
print(
f"Added: {res}",
)
- await asyncio.sleep(0.5)
+ time.sleep(0.5)
mem_scheduler.redis_stop_listening()
@@ -72,4 +72,4 @@ async def service_run():
if __name__ == "__main__":
- asyncio.run(service_run())
+ service_run()
diff --git a/examples/mem_scheduler/try_schedule_modules.py b/examples/mem_scheduler/try_schedule_modules.py
index de99f1c95..4aedac711 100644
--- a/examples/mem_scheduler/try_schedule_modules.py
+++ b/examples/mem_scheduler/try_schedule_modules.py
@@ -176,6 +176,7 @@ def show_web_logs(mem_scheduler: GeneralScheduler):
mos.register_mem_cube(
mem_cube_name_or_path=mem_cube_name_or_path, mem_cube_id=mem_cube_id, user_id=user_id
)
+ mos.mem_scheduler.current_mem_cube = mem_cube
mos.add(conversations, user_id=user_id, mem_cube_id=mem_cube_id)
diff --git a/src/memos/api/config.py b/src/memos/api/config.py
index f02edaad6..a276fa63d 100644
--- a/src/memos/api/config.py
+++ b/src/memos/api/config.py
@@ -175,7 +175,7 @@ def start_config_watch(cls):
@classmethod
def start_watch_if_enabled(cls) -> None:
enable = os.getenv("NACOS_ENABLE_WATCH", "false").lower() == "true"
- print("enable:", enable)
+ logger.info(f"NACOS_ENABLE_WATCH: {enable}")
if not enable:
return
interval = int(os.getenv("NACOS_WATCH_INTERVAL", "60"))
@@ -623,7 +623,10 @@ def get_scheduler_config() -> dict[str, Any]:
"MOS_SCHEDULER_ENABLE_PARALLEL_DISPATCH", "true"
).lower()
== "true",
- "enable_activation_memory": True,
+ "enable_activation_memory": os.getenv(
+ "MOS_SCHEDULER_ENABLE_ACTIVATION_MEMORY", "false"
+ ).lower()
+ == "true",
},
}
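Both scheduler flags in get_scheduler_config() now share the same `os.getenv(...).lower() == "true"` idiom. A hedged sketch of that idiom factored into a helper; the helper name `env_flag` is illustrative, and the config above simply inlines the expression:

import os


def env_flag(name: str, default: str = "false") -> bool:
    """Interpret an environment variable as a boolean feature flag."""
    return os.getenv(name, default).lower() == "true"


# Equivalent to the inline expressions used above:
enable_parallel_dispatch = env_flag("MOS_SCHEDULER_ENABLE_PARALLEL_DISPATCH", "true")
enable_activation_memory = env_flag("MOS_SCHEDULER_ENABLE_ACTIVATION_MEMORY", "false")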
diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py
index 0412754c3..3b1ce2fc9 100644
--- a/src/memos/api/product_models.py
+++ b/src/memos/api/product_models.py
@@ -171,7 +171,9 @@ class APISearchRequest(BaseRequest):
query: str = Field(..., description="Search query")
user_id: str = Field(None, description="User ID")
mem_cube_id: str | None = Field(None, description="Cube ID to search in")
- mode: SearchMode = Field(SearchMode.FAST, description="search mode: fast, fine, or mixture")
+ mode: SearchMode = Field(
+ SearchMode.NOT_INITIALIZED, description="search mode: fast, fine, or mixture"
+ )
internet_search: bool = Field(False, description="Whether to use internet search")
moscube: bool = Field(False, description="Whether to use MemOSCube")
top_k: int = Field(10, description="Number of results to return")
diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py
index 8df383bfb..ad43a07e4 100644
--- a/src/memos/api/routers/server_router.py
+++ b/src/memos/api/routers/server_router.py
@@ -256,12 +256,14 @@ def init_server():
db_engine=BaseDBManager.create_default_sqlite_engine(),
mem_reader=mem_reader,
)
- mem_scheduler.current_mem_cube = naive_mem_cube
- mem_scheduler.start()
+ mem_scheduler.init_mem_cube(mem_cube=naive_mem_cube)
# Initialize SchedulerAPIModule
api_module = mem_scheduler.api_module
+ if os.getenv("API_SCHEDULER_ON", "true").lower() == "true":
+ mem_scheduler.start()
+
return (
graph_db,
mem_reader,
@@ -357,8 +359,10 @@ def search_memories(search_req: APISearchRequest):
"pref_mem": [],
"pref_note": "",
}
-
- search_mode = search_req.mode
+ if search_req.mode == SearchMode.NOT_INITIALIZED:
+ search_mode = os.getenv("SEARCH_MODE", SearchMode.FAST)
+ else:
+ search_mode = search_req.mode
def _search_text():
if search_mode == SearchMode.FAST:
@@ -444,22 +448,38 @@ def fine_search_memories(
target_session_id = "default_session"
search_filter = {"session_id": search_req.session_id} if search_req.session_id else None
- # Create MemCube and perform search
- search_results = naive_mem_cube.text_mem.search(
+ searcher = mem_scheduler.searcher
+
+ info = {
+ "user_id": search_req.user_id,
+ "session_id": target_session_id,
+ "chat_history": search_req.chat_history,
+ }
+
+ fast_retrieved_memories = searcher.retrieve(
query=search_req.query,
user_name=user_context.mem_cube_id,
top_k=search_req.top_k,
- mode=SearchMode.FINE,
+ mode=SearchMode.FAST,
manual_close_internet=not search_req.internet_search,
moscube=search_req.moscube,
search_filter=search_filter,
- info={
- "user_id": search_req.user_id,
- "session_id": target_session_id,
- "chat_history": search_req.chat_history,
- },
+ info=info,
)
- formatted_memories = [_format_memory_item(data) for data in search_results]
+
+ fast_memories = searcher.post_retrieve(
+ retrieved_results=fast_retrieved_memories,
+ top_k=search_req.top_k,
+ user_name=user_context.mem_cube_id,
+ info=info,
+ )
+
+ enhanced_results, _ = mem_scheduler.retriever.enhance_memories_with_query(
+ query_history=[search_req.query],
+ memories=fast_memories,
+ )
+
+ formatted_memories = [_format_memory_item(data) for data in enhanced_results]
return formatted_memories
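For readability, the reworked fine-search path boils down to a fast retrieval pass followed by scheduler-side enhancement. A rough sketch under the assumption that `mem_scheduler`, `SearchMode`, and `_format_memory_item` are in scope as in server_router.py; optional arguments such as `search_filter`, `moscube`, and internet-search toggles are omitted here:

def fine_search_sketch(query: str, top_k: int, user_name: str, info: dict):
    """Hedged outline of the fast-retrieve-then-enhance flow used above."""
    searcher = mem_scheduler.searcher

    # 1) Fast retrieval pass over the user's memory cube.
    candidates = searcher.retrieve(
        query=query, user_name=user_name, top_k=top_k, mode=SearchMode.FAST, info=info
    )

    # 2) Post-processing of the fast candidates into memory items.
    fast_memories = searcher.post_retrieve(
        retrieved_results=candidates, top_k=top_k, user_name=user_name, info=info
    )

    # 3) Query-aware enhancement via the scheduler's retriever.
    enhanced, _ = mem_scheduler.retriever.enhance_memories_with_query(
        query_history=[query], memories=fast_memories
    )
    return [_format_memory_item(item) for item in enhanced]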
diff --git a/src/memos/configs/mem_scheduler.py b/src/memos/configs/mem_scheduler.py
index e757f243b..afdaf6871 100644
--- a/src/memos/configs/mem_scheduler.py
+++ b/src/memos/configs/mem_scheduler.py
@@ -12,10 +12,13 @@
BASE_DIR,
DEFAULT_ACT_MEM_DUMP_PATH,
DEFAULT_ACTIVATION_MEM_MONITOR_SIZE_LIMIT,
+ DEFAULT_CONSUME_BATCH,
DEFAULT_CONSUME_INTERVAL_SECONDS,
DEFAULT_CONTEXT_WINDOW_SIZE,
DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE,
DEFAULT_MULTI_TASK_RUNNING_TIMEOUT,
+ DEFAULT_SCHEDULER_RETRIEVER_BATCH_SIZE,
+ DEFAULT_SCHEDULER_RETRIEVER_RETRIES,
DEFAULT_THREAD_POOL_MAX_WORKERS,
DEFAULT_TOP_K,
DEFAULT_USE_REDIS_QUEUE,
@@ -43,6 +46,11 @@ class BaseSchedulerConfig(BaseConfig):
gt=0,
description=f"Interval for consuming messages from queue in seconds (default: {DEFAULT_CONSUME_INTERVAL_SECONDS})",
)
+ consume_batch: int = Field(
+ default=DEFAULT_CONSUME_BATCH,
+ gt=0,
+ description=f"Number of messages to consume in each batch (default: {DEFAULT_CONSUME_BATCH})",
+ )
auth_config_path: str | None = Field(
default=None,
description="Path to the authentication configuration file containing private credentials",
@@ -91,6 +99,17 @@ class GeneralSchedulerConfig(BaseSchedulerConfig):
description="Capacity of the activation memory monitor",
)
+ # Memory enhancement concurrency & retries configuration
+ enhance_batch_size: int | None = Field(
+ default=DEFAULT_SCHEDULER_RETRIEVER_BATCH_SIZE,
+ description="Batch size for concurrent memory enhancement; None or <=1 disables batching",
+ )
+ enhance_retries: int = Field(
+ default=DEFAULT_SCHEDULER_RETRIEVER_RETRIES,
+ ge=0,
+ description="Number of retry attempts per enhancement batch",
+ )
+
# Database configuration for ORM persistence
db_path: str | None = Field(
default=None,
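A hedged sketch of how the new scheduler fields might be set when building the config in Python; the field names come from the classes above, the values are illustrative only, and the remaining fields are assumed to keep their defaults:

from memos.configs.mem_scheduler import GeneralSchedulerConfig

# Illustrative values only; omitting these fields falls back to
# DEFAULT_CONSUME_BATCH and the DEFAULT_SCHEDULER_RETRIEVER_* constants.
scheduler_config = GeneralSchedulerConfig(
    consume_batch=5,        # messages pulled from the queue per consume cycle
    enhance_batch_size=4,   # None or <=1 disables batched enhancement
    enhance_retries=2,      # retry attempts per enhancement batch
)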
diff --git a/src/memos/mem_os/core.py b/src/memos/mem_os/core.py
index 97ff9879f..1b6d4e126 100644
--- a/src/memos/mem_os/core.py
+++ b/src/memos/mem_os/core.py
@@ -283,7 +283,6 @@ def chat(self, query: str, user_id: str | None = None, base_prompt: str | None =
message_item = ScheduleMessageItem(
user_id=target_user_id,
mem_cube_id=mem_cube_id,
- mem_cube=mem_cube,
label=QUERY_LABEL,
content=query,
timestamp=datetime.utcnow(),
@@ -344,7 +343,6 @@ def chat(self, query: str, user_id: str | None = None, base_prompt: str | None =
message_item = ScheduleMessageItem(
user_id=target_user_id,
mem_cube_id=mem_cube_id,
- mem_cube=mem_cube,
label=ANSWER_LABEL,
content=response,
timestamp=datetime.utcnow(),
@@ -768,12 +766,10 @@ def process_textual_memory():
)
# submit messages for scheduler
if self.enable_mem_scheduler and self.mem_scheduler is not None:
- mem_cube = self.mem_cubes[mem_cube_id]
if sync_mode == "async":
message_item = ScheduleMessageItem(
user_id=target_user_id,
mem_cube_id=mem_cube_id,
- mem_cube=mem_cube,
label=MEM_READ_LABEL,
content=json.dumps(mem_ids),
timestamp=datetime.utcnow(),
@@ -783,7 +779,6 @@ def process_textual_memory():
message_item = ScheduleMessageItem(
user_id=target_user_id,
mem_cube_id=mem_cube_id,
- mem_cube=mem_cube,
label=ADD_LABEL,
content=json.dumps(mem_ids),
timestamp=datetime.utcnow(),
@@ -797,7 +792,6 @@ def process_preference_memory():
and self.mem_cubes[mem_cube_id].pref_mem
):
messages_list = [messages]
- mem_cube = self.mem_cubes[mem_cube_id]
if sync_mode == "sync":
pref_memories = self.mem_cubes[mem_cube_id].pref_mem.get_memory(
messages_list,
@@ -816,7 +810,6 @@ def process_preference_memory():
user_id=target_user_id,
session_id=target_session_id,
mem_cube_id=mem_cube_id,
- mem_cube=mem_cube,
label=PREF_ADD_LABEL,
content=json.dumps(messages_list),
timestamp=datetime.utcnow(),
@@ -867,12 +860,10 @@ def process_preference_memory():
# submit messages for scheduler
if self.enable_mem_scheduler and self.mem_scheduler is not None:
- mem_cube = self.mem_cubes[mem_cube_id]
if sync_mode == "async":
message_item = ScheduleMessageItem(
user_id=target_user_id,
mem_cube_id=mem_cube_id,
- mem_cube=mem_cube,
label=MEM_READ_LABEL,
content=json.dumps(mem_ids),
timestamp=datetime.utcnow(),
@@ -882,7 +873,6 @@ def process_preference_memory():
message_item = ScheduleMessageItem(
user_id=target_user_id,
mem_cube_id=mem_cube_id,
- mem_cube=mem_cube,
label=ADD_LABEL,
content=json.dumps(mem_ids),
timestamp=datetime.utcnow(),
@@ -909,11 +899,9 @@ def process_preference_memory():
# submit messages for scheduler
if self.enable_mem_scheduler and self.mem_scheduler is not None:
- mem_cube = self.mem_cubes[mem_cube_id]
message_item = ScheduleMessageItem(
user_id=target_user_id,
mem_cube_id=mem_cube_id,
- mem_cube=mem_cube,
label=ADD_LABEL,
content=json.dumps(mem_ids),
timestamp=datetime.utcnow(),
diff --git a/src/memos/mem_os/main.py b/src/memos/mem_os/main.py
index 6fc64c5e3..0114fc0da 100644
--- a/src/memos/mem_os/main.py
+++ b/src/memos/mem_os/main.py
@@ -205,7 +205,6 @@ def _chat_with_cot_enhancement(
# Step 7: Submit message to scheduler (same as core method)
if len(accessible_cubes) == 1:
mem_cube_id = accessible_cubes[0].cube_id
- mem_cube = self.mem_cubes[mem_cube_id]
if self.enable_mem_scheduler and self.mem_scheduler is not None:
from datetime import datetime
@@ -217,7 +216,6 @@ def _chat_with_cot_enhancement(
message_item = ScheduleMessageItem(
user_id=target_user_id,
mem_cube_id=mem_cube_id,
- mem_cube=mem_cube,
label=ANSWER_LABEL,
content=enhanced_response,
timestamp=datetime.now().isoformat(),
diff --git a/src/memos/mem_os/product.py b/src/memos/mem_os/product.py
index 89e468bd7..cea8c89af 100644
--- a/src/memos/mem_os/product.py
+++ b/src/memos/mem_os/product.py
@@ -609,7 +609,6 @@ def _send_message_to_scheduler(
message_item = ScheduleMessageItem(
user_id=user_id,
mem_cube_id=mem_cube_id,
- mem_cube=self.mem_cubes[mem_cube_id],
label=label,
content=query,
timestamp=datetime.utcnow(),
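Across core.py, main.py, and product.py the scheduler messages no longer embed the `mem_cube` object; only identifiers travel in the message, and the scheduler resolves the cube itself (see `current_mem_cube` / `init_mem_cube` in server_router.py and the examples). A hedged sketch of the slimmed-down message construction, assuming a scheduler instance named `mem_scheduler` and an illustrative label string in place of the real label constants:

import json
from datetime import datetime

from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem

# The message now carries only identifiers; the scheduler looks up the cube
# on its side instead of receiving it inside every message.
message_item = ScheduleMessageItem(
    user_id="user_1",
    mem_cube_id="mem_cube_5",
    label="add",  # in the real code this is a label constant such as ADD_LABEL
    content=json.dumps(["mem_id_1"]),
    timestamp=datetime.utcnow(),
)
mem_scheduler.submit_messages([message_item])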
diff --git a/src/memos/mem_scheduler/analyzer/api_analyzer.py b/src/memos/mem_scheduler/analyzer/api_analyzer.py
index 28ca182e5..085025b7f 100644
--- a/src/memos/mem_scheduler/analyzer/api_analyzer.py
+++ b/src/memos/mem_scheduler/analyzer/api_analyzer.py
@@ -7,7 +7,6 @@
import http.client
import json
-import time
from typing import Any
from urllib.parse import urlparse
@@ -15,6 +14,7 @@
import requests
from memos.log import get_logger
+from memos.mem_scheduler.schemas.general_schemas import SearchMode
logger = get_logger(__name__)
@@ -487,7 +487,7 @@ def search_in_conversation(self, query, mode="fast", top_k=10, include_history=T
return result
- def test_continuous_conversation(self):
+ def test_continuous_conversation(self, mode=SearchMode.MIXTURE):
"""Test continuous conversation functionality"""
print("=" * 80)
print("Testing Continuous Conversation Functionality")
@@ -542,15 +542,15 @@ def test_continuous_conversation(self):
# Search for trip-related information
self.search_in_conversation(
- query="New Year's Eve Shanghai recommendations", mode="mixture", top_k=5
+ query="New Year's Eve Shanghai recommendations", mode=mode, top_k=5
)
# Search for food-related information
- self.search_in_conversation(query="budget food Shanghai", mode="mixture", top_k=3)
+ self.search_in_conversation(query="budget food Shanghai", mode=mode, top_k=3)
# Search without conversation history
self.search_in_conversation(
- query="Shanghai travel", mode="mixture", top_k=3, include_history=False
+ query="Shanghai travel", mode=mode, top_k=3, include_history=False
)
print("\n✅ Continuous conversation test completed successfully!")
@@ -645,7 +645,7 @@ def create_test_add_request(
operation=None,
)
- def run_all_tests(self):
+ def run_all_tests(self, mode=SearchMode.MIXTURE):
"""Run all available tests"""
print("Starting comprehensive test suite")
print("=" * 80)
@@ -653,8 +653,7 @@ def run_all_tests(self):
# Test continuous conversation functionality
print("\nTesting CONTINUOUS CONVERSATION functions:")
try:
- self.test_continuous_conversation()
- time.sleep(5)
+ self.test_continuous_conversation(mode=mode)
print("✅ Continuous conversation test completed successfully")
except Exception as e:
print(f"❌ Continuous conversation test failed: {e}")
@@ -682,7 +681,7 @@ def run_all_tests(self):
print("Using direct test mode")
try:
direct_analyzer = DirectSearchMemoriesAnalyzer()
- direct_analyzer.run_all_tests()
+ direct_analyzer.run_all_tests(mode=SearchMode.MIXTURE)
except Exception as e:
print(f"Direct test mode failed: {e}")
import traceback
diff --git a/src/memos/mem_scheduler/analyzer/eval_analyzer.py b/src/memos/mem_scheduler/analyzer/eval_analyzer.py
new file mode 100644
index 000000000..d37e17456
--- /dev/null
+++ b/src/memos/mem_scheduler/analyzer/eval_analyzer.py
@@ -0,0 +1,1322 @@
+"""
+Evaluation Analyzer for Bad Cases
+
+This module provides the EvalAnalyzer class that extracts bad cases from evaluation results
+and analyzes whether memories contain sufficient information to answer golden answers.
+"""
+
+import json
+import os
+import sys
+
+from pathlib import Path
+from typing import Any
+
+from openai import OpenAI
+
+from memos.api.routers.server_router import mem_scheduler
+from memos.log import get_logger
+from memos.memories.textual.item import TextualMemoryMetadata
+from memos.memories.textual.tree import TextualMemoryItem
+
+
+FILE_PATH = Path(__file__).absolute()
+BASE_DIR = FILE_PATH.parent.parent.parent.parent.parent # Go up to project root
+sys.path.insert(0, str(BASE_DIR)) # Enable execution from any working directory
+
+logger = get_logger(__name__)
+
+
+class EvalAnalyzer:
+ """
+ Evaluation Analyzer class for extracting and analyzing bad cases.
+
+ This class extracts bad cases from evaluation results and uses LLM to analyze
+ whether memories contain sufficient information to answer golden answers.
+ """
+
+ def __init__(
+ self,
+ openai_api_key: str | None = None,
+ openai_base_url: str | None = None,
+ openai_model: str = "gpt-4o-mini",
+ output_dir: str = "./tmp/eval_analyzer",
+ ):
+ """
+ Initialize the EvalAnalyzer.
+
+ Args:
+ openai_api_key: OpenAI API key
+ openai_base_url: OpenAI base URL
+ openai_model: OpenAI model to use
+ output_dir: Output directory for results
+ """
+ self.output_dir = Path(output_dir)
+ self.output_dir.mkdir(parents=True, exist_ok=True)
+
+ # Initialize OpenAI client
+ self.openai_client = OpenAI(
+ api_key=openai_api_key or os.getenv("MEMSCHEDULER_OPENAI_API_KEY"),
+ base_url=openai_base_url or os.getenv("MEMSCHEDULER_OPENAI_BASE_URL"),
+ )
+ self.openai_model = openai_model or os.getenv(
+ "MEMSCHEDULER_OPENAI_DEFAULT_MODEL", "gpt-4o-mini"
+ )
+
+ logger.info(f"EvalAnalyzer initialized with model: {self.openai_model}")
+
+ def load_json_file(self, filepath: str) -> Any:
+ """Load JSON file safely."""
+ try:
+ with open(filepath, encoding="utf-8") as f:
+ return json.load(f)
+ except FileNotFoundError:
+ logger.error(f"File not found: {filepath}")
+ return None
+ except json.JSONDecodeError as e:
+ logger.error(f"JSON decode error in {filepath}: {e}")
+ return None
+
+ def extract_bad_cases(self, judged_file: str, search_results_file: str) -> list[dict[str, Any]]:
+ """
+ Extract bad cases from judged results and corresponding search results.
+
+ Args:
+ judged_file: Path to the judged results JSON file
+ search_results_file: Path to the search results JSON file
+
+ Returns:
+ List of bad cases with their memories
+ """
+ logger.info(f"Loading judged results from: {judged_file}")
+ judged_data = self.load_json_file(judged_file)
+ if not judged_data:
+ return []
+
+ logger.info(f"Loading search results from: {search_results_file}")
+ search_data = self.load_json_file(search_results_file)
+ if not search_data:
+ return []
+
+ bad_cases = []
+
+ # Process each user's data
+ for user_id, user_judged_results in judged_data.items():
+ user_search_results = search_data.get(user_id, [])
+
+ # Create a mapping from query to search context
+ search_context_map = {}
+ for search_result in user_search_results:
+ query = search_result.get("query", "")
+ context = search_result.get("context", "")
+ search_context_map[query] = context
+
+ # Process each question for this user
+ for result in user_judged_results:
+ # Check if this is a bad case (all judgments are False)
+ judgments = result.get("llm_judgments", {})
+ is_bad_case = all(not judgment for judgment in judgments.values())
+
+ if is_bad_case:
+ question = result.get("question", "")
+ answer = result.get("answer", "")
+ golden_answer = result.get("golden_answer", "")
+
+ # Find corresponding memories from search results
+ memories = search_context_map.get(question, "")
+
+ bad_case = {
+ "user_id": user_id,
+ "query": question,
+ "answer": answer,
+ "golden_answer": golden_answer,
+ "memories": memories,
+ "category": result.get("category", 0),
+ "nlp_metrics": result.get("nlp_metrics", {}),
+ "response_duration_ms": result.get("response_duration_ms", 0),
+ "search_duration_ms": result.get("search_duration_ms", 0),
+ "total_duration_ms": result.get("total_duration_ms", 0),
+ }
+
+ bad_cases.append(bad_case)
+
+ logger.info(f"Extracted {len(bad_cases)} bad cases")
+ return bad_cases
+
+ def analyze_memory_sufficiency(
+ self, query: str, golden_answer: str, memories: str
+ ) -> dict[str, Any]:
+ """
+ Use LLM to analyze whether memories contain sufficient information to answer the golden answer.
+
+ Args:
+ query: The original query
+ golden_answer: The correct answer
+ memories: The memory context
+
+ Returns:
+ Analysis result containing sufficiency judgment and relevant memory indices
+ """
+ prompt = f"""
+You are an expert analyst tasked with determining whether the provided memories contain sufficient information to answer a specific question correctly.
+
+**Question:** {query}
+
+**Golden Answer (Correct Answer):** {golden_answer}
+
+**Available Memories:**
+{memories}
+
+**Task:**
+1. Analyze whether the memories contain enough information to derive the golden answer
+2. Identify which specific memory entries (if any) contain relevant information
+3. Provide a clear judgment: True if sufficient, False if insufficient
+
+**Response Format (JSON):**
+{{
+ "sufficient": true/false,
+ "confidence": 0.0-1.0,
+ "relevant_memories": ["memory_1", "memory_2", ...],
+ "reasoning": "Detailed explanation of your analysis",
+ "missing_information": "What key information is missing (if insufficient)"
+}}
+
+**Guidelines:**
+- Be strict in your evaluation - only mark as sufficient if the memories clearly contain the information needed
+- Consider both direct and indirect information that could lead to the golden answer
+- Pay attention to dates, names, events, and specific details
+- If information is ambiguous or requires significant inference, lean towards insufficient
+"""
+
+ try:
+ response = self.openai_client.chat.completions.create(
+ model=self.openai_model,
+ messages=[
+ {
+ "role": "system",
+ "content": "You are a precise analyst who evaluates information sufficiency.",
+ },
+ {"role": "user", "content": prompt},
+ ],
+ temperature=0.1,
+ max_tokens=1000,
+ )
+
+ content = response.choices[0].message.content.strip()
+
+ # Try to parse JSON response
+ try:
+ # Remove markdown code blocks if present
+ if content.startswith("```json"):
+ content = content[7:]
+ if content.endswith("```"):
+ content = content[:-3]
+ content = content.strip()
+
+ analysis = json.loads(content)
+ return analysis
+
+ except json.JSONDecodeError:
+ logger.warning(f"Failed to parse LLM response as JSON: {content}")
+ return {
+ "sufficient": False,
+ "confidence": 0.0,
+ "relevant_memories": [],
+ "reasoning": f"Failed to parse LLM response: {content}",
+ "missing_information": "Analysis failed",
+ }
+
+ except Exception as e:
+ logger.error(f"Error in LLM analysis: {e}")
+ return {
+ "sufficient": False,
+ "confidence": 0.0,
+ "relevant_memories": [],
+ "reasoning": f"Error occurred: {e!s}",
+ "missing_information": "Analysis failed due to error",
+ }
+
+ def process_memories_with_llm(
+ self, memories: str, query: str, processing_type: str = "summarize"
+ ) -> dict[str, Any]:
+ """
+ Use LLM to process memories for better question answering.
+
+ Args:
+ memories: The raw memory content
+ query: The query that will be answered using these memories
+ processing_type: Type of processing ("summarize", "restructure", "enhance")
+
+ Returns:
+ Dictionary containing processed memories and processing metadata
+ """
+ if processing_type == "summarize":
+ prompt = f"""
+You are an expert at summarizing and organizing information to help answer specific questions.
+
+**Target Question:** {query}
+
+**Raw Memories:**
+{memories}
+
+**Task:**
+Summarize and organize the above memories in a way that would be most helpful for answering the target question. Focus on:
+1. Key facts and information relevant to the question
+2. Important relationships and connections
+3. Chronological or logical organization where applicable
+4. Remove redundant or irrelevant information
+
+**Processed Memories:**
+"""
+ elif processing_type == "restructure":
+ prompt = f"""
+You are an expert at restructuring information to optimize question answering.
+
+**Target Question:** {query}
+
+**Raw Memories:**
+{memories}
+
+**Task:**
+Restructure the above memories into a clear, logical format that directly supports answering the target question. Organize by:
+1. Most relevant information first
+2. Supporting details and context
+3. Clear categorization of different types of information
+4. Logical flow that leads to the answer
+
+**Restructured Memories:**
+"""
+ elif processing_type == "enhance":
+ prompt = f"""
+You are an expert at enhancing information by adding context and making connections.
+
+**Target Question:** {query}
+
+**Raw Memories:**
+{memories}
+
+**Task:**
+Enhance the above memories by:
+1. Making implicit connections explicit
+2. Adding relevant context that helps answer the question
+3. Highlighting key relationships between different pieces of information
+4. Organizing information in a question-focused manner
+
+**Enhanced Memories:**
+"""
+ else:
+ raise ValueError(f"Unknown processing_type: {processing_type}")
+
+ try:
+ response = self.openai_client.chat.completions.create(
+ model=self.openai_model,
+ messages=[
+ {
+ "role": "system",
+ "content": "You are an expert information processor who optimizes content for question answering.",
+ },
+ {"role": "user", "content": prompt},
+ ],
+ temperature=0.3,
+ max_tokens=2000,
+ )
+
+ processed_memories = response.choices[0].message.content.strip()
+
+ return {
+ "processed_memories": processed_memories,
+ "processing_type": processing_type,
+ "original_length": len(memories),
+ "processed_length": len(processed_memories),
+ "compression_ratio": len(processed_memories) / len(memories)
+ if len(memories) > 0
+ else 0,
+ }
+
+ except Exception as e:
+ logger.error(f"Error in memory processing: {e}")
+ return {
+ "processed_memories": memories, # Fallback to original
+ "processing_type": processing_type,
+ "original_length": len(memories),
+ "processed_length": len(memories),
+ "compression_ratio": 1.0,
+ "error": str(e),
+ }
+
+ def generate_answer_with_memories(
+ self, query: str, memories: str, memory_type: str = "original"
+ ) -> dict[str, Any]:
+ """
+ Generate an answer to the query using the provided memories.
+
+ Args:
+ query: The question to answer
+ memories: The memory content to use
+ memory_type: Type of memories ("original", "processed")
+
+ Returns:
+ Dictionary containing the generated answer and metadata
+ """
+ prompt = f"""
+ You are a knowledgeable and helpful AI assistant.
+
+ # CONTEXT:
+ You have access to memories from two speakers in a conversation. These memories contain
+ timestamped information that may be relevant to answering the question.
+
+ # INSTRUCTIONS:
+ 1. Carefully analyze all provided memories. Synthesize information across different entries if needed to form a complete answer.
+ 2. Pay close attention to the timestamps to determine the answer. If memories contain contradictory information, the **most recent memory** is the source of truth.
+ 3. If the question asks about a specific event or fact, look for direct evidence in the memories.
+ 4. Your answer must be grounded in the memories. However, you may use general world knowledge to interpret or complete information found within a memory (e.g., identifying a landmark mentioned by description).
+ 5. If the question involves time references (like "last year", "two months ago", etc.), you **must** calculate the actual date based on the memory's timestamp. For example, if a memory from 4 May 2022 mentions "went to India last year," then the trip occurred in 2021.
+ 6. Always convert relative time references to specific dates, months, or years in your final answer.
+ 7. Do not confuse character names mentioned in memories with the actual users who created them.
+ 8. The answer must be brief (under 5-6 words) and direct, with no extra description.
+
+ # APPROACH (Think step by step):
+ 1. First, examine all memories that contain information related to the question.
+ 2. Synthesize findings from multiple memories if a single entry is insufficient.
+ 3. Examine timestamps and content carefully, looking for explicit dates, times, locations, or events.
+ 4. If the answer requires calculation (e.g., converting relative time references), perform the calculation.
+ 5. Formulate a precise, concise answer based on the evidence from the memories (and allowed world knowledge).
+ 6. Double-check that your answer directly addresses the question asked and adheres to all instructions.
+ 7. Ensure your final answer is specific and avoids vague time references.
+
+ {memories}
+
+ Question: {query}
+
+ Answer:
+"""
+
+ try:
+ response = self.openai_client.chat.completions.create(
+ model=self.openai_model,
+ messages=[
+ {
+ "role": "system",
+ "content": "You are a precise assistant who answers questions based only on provided information.",
+ },
+ {"role": "user", "content": prompt},
+ ],
+ temperature=0.1,
+ max_tokens=1000,
+ )
+
+ answer = response.choices[0].message.content.strip()
+
+ return {
+ "answer": answer,
+ "memory_type": memory_type,
+ "query": query,
+ "memory_length": len(memories),
+ "answer_length": len(answer),
+ }
+
+ except Exception as e:
+ logger.error(f"Error in answer generation: {e}")
+ return {
+ "answer": f"Error generating answer: {e!s}",
+ "memory_type": memory_type,
+ "query": query,
+ "memory_length": len(memories),
+ "answer_length": 0,
+ "error": str(e),
+ }
+
+ def compare_answer_quality(
+ self, query: str, golden_answer: str, original_answer: str, processed_answer: str
+ ) -> dict[str, Any]:
+ """
+ Compare the quality of answers generated from original vs processed memories.
+
+ Args:
+ query: The original query
+ golden_answer: The correct/expected answer
+ original_answer: Answer generated from original memories
+ processed_answer: Answer generated from processed memories
+
+ Returns:
+ Dictionary containing comparison results
+ """
+ prompt = f"""
+You are an expert evaluator comparing the quality of two answers against a golden standard.
+
+**Question:** {query}
+
+**Golden Answer (Correct):** {golden_answer}
+
+**Answer A (Original Memories):** {original_answer}
+
+**Answer B (Processed Memories):** {processed_answer}
+
+**Task:**
+Compare both answers against the golden answer and evaluate:
+1. Accuracy: How correct is each answer?
+2. Completeness: How complete is each answer?
+3. Relevance: How relevant is each answer to the question?
+4. Clarity: How clear and well-structured is each answer?
+
+**Response Format (JSON):**
+{{
+ "original_scores": {{
+ "accuracy": 0.0-1.0,
+ "completeness": 0.0-1.0,
+ "relevance": 0.0-1.0,
+ "clarity": 0.0-1.0,
+ "overall": 0.0-1.0
+ }},
+ "processed_scores": {{
+ "accuracy": 0.0-1.0,
+ "completeness": 0.0-1.0,
+ "relevance": 0.0-1.0,
+ "clarity": 0.0-1.0,
+ "overall": 0.0-1.0
+ }},
+ "winner": "original|processed|tie",
+ "improvement": 0.0-1.0,
+ "reasoning": "Detailed explanation of the comparison"
+}}
+"""
+
+ try:
+ response = self.openai_client.chat.completions.create(
+ model=self.openai_model,
+ messages=[
+ {
+ "role": "system",
+ "content": "You are an expert evaluator who compares answer quality objectively.",
+ },
+ {"role": "user", "content": prompt},
+ ],
+ temperature=0.1,
+ max_tokens=1500,
+ )
+
+ content = response.choices[0].message.content.strip()
+
+ # Try to parse JSON response
+ try:
+ if content.startswith("```json"):
+ content = content[7:]
+ if content.endswith("```"):
+ content = content[:-3]
+ content = content.strip()
+
+ comparison = json.loads(content)
+ return comparison
+
+ except json.JSONDecodeError:
+ logger.warning(f"Failed to parse comparison response as JSON: {content}")
+ return {
+ "original_scores": {
+ "accuracy": 0.5,
+ "completeness": 0.5,
+ "relevance": 0.5,
+ "clarity": 0.5,
+ "overall": 0.5,
+ },
+ "processed_scores": {
+ "accuracy": 0.5,
+ "completeness": 0.5,
+ "relevance": 0.5,
+ "clarity": 0.5,
+ "overall": 0.5,
+ },
+ "winner": "tie",
+ "improvement": 0.0,
+ "reasoning": f"Failed to parse comparison: {content}",
+ }
+
+ except Exception as e:
+ logger.error(f"Error in answer comparison: {e}")
+ return {
+ "original_scores": {
+ "accuracy": 0.0,
+ "completeness": 0.0,
+ "relevance": 0.0,
+ "clarity": 0.0,
+ "overall": 0.0,
+ },
+ "processed_scores": {
+ "accuracy": 0.0,
+ "completeness": 0.0,
+ "relevance": 0.0,
+ "clarity": 0.0,
+ "overall": 0.0,
+ },
+ "winner": "tie",
+ "improvement": 0.0,
+ "reasoning": f"Error occurred: {e!s}",
+ }
+
+ def analyze_memory_processing_effectiveness(
+ self,
+ bad_cases: list[dict[str, Any]],
+ processing_types: list[str] | None = None,
+ ) -> dict[str, Any]:
+ """
+ Analyze the effectiveness of different memory processing techniques.
+
+ Args:
+ bad_cases: List of bad cases to analyze
+ processing_types: List of processing types to test
+
+ Returns:
+ Dictionary containing comprehensive analysis results
+ """
+ if processing_types is None:
+ processing_types = ["summarize", "restructure", "enhance"]
+ results = {"processing_results": [], "statistics": {}, "processing_types": processing_types}
+
+ for i, case in enumerate(bad_cases):
+ logger.info(f"Processing case {i + 1}/{len(bad_cases)}: {case['query'][:50]}...")
+
+ case_result = {
+ "case_id": i,
+ "query": case["query"],
+ "golden_answer": case["golden_answer"],
+ "original_memories": case["memories"],
+ "processing_results": {},
+ }
+
+ # Generate answer with original memories
+ original_answer_result = self.generate_answer_with_memories(
+ case["query"], case["memories"], "original"
+ )
+ case_result["original_answer"] = original_answer_result
+
+ # Test each processing type
+ for processing_type in processing_types:
+ logger.info(f" Testing {processing_type} processing...")
+
+ # Process memories
+ processing_result = self.process_memories_with_llm(
+ case["memories"], case["query"], processing_type
+ )
+
+ # Generate answer with processed memories
+ processed_answer_result = self.generate_answer_with_memories(
+ case["query"],
+ processing_result["processed_memories"],
+ f"processed_{processing_type}",
+ )
+
+ # Compare answer quality
+ comparison_result = self.compare_answer_quality(
+ case["query"],
+ case["golden_answer"],
+ original_answer_result["answer"],
+ processed_answer_result["answer"],
+ )
+
+ case_result["processing_results"][processing_type] = {
+ "processing": processing_result,
+ "answer": processed_answer_result,
+ "comparison": comparison_result,
+ }
+
+ results["processing_results"].append(case_result)
+
+ # Calculate statistics
+ self._calculate_processing_statistics(results)
+
+ return results
+
+ def _calculate_processing_statistics(self, results: dict[str, Any]) -> None:
+ """Calculate statistics for processing effectiveness analysis."""
+ processing_types = results["processing_types"]
+ processing_results = results["processing_results"]
+
+ if not processing_results:
+ results["statistics"] = {}
+ return
+
+ stats = {"total_cases": len(processing_results), "processing_type_stats": {}}
+
+ for processing_type in processing_types:
+ type_stats = {
+ "wins": 0,
+ "ties": 0,
+ "losses": 0,
+ "avg_improvement": 0.0,
+ "avg_compression_ratio": 0.0,
+ "avg_scores": {
+ "accuracy": 0.0,
+ "completeness": 0.0,
+ "relevance": 0.0,
+ "clarity": 0.0,
+ "overall": 0.0,
+ },
+ }
+
+ valid_cases = []
+ for case in processing_results:
+ if processing_type in case["processing_results"]:
+ result = case["processing_results"][processing_type]
+ comparison = result["comparison"]
+
+ # Count wins/ties/losses
+ if comparison["winner"] == "processed":
+ type_stats["wins"] += 1
+ elif comparison["winner"] == "tie":
+ type_stats["ties"] += 1
+ else:
+ type_stats["losses"] += 1
+
+ valid_cases.append(result)
+
+ if valid_cases:
+ # Calculate averages
+ type_stats["avg_improvement"] = sum(
+ case["comparison"]["improvement"] for case in valid_cases
+ ) / len(valid_cases)
+
+ type_stats["avg_compression_ratio"] = sum(
+ case["processing"]["compression_ratio"] for case in valid_cases
+ ) / len(valid_cases)
+
+ # Calculate average scores
+ for score_type in type_stats["avg_scores"]:
+ type_stats["avg_scores"][score_type] = sum(
+ case["comparison"]["processed_scores"][score_type] for case in valid_cases
+ ) / len(valid_cases)
+
+ # Calculate win rate
+ total_decisions = type_stats["wins"] + type_stats["ties"] + type_stats["losses"]
+ type_stats["win_rate"] = (
+ type_stats["wins"] / total_decisions if total_decisions > 0 else 0.0
+ )
+ type_stats["success_rate"] = (
+ (type_stats["wins"] + type_stats["ties"]) / total_decisions
+ if total_decisions > 0
+ else 0.0
+ )
+
+ stats["processing_type_stats"][processing_type] = type_stats
+
+ results["statistics"] = stats
+
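+    # Worked example for the rates above (illustrative numbers only): with wins=3, ties=1,
+    # losses=1, total_decisions=5, so win_rate = 3/5 = 0.60 and success_rate = (3 + 1)/5 = 0.80.
+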
+ def analyze_bad_cases(self, bad_cases: list[dict[str, Any]]) -> list[dict[str, Any]]:
+ """
+ Analyze all bad cases to determine memory sufficiency.
+
+ Args:
+ bad_cases: List of bad cases to analyze
+
+ Returns:
+ List of analyzed bad cases with sufficiency information
+ """
+ analyzed_cases = []
+
+ for i, case in enumerate(bad_cases):
+ logger.info(f"Analyzing bad case {i + 1}/{len(bad_cases)}: {case['query'][:50]}...")
+
+ analysis = self.analyze_memory_sufficiency(
+ case["query"], case["golden_answer"], case["memories"]
+ )
+
+ # Add analysis results to the case
+ analyzed_case = case.copy()
+ analyzed_case.update(
+ {
+ "memory_analysis": analysis,
+ "has_sufficient_memories": analysis["sufficient"],
+ "analysis_confidence": analysis["confidence"],
+ "relevant_memory_count": len(analysis["relevant_memories"]),
+ }
+ )
+
+ analyzed_cases.append(analyzed_case)
+
+ return analyzed_cases
+
+ def collect_bad_cases(self, eval_result_dir: str | None = None) -> dict[str, Any]:
+ """
+ Main method to collect and analyze bad cases from evaluation results.
+
+ Args:
+ eval_result_dir: Directory containing evaluation results
+
+ Returns:
+ Dictionary containing analysis results and statistics
+ """
+ if eval_result_dir is None:
+ eval_result_dir = f"{BASE_DIR}/evaluation/results/locomo/memos-api-072005-fast"
+
+ judged_file = os.path.join(eval_result_dir, "memos-api_locomo_judged.json")
+ search_results_file = os.path.join(eval_result_dir, "memos-api_locomo_search_results.json")
+
+ # Extract bad cases
+ bad_cases = self.extract_bad_cases(judged_file, search_results_file)
+
+ if not bad_cases:
+ logger.warning("No bad cases found")
+ return {"bad_cases": [], "statistics": {}}
+
+ # Analyze bad cases
+ analyzed_cases = self.analyze_bad_cases(bad_cases)
+
+ # Calculate statistics
+ total_cases = len(analyzed_cases)
+ sufficient_cases = sum(
+ 1 for case in analyzed_cases if case.get("has_sufficient_memories", False)
+ )
+ insufficient_cases = total_cases - sufficient_cases
+
+ avg_confidence = (
+ sum(case["analysis_confidence"] for case in analyzed_cases) / total_cases
+ if total_cases > 0
+ else 0
+ )
+ avg_relevant_memories = (
+ sum(case["relevant_memory_count"] for case in analyzed_cases) / total_cases
+ if total_cases > 0
+ else 0
+ )
+
+ statistics = {
+ "total_bad_cases": total_cases,
+ "sufficient_memory_cases": sufficient_cases,
+ "insufficient_memory_cases": insufficient_cases,
+ "sufficiency_rate": sufficient_cases / total_cases if total_cases > 0 else 0,
+ "average_confidence": avg_confidence,
+ "average_relevant_memories": avg_relevant_memories,
+ }
+
+ # Save results
+ results = {
+ "bad_cases": analyzed_cases,
+ "statistics": statistics,
+ "metadata": {
+ "eval_result_dir": eval_result_dir,
+ "judged_file": judged_file,
+ "search_results_file": search_results_file,
+ "analysis_model": self.openai_model,
+ },
+ }
+
+ output_file = self.output_dir / "bad_cases_analysis.json"
+ with open(output_file, "w", encoding="utf-8") as f:
+ json.dump(results, f, indent=2, ensure_ascii=False)
+
+ logger.info(f"Analysis complete. Results saved to: {output_file}")
+ logger.info(f"Statistics: {statistics}")
+
+ return results
+
+ def _parse_json_response(self, response_text: str) -> dict:
+ """
+ Parse JSON response from LLM, handling various formats and potential errors.
+
+ Args:
+ response_text: Raw response text from LLM
+
+ Returns:
+ Parsed JSON dictionary
+
+ Raises:
+ ValueError: If JSON cannot be parsed
+ """
+ import re
+
+ # Try to extract JSON from response text
+ # Look for JSON blocks between ```json and ``` or just {} blocks
+ json_patterns = [r"```json\s*(\{.*?\})\s*```", r"```\s*(\{.*?\})\s*```", r"(\{.*\})"]
+
+ for pattern in json_patterns:
+ matches = re.findall(pattern, response_text, re.DOTALL)
+ if matches:
+ json_str = matches[0].strip()
+ try:
+ return json.loads(json_str)
+ except json.JSONDecodeError:
+ continue
+
+ # If no JSON pattern found, try parsing the entire response
+ try:
+ return json.loads(response_text.strip())
+ except json.JSONDecodeError as e:
+ logger.error(f"Failed to parse JSON response: {response_text[:200]}...")
+ raise ValueError(f"Invalid JSON response: {e!s}") from e
+
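+    # Sketch of inputs _parse_json_response accepts (illustrative strings):
+    #   '```json\n{"can_answer": true}\n```'      -> {"can_answer": True}
+    #   'prefix text {"can_answer": true} suffix' -> {"can_answer": True}
+    #   'no JSON here'                            -> raises ValueError
+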
+ def filter_memories_with_llm(self, memories: list[str], query: str) -> tuple[list[str], bool]:
+ """
+ Use LLM to filter memories based on relevance to the query.
+
+ Args:
+ memories: List of memory strings
+ query: Query to filter memories against
+
+ Returns:
+ Tuple of (filtered_memories, success_flag)
+ """
+ if not memories:
+ return [], True
+
+ # Build prompt for memory filtering
+ memories_text = "\n".join([f"{i + 1}. {memory}" for i, memory in enumerate(memories)])
+
+ prompt = f"""You are a memory filtering system. Given a query and a list of memories, identify which memories are relevant and non-redundant for answering the query.
+
+Query: {query}
+
+Memories:
+{memories_text}
+
+Please analyze each memory and return a JSON response with the following format:
+{{
+ "relevant_memory_indices": [list of indices (1-based) of memories that are relevant to the query],
+ "reasoning": "Brief explanation of your filtering decisions"
+}}
+
+Only include memories that are directly relevant to answering the query. Remove redundant or unrelated memories."""
+
+ try:
+ response = self.openai_client.chat.completions.create(
+ model=self.openai_model,
+ messages=[{"role": "user", "content": prompt}],
+ temperature=0.1,
+ )
+
+ response_text = response.choices[0].message.content
+
+ # Extract JSON from response
+ result = self._parse_json_response(response_text)
+
+ if "relevant_memory_indices" in result:
+ relevant_indices = result["relevant_memory_indices"]
+ filtered_memories = []
+
+ for idx in relevant_indices:
+ if 1 <= idx <= len(memories):
+ filtered_memories.append(memories[idx - 1])
+
+ logger.info(f"Filtered memories: {len(memories)} -> {len(filtered_memories)}")
+ return filtered_memories, True
+ else:
+ logger.warning("Invalid response format from memory filtering LLM")
+ return memories, False
+
+ except Exception as e:
+ logger.error(f"Error in memory filtering: {e}")
+ return memories, False
+
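+    # Behaviour sketch (illustrative values): filter_memories_with_llm(["m1", "m2", "m3"], "q")
+    # returns (relevant_subset, True) on success, and (original_memories, False) when the
+    # LLM response cannot be parsed or an error occurs.
+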
+ def evaluate_answer_ability_with_llm(self, query: str, memories: list[str]) -> bool:
+ """
+ Use LLM to evaluate whether the given memories can answer the query.
+
+ Args:
+ query: Query to evaluate
+ memories: List of memory strings
+
+ Returns:
+ Boolean indicating whether memories can answer the query
+ """
+ if not memories:
+ return False
+
+ memories_text = "\n".join([f"- {memory}" for memory in memories])
+
+ prompt = f"""You are an answer ability evaluator. Given a query and a list of memories, determine whether the memories contain sufficient information to answer the query.
+
+Query: {query}
+
+Available Memories:
+{memories_text}
+
+Please analyze the memories and return a JSON response with the following format:
+{{
+ "can_answer": true/false,
+ "confidence": 0.0-1.0,
+ "reasoning": "Brief explanation of your decision"
+}}
+
+Consider whether the memories contain the specific information needed to provide a complete and accurate answer to the query."""
+
+ try:
+ response = self.openai_client.chat.completions.create(
+ model=self.openai_model,
+ messages=[{"role": "user", "content": prompt}],
+ temperature=0.1,
+ )
+
+ response_text = response.choices[0].message.content
+ result = self._parse_json_response(response_text)
+
+ if "can_answer" in result:
+ can_answer = result["can_answer"]
+ confidence = result.get("confidence", 0.5)
+ reasoning = result.get("reasoning", "No reasoning provided")
+
+ logger.info(
+ f"Answer ability evaluation: {can_answer} (confidence: {confidence:.2f}) - {reasoning}"
+ )
+ return can_answer
+ else:
+ logger.warning("Invalid response format from answer ability evaluation")
+ return False
+
+ except Exception as e:
+ logger.error(f"Error in answer ability evaluation: {e}")
+ return False
+
+ def memory_llm_processing_analysis(
+ self, bad_cases: list[dict[str, Any]], use_llm_filtering: bool = True
+ ) -> list[dict[str, Any]]:
+ """
+ Analyze bad cases by processing memories with LLM filtering and testing answer ability.
+
+ This method:
+ 1. Parses memory strings from bad cases
+ 2. Uses LLM to filter unrelated and redundant memories
+ 3. Tests whether processed memories can help answer questions correctly
+ 4. Compares results before and after LLM processing
+
+ Args:
+ bad_cases: List of bad cases to analyze
+ use_llm_filtering: Whether to use LLM filtering
+
+ Returns:
+ List of analyzed bad cases with LLM processing results
+ """
+ analyzed_cases = []
+
+ for i, case in enumerate(bad_cases):
+ logger.info(f"Processing bad case {i + 1}/{len(bad_cases)}: {case['query'][:50]}...")
+
+ try:
+ # Parse memory string
+ memories_text = case.get("memories", "")
+ if not memories_text:
+ logger.warning(f"No memories found for case {i + 1}")
+ analyzed_case = case.copy()
+ analyzed_case.update(
+ {
+ "llm_processing_analysis": {
+ "error": "No memories available",
+ "original_memories_count": 0,
+ "processed_memories_count": 0,
+ "can_answer_with_original": False,
+ "can_answer_with_processed": False,
+ "processing_improved_answer": False,
+ }
+ }
+ )
+ analyzed_cases.append(analyzed_case)
+ continue
+
+ # Split memories by lines
+ memory_lines = [line.strip() for line in memories_text.split("\n") if line.strip()]
+ original_memories = [line for line in memory_lines if line]
+
+ logger.info(f"Parsed {len(original_memories)} memories from text")
+
+ # Test answer ability with original memories
+ can_answer_original = self.evaluate_answer_ability_with_llm(
+ query=case["query"], memories=original_memories
+ )
+
+ # Process memories with LLM filtering if enabled
+ processed_memories = original_memories
+ processing_success = False
+
+ if use_llm_filtering and len(original_memories) > 0:
+ processed_memories, processing_success = self.filter_memories_with_llm(
+ memories=original_memories, query=case["query"]
+ )
+ logger.info(
+ f"LLM filtering: {len(original_memories)} -> {len(processed_memories)} memories, success: {processing_success}"
+ )
+
+ # Test answer ability with processed memories
+ can_answer_processed = self.evaluate_answer_ability_with_llm(
+ query=case["query"], memories=processed_memories
+ )
+
+ # Determine if processing improved answer ability
+ processing_improved = can_answer_processed and not can_answer_original
+
+ # Create analysis result
+ llm_analysis = {
+ "processing_success": processing_success,
+ "original_memories_count": len(original_memories),
+ "processed_memories_count": len(processed_memories),
+ "memories_removed_count": len(original_memories) - len(processed_memories),
+ "can_answer_with_original": can_answer_original,
+ "can_answer_with_processed": can_answer_processed,
+ "processing_improved_answer": processing_improved,
+ "original_memories": original_memories,
+ "processed_memories": processed_memories,
+ }
+
+ # Add analysis to case
+ analyzed_case = case.copy()
+ analyzed_case["llm_processing_analysis"] = llm_analysis
+
+ logger.info(
+ f"Case {i + 1} analysis complete: "
+ f"Original: {can_answer_original}, "
+ f"Processed: {can_answer_processed}, "
+ f"Improved: {processing_improved}"
+ )
+
+ except Exception as e:
+ logger.error(f"Error processing case {i + 1}: {e}")
+ analyzed_case = case.copy()
+ analyzed_case["llm_processing_analysis"] = {
+ "error": str(e),
+ "processing_success": False,
+ "original_memories_count": 0,
+ "processed_memories_count": 0,
+ "can_answer_with_original": False,
+ "can_answer_with_processed": False,
+ "processing_improved_answer": False,
+ }
+
+ analyzed_cases.append(analyzed_case)
+
+ return analyzed_cases
+
+    def scheduler_mem_process(self, query, memories):
+        """Enhance memory strings for a query using the scheduler retriever's enhancement prompt and LLM."""
+        from memos.mem_scheduler.utils.misc_utils import extract_list_items_in_answer
+
+ _memories = []
+ for mem in memories:
+ mem_item = TextualMemoryItem(memory=mem, metadata=TextualMemoryMetadata())
+ _memories.append(mem_item)
+ prompt = mem_scheduler.retriever._build_enhancement_prompt(
+ query_history=[query], batch_texts=memories
+ )
+ logger.debug(
+ f"[Enhance][batch={0}] Prompt (first 200 chars, len={len(prompt)}): {prompt[:200]}..."
+ )
+
+ response = mem_scheduler.retriever.process_llm.generate(
+ [{"role": "user", "content": prompt}]
+ )
+ logger.debug(f"[Enhance][batch={0}] Response (first 200 chars): {response[:200]}...")
+
+ processed_results = extract_list_items_in_answer(response)
+
+ return {
+ "processed_memories": processed_results,
+ "processing_type": "enhance",
+ "original_length": len("\n".join(memories)),
+ "processed_length": len("\n".join(processed_results)),
+ "compression_ratio": len("\n".join(processed_results)) / len("\n".join(memories))
+ if len(memories) > 0
+ else 0,
+ }
+
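+    # Worked example for compression_ratio above (illustrative numbers only): if the joined
+    # input memories total 600 characters and the enhanced output totals 240 characters,
+    # compression_ratio = 240 / 600 = 0.4.
+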
+ def analyze_bad_cases_with_llm_processing(
+ self,
+ bad_cases: list[dict[str, Any]],
+ save_results: bool = True,
+ output_file: str | None = None,
+ ) -> dict[str, Any]:
+ """
+ Comprehensive analysis of bad cases with LLM memory processing.
+
+ This method performs a complete analysis including:
+ 1. Basic bad case analysis
+ 2. LLM memory processing analysis
+ 3. Statistical summary of improvements
+ 4. Detailed reporting
+
+ Args:
+ bad_cases: List of bad cases to analyze
+ save_results: Whether to save results to file
+ output_file: Optional output file path
+
+ Returns:
+ Dictionary containing comprehensive analysis results
+ """
+ from datetime import datetime
+
+ logger.info(
+ f"Starting comprehensive analysis of {len(bad_cases)} bad cases with LLM processing"
+ )
+
+ # Perform LLM memory processing analysis
+ analyzed_cases = self.memory_llm_processing_analysis(
+ bad_cases=bad_cases, use_llm_filtering=True
+ )
+
+ # Calculate statistics
+ total_cases = len(analyzed_cases)
+ successful_processing = 0
+ improved_cases = 0
+ original_answerable = 0
+ processed_answerable = 0
+ total_memories_before = 0
+ total_memories_after = 0
+
+ for case in analyzed_cases:
+ llm_analysis = case.get("llm_processing_analysis", {})
+
+ if llm_analysis.get("processing_success", False):
+ successful_processing += 1
+
+ if llm_analysis.get("processing_improved_answer", False):
+ improved_cases += 1
+
+ if llm_analysis.get("can_answer_with_original", False):
+ original_answerable += 1
+
+ if llm_analysis.get("can_answer_with_processed", False):
+ processed_answerable += 1
+
+ total_memories_before += llm_analysis.get("original_memories_count", 0)
+ total_memories_after += llm_analysis.get("processed_memories_count", 0)
+
+ # Calculate improvement metrics
+ processing_success_rate = successful_processing / total_cases if total_cases > 0 else 0
+ improvement_rate = improved_cases / total_cases if total_cases > 0 else 0
+ original_answer_rate = original_answerable / total_cases if total_cases > 0 else 0
+ processed_answer_rate = processed_answerable / total_cases if total_cases > 0 else 0
+ memory_reduction_rate = (
+ (total_memories_before - total_memories_after) / total_memories_before
+ if total_memories_before > 0
+ else 0
+ )
+
+ # Create comprehensive results
+ results = {
+ "analysis_metadata": {
+ "total_cases_analyzed": total_cases,
+ "analysis_timestamp": datetime.now().isoformat(),
+ "llm_model_used": self.openai_model,
+ },
+ "processing_statistics": {
+ "successful_processing_count": successful_processing,
+ "processing_success_rate": processing_success_rate,
+ "cases_with_improvement": improved_cases,
+ "improvement_rate": improvement_rate,
+ "original_answerable_cases": original_answerable,
+ "original_answer_rate": original_answer_rate,
+ "processed_answerable_cases": processed_answerable,
+ "processed_answer_rate": processed_answer_rate,
+ "answer_rate_improvement": processed_answer_rate - original_answer_rate,
+ },
+ "memory_statistics": {
+ "total_memories_before_processing": total_memories_before,
+ "total_memories_after_processing": total_memories_after,
+ "memories_removed": total_memories_before - total_memories_after,
+ "memory_reduction_rate": memory_reduction_rate,
+ "average_memories_per_case_before": total_memories_before / total_cases
+ if total_cases > 0
+ else 0,
+ "average_memories_per_case_after": total_memories_after / total_cases
+ if total_cases > 0
+ else 0,
+ },
+ "analyzed_cases": analyzed_cases,
+ }
+
+ # Log summary
+ logger.info("LLM Processing Analysis Summary:")
+ logger.info(f" - Total cases: {total_cases}")
+ logger.info(f" - Processing success rate: {processing_success_rate:.2%}")
+ logger.info(f" - Cases with improvement: {improved_cases} ({improvement_rate:.2%})")
+ logger.info(f" - Original answer rate: {original_answer_rate:.2%}")
+ logger.info(f" - Processed answer rate: {processed_answer_rate:.2%}")
+ logger.info(
+ f" - Answer rate improvement: {processed_answer_rate - original_answer_rate:.2%}"
+ )
+ logger.info(f" - Memory reduction: {memory_reduction_rate:.2%}")
+
+ # Save results if requested
+ if save_results:
+ if output_file is None:
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ output_file = f"llm_processing_analysis_{timestamp}.json"
+
+ try:
+ with open(output_file, "w", encoding="utf-8") as f:
+ json.dump(results, f, indent=2, ensure_ascii=False)
+ logger.info(f"Analysis results saved to: {output_file}")
+ except Exception as e:
+ logger.error(f"Failed to save results to {output_file}: {e}")
+
+ return results
+
+
+def main():
+ """Main test function."""
+ print("=== EvalAnalyzer Simple Test ===")
+
+ # Initialize analyzer
+ analyzer = EvalAnalyzer(output_dir="./tmp/eval_analyzer")
+
+ print("Analyzer initialized")
+
+ # Test file paths
+ eval_result_dir = f"{BASE_DIR}/evaluation/results/locomo/memos-api-xcy-1030-2114-locomo"
+ judged_file = os.path.join(eval_result_dir, "memos-api_locomo_judged.json")
+ search_results_file = os.path.join(eval_result_dir, "memos-api_locomo_search_results.json")
+
+ print("Testing with files:")
+ print(f" Judged file: {judged_file}")
+ print(f" Search results file: {search_results_file}")
+
+ # Check if files exist
+ if not os.path.exists(judged_file):
+ print(f"โ Judged file not found: {judged_file}")
+ return
+
+ if not os.path.exists(search_results_file):
+ print(f"โ Search results file not found: {search_results_file}")
+ return
+
+ print("โ
Both files exist")
+
+ # Test bad case extraction only
+ try:
+ print("\n=== Testing Bad Case Extraction ===")
+ bad_cases = analyzer.extract_bad_cases(judged_file, search_results_file)
+
+ print(f"โ
Successfully extracted {len(bad_cases)} bad cases")
+
+ if bad_cases:
+ print("\n=== Sample Bad Cases ===")
+ for i, case in enumerate(bad_cases[:3]): # Show first 3 cases
+ print(f"\nBad Case {i + 1}:")
+ print(f" User ID: {case['user_id']}")
+ print(f" Query: {case['query'][:100]}...")
+ print(f" Golden Answer: {case['golden_answer']}...")
+ print(f" Answer: {case['answer']}...")
+ print(f" Has Memories: {len(case['memories']) > 0}")
+ print(f" Memory Length: {len(case['memories'])} chars")
+
+ # Save basic results without LLM analysis
+ basic_results = {
+ "bad_cases_count": len(bad_cases),
+ "bad_cases": bad_cases,
+ "metadata": {
+ "eval_result_dir": eval_result_dir,
+ "judged_file": judged_file,
+ "search_results_file": search_results_file,
+ "extraction_only": True,
+ },
+ }
+
+ output_file = analyzer.output_dir / "bad_cases_extraction_only.json"
+ import json
+
+ with open(output_file, "w", encoding="utf-8") as f:
+ json.dump(basic_results, f, indent=2, ensure_ascii=False)
+
+ print(f"\nโ
Basic extraction results saved to: {output_file}")
+
+ except Exception as e:
+ print(f"โ Error during extraction: {e}")
+ import traceback
+
+ traceback.print_exc()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/memos/mem_scheduler/analyzer/memory_processing.py b/src/memos/mem_scheduler/analyzer/memory_processing.py
new file mode 100644
index 000000000..b692341c2
--- /dev/null
+++ b/src/memos/mem_scheduler/analyzer/memory_processing.py
@@ -0,0 +1,246 @@
+#!/usr/bin/env python3
+"""
+Test script for memory processing functionality in eval_analyzer.py
+
+This script demonstrates how to use the new LLM memory processing features
+to analyze and improve memory-based question answering.
+"""
+
+import json
+import os
+import sys
+
+from pathlib import Path
+from typing import Any
+
+from memos.log import get_logger
+from memos.mem_scheduler.analyzer.eval_analyzer import EvalAnalyzer
+
+
+FILE_PATH = Path(__file__).absolute()
+BASE_DIR = FILE_PATH.parent  # Directory containing this script
+sys.path.insert(0, str(BASE_DIR)) # Enable execution from any working directory
+
+
+logger = get_logger(__name__)
+
+
+def create_sample_bad_cases() -> list[dict[str, Any]]:
+ """Create sample bad cases for testing memory processing."""
+ return [
+ {
+ "query": "What is the capital of France?",
+ "golden_answer": "Paris",
+ "memories": """
+ Memory 1: France is a country in Western Europe.
+ Memory 2: The Eiffel Tower is located in Paris.
+ Memory 3: Paris is known for its art museums and fashion.
+ Memory 4: French cuisine is famous worldwide.
+ Memory 5: The Seine River flows through Paris.
+ """,
+ },
+ {
+ "query": "When was the iPhone first released?",
+ "golden_answer": "June 29, 2007",
+ "memories": """
+ Memory 1: Apple Inc. was founded by Steve Jobs, Steve Wozniak, and Ronald Wayne.
+ Memory 2: The iPhone was announced by Steve Jobs at the Macworld Conference & Expo on January 9, 2007.
+ Memory 3: The iPhone went on sale on June 29, 2007.
+ Memory 4: The original iPhone had a 3.5-inch screen.
+ Memory 5: Apple's stock price increased significantly after the iPhone launch.
+ """,
+ },
+ {
+ "query": "What is photosynthesis?",
+ "golden_answer": "Photosynthesis is the process by which plants use sunlight, water, and carbon dioxide to produce glucose and oxygen.",
+ "memories": """
+ Memory 1: Plants are living organisms that need sunlight to grow.
+ Memory 2: Chlorophyll is the green pigment in plants.
+ Memory 3: Plants take in carbon dioxide from the air.
+ Memory 4: Water is absorbed by plant roots from the soil.
+ Memory 5: Oxygen is released by plants during the day.
+ Memory 6: Glucose is a type of sugar that plants produce.
+ """,
+ },
+ ]
+
+
+def memory_processing(bad_cases):
+ """
+ Test the memory processing functionality with cover rate and acc rate analysis.
+
+ This function analyzes:
+ 1. Cover rate: Whether memories contain all information needed to answer the query
+ 2. Acc rate: Whether processed memories can correctly answer the query
+ """
+ print("๐งช Testing Memory Processing Functionality with Cover Rate & Acc Rate Analysis")
+ print("=" * 80)
+
+ # Initialize analyzer
+ analyzer = EvalAnalyzer()
+
+ print(f"๐ Testing with {len(bad_cases)} sample cases")
+ print()
+
+ # Initialize counters for real-time statistics
+ total_cases = 0
+ cover_count = 0 # Cases where memories cover all needed information
+ acc_count = 0 # Cases where processed memories can correctly answer
+
+ # Process each case
+ for i, case in enumerate(bad_cases):
+ total_cases += 1
+
+ # Safely handle query display
+ query_display = str(case.get("query", "Unknown query"))
+ print(f"๐ Case {i + 1}/{len(bad_cases)}: {query_display}...")
+
+ # Safely handle golden_answer display (convert to string if needed)
+ golden_answer = case.get("golden_answer", "Unknown answer")
+ golden_answer_str = str(golden_answer) if golden_answer is not None else "Unknown answer"
+ print(f"๐ Golden Answer: {golden_answer_str}")
+ print()
+
+ # Step 1: Analyze if memories contain sufficient information (Cover Rate)
+ print(" ๐ Step 1: Analyzing memory coverage...")
+ coverage_analysis = analyzer.analyze_memory_sufficiency(
+ case["query"],
+ golden_answer_str, # Use the string version
+ case["memories"],
+ )
+
+ has_coverage = coverage_analysis.get("sufficient", False)
+ if has_coverage:
+ cover_count += 1
+
+ print(f" โ
Memory Coverage: {'SUFFICIENT' if has_coverage else 'INSUFFICIENT'}")
+ print(f" ๐ฏ Confidence: {coverage_analysis.get('confidence', 0):.2f}")
+ print(f" ๐ญ Reasoning: {coverage_analysis.get('reasoning', 'N/A')}...")
+ if not has_coverage:
+ print(
+ f" โ Missing Info: {coverage_analysis.get('missing_information', 'N/A')[:100]}..."
+ )
+ continue
+ print()
+
+ # Step 2: Process memories and test answer ability (Acc Rate)
+ print(" ๐ Step 2: Processing memories and testing answer ability...")
+
+ processing_result = analyzer.scheduler_mem_process(
+ query=case["query"],
+ memories=case["memories"],
+ )
+ print(f"Original Memories: {case['memories']}")
+ print(f"Processed Memories: {processing_result['processed_memories']}")
+ print(f" ๐ Compression ratio: {processing_result['compression_ratio']:.2f}")
+ print(f" ๐ Processed memories length: {processing_result['processed_length']} chars")
+
+ # Generate answer with processed memories
+ answer_result = analyzer.generate_answer_with_memories(
+ case["query"], processing_result["processed_memories"], "processed_enhanced"
+ )
+
+ # Evaluate if the generated answer is correct
+ print(" ๐ฏ Step 3: Evaluating answer correctness...")
+ answer_evaluation = analyzer.compare_answer_quality(
+ case["query"],
+ golden_answer_str, # Use the string version
+ "No original answer available", # We don't have original answer
+ answer_result["answer"],
+ )
+
+ # Determine if processed memories can correctly answer (simplified logic)
+ processed_accuracy = answer_evaluation.get("processed_scores", {}).get("accuracy", 0)
+ can_answer_correctly = processed_accuracy >= 0.7 # Threshold for "correct" answer
+
+ if can_answer_correctly:
+ acc_count += 1
+
+ print(f" ๐ฌ Generated Answer: {answer_result['answer']}...")
+ print(
+ f" โ
Answer Accuracy: {'CORRECT' if can_answer_correctly else 'INCORRECT'} (score: {processed_accuracy:.2f})"
+ )
+ print()
+
+ # Calculate and print real-time rates
+ current_cover_rate = cover_count / total_cases
+ current_acc_rate = acc_count / total_cases
+
+ print(" ๐ REAL-TIME STATISTICS:")
+ print(f" ๐ฏ Cover Rate: {current_cover_rate:.2%} ({cover_count}/{total_cases})")
+ print(f" โ
Acc Rate: {current_acc_rate:.2%} ({acc_count}/{total_cases})")
+ print()
+
+ print("-" * 80)
+ print()
+
+ # Final summary
+ print("๐ FINAL ANALYSIS SUMMARY")
+ print("=" * 80)
+ print(f"๐ Total Cases Processed: {total_cases}")
+ print(f"๐ฏ Final Cover Rate: {cover_count / total_cases:.2%} ({cover_count}/{total_cases})")
+ print(f" - Cases with sufficient memory coverage: {cover_count}")
+ print(f" - Cases with insufficient memory coverage: {total_cases - cover_count}")
+ print()
+ print(f"โ
Final Acc Rate: {acc_count / total_cases:.2%} ({acc_count}/{total_cases})")
+ print(f" - Cases where processed memories can answer correctly: {acc_count}")
+ print(f" - Cases where processed memories cannot answer correctly: {total_cases - acc_count}")
+ print()
+
+ # Additional insights
+ if cover_count > 0:
+ effective_processing_rate = acc_count / cover_count if cover_count > 0 else 0
+ print(f"๐ Processing Effectiveness: {effective_processing_rate:.2%}")
+ print(
+ f" - Among cases with sufficient coverage, {effective_processing_rate:.1%} can be answered correctly after processing"
+ )
+
+ print("=" * 80)
+
+
+def load_real_bad_cases(file_path: str) -> list[dict[str, Any]]:
+ """Load real bad cases from JSON file."""
+ print(f"๐ Loading bad cases from: {file_path}")
+
+ with open(file_path, encoding="utf-8") as f:
+ data = json.load(f)
+
+ bad_cases = data.get("bad_cases", [])
+ print(f"โ
Loaded {len(bad_cases)} bad cases")
+
+ return bad_cases
+
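+# Expected input shape for load_real_bad_cases (sketch; keys beyond "bad_cases" are only
+# provenance):
+#   {"bad_cases": [{"query": ..., "golden_answer": ..., "memories": ..., ...}, ...],
+#    "metadata": {...}}
+# This matches the bad_cases_extraction_only.json file written by EvalAnalyzer's extraction
+# step; only the "bad_cases" list is consumed here.
+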
+
+def main():
+ """Main test function."""
+ print("๐ Memory Processing Test Suite")
+ print("=" * 60)
+ print()
+
+ # Check if OpenAI API key is set
+ if not os.getenv("OPENAI_API_KEY"):
+ print("โ ๏ธ Warning: OPENAI_API_KEY not found in environment variables")
+ print(" Please set your OpenAI API key to run the tests")
+ return
+
+ try:
+ bad_cases_file = f"{BASE_DIR}/tmp/eval_analyzer/bad_cases_extraction_only.json"
+ bad_cases = load_real_bad_cases(bad_cases_file)
+
+ print(f"โ
Created {len(bad_cases)} sample bad cases")
+ print()
+
+ # Run memory processing tests
+ memory_processing(bad_cases)
+
+ print("โ
All tests completed successfully!")
+
+ except Exception as e:
+ print(f"โ Test failed with error: {e}")
+ import traceback
+
+ traceback.print_exc()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/memos/mem_scheduler/analyzer/mos_for_test_scheduler.py b/src/memos/mem_scheduler/analyzer/mos_for_test_scheduler.py
index ace67eff6..03e1fc778 100644
--- a/src/memos/mem_scheduler/analyzer/mos_for_test_scheduler.py
+++ b/src/memos/mem_scheduler/analyzer/mos_for_test_scheduler.py
@@ -427,7 +427,6 @@ def chat(self, query: str, user_id: str | None = None) -> str:
message_item = ScheduleMessageItem(
user_id=target_user_id,
mem_cube_id=mem_cube_id,
- mem_cube=mem_cube,
label=QUERY_LABEL,
content=query,
timestamp=datetime.now(),
@@ -518,7 +517,6 @@ def chat(self, query: str, user_id: str | None = None) -> str:
message_item = ScheduleMessageItem(
user_id=target_user_id,
mem_cube_id=mem_cube_id,
- mem_cube=mem_cube,
label=ANSWER_LABEL,
content=response,
timestamp=datetime.now(),
diff --git a/src/memos/mem_scheduler/analyzer/scheduler_for_eval.py b/src/memos/mem_scheduler/analyzer/scheduler_for_eval.py
index 7c0fa5a4a..3d0235871 100644
--- a/src/memos/mem_scheduler/analyzer/scheduler_for_eval.py
+++ b/src/memos/mem_scheduler/analyzer/scheduler_for_eval.py
@@ -226,9 +226,9 @@ def evaluate_memory_answer_ability(
try:
# Extract JSON response
- from memos.mem_scheduler.utils.misc_utils import extract_json_dict
+ from memos.mem_scheduler.utils.misc_utils import extract_json_obj
- result = extract_json_dict(response)
+ result = extract_json_obj(response)
# Validate response structure
if "result" in result:
diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py
index 028fe8e3f..eb49d0238 100644
--- a/src/memos/mem_scheduler/base_scheduler.py
+++ b/src/memos/mem_scheduler/base_scheduler.py
@@ -1,6 +1,5 @@
import contextlib
import multiprocessing
-import queue
import threading
import time
@@ -18,15 +17,18 @@
from memos.mem_cube.general import GeneralMemCube
from memos.mem_scheduler.general_modules.dispatcher import SchedulerDispatcher
from memos.mem_scheduler.general_modules.misc import AutoDroppingQueue as Queue
+from memos.mem_scheduler.general_modules.redis_queue import SchedulerRedisQueue
from memos.mem_scheduler.general_modules.scheduler_logger import SchedulerLoggerModule
from memos.mem_scheduler.memory_manage_modules.retriever import SchedulerRetriever
from memos.mem_scheduler.monitors.dispatcher_monitor import SchedulerDispatcherMonitor
from memos.mem_scheduler.monitors.general_monitor import SchedulerGeneralMonitor
from memos.mem_scheduler.schemas.general_schemas import (
DEFAULT_ACT_MEM_DUMP_PATH,
+ DEFAULT_CONSUME_BATCH,
DEFAULT_CONSUME_INTERVAL_SECONDS,
DEFAULT_CONTEXT_WINDOW_SIZE,
DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE,
+ DEFAULT_MAX_WEB_LOG_QUEUE_SIZE,
DEFAULT_STARTUP_MODE,
DEFAULT_THREAD_POOL_MAX_WORKERS,
DEFAULT_TOP_K,
@@ -86,6 +88,22 @@ def __init__(self, config: BaseSchedulerConfig):
"scheduler_startup_mode", DEFAULT_STARTUP_MODE
)
+ # message queue configuration
+ self.use_redis_queue = self.config.get("use_redis_queue", DEFAULT_USE_REDIS_QUEUE)
+ self.max_internal_message_queue_size = self.config.get(
+ "max_internal_message_queue_size", DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE
+ )
+
+ # Initialize message queue based on configuration
+ if self.use_redis_queue:
+ self.memos_message_queue = SchedulerRedisQueue(
+ maxsize=self.max_internal_message_queue_size
+ )
+ else:
+ self.memos_message_queue: Queue[ScheduleMessageItem] = Queue(
+ maxsize=self.max_internal_message_queue_size
+ )
+
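+        # Both queue backends expose the same put()/get(batch_size=...)/qsize() interface,
+        # so downstream consumers do not need to branch on the backend. Config sketch
+        # (illustrative values, not defaults): use_redis_queue: true,
+        # max_internal_message_queue_size: 1000
+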
self.retriever: SchedulerRetriever | None = None
self.db_engine: Engine | None = None
self.monitor: SchedulerGeneralMonitor | None = None
@@ -93,6 +111,8 @@ def __init__(self, config: BaseSchedulerConfig):
self.mem_reader = None # Will be set by MOSCore
self.dispatcher = SchedulerDispatcher(
config=self.config,
+ memos_message_queue=self.memos_message_queue,
+ use_redis_queue=self.use_redis_queue,
max_workers=self.thread_pool_max_workers,
enable_parallel_dispatch=self.enable_parallel_dispatch,
)
@@ -100,23 +120,9 @@ def __init__(self, config: BaseSchedulerConfig):
# optional configs
self.disable_handlers: list | None = self.config.get("disable_handlers", None)
- # message queue configuration
- self.use_redis_queue = self.config.get("use_redis_queue", DEFAULT_USE_REDIS_QUEUE)
- self.max_internal_message_queue_size = self.config.get(
- "max_internal_message_queue_size", DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE
+ self.max_web_log_queue_size = self.config.get(
+ "max_web_log_queue_size", DEFAULT_MAX_WEB_LOG_QUEUE_SIZE
)
-
- # Initialize message queue based on configuration
- if self.use_redis_queue:
- self.memos_message_queue = None # Will use Redis instead
- # Initialize Redis if using Redis queue with auto-initialization
- self.auto_initialize_redis()
- else:
- self.memos_message_queue: Queue[ScheduleMessageItem] = Queue(
- maxsize=self.max_internal_message_queue_size
- )
-
- self.max_web_log_queue_size = self.config.get("max_web_log_queue_size", 50)
self._web_log_message_queue: Queue[ScheduleLogForWebItem] = Queue(
maxsize=self.max_web_log_queue_size
)
@@ -126,6 +132,7 @@ def __init__(self, config: BaseSchedulerConfig):
self._consume_interval = self.config.get(
"consume_interval_seconds", DEFAULT_CONSUME_INTERVAL_SECONDS
)
+ self.consume_batch = self.config.get("consume_batch", DEFAULT_CONSUME_BATCH)
# other attributes
self._context_lock = threading.Lock()
@@ -216,7 +223,7 @@ def _set_current_context_from_message(self, msg: ScheduleMessageItem) -> None:
with self._context_lock:
self.current_user_id = msg.user_id
self.current_mem_cube_id = msg.mem_cube_id
- self.current_mem_cube = msg.mem_cube
+ self.current_mem_cube = self.get_mem_cube(msg.mem_cube_id)
def transform_working_memories_to_monitors(
self, query_keywords, memories: list[TextualMemoryItem]
@@ -533,17 +540,9 @@ def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageIt
if self.disable_handlers and message.label in self.disable_handlers:
logger.info(f"Skipping disabled handler: {message.label} - {message.content}")
continue
+ self.memos_message_queue.put(message)
+            logger.info(f"Submitted message to scheduler queue: {message.label} - {message.content}")
- if self.use_redis_queue:
- # Use Redis stream for message queue
- self.redis_add_message_stream(message.to_dict())
- logger.info(f"Submitted message to Redis: {message.label} - {message.content}")
- else:
- # Use local queue
- self.memos_message_queue.put(message)
- logger.info(
- f"Submitted message to local queue: {message.label} - {message.content}"
- )
with contextlib.suppress(Exception):
if messages:
self.dispatcher.on_messages_enqueued(messages)
@@ -590,7 +589,7 @@ def get_web_log_messages(self) -> list[dict]:
try:
item = self._web_log_message_queue.get_nowait() # Thread-safe get
messages.append(item.to_dict())
- except queue.Empty:
+ except Exception:
break
return messages
@@ -601,62 +600,28 @@ def _message_consumer(self) -> None:
Runs in a dedicated thread to process messages at regular intervals.
For Redis queue, this method starts the Redis listener.
"""
- if self.use_redis_queue:
- # For Redis queue, start the Redis listener
- def redis_message_handler(message_data):
- """Handler for Redis messages"""
- try:
- # Redis message data needs to be decoded from bytes to string
- decoded_data = {}
- for key, value in message_data.items():
- if isinstance(key, bytes):
- key = key.decode("utf-8")
- if isinstance(value, bytes):
- value = value.decode("utf-8")
- decoded_data[key] = value
-
- message = ScheduleMessageItem.from_dict(decoded_data)
- self.dispatcher.dispatch([message])
- except Exception as e:
- logger.error(f"Error processing Redis message: {e}")
- logger.error(f"Message data: {message_data}")
-
- self.redis_start_listening(handler=redis_message_handler)
-
- # Keep the thread alive while Redis listener is running
- while self._running:
- time.sleep(self._consume_interval)
- else:
- # Original local queue logic
- while self._running: # Use a running flag for graceful shutdown
- try:
- # Get all available messages at once (thread-safe approach)
- messages = []
- while True:
- try:
- # Use get_nowait() directly without empty() check to avoid race conditions
- message = self.memos_message_queue.get_nowait()
- messages.append(message)
- except queue.Empty:
- # No more messages available
- break
- if messages:
- try:
- self.dispatcher.dispatch(messages)
- except Exception as e:
- logger.error(f"Error dispatching messages: {e!s}")
- finally:
- # Mark all messages as processed
- for _ in messages:
- self.memos_message_queue.task_done()
+        # Unified consumer loop: works for both the local and Redis-backed queues
+ while self._running: # Use a running flag for graceful shutdown
+ try:
+ # Get messages in batches based on consume_batch setting
+
+ messages = self.memos_message_queue.get(block=True, batch_size=self.consume_batch)
+
+ if messages:
+ try:
+ self.dispatcher.dispatch(messages)
+ except Exception as e:
+ logger.error(f"Error dispatching messages: {e!s}")
- # Sleep briefly to prevent busy waiting
- time.sleep(self._consume_interval) # Adjust interval as needed
+ # Sleep briefly to prevent busy waiting
+ time.sleep(self._consume_interval) # Adjust interval as needed
- except Exception as e:
+ except Exception as e:
+ # Don't log error for "No messages available in Redis queue" as it's expected
+ if "No messages available in Redis queue" not in str(e):
logger.error(f"Unexpected error in message consumer: {e!s}")
- time.sleep(self._consume_interval) # Prevent tight error loops
+ time.sleep(self._consume_interval) # Prevent tight error loops
def start(self) -> None:
"""
@@ -666,16 +631,25 @@ def start(self) -> None:
1. Message consumer thread or process (based on startup_mode)
2. Dispatcher thread pool (if parallel dispatch enabled)
"""
- if self._running:
- logger.warning("Memory Scheduler is already running")
- return
-
# Initialize dispatcher resources
if self.enable_parallel_dispatch:
logger.info(
f"Initializing dispatcher thread pool with {self.thread_pool_max_workers} workers"
)
+ self.start_consumer()
+
+ def start_consumer(self) -> None:
+ """
+ Start only the message consumer thread/process.
+
+ This method can be used to restart the consumer after it has been stopped
+ with stop_consumer(), without affecting other scheduler components.
+ """
+ if self._running:
+ logger.warning("Memory Scheduler consumer is already running")
+ return
+
# Start consumer based on startup mode
self._running = True
@@ -698,15 +672,15 @@ def start(self) -> None:
self._consumer_thread.start()
logger.info("Message consumer thread started")
- def stop(self) -> None:
- """Stop all scheduler components gracefully.
+ def stop_consumer(self) -> None:
+ """Stop only the message consumer thread/process gracefully.
- 1. Stops message consumer thread/process
- 2. Shuts down dispatcher thread pool
- 3. Cleans up resources
+ This method stops the consumer without affecting other components like
+ dispatcher or monitors. Useful when you want to pause message processing
+ while keeping other scheduler components running.
"""
if not self._running:
- logger.warning("Memory Scheduler is not running")
+ logger.warning("Memory Scheduler consumer is not running")
return
# Signal consumer thread/process to stop
@@ -726,12 +700,30 @@ def stop(self) -> None:
logger.info("Consumer process terminated")
else:
logger.info("Consumer process stopped")
+ self._consumer_process = None
elif self._consumer_thread and self._consumer_thread.is_alive():
self._consumer_thread.join(timeout=5.0)
if self._consumer_thread.is_alive():
logger.warning("Consumer thread did not stop gracefully")
else:
logger.info("Consumer thread stopped")
+ self._consumer_thread = None
+
+ logger.info("Memory Scheduler consumer stopped")
+
+ def stop(self) -> None:
+ """Stop all scheduler components gracefully.
+
+ 1. Stops message consumer thread/process
+ 2. Shuts down dispatcher thread pool
+ 3. Cleans up resources
+ """
+ if not self._running:
+ logger.warning("Memory Scheduler is not running")
+ return
+
+ # Stop consumer first
+ self.stop_consumer()
# Shutdown dispatcher
if self.dispatcher:
@@ -743,10 +735,6 @@ def stop(self) -> None:
logger.info("Shutting down monitor...")
self.dispatcher_monitor.stop()
- # Clean up queues
- self._cleanup_queues()
- logger.info("Memory Scheduler stopped completely")
-
@property
def handlers(self) -> dict[str, Callable]:
"""
@@ -819,30 +807,6 @@ def get_running_tasks(self, filter_func: Callable | None = None) -> dict[str, di
return result
- def _cleanup_queues(self) -> None:
- """Ensure all queues are emptied and marked as closed."""
- if self.use_redis_queue:
- # For Redis queue, stop the listener and close connection
- try:
- self.redis_stop_listening()
- self.redis_close()
- except Exception as e:
- logger.error(f"Error cleaning up Redis connection: {e}")
- else:
- # Original local queue cleanup
- try:
- while not self.memos_message_queue.empty():
- self.memos_message_queue.get_nowait()
- self.memos_message_queue.task_done()
- except queue.Empty:
- pass
-
- try:
- while not self._web_log_message_queue.empty():
- self._web_log_message_queue.get_nowait()
- except queue.Empty:
- pass
-
def mem_scheduler_wait(
self, timeout: float = 180.0, poll: float = 0.1, log_every: float = 0.01
) -> bool:
@@ -906,11 +870,24 @@ def _fmt_eta(seconds: float | None) -> str:
st = (
stats_fn()
) # expected: {'pending':int,'running':int,'done':int?,'rate':float?}
- pend = int(st.get("pending", 0))
run = int(st.get("running", 0))
+
except Exception:
pass
+ if isinstance(self.memos_message_queue, SchedulerRedisQueue):
+ # For Redis queue, prefer XINFO GROUPS to compute pending
+ groups_info = self.memos_message_queue.redis.xinfo_groups(
+ self.memos_message_queue.stream_name
+ )
+ if groups_info:
+ for group in groups_info:
+ if group.get("name") == self.memos_message_queue.consumer_group:
+ pend = int(group.get("pending", pend))
+ break
+ else:
+ pend = run
+
# 2) dynamic total (allows new tasks queued while waiting)
total_now = max(init_unfinished, done_total + curr_unfinished)
done_total = max(0, total_now - curr_unfinished)
diff --git a/src/memos/mem_scheduler/general_modules/dispatcher.py b/src/memos/mem_scheduler/general_modules/dispatcher.py
index c2407b9e6..b74529c8c 100644
--- a/src/memos/mem_scheduler/general_modules/dispatcher.py
+++ b/src/memos/mem_scheduler/general_modules/dispatcher.py
@@ -10,7 +10,9 @@
from memos.context.context import ContextThreadPoolExecutor
from memos.log import get_logger
from memos.mem_scheduler.general_modules.base import BaseSchedulerModule
+from memos.mem_scheduler.general_modules.redis_queue import SchedulerRedisQueue
from memos.mem_scheduler.general_modules.task_threads import ThreadManager
+from memos.mem_scheduler.schemas.general_schemas import DEFAULT_STOP_WAIT
from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
from memos.mem_scheduler.schemas.task_schemas import RunningTaskItem
from memos.mem_scheduler.utils.metrics import MetricsRegistry
@@ -32,13 +34,23 @@ class SchedulerDispatcher(BaseSchedulerModule):
- Thread race competition for parallel task execution
"""
- def __init__(self, max_workers=30, enable_parallel_dispatch=True, config=None):
+ def __init__(
+ self,
+ max_workers: int = 30,
+ memos_message_queue: Any | None = None,
+ use_redis_queue: bool | None = None,
+ enable_parallel_dispatch: bool = True,
+ config=None,
+ ):
super().__init__()
self.config = config
# Main dispatcher thread pool
self.max_workers = max_workers
+ self.memos_message_queue = memos_message_queue
+ self.use_redis_queue = use_redis_queue
+
# Get multi-task timeout from config
self.multi_task_running_timeout = (
self.config.get("multi_task_running_timeout") if self.config else None
@@ -73,6 +85,11 @@ def __init__(self, max_workers=30, enable_parallel_dispatch=True, config=None):
self._completed_tasks = []
self.completed_tasks_max_show_size = 10
+ # Configure shutdown wait behavior from config or default
+ self.stop_wait = (
+ self.config.get("stop_wait", DEFAULT_STOP_WAIT) if self.config else DEFAULT_STOP_WAIT
+ )
+
self.metrics = MetricsRegistry(
topk_per_label=(self.config or {}).get("metrics_topk_per_label", 50)
)
@@ -131,6 +148,19 @@ def wrapped_handler(messages: list[ScheduleMessageItem]):
# --- mark done ---
for m in messages:
self.metrics.on_done(label=m.label, mem_cube_id=m.mem_cube_id, now=time.time())
+
+ # acknowledge redis messages
+
+ if (
+ self.use_redis_queue
+ and self.memos_message_queue is not None
+ and isinstance(self.memos_message_queue, SchedulerRedisQueue)
+ ):
+ for msg in messages:
+ redis_message_id = msg.redis_message_id
+ # Acknowledge message processing
+ self.memos_message_queue.ack_message(redis_message_id=redis_message_id)
+
# Mark task as completed and remove from tracking
with self._task_lock:
if task_item.item_id in self._running_tasks:
@@ -138,7 +168,7 @@ def wrapped_handler(messages: list[ScheduleMessageItem]):
del self._running_tasks[task_item.item_id]
self._completed_tasks.append(task_item)
if len(self._completed_tasks) > self.completed_tasks_max_show_size:
- self._completed_tasks[-self.completed_tasks_max_show_size :]
+ self._completed_tasks.pop(0)
logger.info(f"Task completed: {task_item.get_execution_info()}")
return result
@@ -152,7 +182,7 @@ def wrapped_handler(messages: list[ScheduleMessageItem]):
task_item.mark_failed(str(e))
del self._running_tasks[task_item.item_id]
if len(self._completed_tasks) > self.completed_tasks_max_show_size:
- self._completed_tasks[-self.completed_tasks_max_show_size :]
+ self._completed_tasks.pop(0)
logger.error(f"Task failed: {task_item.get_execution_info()}, Error: {e}")
raise
@@ -381,17 +411,13 @@ def dispatch(self, msg_list: list[ScheduleMessageItem]):
wrapped_handler = self._create_task_wrapper(handler, task_item)
# dispatch to different handler
- logger.debug(
- f"Dispatch {len(msgs)} message(s) to {label} handler for user {user_id} and mem_cube {mem_cube_id}."
- )
- logger.info(f"Task started: {task_item.get_execution_info()}")
-
+ logger.debug(f"Task started: {task_item.get_execution_info()}")
if self.enable_parallel_dispatch and self.dispatcher_executor is not None:
# Capture variables in lambda to avoid loop variable issues
- future = self.dispatcher_executor.submit(wrapped_handler, msgs)
- self._futures.add(future)
- future.add_done_callback(self._handle_future_result)
- logger.info(f"Dispatched {len(msgs)} message(s) as future task")
+ _ = self.dispatcher_executor.submit(wrapped_handler, msgs)
+ logger.info(
+ f"Dispatch {len(msgs)} message(s) to {label} handler for user {user_id} and mem_cube {mem_cube_id}."
+ )
else:
wrapped_handler(msgs)
@@ -484,17 +510,9 @@ def shutdown(self) -> None:
"""Gracefully shutdown the dispatcher."""
self._running = False
- if self.dispatcher_executor is not None:
- # Cancel pending tasks
- cancelled = 0
- for future in self._futures:
- if future.cancel():
- cancelled += 1
- logger.info(f"Cancelled {cancelled}/{len(self._futures)} pending tasks")
-
# Shutdown executor
try:
- self.dispatcher_executor.shutdown(wait=True)
+ self.dispatcher_executor.shutdown(wait=self.stop_wait, cancel_futures=True)
except Exception as e:
logger.error(f"Executor shutdown error: {e}", exc_info=True)
finally:
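
The executor shutdown above now honours a configurable stop_wait flag and always passes cancel_futures=True. A minimal standalone sketch of those semantics, using a plain ThreadPoolExecutor rather than the dispatcher pool (cancel_futures needs Python 3.9+); names here are illustrative only:

import time
from concurrent.futures import ThreadPoolExecutor


def slow_task(i: int) -> int:
    time.sleep(0.5)
    return i


executor = ThreadPoolExecutor(max_workers=2)
futures = [executor.submit(slow_task, i) for i in range(8)]

# stop_wait=False returns immediately and cancels tasks that have not started yet;
# stop_wait=True would block here until the already running tasks finish.
stop_wait = False
executor.shutdown(wait=stop_wait, cancel_futures=True)
print(sum(f.cancelled() for f in futures), "queued tasks cancelled")
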
diff --git a/src/memos/mem_scheduler/general_modules/misc.py b/src/memos/mem_scheduler/general_modules/misc.py
index b6f48d043..e4e7edb89 100644
--- a/src/memos/mem_scheduler/general_modules/misc.py
+++ b/src/memos/mem_scheduler/general_modules/misc.py
@@ -199,6 +199,9 @@ class AutoDroppingQueue(Queue[T]):
"""A thread-safe queue that automatically drops the oldest item when full."""
def __init__(self, maxsize: int = 0):
+        # Normalize non-positive values to 0, which the stdlib Queue treats as "unbounded"
+        if maxsize <= 0:
+            maxsize = 0
super().__init__(maxsize=maxsize)
def put(self, item: T, block: bool = False, timeout: float | None = None) -> None:
@@ -218,7 +221,7 @@ def put(self, item: T, block: bool = False, timeout: float | None = None) -> Non
# First try non-blocking put
super().put(item, block=block, timeout=timeout)
except Full:
- # Remove oldest item and mark it done to avoid leaking unfinished_tasks
+ # Remove the oldest item and mark it done to avoid leaking unfinished_tasks
with suppress(Empty):
_ = self.get_nowait()
# If the removed item had previously incremented unfinished_tasks,
@@ -228,12 +231,70 @@ def put(self, item: T, block: bool = False, timeout: float | None = None) -> Non
# Retry putting the new item
super().put(item, block=block, timeout=timeout)
+ def get(
+ self, block: bool = True, timeout: float | None = None, batch_size: int | None = None
+ ) -> list[T] | T:
+        """Get items from the queue.
+
+        Args:
+            block: Whether to block if no items are available (default: True)
+            timeout: Timeout in seconds for blocking operations (default: None)
+            batch_size: Number of items to retrieve; None returns a single item (default: None)
+
+        Returns:
+            A single item when batch_size is None, otherwise a list of up to batch_size items
+
+        Raises:
+            Empty: If no item is available (single get), or if block=True and no items
+                arrive before the timeout (batched get)
+        """
+
+ if batch_size is None:
+ return super().get(block=block, timeout=timeout)
+ items = []
+ for _ in range(batch_size):
+ try:
+ items.append(super().get(block=block, timeout=timeout))
+ except Empty:
+ if not items and block:
+ # If we haven't gotten any items and we're blocking, re-raise Empty
+ raise
+ break
+ return items
+
+    def get_nowait(self, batch_size: int | None = None) -> list[T] | T:
+        """Get items from the queue without blocking.
+
+        Args:
+            batch_size: Number of items to retrieve; None returns a single item (default: None)
+
+        Returns:
+            A single item when batch_size is None, otherwise a list of up to batch_size items
+        """
+ if batch_size is None:
+ return super().get_nowait()
+
+ items = []
+ for _ in range(batch_size):
+ try:
+ items.append(super().get_nowait())
+ except Empty:
+ break
+ return items
+
def get_queue_content_without_pop(self) -> list[T]:
"""Return a copy of the queue's contents without modifying it."""
# Ensure a consistent snapshot by holding the mutex
with self.mutex:
return list(self.queue)
+ def qsize(self) -> int:
+ """Return the approximate size of the queue.
+
+ Returns:
+ Number of items currently in the queue
+ """
+ return super().qsize()
+
def clear(self) -> None:
"""Remove all items from the queue.
diff --git a/src/memos/mem_scheduler/general_modules/redis_queue.py b/src/memos/mem_scheduler/general_modules/redis_queue.py
new file mode 100644
index 000000000..c10765d05
--- /dev/null
+++ b/src/memos/mem_scheduler/general_modules/redis_queue.py
@@ -0,0 +1,460 @@
+"""
+Redis Queue implementation for SchedulerMessageItem objects.
+
+This module provides a Redis-based queue implementation that can replace
+the local memos_message_queue functionality in BaseScheduler.
+"""
+
+import time
+
+from collections.abc import Callable
+from queue import Empty
+from uuid import uuid4
+
+from memos.log import get_logger
+from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
+from memos.mem_scheduler.webservice_modules.redis_service import RedisSchedulerModule
+
+
+logger = get_logger(__name__)
+
+
+class SchedulerRedisQueue(RedisSchedulerModule):
+ """
+ Redis-based queue for storing and processing SchedulerMessageItem objects.
+
+ This class provides a Redis Stream-based implementation that can replace
+ the local memos_message_queue functionality, offering better scalability
+ and persistence for message processing.
+
+ Inherits from RedisSchedulerModule to leverage existing Redis connection
+ and initialization functionality.
+ """
+
+ def __init__(
+ self,
+ stream_name: str = "scheduler:messages:stream",
+ consumer_group: str = "scheduler_group",
+ consumer_name: str | None = "scheduler_consumer",
+ max_len: int = 10000,
+ maxsize: int = 0, # For Queue compatibility
+ auto_delete_acked: bool = True, # Whether to automatically delete acknowledged messages
+ ):
+ """
+ Initialize the Redis queue.
+
+ Args:
+ stream_name: Name of the Redis stream
+ consumer_group: Name of the consumer group
+ consumer_name: Name of the consumer (auto-generated if None)
+ max_len: Maximum length of the stream (for memory management)
+ maxsize: Maximum size of the queue (for Queue compatibility, ignored)
+ auto_delete_acked: Whether to automatically delete acknowledged messages from stream
+ """
+ super().__init__()
+
+        # Normalize non-positive values to 0 (unlimited queue size)
+        if maxsize <= 0:
+            maxsize = 0
+
+ # Stream configuration
+ self.stream_name = stream_name
+ self.consumer_group = consumer_group
+ self.consumer_name = consumer_name or f"consumer_{uuid4().hex[:8]}"
+ self.max_len = max_len
+ self.maxsize = maxsize # For Queue compatibility
+ self.auto_delete_acked = auto_delete_acked # Whether to delete acknowledged messages
+
+ # Consumer state
+ self._is_listening = False
+ self._message_handler: Callable[[ScheduleMessageItem], None] | None = None
+
+ # Connection state
+ self._is_connected = False
+
+ # Task tracking for mem_scheduler_wait compatibility
+ self._unfinished_tasks = 0
+
+ # Auto-initialize Redis connection
+ if self.auto_initialize_redis():
+ self._is_connected = True
+ self._ensure_consumer_group()
+
+ def _ensure_consumer_group(self) -> None:
+ """Ensure the consumer group exists for the stream."""
+ if not self._redis_conn:
+ return
+
+ try:
+ self._redis_conn.xgroup_create(
+ self.stream_name, self.consumer_group, id="0", mkstream=True
+ )
+ logger.debug(
+ f"Created consumer group '{self.consumer_group}' for stream '{self.stream_name}'"
+ )
+ except Exception as e:
+ # Check if it's a "consumer group already exists" error
+ error_msg = str(e).lower()
+ if "busygroup" in error_msg or "already exists" in error_msg:
+ logger.info(
+ f"Consumer group '{self.consumer_group}' already exists for stream '{self.stream_name}'"
+ )
+ else:
+ logger.error(f"Error creating consumer group: {e}", exc_info=True)
+
+ def put(
+ self, message: ScheduleMessageItem, block: bool = True, timeout: float | None = None
+ ) -> None:
+ """
+ Add a message to the Redis queue (Queue-compatible interface).
+
+ Args:
+ message: SchedulerMessageItem to add to the queue
+ block: Ignored for Redis implementation (always non-blocking)
+ timeout: Ignored for Redis implementation
+
+ Raises:
+ ConnectionError: If not connected to Redis
+ TypeError: If message is not a ScheduleMessageItem
+ """
+ if not self._redis_conn:
+ raise ConnectionError("Not connected to Redis. Redis connection not available.")
+
+ if not isinstance(message, ScheduleMessageItem):
+ raise TypeError(f"Expected ScheduleMessageItem, got {type(message)}")
+
+ try:
+ # Convert message to dictionary for Redis storage
+ message_data = message.to_dict()
+
+ # Add to Redis stream with automatic trimming
+ message_id = self._redis_conn.xadd(
+ self.stream_name, message_data, maxlen=self.max_len, approximate=True
+ )
+
+ logger.info(
+ f"Added message {message_id} to Redis stream: {message.label} - {message.content[:100]}..."
+ )
+
+ except Exception as e:
+ logger.error(f"Failed to add message to Redis queue: {e}")
+ raise
+
+ def put_nowait(self, message: ScheduleMessageItem) -> None:
+ """
+ Add a message to the Redis queue without blocking (Queue-compatible interface).
+
+ Args:
+ message: SchedulerMessageItem to add to the queue
+ """
+ self.put(message, block=False)
+
+    def ack_message(self, redis_message_id) -> None:
+        """Acknowledge a processed message and optionally delete it from the stream."""
+        self._redis_conn.xack(self.stream_name, self.consumer_group, redis_message_id)
+
+        # Optionally delete the message from the stream to keep it clean
+        if self.auto_delete_acked:
+            try:
+                self._redis_conn.xdel(self.stream_name, redis_message_id)
+                logger.info(f"Successfully deleted acknowledged message {redis_message_id}")
+            except Exception as e:
+                logger.warning(f"Failed to delete acknowledged message {redis_message_id}: {e}")
+
+    def get(
+        self,
+        block: bool = True,
+        timeout: float | None = None,
+        batch_size: int | None = None,
+    ) -> list[ScheduleMessageItem] | ScheduleMessageItem:
+        """Read messages from the consumer group; returns a single item when batch_size is None."""
+        if not self._redis_conn:
+            raise ConnectionError("Not connected to Redis. Redis connection not available.")
+
+ try:
+ # Calculate timeout for Redis
+ redis_timeout = None
+ if block and timeout is not None:
+ redis_timeout = int(timeout * 1000)
+ elif not block:
+ redis_timeout = None # Non-blocking
+
+ # Read messages from the consumer group
+ try:
+ messages = self._redis_conn.xreadgroup(
+ self.consumer_group,
+ self.consumer_name,
+ {self.stream_name: ">"},
+                    count=batch_size if batch_size else 1,
+ block=redis_timeout,
+ )
+ except Exception as read_err:
+ # Handle missing group/stream by creating and retrying once
+ err_msg = str(read_err).lower()
+                if "nogroup" in err_msg or "no such key" in err_msg:
+                    logger.warning(
+                        f"Consumer group or stream missing for '{self.stream_name}/{self.consumer_group}'. Attempting to create and retry."
+                    )
+                    self._ensure_consumer_group()
+                    messages = self._redis_conn.xreadgroup(
+                        self.consumer_group,
+                        self.consumer_name,
+                        {self.stream_name: ">"},
+                        count=batch_size if batch_size else 1,
+                        block=redis_timeout,
+                    )
+ else:
+ raise
+ result_messages = []
+
+ for _stream, stream_messages in messages:
+ for message_id, fields in stream_messages:
+ try:
+ # Convert Redis message back to SchedulerMessageItem
+ message = ScheduleMessageItem.from_dict(fields)
+ message.redis_message_id = message_id
+
+ result_messages.append(message)
+
+ except Exception as e:
+ logger.error(f"Failed to parse message {message_id}: {e}")
+
+ # Always return a list for consistency
+ if not result_messages:
+ if not block:
+ return [] # Return empty list for non-blocking calls
+                else:
+                    # If no messages were found, raise Empty exception
+                    raise Empty("No messages available in Redis queue")
+
+ return result_messages if batch_size is not None else result_messages[0]
+
+ except Exception as e:
+            if isinstance(e, Empty):
+ raise
+ logger.error(f"Failed to get message from Redis queue: {e}")
+ raise
+
+    def get_nowait(
+        self, batch_size: int | None = None
+    ) -> list[ScheduleMessageItem] | ScheduleMessageItem:
+        """
+        Get messages from the Redis queue without blocking (Queue-compatible interface).
+
+        Args:
+            batch_size: Number of messages to retrieve; None returns a single message
+
+        Returns:
+            A single ScheduleMessageItem when batch_size is None, otherwise a list of items
+        """
+ return self.get(block=False, batch_size=batch_size)
+
+ def qsize(self) -> int:
+ """
+ Get the current size of the Redis queue (Queue-compatible interface).
+
+ Returns the number of pending (unacknowledged) messages in the consumer group,
+ which represents the actual queue size for processing.
+
+ Returns:
+ Number of pending messages in the queue
+ """
+ if not self._redis_conn:
+ return 0
+
+ try:
+ # Get pending messages info for the consumer group
+ # XPENDING returns info about pending messages that haven't been acknowledged
+ pending_info = self._redis_conn.xpending(self.stream_name, self.consumer_group)
+
+ # pending_info[0] contains the count of pending messages
+ if pending_info and len(pending_info) > 0 and pending_info[0] is not None:
+ pending_count = int(pending_info[0])
+ if pending_count > 0:
+ return pending_count
+
+ # If no pending messages, check if there are new messages in the stream
+ # that haven't been read by any consumer yet
+ try:
+ # Get the last delivered ID for the consumer group
+ groups_info = self._redis_conn.xinfo_groups(self.stream_name)
+ if not groups_info:
+ # No groups exist, check total stream length
+ return self._redis_conn.xlen(self.stream_name) or 0
+
+ last_delivered_id = "0-0"
+
+ for group_info in groups_info:
+ if group_info and group_info.get("name") == self.consumer_group:
+ last_delivered_id = group_info.get("last-delivered-id", "0-0")
+ break
+
+ # Count messages after the last delivered ID
+ new_messages = self._redis_conn.xrange(
+ self.stream_name,
+ f"({last_delivered_id}", # Exclusive start
+ "+", # End at the latest message
+ count=1000, # Limit to avoid memory issues
+ )
+
+ return len(new_messages) if new_messages else 0
+
+ except Exception as inner_e:
+ logger.debug(f"Failed to get new messages count: {inner_e}")
+ # Fallback: return stream length
+ try:
+ stream_len = self._redis_conn.xlen(self.stream_name)
+ return stream_len if stream_len is not None else 0
+ except Exception:
+ return 0
+
+ except Exception as e:
+ logger.debug(f"Failed to get Redis queue size via XPENDING: {e}")
+ # Fallback to stream length if pending check fails
+ try:
+ stream_len = self._redis_conn.xlen(self.stream_name)
+ return stream_len if stream_len is not None else 0
+ except Exception as fallback_e:
+ logger.error(f"Failed to get Redis queue size (all methods failed): {fallback_e}")
+ return 0
+
+ def size(self) -> int:
+ """
+ Get the current size of the Redis queue (alias for qsize).
+
+ Returns:
+ Number of messages in the queue
+ """
+ return self.qsize()
+
+ def empty(self) -> bool:
+ """
+ Check if the Redis queue is empty (Queue-compatible interface).
+
+ Returns:
+ True if the queue is empty, False otherwise
+ """
+ return self.qsize() == 0
+
+ def full(self) -> bool:
+ """
+ Check if the Redis queue is full (Queue-compatible interface).
+
+ For Redis streams, we consider the queue full if it exceeds maxsize.
+ If maxsize is 0 or None, the queue is never considered full.
+
+ Returns:
+ True if the queue is full, False otherwise
+ """
+ if self.maxsize <= 0:
+ return False
+ return self.qsize() >= self.maxsize
+
+ def join(self) -> None:
+ """
+ Block until all items in the queue have been gotten and processed (Queue-compatible interface).
+
+ For Redis streams, this would require tracking pending messages,
+ which is complex. For now, this is a no-op.
+ """
+
+ def clear(self) -> None:
+ """Clear all messages from the queue."""
+ if not self._is_connected or not self._redis_conn:
+ return
+
+ try:
+ # Delete the entire stream
+ self._redis_conn.delete(self.stream_name)
+ logger.info(f"Cleared Redis stream: {self.stream_name}")
+
+ # Recreate the consumer group
+ self._ensure_consumer_group()
+ except Exception as e:
+ logger.error(f"Failed to clear Redis queue: {e}")
+
+ def start_listening(
+ self,
+ handler: Callable[[ScheduleMessageItem], None],
+ batch_size: int = 10,
+ poll_interval: float = 0.1,
+ ) -> None:
+ """
+ Start listening for messages and process them with the provided handler.
+
+ Args:
+ handler: Function to call for each received message
+ batch_size: Number of messages to process in each batch
+ poll_interval: Interval between polling attempts in seconds
+ """
+ if not self._is_connected:
+ raise ConnectionError("Not connected to Redis. Call connect() first.")
+
+ self._message_handler = handler
+ self._is_listening = True
+
+ logger.info(f"Started listening on Redis stream: {self.stream_name}")
+
+ try:
+ while self._is_listening:
+                try:
+                    messages = self.get(block=True, timeout=poll_interval, batch_size=batch_size)
+                except Empty:
+                    messages = []
+
+ for message in messages:
+ try:
+ self._message_handler(message)
+ except Exception as e:
+ logger.error(f"Error processing message {message.item_id}: {e}")
+
+ # Small sleep to prevent excessive CPU usage
+ if not messages:
+ time.sleep(poll_interval)
+
+ except KeyboardInterrupt:
+ logger.info("Received interrupt signal, stopping listener")
+ except Exception as e:
+ logger.error(f"Error in message listener: {e}")
+ finally:
+ self._is_listening = False
+ logger.info("Stopped listening for messages")
+
+ def stop_listening(self) -> None:
+ """Stop the message listener."""
+ self._is_listening = False
+ logger.info("Requested stop for message listener")
+
+ def connect(self) -> None:
+ """Establish connection to Redis and set up the queue."""
+ if self._redis_conn is not None:
+ try:
+ # Test the connection
+ self._redis_conn.ping()
+ self._is_connected = True
+ logger.debug("Redis connection established successfully")
+ except Exception as e:
+ logger.error(f"Failed to connect to Redis: {e}")
+ self._is_connected = False
+ else:
+ logger.error("Redis connection not initialized")
+ self._is_connected = False
+
+ def disconnect(self) -> None:
+ """Disconnect from Redis and clean up resources."""
+ self._is_connected = False
+ if self._is_listening:
+ self.stop_listening()
+ logger.debug("Disconnected from Redis")
+
+ def __enter__(self):
+ """Context manager entry."""
+ self.connect()
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ """Context manager exit."""
+ self.stop_listening()
+ self.disconnect()
+
+ def __del__(self):
+ """Cleanup when object is destroyed."""
+ if self._is_connected:
+ self.disconnect()
+
+ @property
+ def unfinished_tasks(self) -> int:
+ return self.qsize()
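
A usage sketch for the new queue: it needs a reachable Redis instance and whatever connection settings RedisSchedulerModule.auto_initialize_redis reads from the environment, so treat this as illustrative rather than a ready-to-run test:

from memos.mem_scheduler.general_modules.redis_queue import SchedulerRedisQueue
from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem

queue = SchedulerRedisQueue(stream_name="scheduler:messages:demo")

queue.put(
    ScheduleMessageItem(
        user_id="u1",
        mem_cube_id="cube1",
        label="test_handler",
        content="hello",
    )
)

# Batched reads return a list; each item carries the Redis stream id used for acking.
for msg in queue.get(block=True, timeout=1.0, batch_size=10):
    print(msg.item_id, msg.content)
    queue.ack_message(redis_message_id=msg.redis_message_id)

print("pending:", queue.qsize())
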
diff --git a/src/memos/mem_scheduler/general_scheduler.py b/src/memos/mem_scheduler/general_scheduler.py
index 6840adc2b..32fefce63 100644
--- a/src/memos/mem_scheduler/general_scheduler.py
+++ b/src/memos/mem_scheduler/general_scheduler.py
@@ -51,7 +51,7 @@ def __init__(self, config: GeneralSchedulerConfig):
def long_memory_update_process(
self, user_id: str, mem_cube_id: str, messages: list[ScheduleMessageItem]
):
- mem_cube = messages[0].mem_cube
+ mem_cube = self.current_mem_cube
# for status update
self._set_current_context_from_message(msg=messages[0])
@@ -140,7 +140,7 @@ def long_memory_update_process(
label=QUERY_LABEL,
user_id=user_id,
mem_cube_id=mem_cube_id,
- mem_cube=messages[0].mem_cube,
+ mem_cube=self.current_mem_cube,
)
def _query_message_consumer(self, messages: list[ScheduleMessageItem]) -> None:
@@ -212,7 +212,7 @@ def _add_message_consumer(self, messages: list[ScheduleMessageItem]) -> None:
logger.error(f"Error: {e}. Content: {msg.content}", exc_info=True)
userinput_memory_ids = []
- mem_cube = msg.mem_cube
+ mem_cube = self.current_mem_cube
for memory_id in userinput_memory_ids:
try:
mem_item: TextualMemoryItem = mem_cube.text_mem.get(
@@ -234,7 +234,7 @@ def _add_message_consumer(self, messages: list[ScheduleMessageItem]) -> None:
memory_type=mem_type,
user_id=msg.user_id,
mem_cube_id=msg.mem_cube_id,
- mem_cube=msg.mem_cube,
+ mem_cube=self.current_mem_cube,
log_func_callback=self._submit_web_logs,
)
@@ -248,7 +248,7 @@ def process_message(message: ScheduleMessageItem):
try:
user_id = message.user_id
mem_cube_id = message.mem_cube_id
- mem_cube = message.mem_cube
+ mem_cube = self.current_mem_cube
content = message.content
user_name = message.user_name
@@ -412,7 +412,7 @@ def process_message(message: ScheduleMessageItem):
try:
user_id = message.user_id
mem_cube_id = message.mem_cube_id
- mem_cube = message.mem_cube
+ mem_cube = self.current_mem_cube
content = message.content
user_name = message.user_name
@@ -516,7 +516,7 @@ def process_message(message: ScheduleMessageItem):
user_id = message.user_id
session_id = message.session_id
mem_cube_id = message.mem_cube_id
- mem_cube = message.mem_cube
+ mem_cube = self.current_mem_cube
content = message.content
messages_list = json.loads(content)
diff --git a/src/memos/mem_scheduler/memory_manage_modules/memory_filter.py b/src/memos/mem_scheduler/memory_manage_modules/memory_filter.py
index e18c6e51a..25b9a98f3 100644
--- a/src/memos/mem_scheduler/memory_manage_modules/memory_filter.py
+++ b/src/memos/mem_scheduler/memory_manage_modules/memory_filter.py
@@ -2,7 +2,7 @@
from memos.llms.base import BaseLLM
from memos.log import get_logger
from memos.mem_scheduler.general_modules.base import BaseSchedulerModule
-from memos.mem_scheduler.utils.misc_utils import extract_json_dict
+from memos.mem_scheduler.utils.misc_utils import extract_json_obj
from memos.memories.textual.tree import TextualMemoryItem
@@ -66,7 +66,7 @@ def filter_unrelated_memories(
try:
# Parse JSON response
- response = extract_json_dict(response)
+ response = extract_json_obj(response)
logger.debug(f"Parsed JSON response: {response}")
relevant_indices = response["relevant_memories"]
filtered_count = response["filtered_count"]
@@ -164,7 +164,7 @@ def filter_redundant_memories(
try:
# Parse JSON response
- response = extract_json_dict(response)
+ response = extract_json_obj(response)
logger.debug(f"Parsed JSON response: {response}")
kept_indices = response["kept_memories"]
redundant_groups = response.get("redundant_groups", [])
@@ -226,8 +226,6 @@ def filter_unrelated_and_redundant_memories(
Note:
If LLM filtering fails, returns all memories (conservative approach)
"""
- success_flag = False
-
if not memories:
logger.info("No memories to filter for unrelated and redundant - returning empty list")
return [], True
@@ -265,7 +263,7 @@ def filter_unrelated_and_redundant_memories(
try:
# Parse JSON response
- response = extract_json_dict(response)
+ response = extract_json_obj(response)
logger.debug(f"Parsed JSON response: {response}")
kept_indices = response["kept_memories"]
unrelated_removed_count = response.get("unrelated_removed_count", 0)
diff --git a/src/memos/mem_scheduler/memory_manage_modules/retriever.py b/src/memos/mem_scheduler/memory_manage_modules/retriever.py
index b766f0010..848b1d257 100644
--- a/src/memos/mem_scheduler/memory_manage_modules/retriever.py
+++ b/src/memos/mem_scheduler/memory_manage_modules/retriever.py
@@ -1,9 +1,14 @@
+from concurrent.futures import as_completed
+
from memos.configs.mem_scheduler import BaseSchedulerConfig
+from memos.context.context import ContextThreadPoolExecutor
from memos.llms.base import BaseLLM
from memos.log import get_logger
from memos.mem_cube.general import GeneralMemCube
from memos.mem_scheduler.general_modules.base import BaseSchedulerModule
from memos.mem_scheduler.schemas.general_schemas import (
+ DEFAULT_SCHEDULER_RETRIEVER_BATCH_SIZE,
+ DEFAULT_SCHEDULER_RETRIEVER_RETRIES,
TreeTextMemory_FINE_SEARCH_METHOD,
TreeTextMemory_SEARCH_METHOD,
)
@@ -12,11 +17,11 @@
filter_vector_based_similar_memories,
transform_name_to_key,
)
-from memos.mem_scheduler.utils.misc_utils import (
- extract_json_dict,
-)
+from memos.mem_scheduler.utils.misc_utils import extract_json_obj, extract_list_items_in_answer
+from memos.memories.textual.item import TextualMemoryMetadata
from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory
from .memory_filter import MemoryFilter
@@ -30,12 +35,213 @@ def __init__(self, process_llm: BaseLLM, config: BaseSchedulerConfig):
# hyper-parameters
self.filter_similarity_threshold = 0.75
self.filter_min_length_threshold = 6
-
- self.config: BaseSchedulerConfig = config
+ self.memory_filter = MemoryFilter(process_llm=process_llm, config=config)
self.process_llm = process_llm
+ self.config = config
- # Initialize memory filter
- self.memory_filter = MemoryFilter(process_llm=process_llm, config=config)
+ # Configure enhancement batching & retries from config with safe defaults
+ self.batch_size: int | None = getattr(
+ config, "scheduler_retriever_batch_size", DEFAULT_SCHEDULER_RETRIEVER_BATCH_SIZE
+ )
+ self.retries: int = getattr(
+ config, "scheduler_retriever_enhance_retries", DEFAULT_SCHEDULER_RETRIEVER_RETRIES
+ )
+
+ def evaluate_memory_answer_ability(
+ self, query: str, memory_texts: list[str], top_k: int | None = None
+ ) -> bool:
+ limited_memories = memory_texts[:top_k] if top_k is not None else memory_texts
+ # Build prompt using the template
+ prompt = self.build_prompt(
+ template_name="memory_answer_ability_evaluation",
+ query=query,
+ memory_list="\n".join([f"- {memory}" for memory in limited_memories])
+ if limited_memories
+ else "No memories available",
+ )
+
+ # Use the process LLM to generate response
+ response = self.process_llm.generate([{"role": "user", "content": prompt}])
+
+ try:
+ result = extract_json_obj(response)
+
+ # Validate response structure
+ if "result" in result:
+ logger.info(
+ f"Answerability: result={result['result']}; reason={result.get('reason', 'n/a')}; evaluated={len(limited_memories)}"
+ )
+ return result["result"]
+ else:
+ logger.warning(f"Answerability: invalid LLM JSON structure; payload={result}")
+ return False
+
+ except Exception as e:
+ logger.error(f"Answerability: parse failed; err={e}; raw={str(response)[:200]}...")
+ # Fallback: return False if we can't determine answer ability
+ return False
+
+ # ---------------------- Enhancement helpers ----------------------
+ def _build_enhancement_prompt(self, query_history: list[str], batch_texts: list[str]) -> str:
+        if len(query_history) == 1:
+            query_history = query_history[0]
+        else:
+            query_history = "\n".join(f"[{i}] {query}" for i, query in enumerate(query_history))
+        text_memories = "\n".join(f"- {mem}" for mem in batch_texts)
+ return self.build_prompt(
+ "memory_enhancement",
+ query_history=query_history,
+ memories=text_memories,
+ )
+
+ def _process_enhancement_batch(
+ self,
+ batch_index: int,
+ query_history: list[str],
+ memories: list[TextualMemoryItem],
+ retries: int,
+ ) -> tuple[list[TextualMemoryItem], bool]:
+ attempt = 0
+ text_memories = [one.memory for one in memories]
+ while attempt <= max(0, retries) + 1:
+ try:
+ prompt = self._build_enhancement_prompt(
+ query_history=query_history, batch_texts=text_memories
+ )
+ logger.debug(
+ f"[Enhance][batch={batch_index}] Prompt (first 200 chars, len={len(prompt)}): "
+ f"{prompt[:200]}]..."
+ )
+
+ response = self.process_llm.generate([{"role": "user", "content": prompt}])
+ logger.debug(
+ f"[Enhance][batch={batch_index}] Response (first 200 chars): {response}..."
+ )
+
+ processed_text_memories = extract_list_items_in_answer(response)
+ if len(processed_text_memories) == len(memories):
+ # Update
+ for i, new_mem in enumerate(processed_text_memories):
+ memories[i].memory = new_mem
+ enhanced_memories = memories
+ else:
+ # create new
+ enhanced_memories = []
+ user_id = memories[0].metadata.user_id
+ for new_mem in processed_text_memories:
+ enhanced_memories.append(
+ TextualMemoryItem(
+ memory=new_mem, metadata=TextualMemoryMetadata(user_id=user_id)
+ )
+ )
+ enhanced_memories = (
+ enhanced_memories + memories[: len(memories) - len(enhanced_memories)]
+ )
+
+ logger.info(
+ f"[Enhance]: processed_text_memories: {len(processed_text_memories)}; padded with original memories to preserve total count"
+ )
+
+ return enhanced_memories, True
+ except Exception as e:
+ attempt += 1
+ logger.debug(
+                    f"[Enhance][batch={batch_index}] retry {attempt}/{max(1, retries) + 1} failed: {e}"
+ )
+ logger.error(
+ f"Fail to run memory enhancement; original memories: {memories}", exc_info=True
+ )
+ return memories, False
+
+ @staticmethod
+ def _split_batches(
+ memories: list[TextualMemoryItem], batch_size: int
+ ) -> list[tuple[int, int, list[TextualMemoryItem]]]:
+ batches: list[tuple[int, int, list[TextualMemoryItem]]] = []
+ start = 0
+ n = len(memories)
+ while start < n:
+ end = min(start + batch_size, n)
+ batches.append((start, end, memories[start:end]))
+ start = end
+ return batches
+
+ def enhance_memories_with_query(
+ self,
+ query_history: list[str],
+ memories: list[TextualMemoryItem],
+    ) -> tuple[list[TextualMemoryItem], bool]:
+ """
+ Enhance memories by adding context and making connections to better answer queries.
+
+ Args:
+ query_history: List of user queries in chronological order
+ memories: List of memory items to enhance
+
+ Returns:
+ Tuple of (enhanced_memories, success_flag)
+ """
+ if not memories:
+            logger.warning("[Enhance] skipped (no memories to process)")
+ return memories, True
+
+ batch_size = self.batch_size
+ retries = self.retries
+ num_of_memories = len(memories)
+ try:
+ # no parallel
+ if batch_size is None or num_of_memories <= batch_size:
+ # Single batch path with retry
+ enhanced_memories, success_flag = self._process_enhancement_batch(
+ batch_index=0,
+ query_history=query_history,
+ memories=memories,
+ retries=retries,
+ )
+
+ all_success = success_flag
+ else:
+ # parallel running batches
+ # Split into batches preserving order
+ batches = self._split_batches(memories=memories, batch_size=batch_size)
+
+ # Process batches concurrently
+ all_success = True
+ failed_batches = 0
+ with ContextThreadPoolExecutor(max_workers=len(batches)) as executor:
+ future_map = {
+ executor.submit(
+ self._process_enhancement_batch, bi, query_history, texts, retries
+ ): (bi, s, e)
+ for bi, (s, e, texts) in enumerate(batches)
+ }
+ enhanced_memories = []
+ for fut in as_completed(future_map):
+ bi, s, e = future_map[fut]
+
+ batch_memories, ok = fut.result()
+ enhanced_memories.extend(batch_memories)
+ if not ok:
+ all_success = False
+ failed_batches += 1
+ logger.info(
+                    f"[Enhance] multi-batch done | batches={len(batches)} | enhanced={len(enhanced_memories)} |"
+                    f" failed_batches={failed_batches} | success={all_success}"
+ )
+
+ except Exception as e:
+            logger.error(f"[Enhance] fatal error: {e}", exc_info=True)
+ all_success = False
+ enhanced_memories = memories
+
+ if len(enhanced_memories) == 0:
+ enhanced_memories = memories
+            logger.error("[Enhance] enhanced_memories is empty; falling back to original memories", exc_info=True)
+ return enhanced_memories, all_success
def search(
self,
@@ -115,7 +321,7 @@ def rerank_memories(
try:
# Parse JSON response
- response = extract_json_dict(response)
+ response = extract_json_obj(response)
new_order = response["new_order"][:top_k]
text_memories_with_new_order = [original_memories[idx] for idx in new_order]
logger.info(
diff --git a/src/memos/mem_scheduler/monitors/dispatcher_monitor.py b/src/memos/mem_scheduler/monitors/dispatcher_monitor.py
index 46c4e2d49..99982d2e6 100644
--- a/src/memos/mem_scheduler/monitors/dispatcher_monitor.py
+++ b/src/memos/mem_scheduler/monitors/dispatcher_monitor.py
@@ -11,6 +11,7 @@
from memos.mem_scheduler.schemas.general_schemas import (
DEFAULT_DISPATCHER_MONITOR_CHECK_INTERVAL,
DEFAULT_DISPATCHER_MONITOR_MAX_FAILURES,
+ DEFAULT_STOP_WAIT,
DEFAULT_STUCK_THREAD_TOLERANCE,
)
from memos.mem_scheduler.utils.db_utils import get_utc_now
@@ -46,6 +47,11 @@ def __init__(self, config: BaseSchedulerConfig):
self.dispatcher: SchedulerDispatcher | None = None
self.dispatcher_pool_name = "dispatcher"
+ # Configure shutdown wait behavior from config or default
+ self.stop_wait = (
+ self.config.get("stop_wait", DEFAULT_STOP_WAIT) if self.config else DEFAULT_STOP_WAIT
+ )
+
def initialize(self, dispatcher: SchedulerDispatcher):
self.dispatcher = dispatcher
self.register_pool(
@@ -367,12 +373,9 @@ def stop(self) -> None:
if not executor._shutdown: # pylint: disable=protected-access
try:
logger.info(f"Shutting down thread pool '{name}'")
- executor.shutdown(wait=True, cancel_futures=True)
+ executor.shutdown(wait=self.stop_wait, cancel_futures=True)
logger.info(f"Successfully shut down thread pool '{name}'")
except Exception as e:
logger.error(f"Error shutting down pool '{name}': {e!s}", exc_info=True)
- # Clear the pool registry
- self._pools.clear()
-
logger.info("Thread pool monitor and all pools stopped")
diff --git a/src/memos/mem_scheduler/monitors/general_monitor.py b/src/memos/mem_scheduler/monitors/general_monitor.py
index a789d581e..a5f1c0097 100644
--- a/src/memos/mem_scheduler/monitors/general_monitor.py
+++ b/src/memos/mem_scheduler/monitors/general_monitor.py
@@ -29,7 +29,7 @@
QueryMonitorQueue,
)
from memos.mem_scheduler.utils.db_utils import get_utc_now
-from memos.mem_scheduler.utils.misc_utils import extract_json_dict
+from memos.mem_scheduler.utils.misc_utils import extract_json_obj
from memos.memories.textual.tree import TreeTextMemory
@@ -92,7 +92,7 @@ def extract_query_keywords(self, query: str) -> list:
llm_response = self._process_llm.generate([{"role": "user", "content": prompt}])
try:
# Parse JSON output from LLM response
- keywords = extract_json_dict(llm_response)
+ keywords = extract_json_obj(llm_response)
assert isinstance(keywords, list)
except Exception as e:
logger.error(
@@ -206,7 +206,7 @@ def update_working_memory_monitors(
self.working_mem_monitor_capacity = min(
DEFAULT_WORKING_MEM_MONITOR_SIZE_LIMIT,
(
- text_mem_base.memory_manager.memory_size["WorkingMemory"]
+ int(text_mem_base.memory_manager.memory_size["WorkingMemory"])
+ self.partial_retention_number
),
)
@@ -353,7 +353,7 @@ def detect_intent(
)
response = self._process_llm.generate([{"role": "user", "content": prompt}])
try:
- response = extract_json_dict(response)
+ response = extract_json_obj(response)
assert ("trigger_retrieval" in response) and ("missing_evidences" in response)
except Exception:
logger.error(f"Fail to extract json dict from response: {response}")
diff --git a/src/memos/mem_scheduler/optimized_scheduler.py b/src/memos/mem_scheduler/optimized_scheduler.py
index a087ab2df..b62b1e51d 100644
--- a/src/memos/mem_scheduler/optimized_scheduler.py
+++ b/src/memos/mem_scheduler/optimized_scheduler.py
@@ -2,7 +2,7 @@
import os
from collections import OrderedDict
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
from memos.api.product_models import APISearchRequest
from memos.configs.mem_scheduler import GeneralSchedulerConfig
@@ -52,11 +52,24 @@ def __init__(self, config: GeneralSchedulerConfig):
API_MIX_SEARCH_LABEL: self._api_mix_search_message_consumer,
}
)
+ self.searcher = None
+ self.reranker = None
+ self.text_mem = None
+
+ def init_mem_cube(self, mem_cube):
+ self.current_mem_cube = mem_cube
+ self.text_mem: TreeTextMemory = self.current_mem_cube.text_mem
+ self.searcher: Searcher = self.text_mem.get_searcher(
+ manual_close_internet=False,
+ moscube=False,
+ )
+ self.reranker: HTTPBGEReranker = self.text_mem.reranker
def submit_memory_history_async_task(
self,
search_req: APISearchRequest,
user_context: UserContext,
+ memories_to_store: dict | None = None,
session_id: str | None = None,
):
# Create message for async fine search
@@ -71,19 +84,16 @@ def submit_memory_history_async_task(
"chat_history": search_req.chat_history,
},
"user_context": {"mem_cube_id": user_context.mem_cube_id},
+ "memories_to_store": memories_to_store,
}
async_task_id = f"mix_search_{search_req.user_id}_{get_utc_now().timestamp()}"
- # Get mem_cube for the message
- mem_cube = self.current_mem_cube
-
message = ScheduleMessageItem(
item_id=async_task_id,
user_id=search_req.user_id,
mem_cube_id=user_context.mem_cube_id,
label=API_MIX_SEARCH_LABEL,
- mem_cube=mem_cube,
content=json.dumps(message_content),
timestamp=get_utc_now(),
)
@@ -127,33 +137,26 @@ def mix_search_memories(
self,
search_req: APISearchRequest,
user_context: UserContext,
- ):
+ ) -> list[dict[str, Any]]:
"""
Mix search memories: fast search + async fine search
"""
# Get mem_cube for fast search
- mem_cube = self.current_mem_cube
-
target_session_id = search_req.session_id
if not target_session_id:
target_session_id = "default_session"
search_filter = {"session_id": search_req.session_id} if search_req.session_id else None
- text_mem: TreeTextMemory = mem_cube.text_mem
- searcher: Searcher = text_mem.get_searcher(
- manual_close_internet=not search_req.internet_search,
- moscube=False,
- )
# Rerank Memories - reranker expects TextualMemoryItem objects
- reranker: HTTPBGEReranker = text_mem.reranker
+
info = {
"user_id": search_req.user_id,
"session_id": target_session_id,
"chat_history": search_req.chat_history,
}
- fast_retrieved_memories = searcher.retrieve(
+ fast_retrieved_memories = self.searcher.retrieve(
query=search_req.query,
user_name=user_context.mem_cube_id,
top_k=search_req.top_k,
@@ -164,13 +167,7 @@ def mix_search_memories(
info=info,
)
- self.submit_memory_history_async_task(
- search_req=search_req,
- user_context=user_context,
- session_id=search_req.session_id,
- )
-
- # Try to get pre-computed fine memories if available
+ # Try to get pre-computed memories if available
history_memories = self.api_module.get_history_memories(
user_id=search_req.user_id,
mem_cube_id=user_context.mem_cube_id,
@@ -178,7 +175,7 @@ def mix_search_memories(
)
if not history_memories:
- fast_memories = searcher.post_retrieve(
+ fast_memories = self.searcher.post_retrieve(
retrieved_results=fast_retrieved_memories,
top_k=search_req.top_k,
user_name=user_context.mem_cube_id,
@@ -187,39 +184,72 @@ def mix_search_memories(
# Format fast memories for return
formatted_memories = [format_textual_memory_item(data) for data in fast_memories]
return formatted_memories
+ else:
+            # History memories exist: rerank them and check whether they can answer the query
+ sorted_history_memories = self.reranker.rerank(
+ query=search_req.query, # Use search_req.query instead of undefined query
+ graph_results=history_memories, # Pass TextualMemoryItem objects directly
+ top_k=search_req.top_k, # Use search_req.top_k instead of undefined top_k
+ search_filter=search_filter,
+ )
- sorted_history_memories = reranker.rerank(
- query=search_req.query, # Use search_req.query instead of undefined query
- graph_results=history_memories, # Pass TextualMemoryItem objects directly
- top_k=search_req.top_k, # Use search_req.top_k instead of undefined top_k
- search_filter=search_filter,
- )
+ processed_hist_mem = self.searcher.post_retrieve(
+ retrieved_results=sorted_history_memories,
+ top_k=search_req.top_k,
+ user_name=user_context.mem_cube_id,
+ info=info,
+ )
- sorted_results = fast_retrieved_memories + sorted_history_memories
- final_results = searcher.post_retrieve(
- retrieved_results=sorted_results,
- top_k=search_req.top_k,
- user_name=user_context.mem_cube_id,
- info=info,
- )
+ can_answer = self.retriever.evaluate_memory_answer_ability(
+ query=search_req.query, memory_texts=[one.memory for one in processed_hist_mem]
+ )
- formatted_memories = [
- format_textual_memory_item(item) for item in final_results[: search_req.top_k]
- ]
+ if can_answer:
+ sorted_results = fast_retrieved_memories + sorted_history_memories
+ combined_results = self.searcher.post_retrieve(
+ retrieved_results=sorted_results,
+ top_k=search_req.top_k,
+ user_name=user_context.mem_cube_id,
+ info=info,
+ )
+ memories = combined_results[: search_req.top_k]
+ formatted_memories = [format_textual_memory_item(item) for item in memories]
+                logger.info("History memories can answer the query; returning combined results")
+ else:
+ sorted_results = fast_retrieved_memories + sorted_history_memories
+ combined_results = self.searcher.post_retrieve(
+ retrieved_results=sorted_results,
+ top_k=search_req.top_k,
+ user_name=user_context.mem_cube_id,
+ info=info,
+ )
+ enhanced_results, _ = self.retriever.enhance_memories_with_query(
+ query_history=[search_req.query],
+ memories=combined_results,
+ )
+ memories = enhanced_results[: search_req.top_k]
+ formatted_memories = [format_textual_memory_item(item) for item in memories]
+                logger.info("History memories cannot answer the query; returning enhanced results")
+
+ self.submit_memory_history_async_task(
+ search_req=search_req,
+ user_context=user_context,
+ memories_to_store={
+ "memories": [one.to_dict() for one in memories],
+ "formatted_memories": formatted_memories,
+ },
+ )
- return formatted_memories
+ return formatted_memories
def update_search_memories_to_redis(
self,
messages: list[ScheduleMessageItem],
):
- mem_cube: NaiveMemCube = self.current_mem_cube
-
for msg in messages:
content_dict = json.loads(msg.content)
search_req = content_dict["search_req"]
user_context = content_dict["user_context"]
-
session_id = search_req.get("session_id")
if session_id:
if session_id not in self.session_counter:
@@ -237,13 +267,20 @@ def update_search_memories_to_redis(
else:
session_turn = 0
- memories: list[TextualMemoryItem] = self.search_memories(
- search_req=APISearchRequest(**content_dict["search_req"]),
- user_context=UserContext(**content_dict["user_context"]),
- mem_cube=mem_cube,
- mode=SearchMode.FAST,
- )
- formatted_memories = [format_textual_memory_item(data) for data in memories]
+ memories_to_store = content_dict["memories_to_store"]
+ if memories_to_store is None:
+ memories: list[TextualMemoryItem] = self.search_memories(
+ search_req=APISearchRequest(**content_dict["search_req"]),
+ user_context=UserContext(**content_dict["user_context"]),
+ mem_cube=self.current_mem_cube,
+ mode=SearchMode.FAST,
+ )
+ formatted_memories = [format_textual_memory_item(data) for data in memories]
+ else:
+ memories = [
+ TextualMemoryItem.from_dict(one) for one in memories_to_store["memories"]
+ ]
+ formatted_memories = memories_to_store["formatted_memories"]
# Sync search data to Redis
self.api_module.sync_search_data(
diff --git a/src/memos/mem_scheduler/schemas/general_schemas.py b/src/memos/mem_scheduler/schemas/general_schemas.py
index f3d2191f8..7f2c09b7d 100644
--- a/src/memos/mem_scheduler/schemas/general_schemas.py
+++ b/src/memos/mem_scheduler/schemas/general_schemas.py
@@ -6,6 +6,7 @@
class SearchMode(str, Enum):
"""Enumeration for search modes."""
+ NOT_INITIALIZED = "not_initialized"
FAST = "fast"
FINE = "fine"
MIXTURE = "mixture"
@@ -32,14 +33,18 @@ class SearchMode(str, Enum):
DEFAULT_ACT_MEM_DUMP_PATH = f"{BASE_DIR}/outputs/mem_scheduler/mem_cube_scheduler_test.kv_cache"
DEFAULT_THREAD_POOL_MAX_WORKERS = 50
DEFAULT_CONSUME_INTERVAL_SECONDS = 0.05
+DEFAULT_CONSUME_BATCH = 1
DEFAULT_DISPATCHER_MONITOR_CHECK_INTERVAL = 300
DEFAULT_DISPATCHER_MONITOR_MAX_FAILURES = 2
DEFAULT_STUCK_THREAD_TOLERANCE = 10
-DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE = 1000000
+DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE = 0
DEFAULT_TOP_K = 10
DEFAULT_CONTEXT_WINDOW_SIZE = 5
-DEFAULT_USE_REDIS_QUEUE = False
+DEFAULT_USE_REDIS_QUEUE = True
DEFAULT_MULTI_TASK_RUNNING_TIMEOUT = 30
+DEFAULT_SCHEDULER_RETRIEVER_BATCH_SIZE = 10
+DEFAULT_SCHEDULER_RETRIEVER_RETRIES = 1
+DEFAULT_STOP_WAIT = False
# startup mode configuration
STARTUP_BY_THREAD = "thread"
@@ -64,6 +69,7 @@ class SearchMode(str, Enum):
MONITOR_ACTIVATION_MEMORY_TYPE = "MonitorActivationMemoryType"
DEFAULT_MAX_QUERY_KEY_WORDS = 1000
DEFAULT_WEIGHT_VECTOR_FOR_RANKING = [0.9, 0.05, 0.05]
+DEFAULT_MAX_WEB_LOG_QUEUE_SIZE = 50
# new types
diff --git a/src/memos/mem_scheduler/schemas/message_schemas.py b/src/memos/mem_scheduler/schemas/message_schemas.py
index 7f328474f..f1d48f3f1 100644
--- a/src/memos/mem_scheduler/schemas/message_schemas.py
+++ b/src/memos/mem_scheduler/schemas/message_schemas.py
@@ -2,11 +2,10 @@
from typing import Any
from uuid import uuid4
-from pydantic import BaseModel, ConfigDict, Field, field_serializer
+from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import TypedDict
from memos.log import get_logger
-from memos.mem_cube.base import BaseMemCube
from memos.mem_scheduler.general_modules.misc import DictConversionMixin
from memos.mem_scheduler.utils.db_utils import get_utc_now
@@ -34,22 +33,19 @@
class ScheduleMessageItem(BaseModel, DictConversionMixin):
item_id: str = Field(description="uuid", default_factory=lambda: str(uuid4()))
+ redis_message_id: str = Field(default="", description="the message get from redis stream")
user_id: str = Field(..., description="user id")
mem_cube_id: str = Field(..., description="memcube id")
+ session_id: str = Field(default="", description="Session ID for soft-filtering memories")
label: str = Field(..., description="Label of the schedule message")
- mem_cube: BaseMemCube | str = Field(..., description="memcube for schedule")
content: str = Field(..., description="Content of the schedule message")
timestamp: datetime = Field(
default_factory=get_utc_now, description="submit time for schedule_messages"
)
- user_name: str | None = Field(
- default=None,
+ user_name: str = Field(
+ default="",
description="user name / display name (optional)",
)
- session_id: str | None = Field(
- default=None,
- description="session_id (optional)",
- )
# Pydantic V2 model configuration
model_config = ConfigDict(
@@ -65,7 +61,6 @@ class ScheduleMessageItem(BaseModel, DictConversionMixin):
"user_id": "user123", # Example user identifier
"mem_cube_id": "cube456", # Sample memory cube ID
"label": "sample_label", # Demonstration label value
- "mem_cube": "obj of GeneralMemCube", # Added mem_cube example
"content": "sample content", # Example message content
"timestamp": "2024-07-22T12:00:00Z", # Added timestamp example
"user_name": "Alice", # Added username example
@@ -73,13 +68,6 @@ class ScheduleMessageItem(BaseModel, DictConversionMixin):
},
)
- @field_serializer("mem_cube")
- def serialize_mem_cube(self, cube: BaseMemCube | str, _info) -> str:
- """Custom serializer for BaseMemCube objects to string representation"""
- if isinstance(cube, str):
- return cube
- return f"<{type(cube).__name__}:{id(cube)}>"
-
def to_dict(self) -> dict:
"""Convert model to dictionary suitable for Redis Stream"""
return {
@@ -101,7 +89,6 @@ def from_dict(cls, data: dict) -> "ScheduleMessageItem":
user_id=data["user_id"],
mem_cube_id=data["cube_id"],
label=data["label"],
- mem_cube="Not Applicable", # Custom cube deserialization
content=data["content"],
timestamp=datetime.fromisoformat(data["timestamp"]),
user_name=data.get("user_name"),
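
With the mem_cube field gone, a ScheduleMessageItem is plain data that can be built and serialized without a live memory cube. A minimal sketch, assuming the memos package is importable:

from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem

msg = ScheduleMessageItem(
    user_id="user123",
    mem_cube_id="cube456",
    label="sample_label",
    content="sample content",
)

print(msg.item_id)           # auto-generated uuid
print(msg.redis_message_id)  # "" until the Redis queue sets it on read
print(msg.to_dict())         # flat dict handed to XADD when the Redis queue is used
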
diff --git a/src/memos/mem_scheduler/utils/metrics.py b/src/memos/mem_scheduler/utils/metrics.py
index 5155c98b3..45abc5b36 100644
--- a/src/memos/mem_scheduler/utils/metrics.py
+++ b/src/memos/mem_scheduler/utils/metrics.py
@@ -6,10 +6,14 @@
from dataclasses import dataclass, field
+from memos.log import get_logger
+
# ==== global window config ====
WINDOW_SEC = 120 # 2 minutes sliding window
+logger = get_logger(__name__)
+
# ---------- O(1) EWMA ----------
class Ewma:
@@ -187,7 +191,7 @@ def on_enqueue(
old_lam = ls.lambda_ewma.value_at(now)
ls.lambda_ewma.update(inst_rate, now)
new_lam = ls.lambda_ewma.value_at(now)
- print(
+ logger.info(
f"[DEBUG enqueue] {label} backlog={ls.backlog} dt={dt if dt is not None else 'โ'}s inst={inst_rate:.3f} ฮป {old_lam:.3f}โ{new_lam:.3f}"
)
self._label_topk[label].add(mem_cube_id)
@@ -225,7 +229,7 @@ def on_done(
old_mu = ls.mu_ewma.value_at(now)
ls.mu_ewma.update(inst_rate, now)
new_mu = ls.mu_ewma.value_at(now)
- print(
+ logger.info(
f"[DEBUG done] {label} backlog={ls.backlog} dt={dt if dt is not None else 'โ'}s inst={inst_rate:.3f} ฮผ {old_mu:.3f}โ{new_mu:.3f}"
)
ds = self._detail_stats.get((label, mem_cube_id))
diff --git a/src/memos/mem_scheduler/utils/misc_utils.py b/src/memos/mem_scheduler/utils/misc_utils.py
index aa9b5c489..cce1286bb 100644
--- a/src/memos/mem_scheduler/utils/misc_utils.py
+++ b/src/memos/mem_scheduler/utils/misc_utils.py
@@ -1,5 +1,6 @@
import json
import re
+import traceback
from functools import wraps
from pathlib import Path
@@ -12,7 +13,7 @@
logger = get_logger(__name__)
-def extract_json_dict(text: str):
+def extract_json_obj(text: str):
"""
Safely extracts JSON from LLM response text with robust error handling.
@@ -40,7 +41,7 @@ def extract_json_dict(text: str):
try:
return json.loads(text.strip())
except json.JSONDecodeError as e:
- logger.error(f"Failed to parse JSON from text: {text}. Error: {e!s}", exc_info=True)
+ logger.info(f"Failed to parse JSON from text: {text}. Error: {e!s}", exc_info=True)
# Fallback 1: Extract JSON using regex
json_pattern = r"\{[\s\S]*\}|\[[\s\S]*\]"
@@ -49,7 +50,7 @@ def extract_json_dict(text: str):
try:
return json.loads(matches[0])
except json.JSONDecodeError as e:
- logger.error(f"Failed to parse JSON from text: {text}. Error: {e!s}", exc_info=True)
+ logger.info(f"Failed to parse JSON from text: {text}. Error: {e!s}", exc_info=True)
# Fallback 2: Handle malformed JSON (common LLM issues)
try:
@@ -57,10 +58,125 @@ def extract_json_dict(text: str):
text = re.sub(r"([\{\s,])(\w+)(:)", r'\1"\2"\3', text)
return json.loads(text)
except json.JSONDecodeError as e:
- logger.error(f"Failed to parse JSON from text: {text}. Error: {e!s}", exc_info=True)
+ logger.error(f"Failed to parse JSON from text: {text}. Error: {e!s}")
+ logger.error("Full traceback:\n" + traceback.format_exc())
raise ValueError(text) from e
+def extract_list_items(text: str, bullet_prefixes: tuple[str, ...] = ("- ",)) -> list[str]:
+ """
+ Extract bullet list items from LLM output where each item is on a single line
+ starting with a given bullet prefix (default: "- ").
+
+ This function is designed to be robust to common LLM formatting variations,
+ following similar normalization practices as `extract_json_obj`.
+
+ Behavior:
+ - Strips common code-fence markers (```json, ```python, ``` etc.).
+ - Collects all lines that start with any of the provided `bullet_prefixes`.
+    - Unescapes common sequences like "\\n" and "\\t" within items.
+    - Deduplicates items while preserving their original order.
+    - Returns an empty list (and logs an error) when no bullet lines are found.
+
+ Args:
+ text: Raw text response from LLM.
+ bullet_prefixes: Tuple of accepted bullet line prefixes.
+
+ Returns:
+ List of extracted items (strings). Returns an empty list if none can be parsed.
+ """
+ if not text:
+ return []
+
+ # Normalize the text similar to extract_json_obj
+ normalized = text.strip()
+ patterns_to_remove = ["json```", "```python", "```json", "latex```", "```latex", "```"]
+ for pattern in patterns_to_remove:
+ normalized = normalized.replace(pattern, "")
+ normalized = normalized.replace("\r\n", "\n")
+
+ lines = normalized.splitlines()
+ items: list[str] = []
+ seen: set[str] = set()
+
+ for raw in lines:
+ line = raw.strip()
+ if not line:
+ continue
+
+ matched = False
+ for prefix in bullet_prefixes:
+ if line.startswith(prefix):
+ content = line[len(prefix) :].strip()
+ content = content.replace("\\n", "\n").replace("\\t", "\t").replace("\\r", "\r")
+ if content and content not in seen:
+ items.append(content)
+ seen.add(content)
+ matched = True
+ break
+
+ if matched:
+ continue
+
+ if items:
+ return items
+ else:
+        logger.error(f"Failed to parse any list items from text: {text}")
+
+ return []
+
+
+def extract_list_items_in_answer(
+ text: str, bullet_prefixes: tuple[str, ...] = ("- ",)
+) -> list[str]:
+ """
+    Extract list items specifically from content enclosed within `<answer>...</answer>` tags.
+
+    - When one or more `<answer>...</answer>` blocks are present, concatenates their inner
+      contents with newlines and parses using `extract_list_items`.
+    - When no `<answer>` block is found, falls back to parsing the entire input with
+      `extract_list_items`.
+    - Case-insensitive matching of the `<answer>` tag is used as a final fallback.
+
+ Args:
+        text: Raw text that may contain `<answer>...</answer>` blocks.
+ bullet_prefixes: Accepted bullet prefixes (default: strictly `"- "`).
+
+ Returns:
+ List of extracted items (strings), or an empty list when nothing is parseable.
+ """
+ if not text:
+ return []
+
+ try:
+ normalized = text.strip().replace("\r\n", "\n")
+ # Ordered, exact-case matching for blocks: answer -> Answer -> ANSWER
+ tag_variants = ["answer", "Answer", "ANSWER"]
+ matches: list[str] = []
+ for tag in tag_variants:
+            matches = re.findall(rf"<{tag}>([\s\S]*?)</{tag}>", normalized)
+ if matches:
+ break
+ # Fallback: case-insensitive matching if none of the exact-case variants matched
+ if not matches:
+            matches = re.findall(r"<answer>([\s\S]*?)</answer>", normalized, flags=re.IGNORECASE)
+
+ if matches:
+ combined = "\n".join(m.strip() for m in matches if m is not None)
+ return extract_list_items(combined, bullet_prefixes=bullet_prefixes)
+
+ # Fallback: parse the whole text if tags are absent
+ return extract_list_items(normalized, bullet_prefixes=bullet_prefixes)
+ except Exception as e:
+ logger.info(f"Failed to extract items within tags: {e!s}", exc_info=True)
+ # Final fallback: attempt direct list extraction
+ try:
+ return extract_list_items(text, bullet_prefixes=bullet_prefixes)
+ except Exception:
+ return []
+
+
def parse_yaml(yaml_file: str | Path):
yaml_path = Path(yaml_file)
if not yaml_path.is_file():
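
A quick check of the new list extraction helpers, using the <answer> wrapper that the enhancement prompt asks the model to emit; assuming the memos package is importable:

from memos.mem_scheduler.utils.misc_utils import extract_list_items_in_answer

llm_output = """
Some reasoning the model produced first.
<answer>
- User moved to Berlin in 2023
- User works as a data engineer
</answer>
"""

print(extract_list_items_in_answer(llm_output))
# ['User moved to Berlin in 2023', 'User works as a data engineer']
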
diff --git a/src/memos/mem_scheduler/webservice_modules/redis_service.py b/src/memos/mem_scheduler/webservice_modules/redis_service.py
index 5439af9c6..e79553f33 100644
--- a/src/memos/mem_scheduler/webservice_modules/redis_service.py
+++ b/src/memos/mem_scheduler/webservice_modules/redis_service.py
@@ -333,6 +333,15 @@ def redis_start_listening(self, handler: Callable | None = None):
logger.warning("Listener is already running")
return
+ # Check Redis connection before starting listener
+ if self.redis is None:
+ logger.warning(
+ "Redis connection is None, attempting to auto-initialize before starting listener..."
+ )
+ if not self.auto_initialize_redis():
+ logger.error("Failed to initialize Redis connection, cannot start listener")
+ return
+
if handler is None:
handler = self.redis_consume_message_stream
diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py b/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py
index 55e33494c..b9814f079 100644
--- a/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py
+++ b/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py
@@ -22,6 +22,7 @@ class TaskGoalParser:
def __init__(self, llm=BaseLLM):
self.llm = llm
self.tokenizer = FastTokenizer()
+ self.retries = 1
def parse(
self,
@@ -103,18 +104,24 @@ def _parse_response(self, response: str, **kwargs) -> ParsedTaskGoal:
"""
Parse LLM JSON output safely.
"""
- try:
- context = kwargs.get("context", "")
- response = response.replace("```", "").replace("json", "").strip()
- response_json = eval(response)
- return ParsedTaskGoal(
- memories=response_json.get("memories", []),
- keys=response_json.get("keys", []),
- tags=response_json.get("tags", []),
- rephrased_query=response_json.get("rephrased_instruction", None),
- internet_search=response_json.get("internet_search", False),
- goal_type=response_json.get("goal_type", "default"),
- context=context,
- )
- except Exception as e:
- raise ValueError(f"Failed to parse LLM output: {e}\nRaw response:\n{response}") from e
+ # Ensure at least one attempt
+ attempts = max(1, getattr(self, "retries", 1))
+
+ for attempt_times in range(attempts):
+ try:
+ context = kwargs.get("context", "")
+ response = response.replace("```", "").replace("json", "").strip()
+ response_json = eval(response)
+ return ParsedTaskGoal(
+ memories=response_json.get("memories", []),
+ keys=response_json.get("keys", []),
+ tags=response_json.get("tags", []),
+ rephrased_query=response_json.get("rephrased_instruction", None),
+ internet_search=response_json.get("internet_search", False),
+ goal_type=response_json.get("goal_type", "default"),
+ context=context,
+ )
+            except Exception as e:
+                # Only give up after the final attempt; earlier failures fall through to a retry
+                if attempt_times + 1 >= attempts:
+                    raise ValueError(
+                        f"Failed to parse LLM output: {e}\nRaw response:\n{response} "
+                        f"(attempt {attempt_times + 1}/{attempts})"
+                    ) from e
diff --git a/src/memos/templates/mem_scheduler_prompts.py b/src/memos/templates/mem_scheduler_prompts.py
index b4d091c1f..197a2c1a7 100644
--- a/src/memos/templates/mem_scheduler_prompts.py
+++ b/src/memos/templates/mem_scheduler_prompts.py
@@ -390,6 +390,45 @@
- Focus on whether the memories can fully answer the query without additional information
"""
+MEMORY_ENHANCEMENT_PROMPT = """
+You are a knowledgeable and precise AI assistant.
+
+# GOAL
+Transform each raw memory into an enhanced version that preserves all relevant factual details and makes the information directly useful for answering the user's query.
+
+# CORE PRINCIPLE
+Focus on **relevance** โ the enhanced memories should highlight, clarify, and preserve the information that most directly supports answering the current query.
+
+# RULES & THINKING STEPS
+1. Read the user query carefully and identify what specific facts are needed to answer it.
+2. Go through each memory and:
+ - Keep only details directly relevant to the query (dates, actions, entities, outcomes).
+ - Remove unrelated or background details.
+ - If nothing in a memory relates to the query, delete the entire memory.
+3. Do not add or infer new facts.
+4. Keep facts accurate and phrased clearly.
+5. Each resulting line should stand alone as a usable fact for answering the query.
+
+# OUTPUT FORMAT (STRICT)
+Return ONLY the following block, with **one enhanced memory per line**.
+Each line MUST start with "- " (dash + space).
+
+Wrap the final output inside <answer> tags:
+
+<answer>
+- enhanced memory 1
+- enhanced memory 2
+...
+</answer>
+
+
+## User Query
+{query_history}
+
+## Available Memories
+{memories}
+
+Answer:
+"""
+
PROMPT_MAPPING = {
"intent_recognizing": INTENT_RECOGNIZING_PROMPT,
"memory_reranking": MEMORY_RERANKING_PROMPT,
@@ -398,6 +437,7 @@
"memory_redundancy_filtering": MEMORY_REDUNDANCY_FILTERING_PROMPT,
"memory_combined_filtering": MEMORY_COMBINED_FILTERING_PROMPT,
"memory_answer_ability_evaluation": MEMORY_ANSWER_ABILITY_EVALUATION_PROMPT,
+ "memory_enhancement": MEMORY_ENHANCEMENT_PROMPT,
}
MEMORY_ASSEMBLY_TEMPLATE = """The retrieved memories are listed as follows:\n\n {memory_text}"""
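For context on the new template: because MEMORY_ENHANCEMENT_PROMPT requires one "- "-prefixed memory per line, its output is straightforward to post-process. The sketch below fills the template via the PROMPT_MAPPING entry added above and extracts the enhanced memories; the two helper functions and the use of str.format are illustrative assumptions, not code from this PR.

from memos.templates.mem_scheduler_prompts import PROMPT_MAPPING


def build_enhancement_prompt(query_history: str, memories: list[str]) -> str:
    # Fill the template registered under the new "memory_enhancement" key
    return PROMPT_MAPPING["memory_enhancement"].format(
        query_history=query_history,
        memories="\n".join(f"- {m}" for m in memories),
    )


def extract_enhanced_memories(llm_output: str) -> list[str]:
    # Keep only lines that follow the required "- " (dash + space) prefix
    enhanced = []
    for line in llm_output.splitlines():
        stripped = line.strip()
        if stripped.startswith("- "):
            enhanced.append(stripped[2:].strip())
    return enhanced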
diff --git a/tests/mem_scheduler/test_dispatcher.py b/tests/mem_scheduler/test_dispatcher.py
index e3064660b..fc154e013 100644
--- a/tests/mem_scheduler/test_dispatcher.py
+++ b/tests/mem_scheduler/test_dispatcher.py
@@ -90,7 +90,6 @@ def setUp(self):
ScheduleMessageItem(
item_id="msg1",
user_id="user1",
- mem_cube="cube1",
mem_cube_id="msg1",
label="label1",
content="Test content 1",
@@ -99,7 +98,6 @@ def setUp(self):
ScheduleMessageItem(
item_id="msg2",
user_id="user1",
- mem_cube="cube1",
mem_cube_id="msg2",
label="label2",
content="Test content 2",
@@ -108,7 +106,6 @@ def setUp(self):
ScheduleMessageItem(
item_id="msg3",
user_id="user2",
- mem_cube="cube2",
mem_cube_id="msg3",
label="label1",
content="Test content 3",
@@ -193,46 +190,6 @@ def test_dispatch_serial(self):
self.assertEqual(len(label2_messages), 1)
self.assertEqual(label2_messages[0].item_id, "msg2")
- def test_dispatch_parallel(self):
- """Test dispatching messages in parallel mode."""
- # Create fresh mock handlers for this test
- mock_handler1 = MagicMock()
- mock_handler2 = MagicMock()
-
- # Create a new dispatcher for this test to avoid interference
- parallel_dispatcher = SchedulerDispatcher(max_workers=2, enable_parallel_dispatch=True)
- parallel_dispatcher.register_handler("label1", mock_handler1)
- parallel_dispatcher.register_handler("label2", mock_handler2)
-
- # Dispatch messages
- parallel_dispatcher.dispatch(self.test_messages)
-
- # Wait for all futures to complete
- parallel_dispatcher.join(timeout=1.0)
-
- # Verify handlers were called - label1 handler should be called twice (for user1 and user2)
- # label2 handler should be called once (only for user1)
- self.assertEqual(mock_handler1.call_count, 2) # Called for user1/msg1 and user2/msg3
- mock_handler2.assert_called_once() # Called for user1/msg2
-
- # Check that each handler received the correct messages
- # For label1: should have two calls, each with one message
- label1_calls = mock_handler1.call_args_list
- self.assertEqual(len(label1_calls), 2)
-
- # Extract messages from calls
- call1_messages = label1_calls[0][0][0] # First call, first argument (messages list)
- call2_messages = label1_calls[1][0][0] # Second call, first argument (messages list)
-
- # Verify the messages in each call
- self.assertEqual(len(call1_messages), 1)
- self.assertEqual(len(call2_messages), 1)
-
- # For label2: should have one call with [msg2]
- label2_messages = mock_handler2.call_args[0][0]
- self.assertEqual(len(label2_messages), 1)
- self.assertEqual(label2_messages[0].item_id, "msg2")
-
def test_group_messages_by_user_and_mem_cube(self):
"""Test grouping messages by user and cube."""
# Check actual grouping logic
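The surviving test_group_messages_by_user_and_mem_cube case exercises the grouping step that dispatch builds on. As a rough illustration of that kind of grouping, using only fields that ScheduleMessageItem is shown to carry above, the helper below is a sketch and not the dispatcher's actual implementation.

from collections import defaultdict

from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem


def group_by_user_and_mem_cube(
    messages: list[ScheduleMessageItem],
) -> dict[tuple[str, str], list[ScheduleMessageItem]]:
    # Bucket messages so each (user_id, mem_cube_id) pair can be dispatched as one batch
    groups: dict[tuple[str, str], list[ScheduleMessageItem]] = defaultdict(list)
    for msg in messages:
        groups[(msg.user_id, msg.mem_cube_id)].append(msg)
    return dict(groups)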
diff --git a/tests/mem_scheduler/test_scheduler.py b/tests/mem_scheduler/test_scheduler.py
index 03a8e4318..fed1e8500 100644
--- a/tests/mem_scheduler/test_scheduler.py
+++ b/tests/mem_scheduler/test_scheduler.py
@@ -1,7 +1,6 @@
import sys
import unittest
-from contextlib import suppress
from datetime import datetime
from pathlib import Path
from unittest.mock import MagicMock, patch
@@ -21,12 +20,9 @@
from memos.mem_scheduler.schemas.general_schemas import (
ANSWER_LABEL,
QUERY_LABEL,
- STARTUP_BY_PROCESS,
- STARTUP_BY_THREAD,
)
from memos.mem_scheduler.schemas.message_schemas import (
ScheduleLogForWebItem,
- ScheduleMessageItem,
)
from memos.memories.textual.tree import TreeTextMemory
@@ -182,124 +178,6 @@ def test_submit_web_logs(self):
self.assertTrue(hasattr(actual_message, "timestamp"))
self.assertTrue(isinstance(actual_message.timestamp, datetime))
- def test_scheduler_startup_mode_default(self):
- """Test that scheduler has default startup mode set to thread."""
- self.assertEqual(self.scheduler.scheduler_startup_mode, STARTUP_BY_THREAD)
-
- def test_scheduler_startup_mode_thread(self):
- """Test scheduler with thread startup mode."""
- # Set scheduler startup mode to thread
- self.scheduler.scheduler_startup_mode = STARTUP_BY_THREAD
-
- # Start the scheduler
- self.scheduler.start()
-
- # Verify that consumer thread is created and process is None
- self.assertIsNotNone(self.scheduler._consumer_thread)
- self.assertIsNone(self.scheduler._consumer_process)
- self.assertTrue(self.scheduler._running)
-
- # Stop the scheduler
- self.scheduler.stop()
-
- def test_redis_message_queue(self):
- """Test Redis message queue functionality for sending and receiving messages."""
- import time
-
- from unittest.mock import MagicMock, patch
-
- # Mock Redis connection and operations
- mock_redis = MagicMock()
- mock_redis.xadd = MagicMock(return_value=b"1234567890-0")
-
- # Track received messages
- received_messages = []
-
- def redis_handler(messages: list[ScheduleMessageItem]) -> None:
- """Handler for Redis messages."""
- received_messages.extend(messages)
-
- # Register Redis handler
- redis_label = "test_redis"
- handlers = {redis_label: redis_handler}
- self.scheduler.register_handlers(handlers)
-
- # Enable Redis queue for this test
- with (
- patch.object(self.scheduler, "use_redis_queue", True),
- patch.object(self.scheduler, "_redis_conn", mock_redis),
- ):
- # Start scheduler
- self.scheduler.start()
-
- # Create test message for Redis
- redis_message = ScheduleMessageItem(
- label=redis_label,
- content="Redis test message",
- user_id="redis_user",
- mem_cube_id="redis_cube",
- mem_cube="redis_mem_cube_obj",
- timestamp=datetime.now(),
- )
-
- # Submit message to Redis queue
- self.scheduler.submit_messages(redis_message)
-
- # Verify Redis xadd was called
- mock_redis.xadd.assert_called_once()
- call_args = mock_redis.xadd.call_args
- self.assertEqual(call_args[0][0], "user:queries:stream")
-
- # Verify message data was serialized correctly
- message_data = call_args[0][1]
- self.assertEqual(message_data["label"], redis_label)
- self.assertEqual(message_data["content"], "Redis test message")
- self.assertEqual(message_data["user_id"], "redis_user")
- self.assertEqual(message_data["cube_id"], "redis_cube") # Note: to_dict uses cube_id
-
- # Simulate Redis message consumption
- # This would normally be handled by the Redis consumer in the scheduler
- time.sleep(0.1) # Brief wait for async operations
-
- # Stop scheduler
- self.scheduler.stop()
-
- print("Redis message queue test completed successfully!")
-
- # Removed test_robustness method - was too time-consuming for CI/CD pipeline
-
- def test_scheduler_startup_mode_process(self):
- """Test scheduler with process startup mode."""
- # Set scheduler startup mode to process
- self.scheduler.scheduler_startup_mode = STARTUP_BY_PROCESS
-
- # Start the scheduler
- try:
- self.scheduler.start()
-
- # Verify that consumer process is created and thread is None
- self.assertIsNotNone(self.scheduler._consumer_process)
- self.assertIsNone(self.scheduler._consumer_thread)
- self.assertTrue(self.scheduler._running)
-
- except Exception as e:
- # Process mode may fail due to pickling issues in test environment
- # This is expected behavior - we just verify the startup mode is set correctly
- self.assertEqual(self.scheduler.scheduler_startup_mode, STARTUP_BY_PROCESS)
- print(f"Process mode test encountered expected pickling issue: {e}")
- finally:
- # Always attempt to stop the scheduler
- with suppress(Exception):
- self.scheduler.stop()
-
- # Verify cleanup attempt was made
- self.assertEqual(self.scheduler.scheduler_startup_mode, STARTUP_BY_PROCESS)
-
- def test_scheduler_startup_mode_constants(self):
- """Test that startup mode constants are properly defined."""
- self.assertEqual(STARTUP_BY_THREAD, "thread")
- self.assertEqual(STARTUP_BY_PROCESS, "process")
-
def test_activation_memory_update(self):
"""Test activation memory update functionality with DynamicCache handling."""
if not self.RUN_ACTIVATION_MEMORY_TESTS:
@@ -401,130 +279,3 @@ def test_dynamic_cache_layers_access(self):
# If layers attribute doesn't exist, verify our fix handles this case
print("โ ๏ธ DynamicCache doesn't have 'layers' attribute in this transformers version")
print("โ
Test passed - our code should handle this gracefully")
-
- def test_get_running_tasks_with_filter(self):
- """Test get_running_tasks method with filter function."""
- # Mock dispatcher and its get_running_tasks method
- mock_task_item1 = MagicMock()
- mock_task_item1.item_id = "task_1"
- mock_task_item1.user_id = "user_1"
- mock_task_item1.mem_cube_id = "cube_1"
- mock_task_item1.task_info = {"type": "query"}
- mock_task_item1.task_name = "test_task_1"
- mock_task_item1.start_time = datetime.now()
- mock_task_item1.end_time = None
- mock_task_item1.status = "running"
- mock_task_item1.result = None
- mock_task_item1.error_message = None
- mock_task_item1.messages = []
-
- # Define a filter function
- def user_filter(task):
- return task.user_id == "user_1"
-
- # Mock the filtered result (only task_1 matches the filter)
- with patch.object(
- self.scheduler.dispatcher, "get_running_tasks", return_value={"task_1": mock_task_item1}
- ) as mock_get_running_tasks:
- # Call get_running_tasks with filter
- result = self.scheduler.get_running_tasks(filter_func=user_filter)
-
- # Verify result
- self.assertIsInstance(result, dict)
- self.assertIn("task_1", result)
- self.assertEqual(len(result), 1)
-
- # Verify dispatcher method was called with filter
- mock_get_running_tasks.assert_called_once_with(filter_func=user_filter)
-
- def test_get_running_tasks_empty_result(self):
- """Test get_running_tasks method when no tasks are running."""
- # Mock dispatcher to return empty dict
- with patch.object(
- self.scheduler.dispatcher, "get_running_tasks", return_value={}
- ) as mock_get_running_tasks:
- # Call get_running_tasks
- result = self.scheduler.get_running_tasks()
-
- # Verify empty result
- self.assertIsInstance(result, dict)
- self.assertEqual(len(result), 0)
-
- # Verify dispatcher method was called
- mock_get_running_tasks.assert_called_once_with(filter_func=None)
-
- def test_get_running_tasks_no_dispatcher(self):
- """Test get_running_tasks method when dispatcher is None."""
- # Temporarily set dispatcher to None
- original_dispatcher = self.scheduler.dispatcher
- self.scheduler.dispatcher = None
-
- # Call get_running_tasks
- result = self.scheduler.get_running_tasks()
-
- # Verify empty result and warning behavior
- self.assertIsInstance(result, dict)
- self.assertEqual(len(result), 0)
-
- # Restore dispatcher
- self.scheduler.dispatcher = original_dispatcher
-
- def test_get_running_tasks_multiple_tasks(self):
- """Test get_running_tasks method with multiple tasks."""
- # Mock multiple task items
- mock_task_item1 = MagicMock()
- mock_task_item1.item_id = "task_1"
- mock_task_item1.user_id = "user_1"
- mock_task_item1.mem_cube_id = "cube_1"
- mock_task_item1.task_info = {"type": "query"}
- mock_task_item1.task_name = "test_task_1"
- mock_task_item1.start_time = datetime.now()
- mock_task_item1.end_time = None
- mock_task_item1.status = "running"
- mock_task_item1.result = None
- mock_task_item1.error_message = None
- mock_task_item1.messages = []
-
- mock_task_item2 = MagicMock()
- mock_task_item2.item_id = "task_2"
- mock_task_item2.user_id = "user_2"
- mock_task_item2.mem_cube_id = "cube_2"
- mock_task_item2.task_info = {"type": "answer"}
- mock_task_item2.task_name = "test_task_2"
- mock_task_item2.start_time = datetime.now()
- mock_task_item2.end_time = None
- mock_task_item2.status = "completed"
- mock_task_item2.result = "success"
- mock_task_item2.error_message = None
- mock_task_item2.messages = ["message1", "message2"]
-
- with patch.object(
- self.scheduler.dispatcher,
- "get_running_tasks",
- return_value={"task_1": mock_task_item1, "task_2": mock_task_item2},
- ) as mock_get_running_tasks:
- # Call get_running_tasks
- result = self.scheduler.get_running_tasks()
-
- # Verify result structure
- self.assertIsInstance(result, dict)
- self.assertEqual(len(result), 2)
- self.assertIn("task_1", result)
- self.assertIn("task_2", result)
-
- # Verify task_1 details
- task1_dict = result["task_1"]
- self.assertEqual(task1_dict["item_id"], "task_1")
- self.assertEqual(task1_dict["user_id"], "user_1")
- self.assertEqual(task1_dict["status"], "running")
-
- # Verify task_2 details
- task2_dict = result["task_2"]
- self.assertEqual(task2_dict["item_id"], "task_2")
- self.assertEqual(task2_dict["user_id"], "user_2")
- self.assertEqual(task2_dict["status"], "completed")
- self.assertEqual(task2_dict["result"], "success")
- self.assertEqual(task2_dict["messages"], ["message1", "message2"])
-
- # Verify dispatcher method was called
- mock_get_running_tasks.assert_called_once_with(filter_func=None)
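Although the get_running_tasks tests are dropped here, they still document the shape of that API: a dict keyed by task id whose values are plain dicts, plus an optional filter_func applied to each task item. The small usage sketch below is consistent with those removed assertions; the wrapper function is illustrative, and any started scheduler instance can be passed in.

def print_running_tasks_for_user(scheduler, user_id: str) -> None:
    """List currently running scheduler tasks for a single user."""
    running = scheduler.get_running_tasks(filter_func=lambda task: task.user_id == user_id)
    for task_id, task_info in running.items():
        # Each value is a plain dict, as asserted in the removed tests above
        print(f"{task_id}: user={task_info['user_id']} status={task_info['status']}")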