 import torch
 import random
 from transformers import PreTrainedModel, PreTrainedTokenizer, DynamicCache
-from typing import Tuple, Dict, Any, List, Optional
-import numpy as np
+from typing import Tuple, Dict, Any, List
 import logging
 
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)
 
-def calculate_entropy(logits: torch.Tensor) -> float:
-    """
-    Calculate entropy from logits tensor
-
-    Args:
-        logits: Raw logits from model output (pre-softmax)
-
-    Returns:
-        float: Entropy value
-    """
-    # Convert logits to probabilities using softmax
-    probs = torch.softmax(logits, dim=-1)
-
-    # Calculate entropy: -sum(p_i * log(p_i))
-    # Add small epsilon to avoid log(0)
-    entropy = -torch.sum(probs * torch.log2(probs + 1e-10))
-
-    return entropy.item()
-
-class EntropyTracker:
-    """Tracks entropy over time and provides analysis capabilities"""
-
-    def __init__(self, window_size: int = 10):
-        """
-        Initialize entropy tracker
-
-        Args:
-            window_size: Number of tokens to track for moving averages
-        """
-        self.entropy_history = []
-        self.window_size = window_size
-        self.transition_entropies = {
-            "before": {},
-            "after": {}
-        }
-
-    def add_entropy(self, entropy: float) -> None:
-        """Add entropy value to history"""
-        self.entropy_history.append(entropy)
-
-    def get_recent_avg_entropy(self) -> float:
-        """Get average entropy over recent window"""
-        if len(self.entropy_history) == 0:
-            return 0.0
-
-        window = self.entropy_history[-min(self.window_size, len(self.entropy_history)):]
-        return sum(window) / len(window)
-
-    def record_transition_entropy(self, transition_word: str, before: bool = True) -> None:
-        """
-        Record entropy around a transition word
-
-        Args:
-            transition_word: The transition word being tracked
-            before: True if recording before transition, False if after
-        """
-        key = "before" if before else "after"
-        if transition_word not in self.transition_entropies[key]:
-            self.transition_entropies[key][transition_word] = []
-
-        current_entropy = self.get_recent_avg_entropy()
-        self.transition_entropies[key][transition_word].append(current_entropy)
-
-    def get_entropy_change(self, transition_word: str) -> float:
-        """
-        Get entropy change for a specific transition word
-
-        Args:
-            transition_word: The transition to analyze
-
-        Returns:
-            float: Average entropy change (after - before)
-        """
-        if (transition_word not in self.transition_entropies["before"] or
-                transition_word not in self.transition_entropies["after"]):
-            return 0.0
-
-        before_values = self.transition_entropies["before"][transition_word]
-        after_values = self.transition_entropies["after"][transition_word]
-
-        # Use only entries that have both before and after
-        count = min(len(before_values), len(after_values))
-        if count == 0:
-            return 0.0
-
-        before_avg = sum(before_values[:count]) / count
-        after_avg = sum(after_values[:count]) / count
-
-        return after_avg - before_avg
-
-class InterventionHandler:
-    """Handles detection and injection of verification prompts"""
-
-    def __init__(self, entropy_thresholds: Dict[str, float] = None):
-        """
-        Initialize intervention handler
-
-        Args:
-            entropy_thresholds: Dict mapping transition words to threshold values
-        """
-        # Default entropy thresholds based on the logit analysis
-        self.entropy_thresholds = {
-            "However,": 1.45,
-            "Wait,": 1.50,
-            "Alternatively,": 1.45,
-            "Additionally,": 1.40,
-            # Default threshold for any other transition
-            "default": 1.50
-        }
-
-        # Update with any user-provided thresholds
-        if entropy_thresholds:
-            self.entropy_thresholds.update(entropy_thresholds)
-
-        # Track interventions to avoid repeating too frequently
-        self.last_intervention_token = 0
-        self.current_token_pos = 0
-        self.min_tokens_between_interventions = 50
-
-        # Verification prompts tailored to different transition words
-        self.verification_prompts = {
-            "However,": " Let me carefully verify if this contrary point actually changes my conclusion.",
-            "Wait,": " Let me double-check my previous calculation before changing direction.",
-            "Alternatively,": " Let me evaluate if this alternative approach is consistent with my previous reasoning.",
-            "Additionally,": " Let me verify that this additional information is correctly incorporated.",
-            # Default prompt
-            "default": " Let me verify my reasoning step by step before continuing."
-        }
-
-    def increment_token_pos(self):
-        """Increment current token position counter"""
-        self.current_token_pos += 1
-
-    def should_intervene(self,
-                         transition_word: str,
-                         current_entropy: float,
-                         tokens_generated: int,
-                         max_tokens: int) -> bool:
-        """
-        Determine if we should intervene based on current conditions
-
-        Args:
-            transition_word: The detected transition word
-            current_entropy: Current entropy value
-            tokens_generated: Number of tokens generated so far
-            max_tokens: Maximum tokens allowed
-
-        Returns:
-            bool: True if should intervene, False otherwise
-        """
-        # Get appropriate threshold
-        threshold = self.entropy_thresholds.get(transition_word, self.entropy_thresholds["default"])
-
-        # Check if entropy exceeds threshold
-        entropy_condition = current_entropy > threshold
-
-        # Check if we're in the middle-to-late reasoning phase (40-80% of generation)
-        generation_progress = tokens_generated / max_tokens if max_tokens > 0 else 0.5
-        progress_condition = 0.4 < generation_progress < 0.8
-
-        # Ensure we don't intervene too frequently
-        frequency_condition = (self.current_token_pos - self.last_intervention_token) > self.min_tokens_between_interventions
-
-        # Determine if we should intervene
-        should_intervene = entropy_condition and progress_condition and frequency_condition
-
-        if should_intervene:
-            self.last_intervention_token = self.current_token_pos
-
-        return should_intervene
-
-    def get_verification_prompt(self, transition_word: str) -> str:
-        """
-        Get appropriate verification prompt for the transition word
-
-        Args:
-            transition_word: The transition word to get prompt for
-
-        Returns:
-            str: The verification prompt
-        """
-        return self.verification_prompts.get(transition_word, self.verification_prompts["default"])
-
-
 DEFAULT_CONFIG = {
-    "min_thinking_tokens": 512,
-    "max_thinking_tokens": 2048,  # New parameter to cap thinking length
-    "max_thoughts": 4,  # New parameter to limit number of thought transitions
+    "min_thinking_tokens": 1024,
+    "max_thinking_tokens": 4196,
+    "max_thoughts": 64,
     "prefill": "",
     "start_think_token": "<think>",
     "end_think_token": "</think>",
@@ -217,14 +32,12 @@ def __init__(self, config: Dict[str, Any], tokenizer, model):
         self._start_think_token = start_tokens[0] if len(start_tokens) == 1 else start_tokens[1]
         self.end_think_token = end_tokens[0] if len(end_tokens) == 1 else end_tokens[1]
 
-        # Store thought switch markers as token sequences and their decoded forms
+        # Store thought switch markers as token sequences
         self.thought_switch_sequences = []
-        self.thought_switch_phrases = []
         for phrase in self.config["thought_switch_tokens"]:
             # Encode without adding special tokens to get exact sequence
             token_ids = self.tokenizer.encode(phrase, add_special_tokens=False)
             self.thought_switch_sequences.append(token_ids)
-            self.thought_switch_phrases.append(phrase)
             logger.debug(f"Encoded '{phrase}' to token sequence: {token_ids}")
             logger.debug(f"Decoded back: {self.tokenizer.decode(token_ids)}")
 
@@ -233,32 +46,12 @@ def __init__(self, config: Dict[str, Any], tokenizer, model):
         self.current_sequence = []  # Track recent tokens for sequence matching
         self.max_sequence_length = max(len(seq) for seq in self.thought_switch_sequences)
 
-        # Initialize entropy tracking
-        self.entropy_tracker = EntropyTracker()
-
-        # Initialize intervention handler
-        entropy_thresholds = self.config.get("entropy_thresholds", None)
-        self.intervention_handler = InterventionHandler(entropy_thresholds)
-
-        # Track if we're currently in an intervention
-        self.in_intervention = False
-        self.current_intervention_tokens = []
-
-        # Map token sequences to their phrases
-        self.sequence_to_phrase = {}
-        for phrase, sequence in zip(self.thought_switch_phrases, self.thought_switch_sequences):
-            seq_tuple = tuple(sequence)
-            self.sequence_to_phrase[seq_tuple] = phrase
+        for phrase, sequence in zip(self.config["thought_switch_tokens"], self.thought_switch_sequences):
             logger.debug(f"Thought switch marker '{phrase}' encoded as: {sequence}")
             logger.debug(f"Decoded back as: {self.tokenizer.decode(sequence)}")
 
-    def is_thought_switch(self, token: int) -> Tuple[bool, Optional[str]]:
-        """
-        Check if adding this token creates a thought switch sequence.
-
-        Returns:
-            Tuple[bool, Optional[str]]: (is_switch, transition_phrase)
-        """
+    def is_thought_switch(self, token: int) -> bool:
+        """Check if adding this token creates a thought switch sequence."""
         # Add new token to current sequence
         self.current_sequence.append(token)
 
@@ -267,16 +60,16 @@ def is_thought_switch(self, token: int) -> Tuple[bool, Optional[str]]:
             self.current_sequence = self.current_sequence[-self.max_sequence_length:]
 
         # Check if current sequence ends with any thought switch sequence
-        for i, sequence in enumerate(self.thought_switch_sequences):
+        for sequence in self.thought_switch_sequences:
             if len(sequence) <= len(self.current_sequence) and \
                self.current_sequence[-len(sequence):] == sequence:
-                return True, self.thought_switch_phrases[i]
+                return True
 
-        return False, None
+        return False
 
     @torch.inference_mode()
     def reasoning_effort(self, messages) -> str:
-        """Generate response with ThinkDeeper's controlled thinking process and entropy-based interventions"""
+        """Generate response with ThinkDeeper's controlled thinking process"""
 
         messages.append({"role": "assistant", "content": f"{self.config['start_think_token']}\n{self.config['prefill']}"})
 
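
The simplified `is_thought_switch` reduces to a suffix match of the recent token window against each encoded marker sequence. A standalone sketch of that matching logic, with made-up token IDs for illustration:

```python
from typing import List

def ends_with_marker(recent: List[int], markers: List[List[int]]) -> bool:
    """Return True if the recent-token window ends with any marker sequence."""
    return any(
        len(seq) <= len(recent) and recent[-len(seq):] == seq
        for seq in markers
    )

# Hypothetical token IDs standing in for encoded phrases such as "Wait," and "However,":
markers = [[5005, 11], [12057, 11]]
assert ends_with_marker([17, 42, 5005, 11], markers)  # window ends with a marker
assert not ends_with_marker([5005], markers)          # marker only partially present
```
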
@@ -292,25 +85,10 @@ def reasoning_effort(self, messages) -> str:
         seen_end_think = False
         response_chunks = []
 
-        # Reset tracking for new generation
-        self.thought_count = 0
-        self.current_sequence = []
-        self.in_intervention = False
-        self.current_intervention_tokens = []
-        self.intervention_handler.last_intervention_token = 0
-        self.intervention_handler.current_token_pos = 0
-
         while True:
             out = self.model(input_ids=tokens, past_key_values=kv, use_cache=True)
             logits = out.logits[0, -1, :]
 
-            # Calculate entropy and update tracker
-            current_entropy = calculate_entropy(logits)
-            self.entropy_tracker.add_entropy(current_entropy)
-
-            # Update the token position counter
-            self.intervention_handler.increment_token_pos()
-
             # Check if we need to force end thinking
             force_end = (n_thinking_tokens >= self.config["max_thinking_tokens"] or
                          self.thought_count >= self.config["max_thoughts"])
@@ -324,57 +102,19 @@ def reasoning_effort(self, messages) -> str:
                 tokens = torch.tensor([[next_token]]).to(tokens.device)
                 continue
             else:
-                # If we're in an intervention, continue with it
-                if self.in_intervention and self.current_intervention_tokens:
-                    next_token = self.current_intervention_tokens.pop(0)
-                    logger.debug(f"Continuing intervention with token: {self.tokenizer.decode([next_token])}")
-
-                    # If we're done with the intervention, mark it as complete
-                    if not self.current_intervention_tokens:
-                        self.in_intervention = False
-                        logger.debug("Intervention complete")
-                else:
-                    # Normal generation
-                    next_token = torch.multinomial(
-                        torch.softmax(logits, dim=-1), 1
-                    ).item()
+                next_token = torch.multinomial(
+                    torch.softmax(logits, dim=-1), 1
+                ).item()
 
             kv = out.past_key_values
             next_str = self.tokenizer.decode([next_token])
 
             # Check if this is a thought-switching token (only if not in conclusion phase)
-            if not seen_end_think:
-                is_switch, transition_word = self.is_thought_switch(next_token)
-                if is_switch:
-                    # Record entropy before transition
-                    self.entropy_tracker.record_transition_entropy(transition_word, before=True)
-
-                    self.thought_count += 1
-                    logger.debug(f"Detected thought switch marker '{transition_word}'. Total thoughts: {self.thought_count}")
-
-                    # Decide if we should intervene at this transition
-                    should_intervene = self.intervention_handler.should_intervene(
-                        transition_word,
-                        current_entropy,
-                        n_thinking_tokens,
-                        self.config["max_thinking_tokens"]
-                    )
-
-                    if should_intervene and not self.in_intervention:
-                        # Get verification prompt
-                        verification_prompt = self.intervention_handler.get_verification_prompt(transition_word)
-                        logger.debug(f"Intervening after '{transition_word}' with prompt: {verification_prompt}")
-
-                        # Tokenize the verification prompt and set up for injection
-                        verification_tokens = self.tokenizer.encode(verification_prompt, add_special_tokens=False)
-                        self.current_intervention_tokens = verification_tokens
-                        self.in_intervention = True
-
-                        # Add the verification prompt to response
-                        response_chunks.append(verification_prompt)
-
-                    # Record entropy after transition (will be in next token, but this helps with tracking)
-                    self.entropy_tracker.record_transition_entropy(transition_word, before=False)
+            if not seen_end_think and self.is_thought_switch(next_token):
+                self.thought_count += 1
+                logger.debug(f"Detected thought switch marker. Total thoughts: {self.thought_count}")
+                # Clear the sequence after detecting a switch
+                self.current_sequence = []
 
             # Handle natural end think token
             if next_token == self.end_think_token:
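
With the intervention machinery gone, each loop iteration reduces to plain multinomial sampling plus the force-end budget test. A compact sketch of those two pieces under the new defaults (both function names are mine, not from the commit):

```python
import torch

def sample_next(last_logits: torch.Tensor) -> int:
    """Draw one token id from the softmax of the final position's logits."""
    probs = torch.softmax(last_logits, dim=-1)
    return torch.multinomial(probs, num_samples=1).item()

def must_force_end(n_thinking_tokens: int, thought_count: int, config: dict) -> bool:
    """Mirror of the loop's force-end test: stop thinking once either the
    token budget or the thought-transition budget is exhausted."""
    return (n_thinking_tokens >= config["max_thinking_tokens"]
            or thought_count >= config["max_thoughts"])

next_token = sample_next(torch.randn(32000))  # 32000 is an assumed vocab size
assert must_force_end(4196, 0, {"max_thinking_tokens": 4196, "max_thoughts": 64})
```
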
@@ -418,20 +158,11 @@ def reasoning_effort(self, messages) -> str:
                 seen_end_think = True
                 continue
 
-            # Skip adding token to response if we're injecting an intervention
-            if not self.in_intervention or not self.current_intervention_tokens:
-                response_chunks.append(next_str)
-
+            # Normal token processing
+            response_chunks.append(next_str)
             if not seen_end_think:
                 n_thinking_tokens += 1
-
-            # Set up next token
-            if self.in_intervention and self.current_intervention_tokens:
-                # Next token is from intervention
-                tokens = torch.tensor([[self.current_intervention_tokens[0]]]).to(tokens.device)
-            else:
-                # Normal next token
-                tokens = torch.tensor([[next_token]]).to(tokens.device)
+            tokens = torch.tensor([[next_token]]).to(tokens.device)
 
         # Join all chunks and add framing tokens
         response = "".join(response_chunks)
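
End to end, the simplified processor is driven the same way as before. A hypothetical usage sketch; the class name `ThinkDeeperProcessor` and the model checkpoint are assumptions, since neither appears in this diff:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Any model whose chat template produces <think> ... </think> reasoning blocks
# would work; this checkpoint is only an example.
model_id = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

processor = ThinkDeeperProcessor(DEFAULT_CONFIG, tokenizer, model)  # assumed class name
response = processor.reasoning_effort(
    [{"role": "user", "content": "What is 17 * 24?"}]
)
print(response)
```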