import logging

logger = logging.getLogger(__name__)
+ logger.setLevel(logging.DEBUG)

DEFAULT_CONFIG = {
    "min_thinking_tokens": 512,
@@ -31,14 +32,40 @@ def __init__(self, config: Dict[str, Any], tokenizer, model):
        self._start_think_token = start_tokens[0] if len(start_tokens) == 1 else start_tokens[1]
        self.end_think_token = end_tokens[0] if len(end_tokens) == 1 else end_tokens[1]

-         # Get token IDs for thought switching indicators
-         self.thought_switch_tokens = set()
+         # Store thought switch markers as token sequences
+         self.thought_switch_sequences = []
        for phrase in self.config["thought_switch_tokens"]:
+             # Encode without adding special tokens to get exact sequence
            token_ids = self.tokenizer.encode(phrase, add_special_tokens=False)
-             self.thought_switch_tokens.update(token_ids)
+             self.thought_switch_sequences.append(token_ids)
+             logger.debug(f"Encoded '{phrase}' to token sequence: {token_ids}")
+             logger.debug(f"Decoded back: {self.tokenizer.decode(token_ids)}")

        # Track thought switches
        self.thought_count = 0
+         self.current_sequence = []  # Track recent tokens for sequence matching
+         self.max_sequence_length = max(len(seq) for seq in self.thought_switch_sequences)
+
+         for phrase, sequence in zip(self.config["thought_switch_tokens"], self.thought_switch_sequences):
+             logger.debug(f"Thought switch marker '{phrase}' encoded as: {sequence}")
+             logger.debug(f"Decoded back as: {self.tokenizer.decode(sequence)}")
+
+     def is_thought_switch(self, token: int) -> bool:
+         """Check if adding this token creates a thought switch sequence."""
+         # Add new token to current sequence
+         self.current_sequence.append(token)
+
+         # Keep only the most recent tokens that could match our sequences
+         if len(self.current_sequence) > self.max_sequence_length:
+             self.current_sequence = self.current_sequence[-self.max_sequence_length:]
+
+         # Check if current sequence ends with any thought switch sequence
+         for sequence in self.thought_switch_sequences:
+             if len(sequence) <= len(self.current_sequence) and \
+                self.current_sequence[-len(sequence):] == sequence:
+                 return True
+
+         return False

    @torch.inference_mode()
    def reasoning_effort(self, messages) -> str:
@@ -62,11 +89,16 @@ def reasoning_effort(self, messages) -> str:
            out = self.model(input_ids=tokens, past_key_values=kv, use_cache=True)
            logits = out.logits[0, -1, :]

-             # Force end think token if we exceed limits
-             if (n_thinking_tokens >= self.config["max_thinking_tokens"] or
-                 self.thought_count >= self.config["max_thoughts"]):
-                 next_token = self.end_think_token
+             # Check if we need to force end token
+             force_end = (n_thinking_tokens >= self.config["max_thinking_tokens"] or
+                          self.thought_count >= self.config["max_thoughts"])
+
+             if force_end:
                logger.debug(f"Forcing end think token. Tokens: {n_thinking_tokens}, Thoughts: {self.thought_count}")
+                 next_token = self.end_think_token
+                 response_chunks.append(self.tokenizer.decode([next_token]))
+                 # Break immediately when forcing end token
+                 break
            else:
                next_token = torch.multinomial(
                    torch.softmax(logits, dim=-1), 1
@@ -76,42 +108,56 @@ def reasoning_effort(self, messages) -> str:
            next_str = self.tokenizer.decode([next_token])

            # Check if this is a thought-switching token
-             if next_token in self.thought_switch_tokens:
+             if self.is_thought_switch(next_token):
                self.thought_count += 1
-                 logger.debug(f"Detected thought switch. Total thoughts: {self.thought_count}")
+                 logger.debug(f"Detected thought switch marker. Total thoughts: {self.thought_count}")
+                 # Clear the sequence after detecting a switch
+                 self.current_sequence = []

-             # Track if we've seen the end think token
+             # Handle natural end think token
            if next_token == self.end_think_token:
                seen_end_think = True
                logger.debug("Found end think token")
-
-             # Need to continue generating if:
-             # 1. We hit end think/eos before min tokens OR
-             # 2. We hit eos without seeing end think token
-             if ((next_token in (self.end_think_token, self.model.config.eos_token_id)
-                 and n_thinking_tokens < self.config["min_thinking_tokens"])
-                 or (next_token == self.model.config.eos_token_id and not seen_end_think)):
-
-                 # Insert thought transition
-                 replacement = random.choice(self.config["thought_switch_tokens"])
-                 logger.debug(f"Inserting thought transition: '{replacement}' (tokens: {n_thinking_tokens})")
-                 response_chunks.append(replacement)
-                 replacement_tokens = self.tokenizer.encode(replacement)
-                 n_thinking_tokens += len(replacement_tokens)
-                 tokens = torch.tensor([replacement_tokens]).to(tokens.device)
-                 self.thought_count += 1
-                 seen_end_think = False
-
-             elif next_token == self.model.config.eos_token_id and seen_end_think:
-                 logger.debug("Reached EOS after end think token - stopping generation")
-                 break

-             else:
-                 response_chunks.append(next_str)
-                 n_thinking_tokens += 1
-                 tokens = torch.tensor([[next_token]]).to(tokens.device)
+                 # If we haven't reached minimum tokens, continue with thought transition
+                 if n_thinking_tokens < self.config["min_thinking_tokens"]:
+                     replacement = random.choice(self.config["thought_switch_tokens"])
+                     logger.debug(f"Inserting thought transition: '{replacement}' (tokens: {n_thinking_tokens})")
+                     response_chunks.append(replacement)
+                     replacement_tokens = self.tokenizer.encode(replacement)
+                     n_thinking_tokens += len(replacement_tokens)
+                     tokens = torch.tensor([replacement_tokens]).to(tokens.device)
+                     self.thought_count += 1
+                     seen_end_think = False
+                     continue
+
+             # Handle EOS token
+             if next_token == self.model.config.eos_token_id:
+                 if seen_end_think:
+                     logger.debug("Reached EOS after end think token - stopping generation")
+                     break
+                 elif n_thinking_tokens < self.config["min_thinking_tokens"]:
+                     # Continue with thought transition if under minimum tokens
+                     replacement = random.choice(self.config["thought_switch_tokens"])
+                     logger.debug(f"Inserting thought transition: '{replacement}' (tokens: {n_thinking_tokens})")
+                     response_chunks.append(replacement)
+                     replacement_tokens = self.tokenizer.encode(replacement)
+                     n_thinking_tokens += len(replacement_tokens)
+                     tokens = torch.tensor([replacement_tokens]).to(tokens.device)
+                     self.thought_count += 1
+                     continue
+                 else:
+                     # Force end think token if we haven't seen it
+                     logger.debug("Reached EOS without end think token - adding end token")
+                     response_chunks.append(self.tokenizer.decode([self.end_think_token]))
+                     break
+
+             # Normal token processing
+             response_chunks.append(next_str)
+             n_thinking_tokens += 1
+             tokens = torch.tensor([[next_token]]).to(tokens.device)

-         # Join all chunks and trim off the initial prompt
+         # Join all chunks and add framing tokens
        response = "".join(response_chunks)
        full_response = f"{self.config['start_think_token']}\n{self.config['prefill']}{response}"

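For reference, the matching that `is_thought_switch` performs can be exercised on its own. The snippet below is a minimal standalone sketch, not part of this commit: the `detect_switches` helper and the hard-coded token ids are invented stand-ins for what a real tokenizer would produce.

# Standalone sketch of the sliding-window matching used by is_thought_switch.
# Token ids are invented for illustration; a real tokenizer would supply them.
thought_switch_sequences = [[11, 22], [33, 44, 55]]
max_sequence_length = max(len(seq) for seq in thought_switch_sequences)

def detect_switches(token_stream):
    """Yield the positions at which a thought-switch sequence completes."""
    current = []
    for pos, token in enumerate(token_stream):
        current.append(token)
        current = current[-max_sequence_length:]  # keep only a matchable suffix
        if any(len(seq) <= len(current) and current[-len(seq):] == seq
               for seq in thought_switch_sequences):
            current = []  # reset after a hit, as the handler does
            yield pos

print(list(detect_switches([1, 2, 33, 44, 55, 9, 11, 22])))  # -> [4, 7]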