
Commit 3a7d0f0

Update thinkdeeper.py
1 parent 718b678

File tree

1 file changed (+23, -292 lines)

optillm/thinkdeeper.py

Lines changed: 23 additions & 292 deletions
@@ -1,201 +1,16 @@
 import torch
 import random
 from transformers import PreTrainedModel, PreTrainedTokenizer, DynamicCache
-from typing import Tuple, Dict, Any, List, Optional
-import numpy as np
+from typing import Tuple, Dict, Any, List
 import logging
 
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)
 
-def calculate_entropy(logits: torch.Tensor) -> float:
-    """
-    Calculate entropy from logits tensor
-
-    Args:
-        logits: Raw logits from model output (pre-softmax)
-
-    Returns:
-        float: Entropy value
-    """
-    # Convert logits to probabilities using softmax
-    probs = torch.softmax(logits, dim=-1)
-
-    # Calculate entropy: -sum(p_i * log(p_i))
-    # Add small epsilon to avoid log(0)
-    entropy = -torch.sum(probs * torch.log2(probs + 1e-10))
-
-    return entropy.item()
-
-class EntropyTracker:
-    """Tracks entropy over time and provides analysis capabilities"""
-
-    def __init__(self, window_size: int = 10):
-        """
-        Initialize entropy tracker
-
-        Args:
-            window_size: Number of tokens to track for moving averages
-        """
-        self.entropy_history = []
-        self.window_size = window_size
-        self.transition_entropies = {
-            "before": {},
-            "after": {}
-        }
-
-    def add_entropy(self, entropy: float) -> None:
-        """Add entropy value to history"""
-        self.entropy_history.append(entropy)
-
-    def get_recent_avg_entropy(self) -> float:
-        """Get average entropy over recent window"""
-        if len(self.entropy_history) == 0:
-            return 0.0
-
-        window = self.entropy_history[-min(self.window_size, len(self.entropy_history)):]
-        return sum(window) / len(window)
-
-    def record_transition_entropy(self, transition_word: str, before: bool = True) -> None:
-        """
-        Record entropy around a transition word
-
-        Args:
-            transition_word: The transition word being tracked
-            before: True if recording before transition, False if after
-        """
-        key = "before" if before else "after"
-        if transition_word not in self.transition_entropies[key]:
-            self.transition_entropies[key][transition_word] = []
-
-        current_entropy = self.get_recent_avg_entropy()
-        self.transition_entropies[key][transition_word].append(current_entropy)
-
-    def get_entropy_change(self, transition_word: str) -> float:
-        """
-        Get entropy change for a specific transition word
-
-        Args:
-            transition_word: The transition to analyze
-
-        Returns:
-            float: Average entropy change (after - before)
-        """
-        if (transition_word not in self.transition_entropies["before"] or
-            transition_word not in self.transition_entropies["after"]):
-            return 0.0
-
-        before_values = self.transition_entropies["before"][transition_word]
-        after_values = self.transition_entropies["after"][transition_word]
-
-        # Use only entries that have both before and after
-        count = min(len(before_values), len(after_values))
-        if count == 0:
-            return 0.0
-
-        before_avg = sum(before_values[:count]) / count
-        after_avg = sum(after_values[:count]) / count
-
-        return after_avg - before_avg
-
-class InterventionHandler:
-    """Handles detection and injection of verification prompts"""
-
-    def __init__(self, entropy_thresholds: Dict[str, float] = None):
-        """
-        Initialize intervention handler
-
-        Args:
-            entropy_thresholds: Dict mapping transition words to threshold values
-        """
-        # Default entropy thresholds based on the logit analysis
-        self.entropy_thresholds = {
-            "However,": 1.45,
-            "Wait,": 1.50,
-            "Alternatively,": 1.45,
-            "Additionally,": 1.40,
-            # Default threshold for any other transition
-            "default": 1.50
-        }
-
-        # Update with any user-provided thresholds
-        if entropy_thresholds:
-            self.entropy_thresholds.update(entropy_thresholds)
-
-        # Track interventions to avoid repeating too frequently
-        self.last_intervention_token = 0
-        self.current_token_pos = 0
-        self.min_tokens_between_interventions = 50
-
-        # Verification prompts tailored to different transition words
-        self.verification_prompts = {
-            "However,": " Let me carefully verify if this contrary point actually changes my conclusion.",
-            "Wait,": " Let me double-check my previous calculation before changing direction.",
-            "Alternatively,": " Let me evaluate if this alternative approach is consistent with my previous reasoning.",
-            "Additionally,": " Let me verify that this additional information is correctly incorporated.",
-            # Default prompt
-            "default": " Let me verify my reasoning step by step before continuing."
-        }
-
-    def increment_token_pos(self):
-        """Increment current token position counter"""
-        self.current_token_pos += 1
-
-    def should_intervene(self,
-                         transition_word: str,
-                         current_entropy: float,
-                         tokens_generated: int,
-                         max_tokens: int) -> bool:
-        """
-        Determine if we should intervene based on current conditions
-
-        Args:
-            transition_word: The detected transition word
-            current_entropy: Current entropy value
-            tokens_generated: Number of tokens generated so far
-            max_tokens: Maximum tokens allowed
-
-        Returns:
-            bool: True if should intervene, False otherwise
-        """
-        # Get appropriate threshold
-        threshold = self.entropy_thresholds.get(transition_word, self.entropy_thresholds["default"])
-
-        # Check if entropy exceeds threshold
-        entropy_condition = current_entropy > threshold
-
-        # Check if we're in the middle-to-late reasoning phase (40-80% of generation)
-        generation_progress = tokens_generated / max_tokens if max_tokens > 0 else 0.5
-        progress_condition = 0.4 < generation_progress < 0.8
-
-        # Ensure we don't intervene too frequently
-        frequency_condition = (self.current_token_pos - self.last_intervention_token) > self.min_tokens_between_interventions
-
-        # Determine if we should intervene
-        should_intervene = entropy_condition and progress_condition and frequency_condition
-
-        if should_intervene:
-            self.last_intervention_token = self.current_token_pos
-
-        return should_intervene
-
-    def get_verification_prompt(self, transition_word: str) -> str:
-        """
-        Get appropriate verification prompt for the transition word
-
-        Args:
-            transition_word: The transition word to get prompt for
-
-        Returns:
-            str: The verification prompt
-        """
-        return self.verification_prompts.get(transition_word, self.verification_prompts["default"])
-
-
 DEFAULT_CONFIG = {
-    "min_thinking_tokens": 512,
-    "max_thinking_tokens": 2048, # New parameter to cap thinking length
-    "max_thoughts": 4, # New parameter to limit number of thought transitions
+    "min_thinking_tokens": 1024,
+    "max_thinking_tokens": 4196,
+    "max_thoughts": 64,
     "prefill": "",
     "start_think_token": "<think>",
     "end_think_token": "</think>",
@@ -217,14 +32,12 @@ def __init__(self, config: Dict[str, Any], tokenizer, model):
         self._start_think_token = start_tokens[0] if len(start_tokens) == 1 else start_tokens[1]
         self.end_think_token = end_tokens[0] if len(end_tokens) == 1 else end_tokens[1]
 
-        # Store thought switch markers as token sequences and their decoded forms
+        # Store thought switch markers as token sequences
         self.thought_switch_sequences = []
-        self.thought_switch_phrases = []
         for phrase in self.config["thought_switch_tokens"]:
             # Encode without adding special tokens to get exact sequence
             token_ids = self.tokenizer.encode(phrase, add_special_tokens=False)
             self.thought_switch_sequences.append(token_ids)
-            self.thought_switch_phrases.append(phrase)
             logger.debug(f"Encoded '{phrase}' to token sequence: {token_ids}")
             logger.debug(f"Decoded back: {self.tokenizer.decode(token_ids)}")
 
@@ -233,32 +46,12 @@ def __init__(self, config: Dict[str, Any], tokenizer, model):
         self.current_sequence = [] # Track recent tokens for sequence matching
         self.max_sequence_length = max(len(seq) for seq in self.thought_switch_sequences)
 
-        # Initialize entropy tracking
-        self.entropy_tracker = EntropyTracker()
-
-        # Initialize intervention handler
-        entropy_thresholds = self.config.get("entropy_thresholds", None)
-        self.intervention_handler = InterventionHandler(entropy_thresholds)
-
-        # Track if we're currently in an intervention
-        self.in_intervention = False
-        self.current_intervention_tokens = []
-
-        # Map token sequences to their phrases
-        self.sequence_to_phrase = {}
-        for phrase, sequence in zip(self.thought_switch_phrases, self.thought_switch_sequences):
-            seq_tuple = tuple(sequence)
-            self.sequence_to_phrase[seq_tuple] = phrase
+        for phrase, sequence in zip(self.config["thought_switch_tokens"], self.thought_switch_sequences):
             logger.debug(f"Thought switch marker '{phrase}' encoded as: {sequence}")
             logger.debug(f"Decoded back as: {self.tokenizer.decode(sequence)}")
 
-    def is_thought_switch(self, token: int) -> Tuple[bool, Optional[str]]:
-        """
-        Check if adding this token creates a thought switch sequence.
-
-        Returns:
-            Tuple[bool, Optional[str]]: (is_switch, transition_phrase)
-        """
+    def is_thought_switch(self, token: int) -> bool:
+        """Check if adding this token creates a thought switch sequence."""
         # Add new token to current sequence
         self.current_sequence.append(token)
 
@@ -267,16 +60,16 @@ def is_thought_switch(self, token: int) -> Tuple[bool, Optional[str]]:
             self.current_sequence = self.current_sequence[-self.max_sequence_length:]
 
         # Check if current sequence ends with any thought switch sequence
-        for i, sequence in enumerate(self.thought_switch_sequences):
+        for sequence in self.thought_switch_sequences:
             if len(sequence) <= len(self.current_sequence) and \
                self.current_sequence[-len(sequence):] == sequence:
-                return True, self.thought_switch_phrases[i]
+                return True
 
-        return False, None
+        return False
 
     @torch.inference_mode()
     def reasoning_effort(self, messages) -> str:
-        """Generate response with ThinkDeeper's controlled thinking process and entropy-based interventions"""
+        """Generate response with ThinkDeeper's controlled thinking process"""
 
         messages.append({"role": "assistant", "content": f"{self.config['start_think_token']}\n{self.config['prefill']}"})
 
@@ -292,25 +85,10 @@ def reasoning_effort(self, messages) -> str:
         seen_end_think = False
         response_chunks = []
 
-        # Reset tracking for new generation
-        self.thought_count = 0
-        self.current_sequence = []
-        self.in_intervention = False
-        self.current_intervention_tokens = []
-        self.intervention_handler.last_intervention_token = 0
-        self.intervention_handler.current_token_pos = 0
-
         while True:
             out = self.model(input_ids=tokens, past_key_values=kv, use_cache=True)
             logits = out.logits[0, -1, :]
 
-            # Calculate entropy and update tracker
-            current_entropy = calculate_entropy(logits)
-            self.entropy_tracker.add_entropy(current_entropy)
-
-            # Update the token position counter
-            self.intervention_handler.increment_token_pos()
-
             # Check if we need to force end thinking
             force_end = (n_thinking_tokens >= self.config["max_thinking_tokens"] or
                          self.thought_count >= self.config["max_thoughts"])
@@ -324,57 +102,19 @@ def reasoning_effort(self, messages) -> str:
                 tokens = torch.tensor([[next_token]]).to(tokens.device)
                 continue
             else:
-                # If we're in an intervention, continue with it
-                if self.in_intervention and self.current_intervention_tokens:
-                    next_token = self.current_intervention_tokens.pop(0)
-                    logger.debug(f"Continuing intervention with token: {self.tokenizer.decode([next_token])}")
-
-                    # If we're done with the intervention, mark it as complete
-                    if not self.current_intervention_tokens:
-                        self.in_intervention = False
-                        logger.debug("Intervention complete")
-                else:
-                    # Normal generation
-                    next_token = torch.multinomial(
-                        torch.softmax(logits, dim=-1), 1
-                    ).item()
+                next_token = torch.multinomial(
+                    torch.softmax(logits, dim=-1), 1
+                ).item()
 
             kv = out.past_key_values
             next_str = self.tokenizer.decode([next_token])
 
             # Check if this is a thought-switching token (only if not in conclusion phase)
-            if not seen_end_think:
-                is_switch, transition_word = self.is_thought_switch(next_token)
-                if is_switch:
-                    # Record entropy before transition
-                    self.entropy_tracker.record_transition_entropy(transition_word, before=True)
-
-                    self.thought_count += 1
-                    logger.debug(f"Detected thought switch marker '{transition_word}'. Total thoughts: {self.thought_count}")
-
-                    # Decide if we should intervene at this transition
-                    should_intervene = self.intervention_handler.should_intervene(
-                        transition_word,
-                        current_entropy,
-                        n_thinking_tokens,
-                        self.config["max_thinking_tokens"]
-                    )
-
-                    if should_intervene and not self.in_intervention:
-                        # Get verification prompt
-                        verification_prompt = self.intervention_handler.get_verification_prompt(transition_word)
-                        logger.debug(f"Intervening after '{transition_word}' with prompt: {verification_prompt}")
-
-                        # Tokenize the verification prompt and set up for injection
-                        verification_tokens = self.tokenizer.encode(verification_prompt, add_special_tokens=False)
-                        self.current_intervention_tokens = verification_tokens
-                        self.in_intervention = True
-
-                        # Add the verification prompt to response
-                        response_chunks.append(verification_prompt)
-
-                    # Record entropy after transition (will be in next token, but this helps with tracking)
-                    self.entropy_tracker.record_transition_entropy(transition_word, before=False)
+            if not seen_end_think and self.is_thought_switch(next_token):
+                self.thought_count += 1
+                logger.debug(f"Detected thought switch marker. Total thoughts: {self.thought_count}")
+                # Clear the sequence after detecting a switch
+                self.current_sequence = []
 
             # Handle natural end think token
             if next_token == self.end_think_token:
@@ -418,20 +158,11 @@ def reasoning_effort(self, messages) -> str:
                 seen_end_think = True
                 continue
 
-            # Skip adding token to response if we're injecting an intervention
-            if not self.in_intervention or not self.current_intervention_tokens:
-                response_chunks.append(next_str)
-
+            # Normal token processing
+            response_chunks.append(next_str)
             if not seen_end_think:
                 n_thinking_tokens += 1
-
-            # Set up next token
-            if self.in_intervention and self.current_intervention_tokens:
-                # Next token is from intervention
-                tokens = torch.tensor([[self.current_intervention_tokens[0]]]).to(tokens.device)
-            else:
-                # Normal next token
-                tokens = torch.tensor([[next_token]]).to(tokens.device)
+            tokens = torch.tensor([[next_token]]).to(tokens.device)
 
         # Join all chunks and add framing tokens
         response = "".join(response_chunks)
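
Note: the detection logic this commit keeps reduces to a suffix match of the most recent token IDs against the encoded marker sequences. The sketch below is a minimal standalone illustration of that matching; the token IDs are invented for the example, whereas thinkdeeper.py derives the real sequences from tokenizer.encode(phrase, add_special_tokens=False).

# Standalone sketch of the simplified is_thought_switch matching.
# Marker token IDs below are made up for illustration only.
from typing import List

def is_thought_switch(current_sequence: List[int], marker_sequences: List[List[int]]) -> bool:
    """Return True if the recent token stream ends with any marker sequence."""
    for sequence in marker_sequences:
        if len(sequence) <= len(current_sequence) and \
           current_sequence[-len(sequence):] == sequence:
            return True
    return False

markers = [[42, 7], [99]]                           # stand-ins for encoded phrases
assert is_thought_switch([5, 12, 42, 7], markers)   # stream ends with [42, 7]
assert not is_thought_switch([5, 12, 42], markers)  # no marker suffix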

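With the entropy-gated interventions gone, thinking now terminates on the two raised budgets alone. A minimal sketch of the force_end condition visible in the unchanged context lines of the diff, using the new DEFAULT_CONFIG values:

# Sketch of the force_end budget check; values mirror the new DEFAULT_CONFIG.
DEFAULT_CONFIG = {
    "min_thinking_tokens": 1024,
    "max_thinking_tokens": 4196,
    "max_thoughts": 64,
}

def should_force_end(n_thinking_tokens: int, thought_count: int,
                     config: dict = DEFAULT_CONFIG) -> bool:
    """Stop thinking once either the token budget or the thought budget is hit."""
    return (n_thinking_tokens >= config["max_thinking_tokens"]
            or thought_count >= config["max_thoughts"])

assert not should_force_end(500, 3)   # both budgets still open
assert should_force_end(4196, 3)      # token budget exhausted
assert should_force_end(500, 64)      # thought budget exhausted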
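For comparison with the deleted path: before this commit, interventions were gated on the Shannon entropy of the next-token distribution. A condensed sketch of what the removed calculate_entropy helper computed:

# Condensed version of the removed helper: Shannon entropy (in bits) of the
# softmaxed next-token distribution, with an epsilon guarding log2(0).
import torch

def calculate_entropy(logits: torch.Tensor) -> float:
    probs = torch.softmax(logits, dim=-1)
    return -torch.sum(probs * torch.log2(probs + 1e-10)).item()

peaked = torch.tensor([10.0, 0.0, 0.0, 0.0])  # near-certain next token
flat = torch.tensor([1.0, 1.0, 1.0, 1.0])     # maximally uncertain
print(calculate_entropy(peaked))  # close to 0 bits
print(calculate_entropy(flat))    # 2.0 bits for a flat 4-way distribution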