@@ -89,16 +89,18 @@ def reasoning_effort(self, messages) -> str:
             out = self.model(input_ids=tokens, past_key_values=kv, use_cache=True)
             logits = out.logits[0, -1, :]
 
-            # Check if we need to force end token
+            # Check if we need to force end thinking
             force_end = (n_thinking_tokens >= self.config["max_thinking_tokens"] or
                          self.thought_count >= self.config["max_thoughts"])
 
-            if force_end:
+            if force_end and not seen_end_think:
                 logger.debug(f"Forcing end think token. Tokens: {n_thinking_tokens}, Thoughts: {self.thought_count}")
                 next_token = self.end_think_token
                 response_chunks.append(self.tokenizer.decode([next_token]))
-                # Break immediately when forcing end token
-                break
+                seen_end_think = True
+                # Don't break - continue generating but with end_think token forced
+                tokens = torch.tensor([[next_token]]).to(tokens.device)
+                continue
             else:
                 next_token = torch.multinomial(
                     torch.softmax(logits, dim=-1), 1
@@ -107,8 +109,8 @@ def reasoning_effort(self, messages) -> str:
             kv = out.past_key_values
             next_str = self.tokenizer.decode([next_token])
 
-            # Check if this is a thought-switching token
-            if self.is_thought_switch(next_token):
+            # Check if this is a thought-switching token (only if not in conclusion phase)
+            if not seen_end_think and self.is_thought_switch(next_token):
                 self.thought_count += 1
                 logger.debug(f"Detected thought switch marker. Total thoughts: {self.thought_count}")
                 # Clear the sequence after detecting a switch
@@ -135,6 +137,7 @@ def reasoning_effort(self, messages) -> str:
             if next_token == self.model.config.eos_token_id:
                 if seen_end_think:
                     logger.debug("Reached EOS after end think token - stopping generation")
+                    response_chunks.append(next_str)
                     break
                 elif n_thinking_tokens < self.config["min_thinking_tokens"]:
                     # Continue with thought transition if under minimum tokens
@@ -150,11 +153,13 @@ def reasoning_effort(self, messages) -> str:
                     # Force end think token if we haven't seen it
                     logger.debug("Reached EOS without end think token - adding end token")
                     response_chunks.append(self.tokenizer.decode([self.end_think_token]))
+                    response_chunks.append(next_str)
                     break
 
             # Normal token processing
             response_chunks.append(next_str)
-            n_thinking_tokens += 1
+            if not seen_end_think:
+                n_thinking_tokens += 1
            tokens = torch.tensor([[next_token]]).to(tokens.device)
 
         # Join all chunks and add framing tokens
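Net effect of the four hunks: when the thinking budget is exhausted, the loop no longer breaks out; it injects the end-think token, sets seen_end_think, and keeps sampling so the model can still produce a conclusion. Thinking-token counting and thought-switch detection are frozen during the conclusion phase, and the EOS token is now appended to the output. Below is a minimal, self-contained sketch of that control flow, not the repo's code: string "tokens", the stubbed token stream, and the budget values are all illustrative stand-ins for the real self.model / self.tokenizer sampling.

# Sketch of the revised loop control flow (assumed names; thought-switch
# marker handling from the real code is omitted for brevity).
END_THINK, EOS = "</think>", "<eos>"

def generate(stream, max_thinking_tokens=5, max_thoughts=3):
    chunks = []
    n_thinking_tokens, thought_count = 0, 0
    seen_end_think = False
    it = iter(stream)
    while True:
        # Force end thinking once a budget is hit (new behavior: no break here).
        if not seen_end_think and (n_thinking_tokens >= max_thinking_tokens
                                   or thought_count >= max_thoughts):
            next_tok = END_THINK
        else:
            next_tok = next(it, EOS)  # stub for sampling from the model

        if next_tok == END_THINK:
            chunks.append(next_tok)
            seen_end_think = True
            continue  # keep generating the conclusion

        if next_tok == EOS:
            if not seen_end_think:
                chunks.append(END_THINK)  # close thinking before stopping
            chunks.append(next_tok)       # EOS is now emitted into the output
            break

        chunks.append(next_tok)
        if not seen_end_think:
            n_thinking_tokens += 1  # only thinking-phase tokens count

    return " ".join(chunks)

print(generate(["a", "b", "c", "d", "e", "f", "done"]))
# -> a b c d e </think> f done <eos>

The design point mirrored from the diff is that the forced end-think path now routes through continue rather than break, so the conclusion phase reuses the same sampling loop instead of terminating generation early.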