Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 28 additions & 13 deletions src/memos/mem_os/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,12 +210,16 @@ def _get_all_documents(self, path: str) -> list[str]:
documents.append(str(file_path))
return documents

def chat(self, query: str, user_id: str | None = None) -> str:
def chat(self, query: str, user_id: str | None = None, base_prompt: str | None = None) -> str:
"""
Chat with the MOS.

Args:
query (str): The user's query.
user_id (str, optional): The user ID for the chat session. Defaults to the user ID from the config.
base_prompt (str, optional): A custom base prompt to use for the chat.
It can be a template string with a `{memories}` placeholder.
If not provided, a default prompt is used.

Returns:
str: The response from the MOS.
Expand Down Expand Up @@ -251,9 +255,9 @@ def chat(self, query: str, user_id: str | None = None) -> str:
memories = mem_cube.text_mem.search(query, top_k=self.config.top_k)
memories_all.extend(memories)
logger.info(f"🧠 [Memory] Searched memories:\n{self._str_memories(memories_all)}\n")
system_prompt = self._build_system_prompt(memories_all)
system_prompt = self._build_system_prompt(memories_all, base_prompt=base_prompt)
else:
system_prompt = self._build_system_prompt()
system_prompt = self._build_system_prompt(base_prompt=base_prompt)
current_messages = [
{"role": "system", "content": system_prompt},
*chat_history.chat_history,
Expand Down Expand Up @@ -302,27 +306,38 @@ def chat(self, query: str, user_id: str | None = None) -> str:
return response

def _build_system_prompt(
self, memories: list[TextualMemoryItem] | list[str] | None = None
self,
memories: list[TextualMemoryItem] | list[str] | None = None,
base_prompt: str | None = None,
) -> str:
"""Build system prompt with optional memories context."""
base_prompt = (
"You are a knowledgeable and helpful AI assistant. "
"You have access to conversation memories that help you provide more personalized responses. "
"Use the memories to understand the user's context, preferences, and past interactions. "
"If memories are provided, reference them naturally when relevant, but don't explicitly mention having memories."
)
if base_prompt is None:
base_prompt = (
"You are a knowledgeable and helpful AI assistant. "
"You have access to conversation memories that help you provide more personalized responses. "
"Use the memories to understand the user's context, preferences, and past interactions. "
"If memories are provided, reference them naturally when relevant, but don't explicitly mention having memories."
)

memory_context = ""
if memories:
memory_context = "\n\n## Memories:\n"
memory_list = []
for i, memory in enumerate(memories, 1):
if isinstance(memory, TextualMemoryItem):
text_memory = memory.memory
else:
if not isinstance(memory, str):
logger.error("Unexpected memory type.")
text_memory = memory
memory_context += f"{i}. {text_memory}\n"
return base_prompt + memory_context
memory_list.append(f"{i}. {text_memory}")
memory_context = "\n".join(memory_list)

if "{memories}" in base_prompt:
return base_prompt.format(memories=memory_context)
elif memories:
# For backward compatibility, append memories if no placeholder is found
memory_context_with_header = "\n\n## Memories:\n" + memory_context
return base_prompt + memory_context_with_header
return base_prompt

def _str_memories(
Expand Down
29 changes: 19 additions & 10 deletions src/memos/mem_os/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,16 @@ def __init__(self, config: MOSConfig):
logger.info(PRO_MODE_WELCOME_MESSAGE)
super().__init__(config)

def chat(self, query: str, user_id: str | None = None) -> str:
def chat(self, query: str, user_id: str | None = None, base_prompt: str | None = None) -> str:
"""
Enhanced chat method with optional CoT (Chain of Thought) enhancement.

Args:
query (str): The user's query.
user_id (str, optional): User ID for context.
base_prompt (str, optional): A custom base prompt to use for the chat.
It can be a template string with a `{memories}` placeholder.
If not provided, a default prompt is used.

Returns:
str: The response from the MOS.
Expand All @@ -46,12 +49,14 @@ def chat(self, query: str, user_id: str | None = None) -> str:

if not self.enable_cot:
# Use the original chat method from core
return super().chat(query, user_id)
return super().chat(query, user_id, base_prompt=base_prompt)

# Enhanced chat with CoT decomposition
return self._chat_with_cot_enhancement(query, user_id)
return self._chat_with_cot_enhancement(query, user_id, base_prompt=base_prompt)

def _chat_with_cot_enhancement(self, query: str, user_id: str | None = None) -> str:
def _chat_with_cot_enhancement(
self, query: str, user_id: str | None = None, base_prompt: str | None = None
) -> str:
"""
Chat with CoT enhancement for complex query decomposition.
This method includes all the same validation and processing logic as the core chat method.
Expand Down Expand Up @@ -84,7 +89,7 @@ def _chat_with_cot_enhancement(self, query: str, user_id: str | None = None) ->
# Check if the query is complex and needs decomposition
if not decomposition_result.get("is_complex", False):
logger.info("🔍 [CoT] Query is not complex, using standard chat")
return super().chat(query, user_id)
return super().chat(query, user_id, base_prompt=base_prompt)

sub_questions = decomposition_result.get("sub_questions", [])
logger.info(f"🔍 [CoT] Decomposed into {len(sub_questions)} sub-questions")
Expand All @@ -93,7 +98,7 @@ def _chat_with_cot_enhancement(self, query: str, user_id: str | None = None) ->
search_engine = self._get_search_engine_for_cot_with_validation(user_cube_ids)
if not search_engine:
logger.warning("🔍 [CoT] No search engine available, using standard chat")
return super().chat(query, user_id)
return super().chat(query, user_id, base_prompt=base_prompt)

# Step 4: Get answers for sub-questions
logger.info("🔍 [CoT] Getting answers for sub-questions...")
Expand All @@ -115,6 +120,7 @@ def _chat_with_cot_enhancement(self, query: str, user_id: str | None = None) ->
chat_history=chat_history,
user_id=target_user_id,
search_engine=search_engine,
base_prompt=base_prompt,
)

# Step 6: Update chat history (same as core method)
Expand Down Expand Up @@ -149,7 +155,7 @@ def _chat_with_cot_enhancement(self, query: str, user_id: str | None = None) ->
except Exception as e:
logger.error(f"🔍 [CoT] Error in CoT enhancement: {e}")
logger.info("🔍 [CoT] Falling back to standard chat")
return super().chat(query, user_id)
return super().chat(query, user_id, base_prompt=base_prompt)

def _get_search_engine_for_cot_with_validation(
self, user_cube_ids: list[str]
Expand Down Expand Up @@ -183,6 +189,7 @@ def _generate_enhanced_response_with_context(
chat_history: Any,
user_id: str | None = None,
search_engine: BaseTextMemory | None = None,
base_prompt: str | None = None,
) -> str:
"""
Generate an enhanced response using sub-questions and their answers, with chat context.
Expand All @@ -193,6 +200,8 @@ def _generate_enhanced_response_with_context(
sub_answers (list[str]): List of answers to sub-questions.
chat_history: The user's chat history.
user_id (str, optional): User ID for context.
search_engine (BaseTextMemory, optional): Search engine for context retrieval.
base_prompt (str, optional): A custom base prompt for the chat.

Returns:
str: The enhanced response.
Expand All @@ -213,10 +222,10 @@ def _generate_enhanced_response_with_context(
original_query, top_k=self.config.top_k, mode="fast"
)
system_prompt = self._build_system_prompt(
search_memories
search_memories, base_prompt=base_prompt
) # Use the same system prompt builder
else:
system_prompt = self._build_system_prompt()
system_prompt = self._build_system_prompt(base_prompt=base_prompt)
current_messages = [
{"role": "system", "content": system_prompt + SYNTHESIS_PROMPT.format(qa_text=qa_text)},
*chat_history.chat_history,
Expand Down Expand Up @@ -261,7 +270,7 @@ def _generate_enhanced_response_with_context(
except Exception as e:
logger.error(f"🔍 [CoT] Error generating enhanced response: {e}")
# Fallback to standard chat
return super().chat(original_query, user_id)
return super().chat(original_query, user_id, base_prompt=base_prompt)

@classmethod
def cot_decompose(
Expand Down
73 changes: 73 additions & 0 deletions tests/mem_os/test_memos.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,76 @@ def test_mos_has_core_methods(mock_llm, mock_reader, mock_user_manager, simple_c
assert callable(mos.chat)
assert callable(mos.search)
assert callable(mos.add)


@patch("memos.mem_os.core.UserManager")
@patch("memos.mem_os.core.MemReaderFactory")
@patch("memos.mem_os.core.LLMFactory")
@patch("memos.mem_os.main.MOSCore.chat")
def test_mos_chat_with_custom_prompt_no_cot(
    mock_core_chat, mock_llm, mock_reader, mock_user_manager, simple_config
):
    """MOS.chat should forward base_prompt to MOSCore.chat when CoT is off."""
    # Stub every factory so no real LLM / reader / user manager is built.
    mock_llm.from_config.return_value = MagicMock()
    mock_reader.from_config.return_value = MagicMock()
    fake_user_manager = MagicMock()
    fake_user_manager.validate_user.return_value = True
    mock_user_manager.return_value = fake_user_manager

    simple_config.PRO_MODE = False  # CoT disabled -> plain core chat path
    mos = MOS(simple_config)

    prompt = "You are a helpful bot."
    mos.chat("Hello", user_id="test_user", base_prompt=prompt)

    # The call must be delegated verbatim to the core implementation.
    mock_core_chat.assert_called_once_with("Hello", "test_user", base_prompt=prompt)


@patch("memos.mem_os.core.UserManager")
@patch("memos.mem_os.core.MemReaderFactory")
@patch("memos.mem_os.core.LLMFactory")
@patch("memos.mem_os.main.MOS._generate_enhanced_response_with_context")
@patch("memos.mem_os.main.MOS.cot_decompose")
@patch("memos.mem_os.main.MOS.get_sub_answers")
def test_mos_chat_with_custom_prompt_with_cot(
    mock_get_sub_answers,
    mock_cot_decompose,
    mock_generate_enhanced_response,
    mock_llm,
    mock_reader,
    mock_user_manager,
    simple_config,
):
    """MOS.chat should thread base_prompt through the CoT pipeline."""
    # Factory stubs -- no real components are constructed.
    mock_llm.from_config.return_value = MagicMock()
    mock_reader.from_config.return_value = MagicMock()
    fake_user_manager = MagicMock()
    fake_user_manager.validate_user.return_value = True
    fake_user_manager.get_user_cubes.return_value = [MagicMock(cube_id="test_cube")]
    mock_user_manager.return_value = fake_user_manager

    # Force the CoT branch: the query decomposes into one sub-question.
    mock_cot_decompose.return_value = {"is_complex": True, "sub_questions": ["Sub-question 1"]}
    mock_get_sub_answers.return_value = (["Sub-question 1"], ["Sub-answer 1"])

    simple_config.PRO_MODE = True  # CoT enabled
    mos = MOS(simple_config)

    # Give the cube a text_mem so a search engine can be resolved.
    mos.mem_cubes["test_cube"] = MagicMock()
    mos.mem_cubes["test_cube"].text_mem = MagicMock()

    prompt = "You are a super helpful bot. Context: {memories}"
    mos.chat("Complex question", user_id="test_user", base_prompt=prompt)

    # The synthesis step must receive the caller's prompt as a keyword argument.
    mock_generate_enhanced_response.assert_called_once()
    assert mock_generate_enhanced_response.call_args.kwargs.get("base_prompt") == prompt
92 changes: 92 additions & 0 deletions tests/mem_os/test_memos_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,44 @@ def test_chat_with_memories(
assert mos.chat_history_manager["test_user"].chat_history[1]["role"] == "assistant"
assert mos.chat_history_manager["test_user"].chat_history[1]["content"] == response

@patch("memos.mem_os.core.UserManager")
@patch("memos.mem_os.core.MemReaderFactory")
@patch("memos.mem_os.core.LLMFactory")
def test_chat_with_custom_base_prompt(
    self,
    mock_llm_factory,
    mock_reader_factory,
    mock_user_manager_class,
    mock_config,
    mock_llm,
    mock_mem_reader,
    mock_user_manager,
    mock_mem_cube,
):
    """A caller-supplied base prompt must reach the LLM as the system turn."""
    # Wire the factories to the prebuilt fixture mocks.
    mock_llm_factory.from_config.return_value = mock_llm
    mock_reader_factory.from_config.return_value = mock_mem_reader
    mock_user_manager_class.return_value = mock_user_manager

    mos = MOSCore(MOSConfig(**mock_config))
    mos.mem_cubes["test_cube_1"] = mock_mem_cube
    mos.mem_cubes["test_cube_2"] = mock_mem_cube

    template = "You are a pirate. Answer as such. User memories: {memories}"
    mos.chat("What do I like?", base_prompt=template)

    # Inspect the message list handed to the LLM; index 0 is the system turn.
    mock_llm.generate.assert_called_once()
    messages = mock_llm.generate.call_args[0][0]
    system_prompt = messages[0]["content"]

    assert "You are a pirate." in system_prompt
    # The default prompt must be fully replaced, not merely prefixed.
    assert "You are a knowledgeable and helpful AI assistant." not in system_prompt
    assert "User memories:" in system_prompt
    # The fixture memory should have been interpolated into the template.
    assert "I like playing football" in system_prompt

@patch("memos.mem_os.core.UserManager")
@patch("memos.mem_os.core.MemReaderFactory")
@patch("memos.mem_os.core.LLMFactory")
Expand Down Expand Up @@ -664,6 +702,60 @@ def test_clear_messages(
assert mos.chat_history_manager["test_user"].user_id == "test_user"


class TestMOSSystemPrompt:
    """Unit tests for MOSCore._build_system_prompt."""

    @pytest.fixture
    def mos_core_instance(self, mock_config, mock_user_manager):
        """Build a MOSCore with the LLM and reader factories patched out."""
        with (
            patch("memos.mem_os.core.LLMFactory"),
            patch("memos.mem_os.core.MemReaderFactory"),
        ):
            return MOSCore(MOSConfig(**mock_config), user_manager=mock_user_manager)

    def test_build_prompt_with_template_and_memories(self, mos_core_instance):
        """A {memories} template gets the numbered memory list substituted in."""
        template = "You are a sales agent. Here are past interactions: {memories}"
        items = [TextualMemoryItem(memory="User likes blue cars.")]
        rendered = mos_core_instance._build_system_prompt(items, template)
        assert "You are a sales agent." in rendered
        assert "1. User likes blue cars." in rendered
        assert "{memories}" not in rendered

    def test_build_prompt_with_template_no_memories(self, mos_core_instance):
        """With no memories, the placeholder collapses to an empty string."""
        template = "You are a sales agent. Here are past interactions: {memories}"
        rendered = mos_core_instance._build_system_prompt(None, template)
        assert "You are a sales agent." in rendered
        assert "Here are past interactions:" in rendered
        assert "{memories}" not in rendered
        # Output should be clean: no dangling placeholder, no memories header.
        assert rendered.strip() == "You are a sales agent. Here are past interactions:"
        assert "## Memories:" not in rendered

    def test_build_prompt_no_template_with_memories(self, mos_core_instance):
        """Without a placeholder, memories are appended under a header."""
        rendered = mos_core_instance._build_system_prompt(
            [TextualMemoryItem(memory="User is a developer.")],
            "You are a helpful assistant.",
        )
        assert "You are a helpful assistant." in rendered
        assert "## Memories:" in rendered
        assert "1. User is a developer." in rendered

    def test_build_prompt_default_with_memories(self, mos_core_instance):
        """The built-in default prompt still gets the memories section."""
        rendered = mos_core_instance._build_system_prompt(
            [TextualMemoryItem(memory="User lives in New York.")]
        )
        assert "You are a knowledgeable and helpful AI assistant." in rendered
        assert "## Memories:" in rendered
        assert "1. User lives in New York." in rendered

    def test_build_prompt_default_no_memories(self, mos_core_instance):
        """The default prompt alone is returned when nothing is stored."""
        rendered = mos_core_instance._build_system_prompt()
        assert "You are a knowledgeable and helpful AI assistant." in rendered
        assert "## Memories:" not in rendered

class TestMOSErrorHandling:
"""Test MOS error handling."""

Expand Down