2525 get_enabled_tool_description_for_generate_prompt ,
2626 get_enabled_sub_agent_description_for_generate_prompt ,
2727 generate_system_prompt ,
28- join_info_for_generate_system_prompt
28+ join_info_for_generate_system_prompt ,
29+ _process_thinking_tokens
2930 )
3031
3132
@@ -38,17 +39,14 @@ def setUp(self):
3839 @patch ('backend.services.prompt_service.OpenAIServerModel' )
3940 @patch ('backend.services.prompt_service.tenant_config_manager' )
4041 @patch ('backend.services.prompt_service.get_model_name_from_config' )
41- @patch ('backend.services.prompt_service.remove_think_tags' )
42- def test_call_llm_for_system_prompt (self , mock_remove_think_tags ,
43- mock_get_model_name , mock_tenant_config , mock_openai ):
42+ def test_call_llm_for_system_prompt (self , mock_get_model_name , mock_tenant_config , mock_openai ):
4443 # Setup
4544 mock_model_config = {
4645 "base_url" : "http://example.com" ,
4746 "api_key" : "fake-key"
4847 }
4948 mock_tenant_config .get_model_config .return_value = mock_model_config
5049 mock_get_model_name .return_value = "gpt-4"
51- mock_remove_think_tags .side_effect = lambda x : x # Return input unchanged
5250
5351 mock_llm_instance = mock_openai .return_value
5452
@@ -487,6 +485,147 @@ def test_call_llm_for_system_prompt_exception(self, mock_get_model_name, mock_te
487485
488486 self .assertIn ("LLM error" , str (context .exception ))
489487
488+ def test_process_thinking_tokens_normal_token (self ):
489+ """Test process_thinking_tokens with normal token when not thinking"""
490+ token_join = []
491+ callback_calls = []
492+
493+ def mock_callback (text ):
494+ callback_calls .append (text )
495+
496+ is_thinking = _process_thinking_tokens (
497+ "Hello" , False , token_join , mock_callback )
498+
499+ self .assertFalse (is_thinking )
500+ self .assertEqual (token_join , ["Hello" ])
501+ self .assertEqual (callback_calls , ["Hello" ])
502+
503+ def test_process_thinking_tokens_start_thinking (self ):
504+ """Test process_thinking_tokens when encountering <think> tag"""
505+ token_join = []
506+ callback_calls = []
507+
508+ def mock_callback (text ):
509+ callback_calls .append (text )
510+
511+ is_thinking = _process_thinking_tokens (
512+ "<think>" , False , token_join , mock_callback )
513+
514+ self .assertTrue (is_thinking )
515+ self .assertEqual (token_join , [])
516+ self .assertEqual (callback_calls , [])
517+
518+ def test_process_thinking_tokens_content_while_thinking (self ):
519+ """Test process_thinking_tokens with content while in thinking mode"""
520+ token_join = ["Hello" ]
521+ callback_calls = []
522+
523+ def mock_callback (text ):
524+ callback_calls .append (text )
525+
526+ is_thinking = _process_thinking_tokens (
527+ "thinking content" , True , token_join , mock_callback )
528+
529+ self .assertTrue (is_thinking )
530+ self .assertEqual (token_join , ["Hello" ]) # Should not change
531+ self .assertEqual (callback_calls , [])
532+
533+ def test_process_thinking_tokens_end_thinking (self ):
534+ """Test process_thinking_tokens when encountering </think> tag"""
535+ token_join = ["Hello" ]
536+ callback_calls = []
537+
538+ def mock_callback (text ):
539+ callback_calls .append (text )
540+
541+ is_thinking = _process_thinking_tokens (
542+ "</think>" , True , token_join , mock_callback )
543+
544+ self .assertFalse (is_thinking )
545+ self .assertEqual (token_join , ["Hello" ]) # Should not change
546+ self .assertEqual (callback_calls , [])
547+
548+ def test_process_thinking_tokens_content_after_thinking (self ):
549+ """Test process_thinking_tokens with content after thinking ends"""
550+ token_join = ["Hello" ]
551+ callback_calls = []
552+
553+ def mock_callback (text ):
554+ callback_calls .append (text )
555+
556+ is_thinking = _process_thinking_tokens (
557+ "World" , False , token_join , mock_callback )
558+
559+ self .assertFalse (is_thinking )
560+ self .assertEqual (token_join , ["Hello" , "World" ])
561+ self .assertEqual (callback_calls , ["HelloWorld" ])
562+
563+ def test_process_thinking_tokens_complete_flow (self ):
564+ """Test process_thinking_tokens with complete thinking flow"""
565+ token_join = []
566+ callback_calls = []
567+
568+ def mock_callback (text ):
569+ callback_calls .append (text )
570+
571+ # Start with normal content
572+ is_thinking = _process_thinking_tokens (
573+ "Start " , False , token_join , mock_callback )
574+ self .assertFalse (is_thinking )
575+
576+ # Enter thinking mode
577+ is_thinking = _process_thinking_tokens (
578+ "<think>" , False , token_join , mock_callback )
579+ self .assertTrue (is_thinking )
580+
581+ # Thinking content (ignored)
582+ is_thinking = _process_thinking_tokens (
583+ "thinking" , True , token_join , mock_callback )
584+ self .assertTrue (is_thinking )
585+
586+ # More thinking content (ignored)
587+ is_thinking = _process_thinking_tokens (
588+ " more" , True , token_join , mock_callback )
589+ self .assertTrue (is_thinking )
590+
591+ # End thinking
592+ is_thinking = _process_thinking_tokens (
593+ "</think>" , True , token_join , mock_callback )
594+ self .assertFalse (is_thinking )
595+
596+ # Continue with normal content
597+ is_thinking = _process_thinking_tokens (
598+ " End" , False , token_join , mock_callback )
599+ self .assertFalse (is_thinking )
600+
601+ # Verify final state
602+ self .assertEqual (token_join , ["Start " , " End" ])
603+ self .assertEqual (callback_calls , ["Start " , "Start End" ])
604+
605+ def test_process_thinking_tokens_no_callback (self ):
606+ """Test process_thinking_tokens without callback function"""
607+ token_join = []
608+
609+ is_thinking = _process_thinking_tokens ("Hello" , False , token_join , None )
610+
611+ self .assertFalse (is_thinking )
612+ self .assertEqual (token_join , ["Hello" ])
613+
614+ def test_process_thinking_tokens_empty_token (self ):
615+ """Test process_thinking_tokens with empty token"""
616+ token_join = []
617+ callback_calls = []
618+
619+ def mock_callback (text ):
620+ callback_calls .append (text )
621+
622+ is_thinking = _process_thinking_tokens (
623+ "" , False , token_join , mock_callback )
624+
625+ self .assertFalse (is_thinking )
626+ self .assertEqual (token_join , ["" ])
627+ self .assertEqual (callback_calls , ["" ])
628+
490629
491630if __name__ == '__main__' :
492631 unittest .main ()
0 commit comments