Skip to content

Commit a0990be

Browse files
newmeta-benfridayL
andauthored
Support customized system prompt base (#102)
feat: Add parameter to chat functions to override the base system prompt with a custom system prompt. Co-authored-by: chunyu li <[email protected]>
1 parent bfc32d8 commit a0990be

File tree

4 files changed

+212
-23
lines changed

4 files changed

+212
-23
lines changed

src/memos/mem_os/core.py

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -210,12 +210,16 @@ def _get_all_documents(self, path: str) -> list[str]:
210210
documents.append(str(file_path))
211211
return documents
212212

213-
def chat(self, query: str, user_id: str | None = None) -> str:
213+
def chat(self, query: str, user_id: str | None = None, base_prompt: str | None = None) -> str:
214214
"""
215215
Chat with the MOS.
216216
217217
Args:
218218
query (str): The user's query.
219+
user_id (str, optional): The user ID for the chat session. Defaults to the user ID from the config.
220+
base_prompt (str, optional): A custom base prompt to use for the chat.
221+
It can be a template string with a `{memories}` placeholder.
222+
If not provided, a default prompt is used.
219223
220224
Returns:
221225
str: The response from the MOS.
@@ -251,9 +255,9 @@ def chat(self, query: str, user_id: str | None = None) -> str:
251255
memories = mem_cube.text_mem.search(query, top_k=self.config.top_k)
252256
memories_all.extend(memories)
253257
logger.info(f"🧠 [Memory] Searched memories:\n{self._str_memories(memories_all)}\n")
254-
system_prompt = self._build_system_prompt(memories_all)
258+
system_prompt = self._build_system_prompt(memories_all, base_prompt=base_prompt)
255259
else:
256-
system_prompt = self._build_system_prompt()
260+
system_prompt = self._build_system_prompt(base_prompt=base_prompt)
257261
current_messages = [
258262
{"role": "system", "content": system_prompt},
259263
*chat_history.chat_history,
@@ -302,27 +306,38 @@ def chat(self, query: str, user_id: str | None = None) -> str:
302306
return response
303307

304308
def _build_system_prompt(
305-
self, memories: list[TextualMemoryItem] | list[str] | None = None
309+
self,
310+
memories: list[TextualMemoryItem] | list[str] | None = None,
311+
base_prompt: str | None = None,
306312
) -> str:
307313
"""Build system prompt with optional memories context."""
308-
base_prompt = (
309-
"You are a knowledgeable and helpful AI assistant. "
310-
"You have access to conversation memories that help you provide more personalized responses. "
311-
"Use the memories to understand the user's context, preferences, and past interactions. "
312-
"If memories are provided, reference them naturally when relevant, but don't explicitly mention having memories."
313-
)
314+
if base_prompt is None:
315+
base_prompt = (
316+
"You are a knowledgeable and helpful AI assistant. "
317+
"You have access to conversation memories that help you provide more personalized responses. "
318+
"Use the memories to understand the user's context, preferences, and past interactions. "
319+
"If memories are provided, reference them naturally when relevant, but don't explicitly mention having memories."
320+
)
314321

322+
memory_context = ""
315323
if memories:
316-
memory_context = "\n\n## Memories:\n"
324+
memory_list = []
317325
for i, memory in enumerate(memories, 1):
318326
if isinstance(memory, TextualMemoryItem):
319327
text_memory = memory.memory
320328
else:
321329
if not isinstance(memory, str):
322330
logger.error("Unexpected memory type.")
323331
text_memory = memory
324-
memory_context += f"{i}. {text_memory}\n"
325-
return base_prompt + memory_context
332+
memory_list.append(f"{i}. {text_memory}")
333+
memory_context = "\n".join(memory_list)
334+
335+
if "{memories}" in base_prompt:
336+
return base_prompt.format(memories=memory_context)
337+
elif memories:
338+
# For backward compatibility, append memories if no placeholder is found
339+
memory_context_with_header = "\n\n## Memories:\n" + memory_context
340+
return base_prompt + memory_context_with_header
326341
return base_prompt
327342

328343
def _str_memories(

src/memos/mem_os/main.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,16 @@ def __init__(self, config: MOSConfig):
3131
logger.info(PRO_MODE_WELCOME_MESSAGE)
3232
super().__init__(config)
3333

34-
def chat(self, query: str, user_id: str | None = None) -> str:
34+
def chat(self, query: str, user_id: str | None = None, base_prompt: str | None = None) -> str:
3535
"""
3636
Enhanced chat method with optional CoT (Chain of Thought) enhancement.
3737
3838
Args:
3939
query (str): The user's query.
4040
user_id (str, optional): User ID for context.
41+
base_prompt (str, optional): A custom base prompt to use for the chat.
42+
It can be a template string with a `{memories}` placeholder.
43+
If not provided, a default prompt is used.
4144
4245
Returns:
4346
str: The response from the MOS.
@@ -46,12 +49,14 @@ def chat(self, query: str, user_id: str | None = None) -> str:
4649

4750
if not self.enable_cot:
4851
# Use the original chat method from core
49-
return super().chat(query, user_id)
52+
return super().chat(query, user_id, base_prompt=base_prompt)
5053

5154
# Enhanced chat with CoT decomposition
52-
return self._chat_with_cot_enhancement(query, user_id)
55+
return self._chat_with_cot_enhancement(query, user_id, base_prompt=base_prompt)
5356

54-
def _chat_with_cot_enhancement(self, query: str, user_id: str | None = None) -> str:
57+
def _chat_with_cot_enhancement(
58+
self, query: str, user_id: str | None = None, base_prompt: str | None = None
59+
) -> str:
5560
"""
5661
Chat with CoT enhancement for complex query decomposition.
5762
This method includes all the same validation and processing logic as the core chat method.
@@ -84,7 +89,7 @@ def _chat_with_cot_enhancement(self, query: str, user_id: str | None = None) ->
8489
# Check if the query is complex and needs decomposition
8590
if not decomposition_result.get("is_complex", False):
8691
logger.info("🔍 [CoT] Query is not complex, using standard chat")
87-
return super().chat(query, user_id)
92+
return super().chat(query, user_id, base_prompt=base_prompt)
8893

8994
sub_questions = decomposition_result.get("sub_questions", [])
9095
logger.info(f"🔍 [CoT] Decomposed into {len(sub_questions)} sub-questions")
@@ -93,7 +98,7 @@ def _chat_with_cot_enhancement(self, query: str, user_id: str | None = None) ->
9398
search_engine = self._get_search_engine_for_cot_with_validation(user_cube_ids)
9499
if not search_engine:
95100
logger.warning("🔍 [CoT] No search engine available, using standard chat")
96-
return super().chat(query, user_id)
101+
return super().chat(query, user_id, base_prompt=base_prompt)
97102

98103
# Step 4: Get answers for sub-questions
99104
logger.info("🔍 [CoT] Getting answers for sub-questions...")
@@ -115,6 +120,7 @@ def _chat_with_cot_enhancement(self, query: str, user_id: str | None = None) ->
115120
chat_history=chat_history,
116121
user_id=target_user_id,
117122
search_engine=search_engine,
123+
base_prompt=base_prompt,
118124
)
119125

120126
# Step 6: Update chat history (same as core method)
@@ -149,7 +155,7 @@ def _chat_with_cot_enhancement(self, query: str, user_id: str | None = None) ->
149155
except Exception as e:
150156
logger.error(f"🔍 [CoT] Error in CoT enhancement: {e}")
151157
logger.info("🔍 [CoT] Falling back to standard chat")
152-
return super().chat(query, user_id)
158+
return super().chat(query, user_id, base_prompt=base_prompt)
153159

154160
def _get_search_engine_for_cot_with_validation(
155161
self, user_cube_ids: list[str]
@@ -183,6 +189,7 @@ def _generate_enhanced_response_with_context(
183189
chat_history: Any,
184190
user_id: str | None = None,
185191
search_engine: BaseTextMemory | None = None,
192+
base_prompt: str | None = None,
186193
) -> str:
187194
"""
188195
Generate an enhanced response using sub-questions and their answers, with chat context.
@@ -193,6 +200,8 @@ def _generate_enhanced_response_with_context(
193200
sub_answers (list[str]): List of answers to sub-questions.
194201
chat_history: The user's chat history.
195202
user_id (str, optional): User ID for context.
203+
search_engine (BaseTextMemory, optional): Search engine for context retrieval.
204+
base_prompt (str, optional): A custom base prompt for the chat.
196205
197206
Returns:
198207
str: The enhanced response.
@@ -213,10 +222,10 @@ def _generate_enhanced_response_with_context(
213222
original_query, top_k=self.config.top_k, mode="fast"
214223
)
215224
system_prompt = self._build_system_prompt(
216-
search_memories
225+
search_memories, base_prompt=base_prompt
217226
) # Use the same system prompt builder
218227
else:
219-
system_prompt = self._build_system_prompt()
228+
system_prompt = self._build_system_prompt(base_prompt=base_prompt)
220229
current_messages = [
221230
{"role": "system", "content": system_prompt + SYNTHESIS_PROMPT.format(qa_text=qa_text)},
222231
*chat_history.chat_history,
@@ -261,7 +270,7 @@ def _generate_enhanced_response_with_context(
261270
except Exception as e:
262271
logger.error(f"🔍 [CoT] Error generating enhanced response: {e}")
263272
# Fallback to standard chat
264-
return super().chat(original_query, user_id)
273+
return super().chat(original_query, user_id, base_prompt=base_prompt)
265274

266275
@classmethod
267276
def cot_decompose(

tests/mem_os/test_memos.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,3 +100,76 @@ def test_mos_has_core_methods(mock_llm, mock_reader, mock_user_manager, simple_c
100100
assert callable(mos.chat)
101101
assert callable(mos.search)
102102
assert callable(mos.add)
103+
104+
105+
@patch("memos.mem_os.core.UserManager")
106+
@patch("memos.mem_os.core.MemReaderFactory")
107+
@patch("memos.mem_os.core.LLMFactory")
108+
@patch("memos.mem_os.main.MOSCore.chat")
109+
def test_mos_chat_with_custom_prompt_no_cot(
110+
mock_core_chat, mock_llm, mock_reader, mock_user_manager, simple_config
111+
):
112+
"""Test that MOS.chat passes base_prompt to MOSCore.chat when CoT is disabled."""
113+
# Mock all dependencies
114+
mock_llm.from_config.return_value = MagicMock()
115+
mock_reader.from_config.return_value = MagicMock()
116+
user_manager_instance = MagicMock()
117+
user_manager_instance.validate_user.return_value = True
118+
mock_user_manager.return_value = user_manager_instance
119+
120+
# Disable CoT
121+
simple_config.PRO_MODE = False
122+
mos = MOS(simple_config)
123+
124+
# Call chat with a custom prompt
125+
custom_prompt = "You are a helpful bot."
126+
mos.chat("Hello", user_id="test_user", base_prompt=custom_prompt)
127+
128+
# Assert that the core chat method was called with the custom prompt
129+
mock_core_chat.assert_called_once_with("Hello", "test_user", base_prompt=custom_prompt)
130+
131+
132+
@patch("memos.mem_os.core.UserManager")
133+
@patch("memos.mem_os.core.MemReaderFactory")
134+
@patch("memos.mem_os.core.LLMFactory")
135+
@patch("memos.mem_os.main.MOS._generate_enhanced_response_with_context")
136+
@patch("memos.mem_os.main.MOS.cot_decompose")
137+
@patch("memos.mem_os.main.MOS.get_sub_answers")
138+
def test_mos_chat_with_custom_prompt_with_cot(
139+
mock_get_sub_answers,
140+
mock_cot_decompose,
141+
mock_generate_enhanced_response,
142+
mock_llm,
143+
mock_reader,
144+
mock_user_manager,
145+
simple_config,
146+
):
147+
"""Test that MOS.chat passes base_prompt correctly when CoT is enabled."""
148+
# Mock dependencies
149+
mock_llm.from_config.return_value = MagicMock()
150+
mock_reader.from_config.return_value = MagicMock()
151+
user_manager_instance = MagicMock()
152+
user_manager_instance.validate_user.return_value = True
153+
user_manager_instance.get_user_cubes.return_value = [MagicMock(cube_id="test_cube")]
154+
mock_user_manager.return_value = user_manager_instance
155+
156+
# Mock CoT process
157+
mock_cot_decompose.return_value = {"is_complex": True, "sub_questions": ["Sub-question 1"]}
158+
mock_get_sub_answers.return_value = (["Sub-question 1"], ["Sub-answer 1"])
159+
160+
# Enable CoT
161+
simple_config.PRO_MODE = True
162+
mos = MOS(simple_config)
163+
164+
# Mock the search engine to avoid errors
165+
mos.mem_cubes["test_cube"] = MagicMock()
166+
mos.mem_cubes["test_cube"].text_mem = MagicMock()
167+
168+
# Call chat with a custom prompt
169+
custom_prompt = "You are a super helpful bot. Context: {memories}"
170+
mos.chat("Complex question", user_id="test_user", base_prompt=custom_prompt)
171+
172+
# Assert that the enhanced response generator was called with the prompt
173+
mock_generate_enhanced_response.assert_called_once()
174+
call_args = mock_generate_enhanced_response.call_args[1]
175+
assert call_args.get("base_prompt") == custom_prompt

tests/mem_os/test_memos_core.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -592,6 +592,44 @@ def test_chat_with_memories(
592592
assert mos.chat_history_manager["test_user"].chat_history[1]["role"] == "assistant"
593593
assert mos.chat_history_manager["test_user"].chat_history[1]["content"] == response
594594

595+
@patch("memos.mem_os.core.UserManager")
596+
@patch("memos.mem_os.core.MemReaderFactory")
597+
@patch("memos.mem_os.core.LLMFactory")
598+
def test_chat_with_custom_base_prompt(
599+
self,
600+
mock_llm_factory,
601+
mock_reader_factory,
602+
mock_user_manager_class,
603+
mock_config,
604+
mock_llm,
605+
mock_mem_reader,
606+
mock_user_manager,
607+
mock_mem_cube,
608+
):
609+
"""Test chat functionality with a custom base prompt."""
610+
# Setup mocks
611+
mock_llm_factory.from_config.return_value = mock_llm
612+
mock_reader_factory.from_config.return_value = mock_mem_reader
613+
mock_user_manager_class.return_value = mock_user_manager
614+
615+
mos = MOSCore(MOSConfig(**mock_config))
616+
mos.mem_cubes["test_cube_1"] = mock_mem_cube
617+
mos.mem_cubes["test_cube_2"] = mock_mem_cube
618+
619+
custom_prompt = "You are a pirate. Answer as such. User memories: {memories}"
620+
mos.chat("What do I like?", base_prompt=custom_prompt)
621+
622+
# Verify that the system prompt passed to the LLM is the custom one
623+
mock_llm.generate.assert_called_once()
624+
call_args = mock_llm.generate.call_args[0]
625+
messages = call_args[0]
626+
system_prompt = messages[0]["content"]
627+
628+
assert "You are a pirate." in system_prompt
629+
assert "You are a knowledgeable and helpful AI assistant." not in system_prompt
630+
assert "User memories:" in system_prompt
631+
assert "I like playing football" in system_prompt # Check if memory is interpolated
632+
595633
@patch("memos.mem_os.core.UserManager")
596634
@patch("memos.mem_os.core.MemReaderFactory")
597635
@patch("memos.mem_os.core.LLMFactory")
@@ -664,6 +702,60 @@ def test_clear_messages(
664702
assert mos.chat_history_manager["test_user"].user_id == "test_user"
665703

666704

705+
class TestMOSSystemPrompt:
706+
"""Test the _build_system_prompt method in MOSCore."""
707+
708+
@pytest.fixture
709+
def mos_core_instance(self, mock_config, mock_user_manager):
710+
"""Fixture to create a MOSCore instance for testing the prompt builder."""
711+
with patch("memos.mem_os.core.LLMFactory"), patch("memos.mem_os.core.MemReaderFactory"):
712+
return MOSCore(MOSConfig(**mock_config), user_manager=mock_user_manager)
713+
714+
def test_build_prompt_with_template_and_memories(self, mos_core_instance):
715+
"""Test prompt with a template and memories."""
716+
base_prompt = "You are a sales agent. Here are past interactions: {memories}"
717+
memories = [TextualMemoryItem(memory="User likes blue cars.")]
718+
prompt = mos_core_instance._build_system_prompt(memories, base_prompt)
719+
assert "You are a sales agent." in prompt
720+
assert "1. User likes blue cars." in prompt
721+
assert "{memories}" not in prompt
722+
723+
def test_build_prompt_with_template_no_memories(self, mos_core_instance):
724+
"""Test prompt with a template but no memories."""
725+
base_prompt = "You are a sales agent. Here are past interactions: {memories}"
726+
prompt = mos_core_instance._build_system_prompt(None, base_prompt)
727+
assert "You are a sales agent." in prompt
728+
assert "Here are past interactions:" in prompt
729+
# The placeholder should be replaced with an empty string
730+
assert "{memories}" not in prompt
731+
# Check that the output is clean
732+
assert prompt.strip() == "You are a sales agent. Here are past interactions:"
733+
assert "## Memories:" not in prompt
734+
735+
def test_build_prompt_no_template_with_memories(self, mos_core_instance):
736+
"""Test prompt without a template but with memories (backward compatibility)."""
737+
base_prompt = "You are a helpful assistant."
738+
memories = [TextualMemoryItem(memory="User is a developer.")]
739+
prompt = mos_core_instance._build_system_prompt(memories, base_prompt)
740+
assert "You are a helpful assistant." in prompt
741+
assert "## Memories:" in prompt
742+
assert "1. User is a developer." in prompt
743+
744+
def test_build_prompt_default_with_memories(self, mos_core_instance):
745+
"""Test default prompt with memories."""
746+
memories = [TextualMemoryItem(memory="User lives in New York.")]
747+
prompt = mos_core_instance._build_system_prompt(memories)
748+
assert "You are a knowledgeable and helpful AI assistant." in prompt
749+
assert "## Memories:" in prompt
750+
assert "1. User lives in New York." in prompt
751+
752+
def test_build_prompt_default_no_memories(self, mos_core_instance):
753+
"""Test default prompt without any memories."""
754+
prompt = mos_core_instance._build_system_prompt()
755+
assert "You are a knowledgeable and helpful AI assistant." in prompt
756+
assert "## Memories:" not in prompt
757+
758+
667759
class TestMOSErrorHandling:
668760
"""Test MOS error handling."""
669761

0 commit comments

Comments
 (0)