
Commit a7cf50e

🐛 When generating prompts with deep thinking models, remove the deep thinking content. #1229
2 parents 2e2b205 + 48266ba commit a7cf50e

File tree

6 files changed: +227 additions, −199 deletions

backend/consts/const.py

Lines changed: 4 additions & 0 deletions
@@ -239,3 +239,7 @@
     "PROCESS_FAILED": "PROCESS_FAILED",
     "FORWARD_FAILED": "FORWARD_FAILED",
 }
+
+# Deep Thinking Constants
+THINK_START_PATTERN = "<think>"
+THINK_END_PATTERN = "</think>"

backend/services/conversation_management_service.py

Lines changed: 3 additions & 4 deletions
@@ -5,7 +5,6 @@
 from typing import Any, Dict, List, Optional
 
 from jinja2 import StrictUndefined, Template
-from nexent.core.utils.observer import ProcessType
 from smolagents import OpenAIServerModel
 
 from consts.const import LANGUAGE, MODEL_CONFIG_MAPPING, MESSAGE_ROLE
@@ -28,9 +27,10 @@
     rename_conversation,
     update_message_opinion
 )
+from nexent.core.utils.observer import ProcessType
 from utils.config_utils import get_model_name_from_config, tenant_config_manager
 from utils.prompt_template_utils import get_generate_title_prompt_template
-from utils.str_utils import add_no_think_token, remove_think_tags
+from utils.str_utils import remove_think_blocks
 
 logger = logging.getLogger("conversation_management_service")
 
@@ -274,12 +274,11 @@ def call_llm_for_title(content: str, tenant_id: str, language: str = LANGUAGE["Z
                  "content": prompt_template["SYSTEM_PROMPT"]},
                 {"role": MESSAGE_ROLE["USER"],
                  "content": user_prompt}]
-    add_no_think_token(messages)
 
     # Call the model
     response = llm(messages, max_tokens=10)
 
-    return remove_think_tags(response.content.strip())
+    return remove_think_blocks(response.content.strip())
 
 
 def update_conversation_title(conversation_id: int, title: str, user_id: str = None) -> bool:
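
The title path shows why the helper swap matters: the old remove_think_tags deleted only the tag markers, so a deep thinking model's reasoning leaked straight into the generated title, while remove_think_blocks drops the whole block. A minimal sketch of the difference, using a made-up raw response (the raw string below is hypothetical):

import re

raw = "<think>User wants a short title.</think>Trip Planning"  # hypothetical model output

# Old behavior: only the markers were stripped, leaking the reasoning (the bug in #1229)
old = raw.replace("<think>", "").replace("</think>", "")
print(old)          # -> "User wants a short title.Trip Planning"

# New behavior: the entire block, inner content included, is removed
new = re.sub(r"<think>.*?</think>", "", raw, flags=re.DOTALL | re.IGNORECASE)
print(new.strip())  # -> "Trip Planning"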

backend/services/prompt_service.py

Lines changed: 34 additions & 8 deletions
@@ -6,19 +6,47 @@
 from jinja2 import StrictUndefined, Template
 from smolagents import OpenAIServerModel
 
-from consts.const import LANGUAGE, MODEL_CONFIG_MAPPING, MESSAGE_ROLE
+from consts.const import LANGUAGE, MODEL_CONFIG_MAPPING, MESSAGE_ROLE, THINK_END_PATTERN, THINK_START_PATTERN
 from consts.model import AgentInfoRequest
 from database.agent_db import update_agent, query_sub_agents_id_list, search_agent_info_by_agent_id
 from database.tool_db import query_tools_by_ids
 from services.agent_service import get_enable_tool_id_by_agent_id
 from utils.config_utils import tenant_config_manager, get_model_name_from_config
 from utils.prompt_template_utils import get_prompt_generate_prompt_template
-from utils.str_utils import remove_think_tags, add_no_think_token
 
 # Configure logging
 logger = logging.getLogger("prompt_service")
 
 
+def _process_thinking_tokens(new_token: str, is_thinking: bool, token_join: list, callback=None) -> bool:
+    """
+    Process tokens to filter out thinking content between <think> and </think> tags
+
+    Args:
+        new_token: Current token from LLM stream
+        is_thinking: Current thinking state
+        token_join: List to accumulate non-thinking tokens
+        callback: Callback function for streaming output
+
+    Returns:
+        bool: updated_is_thinking
+    """
+    # Handle thinking mode
+    if is_thinking:
+        return not (THINK_END_PATTERN in new_token)
+
+    # Handle start of thinking
+    if THINK_START_PATTERN in new_token:
+        return True
+
+    # Normal token processing
+    token_join.append(new_token)
+    if callback:
+        callback("".join(token_join))
+
+    return False
+
+
 def call_llm_for_system_prompt(user_prompt: str, system_prompt: str, callback=None, tenant_id: str = None) -> str:
     """
     Call LLM to generate system prompt
@@ -45,7 +73,6 @@ def call_llm_for_system_prompt(user_prompt: str, system_prompt: str, callback=No
     )
     messages = [{"role": MESSAGE_ROLE["SYSTEM"], "content": system_prompt},
                 {"role": MESSAGE_ROLE["USER"], "content": user_prompt}]
-    add_no_think_token(messages)
     try:
         completion_kwargs = llm._prepare_completion_kwargs(
             messages=messages,
@@ -56,14 +83,13 @@ def call_llm_for_system_prompt(user_prompt: str, system_prompt: str, callback=No
         current_request = llm.client.chat.completions.create(
             stream=True, **completion_kwargs)
         token_join = []
+        is_thinking = False
         for chunk in current_request:
             new_token = chunk.choices[0].delta.content
             if new_token is not None:
-                new_token = remove_think_tags(new_token)
-                token_join.append(new_token)
-                current_text = "".join(token_join)
-                if callback is not None:
-                    callback(current_text)
+                is_thinking = _process_thinking_tokens(
+                    new_token, is_thinking, token_join, callback
+                )
         return "".join(token_join)
     except Exception as e:
         logger.error(f"Failed to generate prompt from LLM: {str(e)}")
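
To see the streaming filter end to end, here is a small standalone driver for _process_thinking_tokens, a sketch that restates the function above and assumes each tag arrives as its own whole token (the same assumption the new tests below exercise):

THINK_START_PATTERN = "<think>"
THINK_END_PATTERN = "</think>"

def _process_thinking_tokens(new_token, is_thinking, token_join, callback=None):
    # While thinking, keep dropping tokens until the closing tag shows up
    if is_thinking:
        return not (THINK_END_PATTERN in new_token)
    # A token carrying the opening tag flips us into thinking mode (and is dropped)
    if THINK_START_PATTERN in new_token:
        return True
    # Normal token: accumulate it and stream the full text so far
    token_join.append(new_token)
    if callback:
        callback("".join(token_join))
    return False

token_join, is_thinking = [], False
for tok in ["Start ", "<think>", "hidden reasoning", "</think>", " End"]:
    is_thinking = _process_thinking_tokens(tok, is_thinking, token_join, print)
print("".join(token_join))  # -> "Start  End"

One design note: the filter looks for the markers inside a single token, so a tag split across two streamed chunks (e.g. "<thi" + "nk>") would pass through; remove_think_blocks remains as the whole-string cleanup for the non-streaming paths.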

backend/utils/str_utils.py

Lines changed: 6 additions & 19 deletions
@@ -1,21 +1,8 @@
-from typing import List
+import re
 
 
-def remove_think_tags(text: str) -> str:
-    """
-    Remove thinking tags from text
-
-    Args:
-        text: Input text that may contain thinking tags
-
-    Returns:
-        str: Text with thinking tags removed
-    """
-    return text.replace("<think>", "").replace("</think>", "")
-
-
-def add_no_think_token(messages: List[dict]):
-    if not messages:
-        return
-    if messages[-1]["role"] == "user" and "content" in messages[-1]:
-        messages[-1]["content"] += " /no_think"
+def remove_think_blocks(text: str) -> str:
+    """Remove <think>...</think> blocks including inner content."""
+    if not text:
+        return text
+    return re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL | re.IGNORECASE)
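
A quick check of what the regex buys over the old marker-only replace: re.DOTALL lets a block span newlines, re.IGNORECASE catches mixed-case tags, and the non-greedy .*? removes each block separately. A few illustrative calls, as a sketch that restates the new helper:

import re

def remove_think_blocks(text: str) -> str:
    """Remove <think>...</think> blocks including inner content."""
    if not text:
        return text
    return re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL | re.IGNORECASE)

print(remove_think_blocks("<think>step 1\nstep 2</think>answer"))   # DOTALL: block spans newlines -> "answer"
print(remove_think_blocks("<THINK>hidden</THINK>visible"))          # IGNORECASE: mixed-case tags -> "visible"
print(remove_think_blocks("<think>a</think>keep<think>b</think>"))  # non-greedy: each block removed -> "keep"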

test/backend/services/test_prompt_service.py

Lines changed: 144 additions & 5 deletions
@@ -25,7 +25,8 @@
     get_enabled_tool_description_for_generate_prompt,
     get_enabled_sub_agent_description_for_generate_prompt,
     generate_system_prompt,
-    join_info_for_generate_system_prompt
+    join_info_for_generate_system_prompt,
+    _process_thinking_tokens
 )
 
 
@@ -38,17 +39,14 @@ def setUp(self):
     @patch('backend.services.prompt_service.OpenAIServerModel')
     @patch('backend.services.prompt_service.tenant_config_manager')
     @patch('backend.services.prompt_service.get_model_name_from_config')
-    @patch('backend.services.prompt_service.remove_think_tags')
-    def test_call_llm_for_system_prompt(self, mock_remove_think_tags,
-                                        mock_get_model_name, mock_tenant_config, mock_openai):
+    def test_call_llm_for_system_prompt(self, mock_get_model_name, mock_tenant_config, mock_openai):
         # Setup
         mock_model_config = {
             "base_url": "http://example.com",
             "api_key": "fake-key"
         }
         mock_tenant_config.get_model_config.return_value = mock_model_config
         mock_get_model_name.return_value = "gpt-4"
-        mock_remove_think_tags.side_effect = lambda x: x  # Return input unchanged
 
         mock_llm_instance = mock_openai.return_value
 
@@ -487,6 +485,147 @@ def test_call_llm_for_system_prompt_exception(self, mock_get_model_name, mock_te
 
         self.assertIn("LLM error", str(context.exception))
 
+    def test_process_thinking_tokens_normal_token(self):
+        """Test process_thinking_tokens with normal token when not thinking"""
+        token_join = []
+        callback_calls = []
+
+        def mock_callback(text):
+            callback_calls.append(text)
+
+        is_thinking = _process_thinking_tokens(
+            "Hello", False, token_join, mock_callback)
+
+        self.assertFalse(is_thinking)
+        self.assertEqual(token_join, ["Hello"])
+        self.assertEqual(callback_calls, ["Hello"])
+
+    def test_process_thinking_tokens_start_thinking(self):
+        """Test process_thinking_tokens when encountering <think> tag"""
+        token_join = []
+        callback_calls = []
+
+        def mock_callback(text):
+            callback_calls.append(text)
+
+        is_thinking = _process_thinking_tokens(
+            "<think>", False, token_join, mock_callback)
+
+        self.assertTrue(is_thinking)
+        self.assertEqual(token_join, [])
+        self.assertEqual(callback_calls, [])
+
+    def test_process_thinking_tokens_content_while_thinking(self):
+        """Test process_thinking_tokens with content while in thinking mode"""
+        token_join = ["Hello"]
+        callback_calls = []
+
+        def mock_callback(text):
+            callback_calls.append(text)
+
+        is_thinking = _process_thinking_tokens(
+            "thinking content", True, token_join, mock_callback)
+
+        self.assertTrue(is_thinking)
+        self.assertEqual(token_join, ["Hello"])  # Should not change
+        self.assertEqual(callback_calls, [])
+
+    def test_process_thinking_tokens_end_thinking(self):
+        """Test process_thinking_tokens when encountering </think> tag"""
+        token_join = ["Hello"]
+        callback_calls = []
+
+        def mock_callback(text):
+            callback_calls.append(text)
+
+        is_thinking = _process_thinking_tokens(
+            "</think>", True, token_join, mock_callback)
+
+        self.assertFalse(is_thinking)
+        self.assertEqual(token_join, ["Hello"])  # Should not change
+        self.assertEqual(callback_calls, [])
+
+    def test_process_thinking_tokens_content_after_thinking(self):
+        """Test process_thinking_tokens with content after thinking ends"""
+        token_join = ["Hello"]
+        callback_calls = []
+
+        def mock_callback(text):
+            callback_calls.append(text)
+
+        is_thinking = _process_thinking_tokens(
+            "World", False, token_join, mock_callback)
+
+        self.assertFalse(is_thinking)
+        self.assertEqual(token_join, ["Hello", "World"])
+        self.assertEqual(callback_calls, ["HelloWorld"])
+
+    def test_process_thinking_tokens_complete_flow(self):
+        """Test process_thinking_tokens with complete thinking flow"""
+        token_join = []
+        callback_calls = []
+
+        def mock_callback(text):
+            callback_calls.append(text)
+
+        # Start with normal content
+        is_thinking = _process_thinking_tokens(
+            "Start ", False, token_join, mock_callback)
+        self.assertFalse(is_thinking)
+
+        # Enter thinking mode
+        is_thinking = _process_thinking_tokens(
+            "<think>", False, token_join, mock_callback)
+        self.assertTrue(is_thinking)
+
+        # Thinking content (ignored)
+        is_thinking = _process_thinking_tokens(
+            "thinking", True, token_join, mock_callback)
+        self.assertTrue(is_thinking)
+
+        # More thinking content (ignored)
+        is_thinking = _process_thinking_tokens(
+            " more", True, token_join, mock_callback)
+        self.assertTrue(is_thinking)
+
+        # End thinking
+        is_thinking = _process_thinking_tokens(
+            "</think>", True, token_join, mock_callback)
+        self.assertFalse(is_thinking)
+
+        # Continue with normal content
+        is_thinking = _process_thinking_tokens(
+            " End", False, token_join, mock_callback)
+        self.assertFalse(is_thinking)
+
+        # Verify final state
+        self.assertEqual(token_join, ["Start ", " End"])
+        self.assertEqual(callback_calls, ["Start ", "Start  End"])
+
+    def test_process_thinking_tokens_no_callback(self):
+        """Test process_thinking_tokens without callback function"""
+        token_join = []
+
+        is_thinking = _process_thinking_tokens("Hello", False, token_join, None)
+
+        self.assertFalse(is_thinking)
+        self.assertEqual(token_join, ["Hello"])
+
+    def test_process_thinking_tokens_empty_token(self):
+        """Test process_thinking_tokens with empty token"""
+        token_join = []
+        callback_calls = []
+
+        def mock_callback(text):
+            callback_calls.append(text)
+
+        is_thinking = _process_thinking_tokens(
+            "", False, token_join, mock_callback)
+
+        self.assertFalse(is_thinking)
+        self.assertEqual(token_join, [""])
+        self.assertEqual(callback_calls, [""])
+
 
 if __name__ == '__main__':
     unittest.main()
