Skip to content

Commit 99cf9b2

Browse files
author
jiangpeiling
committed
✨ Add adaptation for deep thinking models.
1 parent 30a3c62 commit 99cf9b2

File tree

6 files changed

+234
-20
lines changed

6 files changed

+234
-20
lines changed

backend/agents/create_agent_info.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ async def create_model_config_list(tenant_id):
3434
model_name=get_model_name_from_config(sub_model_config) if sub_model_config.get(
3535
"model_name") else "",
3636
url=sub_model_config.get("base_url", ""),
37-
is_deep_thinking=main_model_config.get("is_deep_thinking", False))]
37+
is_deep_thinking=sub_model_config.get("is_deep_thinking", False))]
3838

3939

4040
async def create_agent_config(agent_id, tenant_id, user_id, language: str = 'zh'):

backend/utils/str_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,7 @@ def remove_think_tags(text: str) -> str:
1515

1616

1717
def add_no_think_token(messages: List[dict]) -> None:
    """Append the " /no_think" suffix to the last user message, in place.

    Does nothing when ``messages`` is empty, when the final entry is not a
    user turn, or when it carries no "content" field.
    """
    if not messages:
        return
    last_message = messages[-1]
    if last_message["role"] == "user" and "content" in last_message:
        last_message["content"] += " /no_think"

sdk/nexent/core/agents/nexent_agent.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -68,18 +68,15 @@ def __init__(self, observer: MessageObserver,
6868

6969
def create_model(self, model_cite_name: str):
    """Create a model instance for the config whose cite_name matches.

    None placeholders in ``self.model_config_list`` are skipped.

    Raises:
        ValueError: if no matching (non-None) model config is found.
    """
    matched_config = None
    for candidate in self.model_config_list:
        if candidate is not None and candidate.cite_name == model_cite_name:
            matched_config = candidate
            break
    if matched_config is None:
        raise ValueError(f"Model {model_cite_name} not found")
    return ModelFactory.create_model(matched_config, self.observer, self.stop_event)
8380

8481
def create_local_tool(self, tool_config: ToolConfig):
8582
class_name = tool_config.class_name

sdk/nexent/core/models/openai_deep_thinking_llm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def __call__(self, messages: List[Dict[str, Any]], stop_sequences: Optional[List
4949
new_token = self.process_token(new_token)
5050

5151
# If in think block, process as deep thinking content
52-
if self.observer.in_think_block and new_token:
52+
if self.observer.in_think_block:
5353
self.observer.message_query.append(
5454
Message(ProcessType.MODEL_OUTPUT_DEEP_THINKING, new_token).to_json()
5555
)

sdk/nexent/core/models/openai_llm.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,7 @@ def __init__(self, observer: MessageObserver, temperature=0.2, top_p=0.95, *args
2323
def __call__(self, messages: List[Dict[str, Any]], stop_sequences: Optional[List[str]] = None,
2424
grammar: Optional[str] = None, tools_to_call_from: Optional[List[Tool]] = None, **kwargs, ) -> ChatMessage:
2525
try:
26-
# 如果启用no_think,添加/no_think后缀到用户最后一条消息
27-
if messages[-1]["role"] == "user":
26+
if messages and isinstance(messages[-1], dict) and messages[-1].get("role") == "user":
2827
messages[-1]["content"][-1]['text'] += " /no_think"
2928

3029
completion_kwargs = self._prepare_completion_kwargs(messages=messages, stop_sequences=stop_sequences,

test/sdk/core/models/test_openai_deep_thinking_llm.py

Lines changed: 220 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ def _prepare_completion_kwargs(self, *args, **kwargs):
2323
# In tests we will patch this method on the instance directly, so default impl is fine
2424
return {}
2525

26+
def postprocess_message(self, message, tools_to_call_from=None):
    """Identity stub: the tests need the streamed message back unchanged."""
    return message
29+
2630

2731
mock_models_module.OpenAIServerModel = DummyOpenAIServerModel
2832
mock_models_module.ChatMessage = MagicMock()
@@ -59,6 +63,11 @@ def deep_thinking_model_instance():
5963
model.top_p = 1.0
6064
model.custom_role_conversions = {}
6165

66+
# Create a proper mock for stop_event that returns False by default
67+
mock_stop_event = MagicMock()
68+
mock_stop_event.is_set.return_value = False
69+
model.stop_event = mock_stop_event
70+
6271
# Client hierarchy: client.chat.completions.create
6372
mock_client = MagicMock()
6473
mock_chat = MagicMock()
@@ -95,7 +104,7 @@ def test_process_token_with_think_tags(deep_thinking_model_instance):
95104
@pytest.mark.asyncio
96105
async def test_call_with_token_limit_error(deep_thinking_model_instance):
97106
"""Test __call__ method handles token limit errors correctly."""
98-
messages = [{"role": "user", "content": "test message"}]
107+
messages = [{"role": "user", "content": [{"text": "test message"}]}]
99108

100109
# Mock an error response
101110
deep_thinking_model_instance.client.chat.completions.create.side_effect = Exception("context_length_exceeded")
@@ -110,19 +119,226 @@ async def test_call_with_token_limit_error(deep_thinking_model_instance):
110119
@pytest.mark.asyncio
async def test_call_with_stop_event(deep_thinking_model_instance):
    """__call__ must abort with RuntimeError when the stop event is set.

    The fixture's stop_event is a MagicMock, so interruption is simulated by
    forcing is_set() to return True instead of calling set() on a real Event.
    """
    messages = [{"role": "user", "content": [{"text": "test message"}]}]

    # Set up mock chunks that will be interrupted
    mock_chunks = [
        MagicMock(choices=[MagicMock(delta=MagicMock(content="Start ", role="assistant"))])
    ]
    deep_thinking_model_instance.client.chat.completions.create.return_value = mock_chunks

    # Configure the stop event to return True when is_set() is called
    deep_thinking_model_instance.stop_event.is_set.return_value = True

    # _prepare_completion_kwargs is patched out so no real request setup runs.
    with patch.object(deep_thinking_model_instance, "_prepare_completion_kwargs"), pytest.raises(
            RuntimeError) as exc_info:
        deep_thinking_model_instance(messages)

    assert "Model is interrupted by stop event" in str(exc_info.value)
138+
139+
140+
# ---------------------------------------------------------------------------
141+
# Tests for token processing and output generation
142+
# ---------------------------------------------------------------------------
143+
144+
def test_call_normal_operation_with_usage_tracking(deep_thinking_model_instance):
    """__call__ forwards plain tokens to the observer and records usage counts.

    Two streamed chunks are simulated; only the last carries a usage object,
    mirroring streaming APIs that report usage on the final chunk.
    """
    messages = [{"role": "user", "content": [{"text": "Hello"}]}]

    # Mock the stream response with usage info
    mock_chunk1 = MagicMock()
    mock_chunk1.choices = [MagicMock()]
    mock_chunk1.choices[0].delta.content = "Hello"
    mock_chunk1.choices[0].delta.role = "assistant"

    mock_chunk2 = MagicMock()
    mock_chunk2.choices = [MagicMock()]
    mock_chunk2.choices[0].delta.content = " world"
    mock_chunk2.choices[0].delta.role = None
    mock_chunk2.usage = MagicMock()
    mock_chunk2.usage.prompt_tokens = 10
    mock_chunk2.usage.total_tokens = 15

    mock_stream = [mock_chunk1, mock_chunk2]

    with patch.object(deep_thinking_model_instance, "_prepare_completion_kwargs", return_value={}):
        deep_thinking_model_instance.client.chat.completions.create.return_value = mock_stream

        # Call the method
        result = deep_thinking_model_instance.__call__(messages)

        # Verify observer calls: both tokens streamed, buffer flushed once
        deep_thinking_model_instance.observer.add_model_new_token.assert_any_call("Hello")
        deep_thinking_model_instance.observer.add_model_new_token.assert_any_call(" world")
        deep_thinking_model_instance.observer.flush_remaining_tokens.assert_called_once()

        # Verify token counts were set from the final chunk's usage object
        assert deep_thinking_model_instance.last_input_token_count == 10
        assert deep_thinking_model_instance.last_output_token_count == 15

        # Verify result is a ChatMessage
        assert isinstance(result, MagicMock)  # Since we're mocking the parent class method
181+
182+
183+
def test_call_with_no_usage_info(deep_thinking_model_instance):
    """__call__ must default both token counters to 0 when chunk.usage is None."""
    messages = [{"role": "user", "content": [{"text": "Hello"}]}]

    # Mock the stream response with no usage info
    mock_chunk = MagicMock()
    mock_chunk.choices = [MagicMock()]
    mock_chunk.choices[0].delta.content = "Response"
    mock_chunk.choices[0].delta.role = "assistant"
    mock_chunk.usage = None  # explicit None, not a MagicMock attribute

    with patch.object(deep_thinking_model_instance, "_prepare_completion_kwargs", return_value={}):
        deep_thinking_model_instance.client.chat.completions.create.return_value = [mock_chunk]

        # Call the method
        deep_thinking_model_instance.__call__(messages)

        # Verify token counts are set to 0 when usage is None
        assert deep_thinking_model_instance.last_input_token_count == 0
        assert deep_thinking_model_instance.last_output_token_count == 0
203+
204+
205+
def test_call_with_deep_thinking_tokens(deep_thinking_model_instance):
    """__call__ routes <think>...</think> content through the deep-thinking path.

    Streams an opening tag, thinking content, a closing tag, then normal
    output; the observer should leave the think block by the end and receive
    the post-think token as a normal model token.
    """
    messages = [{"role": "user", "content": [{"text": "Think about this"}]}]

    # Mock the stream response with think tags
    mock_chunk1 = MagicMock()
    mock_chunk1.choices = [MagicMock()]
    mock_chunk1.choices[0].delta.content = "<think>"
    mock_chunk1.choices[0].delta.role = "assistant"

    mock_chunk2 = MagicMock()
    mock_chunk2.choices = [MagicMock()]
    mock_chunk2.choices[0].delta.content = "deep thinking"
    mock_chunk2.choices[0].delta.role = None

    mock_chunk3 = MagicMock()
    mock_chunk3.choices = [MagicMock()]
    mock_chunk3.choices[0].delta.content = "</think>"
    mock_chunk3.choices[0].delta.role = None

    mock_chunk4 = MagicMock()
    mock_chunk4.choices = [MagicMock()]
    mock_chunk4.choices[0].delta.content = "final answer"
    mock_chunk4.choices[0].delta.role = None
    mock_chunk4.usage = MagicMock()
    mock_chunk4.usage.prompt_tokens = 5
    mock_chunk4.usage.total_tokens = 8

    mock_stream = [mock_chunk1, mock_chunk2, mock_chunk3, mock_chunk4]

    with patch.object(deep_thinking_model_instance, "_prepare_completion_kwargs", return_value={}):
        deep_thinking_model_instance.client.chat.completions.create.return_value = mock_stream

        # Call the method
        deep_thinking_model_instance.__call__(messages)

        # Verify that deep thinking tokens were processed correctly
        # The think tags should be removed and content should be added to message_query
        assert deep_thinking_model_instance.observer.in_think_block is False  # Should end as False
        deep_thinking_model_instance.observer.add_model_new_token.assert_any_call("final answer")
245+
246+
247+
def test_call_with_mixed_thinking_and_normal_tokens(deep_thinking_model_instance):
    """__call__ interleaves normal output around a <think> block correctly.

    Normal tokens before and after the think span must both reach the
    observer via add_model_new_token; usage from the last chunk is recorded.
    """
    messages = [{"role": "user", "content": [{"text": "Mixed content"}]}]

    # Mock the stream response with mixed content
    mock_chunk1 = MagicMock()
    mock_chunk1.choices = [MagicMock()]
    mock_chunk1.choices[0].delta.content = "Normal "
    mock_chunk1.choices[0].delta.role = "assistant"

    mock_chunk2 = MagicMock()
    mock_chunk2.choices = [MagicMock()]
    mock_chunk2.choices[0].delta.content = "<think>thinking"
    mock_chunk2.choices[0].delta.role = None

    mock_chunk3 = MagicMock()
    mock_chunk3.choices = [MagicMock()]
    mock_chunk3.choices[0].delta.content = "</think>"
    mock_chunk3.choices[0].delta.role = None

    mock_chunk4 = MagicMock()
    mock_chunk4.choices = [MagicMock()]
    mock_chunk4.choices[0].delta.content = " more normal"
    mock_chunk4.choices[0].delta.role = None
    mock_chunk4.usage = MagicMock()
    mock_chunk4.usage.prompt_tokens = 8
    mock_chunk4.usage.total_tokens = 12

    mock_stream = [mock_chunk1, mock_chunk2, mock_chunk3, mock_chunk4]

    with patch.object(deep_thinking_model_instance, "_prepare_completion_kwargs", return_value={}):
        deep_thinking_model_instance.client.chat.completions.create.return_value = mock_stream

        # Call the method
        deep_thinking_model_instance.__call__(messages)

        # Verify that normal tokens were added to observer
        deep_thinking_model_instance.observer.add_model_new_token.assert_any_call("Normal ")
        deep_thinking_model_instance.observer.add_model_new_token.assert_any_call(" more normal")

        # Verify token counts
        assert deep_thinking_model_instance.last_input_token_count == 8
        assert deep_thinking_model_instance.last_output_token_count == 12
290+
291+
292+
def test_call_with_null_tokens(deep_thinking_model_instance):
    """__call__ must skip chunks whose delta.content is None.

    Only the non-None token should be forwarded to the observer.
    """
    messages = [{"role": "user", "content": [{"text": "Hello"}]}]

    # Mock the stream response with null tokens
    mock_chunk1 = MagicMock()
    mock_chunk1.choices = [MagicMock()]
    mock_chunk1.choices[0].delta.content = None
    mock_chunk1.choices[0].delta.role = "assistant"

    mock_chunk2 = MagicMock()
    mock_chunk2.choices = [MagicMock()]
    mock_chunk2.choices[0].delta.content = "Response"
    mock_chunk2.choices[0].delta.role = None
    mock_chunk2.usage = MagicMock()
    mock_chunk2.usage.prompt_tokens = 5
    mock_chunk2.usage.total_tokens = 8

    with patch.object(deep_thinking_model_instance, "_prepare_completion_kwargs", return_value={}):
        deep_thinking_model_instance.client.chat.completions.create.return_value = [mock_chunk1, mock_chunk2]

        # Call the method
        deep_thinking_model_instance.__call__(messages)

        # Verify that null tokens are handled correctly (not added to observer)
        deep_thinking_model_instance.observer.add_model_new_token.assert_called_once_with("Response")
318+
319+
320+
def test_call_with_general_exception(deep_thinking_model_instance):
    """__call__ must re-raise exceptions that are not token-limit errors."""
    messages = [{"role": "user", "content": [{"text": "Hello"}]}]

    with patch.object(deep_thinking_model_instance, "_prepare_completion_kwargs", return_value={}):
        # Mock the client to raise a general exception
        deep_thinking_model_instance.client.chat.completions.create.side_effect = Exception("General error")

        # Call the method and expect the same exception to propagate unchanged
        with pytest.raises(Exception, match="General error"):
            deep_thinking_model_instance.__call__(messages)
331+
332+
333+
def test_call_with_context_length_exceeded_error(deep_thinking_model_instance):
    """__call__ must translate a context_length_exceeded error into ValueError."""
    messages = [{"role": "user", "content": [{"text": "Hello"}]}]

    with patch.object(deep_thinking_model_instance, "_prepare_completion_kwargs", return_value={}):
        # Mock the client to raise context length exceeded error
        deep_thinking_model_instance.client.chat.completions.create.side_effect = Exception(
            "context_length_exceeded: token limit exceeded")

        # Call the method and expect ValueError with the normalized message
        with pytest.raises(ValueError, match="Token limit exceeded"):
            deep_thinking_model_instance.__call__(messages)

0 commit comments

Comments
 (0)