
Commit 4b85f01

removed TIP, added max_thoughts cap
1 parent f89f89e commit 4b85f01

2 files changed (+35 −55 lines)

optillm/inference.py

Lines changed: 6 additions & 7 deletions
@@ -1339,8 +1339,8 @@ def create(
     min_p: float = 0.03,
     thought_switch_tokens: List[str] = ["Wait,", "Alternatively,"],
     min_thinking_tokens: int = 512,
-    tip_alpha: float = 4.0,
-    tip_beta: int = 1024,
+    max_thinking_tokens: int = 2048,
+    max_thoughts: int = 4,
     num_traces: int = 1,
     prefill: str = "",
     start_think_token: str ="<think>",
@@ -1446,12 +1446,11 @@ def create(
     thinkdeeper_config = {
         "thought_switch_tokens": thought_switch_tokens,
         "min_thinking_tokens": min_thinking_tokens,
+        "max_thinking_tokens": max_thinking_tokens,
+        "max_thoughts": max_thoughts,
         "prefill": prefill,
-        "start_think_token" : start_think_token,
-        "end_think_token" : end_think_token,
-        "num_traces" : num_traces,
-        "tip_alpha" : tip_alpha,
-        "tip_beta" : tip_beta,
+        "start_think_token": start_think_token,
+        "end_think_token": end_think_token,
     }
     result = thinkdeeper_decode(
         pipeline.current_model,
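In short, the two TIP knobs (a soft logit penalty of strength tip_alpha applied for tip_beta tokens after each thought switch) are replaced by two hard budgets. A minimal standalone sketch of the new stopping condition, with names taken from the diff (the helper function itself is illustrative, not part of the commit):

    def should_force_end_think(n_thinking_tokens: int, thought_count: int,
                               max_thinking_tokens: int = 2048,
                               max_thoughts: int = 4) -> bool:
        # Mirrors the check in ThinkDeeperProcessor.reasoning_effort below:
        # stop thinking once either budget is exhausted.
        return (n_thinking_tokens >= max_thinking_tokens
                or thought_count >= max_thoughts)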

optillm/thinkdeeper.py

Lines changed: 29 additions & 48 deletions
@@ -6,23 +6,20 @@

 logger = logging.getLogger(__name__)

-# Default configurations
 DEFAULT_CONFIG = {
     "min_thinking_tokens": 512,
+    "max_thinking_tokens": 2048,  # New parameter to cap thinking length
+    "max_thoughts": 4,  # New parameter to limit number of thought transitions
     "prefill": "",
     "start_think_token": "<think>",
     "end_think_token": "</think>",
-
-    # Combined thought transition markers and TIP configs
-    "tip_alpha": 4.0,  # Penalty strength
-    "tip_beta": 1024,  # Penalty duration (number of tokens)
     "thought_switch_tokens": [
         "Wait,",
         "Alternatively,",
     ],
 }

-class ThinkDeeperTIPProcessor:
+class ThinkDeeperProcessor:
     def __init__(self, config: Dict[str, Any], tokenizer, model):
         self.config = {**DEFAULT_CONFIG, **config}
         self.tokenizer = tokenizer
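Because the processor merges per-request config over the defaults with {**DEFAULT_CONFIG, **config}, a request can override one cap while keeping the other at its default. A quick illustrative check (DEFAULT_CONFIG as defined in the hunk above):

    from optillm.thinkdeeper import DEFAULT_CONFIG

    config = {**DEFAULT_CONFIG, **{"max_thoughts": 8}}
    assert config["max_thoughts"] == 8             # per-request override wins
    assert config["max_thinking_tokens"] == 2048   # unset keys keep their defaults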
@@ -40,27 +37,12 @@ def __init__(self, config: Dict[str, Any], tokenizer, model):
             token_ids = self.tokenizer.encode(phrase, add_special_tokens=False)
             self.thought_switch_tokens.update(token_ids)

-        # Track when the last thought switch occurred
-        self.last_thought_switch_pos = 0
-
-    def adjust_logits_with_tip(self, logits: torch.Tensor, current_pos: int) -> torch.Tensor:
-        """Apply Thought Switching Penalty (TIP) to logits"""
-        tokens_since_last_switch = current_pos - self.last_thought_switch_pos
-
-        if tokens_since_last_switch < self.config["tip_beta"]:
-            penalty_mask = torch.zeros_like(logits)
-            for token_id in self.thought_switch_tokens:
-                if token_id < logits.size(-1):  # Ensure token_id is within valid range
-                    penalty_mask[token_id] = self.config["tip_alpha"]
-
-            adjusted_logits = logits - penalty_mask
-            return adjusted_logits
+        # Track thought switches
+        self.thought_count = 0

-        return logits
-
     @torch.inference_mode()
     def reasoning_effort(self, messages) -> str:
-        """Generate response with ThinkDeeper + TIP"""
+        """Generate response with ThinkDeeper's controlled thinking process"""

         messages.append({"role": "assistant", "content": f"{self.config['start_think_token']}\n{self.config['prefill']}"})

@@ -75,27 +57,28 @@ def reasoning_effort(self, messages) -> str:
         n_thinking_tokens = 0
         seen_end_think = False
         response_chunks = []
-        current_pos = 0

         while True:
             out = self.model(input_ids=tokens, past_key_values=kv, use_cache=True)
-
-            # Apply TIP to logits
             logits = out.logits[0, -1, :]
-            adjusted_logits = self.adjust_logits_with_tip(logits, current_pos)

-            next_token = torch.multinomial(
-                torch.softmax(adjusted_logits, dim=-1), 1
-            ).item()
-            kv = out.past_key_values
+            # Force end think token if we exceed limits
+            if (n_thinking_tokens >= self.config["max_thinking_tokens"] or
+                self.thought_count >= self.config["max_thoughts"]):
+                next_token = self.end_think_token
+                logger.debug(f"Forcing end think token. Tokens: {n_thinking_tokens}, Thoughts: {self.thought_count}")
+            else:
+                next_token = torch.multinomial(
+                    torch.softmax(logits, dim=-1), 1
+                ).item()

+            kv = out.past_key_values
             next_str = self.tokenizer.decode([next_token])
-            logger.debug(f"Generated token {next_token} -> '{next_str}'")
-
+
             # Check if this is a thought-switching token
             if next_token in self.thought_switch_tokens:
-                self.last_thought_switch_pos = current_pos
-                logger.debug(f"Detected thought switch at position {current_pos}")
+                self.thought_count += 1
+                logger.debug(f"Detected thought switch. Total thoughts: {self.thought_count}")

             # Track if we've seen the end think token
             if next_token == self.end_think_token:
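The control-flow change is easiest to see in isolation: sampling proceeds normally until either budget runs out, at which point the end-think token is forced rather than sampled. A condensed sketch of that step (the function boundary is illustrative, not part of the commit; it assumes logits, config, and the end-think token id are set up as in the hunk above):

    import torch

    def next_thinking_token(logits: torch.Tensor, n_thinking_tokens: int,
                            thought_count: int, config: dict,
                            end_think_token: int) -> int:
        # Budget exhausted: force the model out of the think block.
        if (n_thinking_tokens >= config["max_thinking_tokens"]
                or thought_count >= config["max_thoughts"]):
            return end_think_token
        # Otherwise sample from the unmodified distribution (no TIP penalty).
        return torch.multinomial(torch.softmax(logits, dim=-1), 1).item()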
@@ -109,14 +92,15 @@ def reasoning_effort(self, messages) -> str:
                 and n_thinking_tokens < self.config["min_thinking_tokens"])
                 or (next_token == self.model.config.eos_token_id and not seen_end_think)):

+                # Insert thought transition
                 replacement = random.choice(self.config["thought_switch_tokens"])
-                logger.debug(f"Inserting thought transition: '{replacement}' (tokens: {n_thinking_tokens}, seen_end_think: {seen_end_think})")
+                logger.debug(f"Inserting thought transition: '{replacement}' (tokens: {n_thinking_tokens})")
                 response_chunks.append(replacement)
                 replacement_tokens = self.tokenizer.encode(replacement)
                 n_thinking_tokens += len(replacement_tokens)
                 tokens = torch.tensor([replacement_tokens]).to(tokens.device)
+                self.thought_count += 1
                 seen_end_think = False
-                logger.debug("Reset seen_end_think flag after replacement")

             elif next_token == self.model.config.eos_token_id and seen_end_think:
                 logger.debug("Reached EOS after end think token - stopping generation")
@@ -126,14 +110,12 @@ def reasoning_effort(self, messages) -> str:
                 response_chunks.append(next_str)
                 n_thinking_tokens += 1
                 tokens = torch.tensor([[next_token]]).to(tokens.device)
-                current_pos += 1
-                logger.debug(f"Added token to response. Total thinking tokens: {n_thinking_tokens}")

         # Join all chunks and trim off the initial prompt
         response = "".join(response_chunks)
         full_response = f"{self.config['start_think_token']}\n{self.config['prefill']}{response}"

-        logger.debug(f"Final response length: {len(full_response)} chars")
+        logger.debug(f"Final response length: {len(full_response)} chars, Total thoughts: {self.thought_count}")
         return full_response

 def thinkdeeper_decode(
@@ -142,25 +124,24 @@ def thinkdeeper_decode(
     messages: List[Dict[str, str]],
     request_config: Dict[str, Any] = None
 ) -> str:
-    """Main plugin execution function with ThinkDeeper + TIP"""
-    logger.info("Starting ThinkDeeper+TIP processing")
+    """Main plugin execution function with ThinkDeeper's controlled thinking process"""
+    logger.info("Starting ThinkDeeper processing")

     # Extract config from request_config if provided
     config = DEFAULT_CONFIG.copy()
     if request_config:
-        thinkdeeper_config = request_config
         # Update only valid keys
         for key in DEFAULT_CONFIG:
-            if key in thinkdeeper_config:
-                config[key] = thinkdeeper_config[key]
+            if key in request_config:
+                config[key] = request_config[key]

     logger.info(f"Using config: {config}")

     try:
-        processor = ThinkDeeperTIPProcessor(config, tokenizer, model)
+        processor = ThinkDeeperProcessor(config, tokenizer, model)
         response = processor.reasoning_effort(messages)
         return response

     except Exception as e:
-        logger.error(f"Error in ThinkDeeper+TIP processing: {str(e)}")
+        logger.error(f"Error in ThinkDeeper processing: {str(e)}")
         raise
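For reference, a minimal call sketch of the renamed entry point. The argument order is assumed from the call site in inference.py (model first, then tokenizer) and the config keys follow DEFAULT_CONFIG; the model id and message are placeholders:

    from transformers import AutoModelForCausalLM, AutoTokenizer
    from optillm.thinkdeeper import thinkdeeper_decode

    model = AutoModelForCausalLM.from_pretrained("my-org/my-reasoning-model")  # hypothetical id
    tokenizer = AutoTokenizer.from_pretrained("my-org/my-reasoning-model")

    response = thinkdeeper_decode(
        model,
        tokenizer,
        [{"role": "user", "content": "Is 97 prime?"}],
        request_config={"max_thinking_tokens": 1024, "max_thoughts": 2},
    )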
