 
 logger = logging.getLogger(__name__)
 
-# Default configurations
 DEFAULT_CONFIG = {
     "min_thinking_tokens": 512,
+    "max_thinking_tokens": 2048,  # New parameter to cap thinking length
+    "max_thoughts": 4,  # New parameter to limit number of thought transitions
     "prefill": "",
     "start_think_token": "<think>",
     "end_think_token": "</think>",
-
-    # Combined thought transition markers and TIP configs
-    "tip_alpha": 4.0,  # Penalty strength
-    "tip_beta": 1024,  # Penalty duration (number of tokens)
     "thought_switch_tokens": [
         "Wait,",
         "Alternatively,",
     ],
 }
 
-class ThinkDeeperTIPProcessor:
+class ThinkDeeperProcessor:
     def __init__(self, config: Dict[str, Any], tokenizer, model):
         self.config = {**DEFAULT_CONFIG, **config}
         self.tokenizer = tokenizer
@@ -40,27 +37,12 @@ def __init__(self, config: Dict[str, Any], tokenizer, model):
             token_ids = self.tokenizer.encode(phrase, add_special_tokens=False)
             self.thought_switch_tokens.update(token_ids)
 
-        # Track when the last thought switch occurred
-        self.last_thought_switch_pos = 0
-
-    def adjust_logits_with_tip(self, logits: torch.Tensor, current_pos: int) -> torch.Tensor:
-        """Apply Thought Switching Penalty (TIP) to logits"""
-        tokens_since_last_switch = current_pos - self.last_thought_switch_pos
-
-        if tokens_since_last_switch < self.config["tip_beta"]:
-            penalty_mask = torch.zeros_like(logits)
-            for token_id in self.thought_switch_tokens:
-                if token_id < logits.size(-1):  # Ensure token_id is within valid range
-                    penalty_mask[token_id] = self.config["tip_alpha"]
-
-            adjusted_logits = logits - penalty_mask
-            return adjusted_logits
+        # Track thought switches
+        self.thought_count = 0
 
-        return logits
-
     @torch.inference_mode()
     def reasoning_effort(self, messages) -> str:
-        """Generate response with ThinkDeeper + TIP"""
+        """Generate response with ThinkDeeper's controlled thinking process"""
 
         messages.append({"role": "assistant", "content": f"{self.config['start_think_token']}\n{self.config['prefill']}"})
 
@@ -75,27 +57,28 @@ def reasoning_effort(self, messages) -> str:
         n_thinking_tokens = 0
         seen_end_think = False
         response_chunks = []
-        current_pos = 0
 
         while True:
             out = self.model(input_ids=tokens, past_key_values=kv, use_cache=True)
-
-            # Apply TIP to logits
             logits = out.logits[0, -1, :]
-            adjusted_logits = self.adjust_logits_with_tip(logits, current_pos)
 
-            next_token = torch.multinomial(
-                torch.softmax(adjusted_logits, dim=-1), 1
-            ).item()
-            kv = out.past_key_values
+            # Force the end think token once limits are hit (skip once </think> was emitted,
+            # otherwise this branch would re-trigger forever and EOS would never be reached)
+            if not seen_end_think and (n_thinking_tokens >= self.config["max_thinking_tokens"]
+                                       or self.thought_count >= self.config["max_thoughts"]):
+                next_token = self.end_think_token
+                logger.debug(f"Forcing end think token. Tokens: {n_thinking_tokens}, Thoughts: {self.thought_count}")
+            else:
+                next_token = torch.multinomial(
+                    torch.softmax(logits, dim=-1), 1
+                ).item()
 
+            kv = out.past_key_values
             next_str = self.tokenizer.decode([next_token])
-            logger.debug(f"Generated token {next_token} -> '{next_str}'")
-
+
             # Check if this is a thought-switching token
             if next_token in self.thought_switch_tokens:
-                self.last_thought_switch_pos = current_pos
-                logger.debug(f"Detected thought switch at position {current_pos}")
+                self.thought_count += 1
+                logger.debug(f"Detected thought switch. Total thoughts: {self.thought_count}")
 
             # Track if we've seen the end think token
             if next_token == self.end_think_token:
@@ -109,14 +92,15 @@ def reasoning_effort(self, messages) -> str:
                 and n_thinking_tokens < self.config["min_thinking_tokens"])
                 or (next_token == self.model.config.eos_token_id and not seen_end_think)):
 
+                # Insert thought transition
                 replacement = random.choice(self.config["thought_switch_tokens"])
-                logger.debug(f"Inserting thought transition: '{replacement}' (tokens: {n_thinking_tokens}, seen_end_think: {seen_end_think})")
+                logger.debug(f"Inserting thought transition: '{replacement}' (tokens: {n_thinking_tokens})")
                 response_chunks.append(replacement)
                 replacement_tokens = self.tokenizer.encode(replacement)
                 n_thinking_tokens += len(replacement_tokens)
                 tokens = torch.tensor([replacement_tokens]).to(tokens.device)
+                self.thought_count += 1
                 seen_end_think = False
-                logger.debug("Reset seen_end_think flag after replacement")
 
             elif next_token == self.model.config.eos_token_id and seen_end_think:
                 logger.debug("Reached EOS after end think token - stopping generation")
@@ -126,14 +110,12 @@ def reasoning_effort(self, messages) -> str:
                 response_chunks.append(next_str)
                 n_thinking_tokens += 1
                 tokens = torch.tensor([[next_token]]).to(tokens.device)
-                current_pos += 1
-                logger.debug(f"Added token to response. Total thinking tokens: {n_thinking_tokens}")
 
         # Join all chunks and trim off the initial prompt
         response = "".join(response_chunks)
         full_response = f"{self.config['start_think_token']}\n{self.config['prefill']}{response}"
 
-        logger.debug(f"Final response length: {len(full_response)} chars")
+        logger.debug(f"Final response length: {len(full_response)} chars, Total thoughts: {self.thought_count}")
         return full_response
 
 def thinkdeeper_decode(
@@ -142,25 +124,24 @@ def thinkdeeper_decode(
     messages: List[Dict[str, str]],
     request_config: Dict[str, Any] = None
 ) -> str:
-    """Main plugin execution function with ThinkDeeper + TIP"""
-    logger.info("Starting ThinkDeeper+TIP processing")
+    """Main plugin execution function with ThinkDeeper's controlled thinking process"""
+    logger.info("Starting ThinkDeeper processing")
 
     # Extract config from request_config if provided
     config = DEFAULT_CONFIG.copy()
     if request_config:
-        thinkdeeper_config = request_config
         # Update only valid keys
         for key in DEFAULT_CONFIG:
-            if key in thinkdeeper_config:
-                config[key] = thinkdeeper_config[key]
+            if key in request_config:
+                config[key] = request_config[key]
 
     logger.info(f"Using config: {config}")
 
     try:
-        processor = ThinkDeeperTIPProcessor(config, tokenizer, model)
+        processor = ThinkDeeperProcessor(config, tokenizer, model)
         response = processor.reasoning_effort(messages)
         return response
 
     except Exception as e:
-        logger.error(f"Error in ThinkDeeper+TIP processing: {str(e)}")
+        logger.error(f"Error in ThinkDeeper processing: {str(e)}")
         raise
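
For reference, a minimal usage sketch of the updated entry point. The leading parameters of thinkdeeper_decode fall outside the hunks shown above; the sketch assumes they are model and tokenizer (consistent with the ThinkDeeperProcessor(config, tokenizer, model) call), and the checkpoint name and messages are placeholders, not part of this commit.

# Usage sketch (assumptions noted above): drive the new thinking caps
# through request_config. Only keys present in DEFAULT_CONFIG are applied;
# unknown keys are silently ignored by the update loop.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "your-org/your-reasoning-model"  # hypothetical checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

messages = [{"role": "user", "content": "How many primes are there below 100?"}]

response = thinkdeeper_decode(
    model,
    tokenizer,
    messages,
    request_config={
        "min_thinking_tokens": 256,
        "max_thinking_tokens": 1024,  # cap total thinking length
        "max_thoughts": 3,            # cap "Wait,"/"Alternatively," transitions
    },
)
print(response)

The two caps work against each other by design: min_thinking_tokens can inject extra thought transitions to extend reasoning, while max_thinking_tokens and max_thoughts force the end-think token once either budget is exhausted.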