Commit 4f9bdc6

Update thinkdeeper.py
1 parent 1bd01f4 commit 4f9bdc6


optillm/thinkdeeper.py

Lines changed: 12 additions & 7 deletions
@@ -89,16 +89,18 @@ def reasoning_effort(self, messages) -> str:
             out = self.model(input_ids=tokens, past_key_values=kv, use_cache=True)
             logits = out.logits[0, -1, :]
 
-            # Check if we need to force end token
+            # Check if we need to force end thinking
             force_end = (n_thinking_tokens >= self.config["max_thinking_tokens"] or
                          self.thought_count >= self.config["max_thoughts"])
 
-            if force_end:
+            if force_end and not seen_end_think:
                 logger.debug(f"Forcing end think token. Tokens: {n_thinking_tokens}, Thoughts: {self.thought_count}")
                 next_token = self.end_think_token
                 response_chunks.append(self.tokenizer.decode([next_token]))
-                # Break immediately when forcing end token
-                break
+                seen_end_think = True
+                # Don't break - continue generating but with end_think token forced
+                tokens = torch.tensor([[next_token]]).to(tokens.device)
+                continue
             else:
                 next_token = torch.multinomial(
                     torch.softmax(logits, dim=-1), 1
@@ -107,8 +109,8 @@ def reasoning_effort(self, messages) -> str:
             kv = out.past_key_values
             next_str = self.tokenizer.decode([next_token])
 
-            # Check if this is a thought-switching token
-            if self.is_thought_switch(next_token):
+            # Check if this is a thought-switching token (only if not in conclusion phase)
+            if not seen_end_think and self.is_thought_switch(next_token):
                 self.thought_count += 1
                 logger.debug(f"Detected thought switch marker. Total thoughts: {self.thought_count}")
                 # Clear the sequence after detecting a switch
@@ -135,6 +137,7 @@ def reasoning_effort(self, messages) -> str:
             if next_token == self.model.config.eos_token_id:
                 if seen_end_think:
                     logger.debug("Reached EOS after end think token - stopping generation")
+                    response_chunks.append(next_str)
                     break
                 elif n_thinking_tokens < self.config["min_thinking_tokens"]:
                     # Continue with thought transition if under minimum tokens
@@ -150,11 +153,13 @@ def reasoning_effort(self, messages) -> str:
                 # Force end think token if we haven't seen it
                 logger.debug("Reached EOS without end think token - adding end token")
                 response_chunks.append(self.tokenizer.decode([self.end_think_token]))
+                response_chunks.append(next_str)
                 break
 
             # Normal token processing
             response_chunks.append(next_str)
-            n_thinking_tokens += 1
+            if not seen_end_think:
+                n_thinking_tokens += 1
             tokens = torch.tensor([[next_token]]).to(tokens.device)
 
         # Join all chunks and add framing tokens
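In effect, this commit changes the loop's budget handling from "break as soon as the thinking budget is hit" to "force the end-think token once, then keep sampling so the model can still write its conclusion until EOS." A minimal standalone sketch of that control flow follows; sample_next, END_THINK, EOS, and generate are hypothetical stand-ins for illustration, not the optillm API, and the real loop's handling of EOS arriving before the end-think token is omitted.

# Sketch of the post-commit control flow (hypothetical names, not optillm's API).
def generate(sample_next, END_THINK, EOS, max_thinking_tokens):
    chunks = []
    n_thinking_tokens = 0
    seen_end_think = False
    token = None
    while True:
        # Budget exhausted: force the end-think token exactly once, then
        # keep looping instead of breaking (the core change in this commit).
        if n_thinking_tokens >= max_thinking_tokens and not seen_end_think:
            token = END_THINK
            seen_end_think = True
            chunks.append(token)
            continue
        token = sample_next(token)
        chunks.append(token)  # the commit also appends the final EOS token
        if token == EOS:
            break
        if not seen_end_think:
            n_thinking_tokens += 1  # only thinking tokens count against the budget
    return chunks

The key difference from the pre-commit behavior is the continue where a break used to be, plus gating the thinking-token budget (and, in the real code, thought-switch detection) on seen_end_think.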

0 commit comments