
Commit aad3062

support reasoning_effort parameter for reasoning models
1 parent 4b85f01 commit aad3062

File tree

3 files changed: +167 -49 lines changed


optillm/inference.py

Lines changed: 80 additions & 12 deletions
@@ -1321,7 +1321,6 @@ def create(
         presence_penalty: float = 0,
         frequency_penalty: float = 0,
         logit_bias: Optional[Dict[str, float]] = None,
-        user: Optional[str] = None,
         seed: Optional[int] = None,
         logprobs: Optional[bool] = None,
         top_logprobs: Optional[int] = None,
@@ -1337,11 +1336,12 @@ def create(
         # Entropy specific params
         top_k: int = 27,
         min_p: float = 0.03,
-        thought_switch_tokens: List[str] = ["Wait,", "Alternatively,"],
-        min_thinking_tokens: int = 512,
-        max_thinking_tokens: int = 2048,
-        max_thoughts: int = 4,
-        num_traces: int = 1,
+        # Thinking specific params
+        reasoning_effort: str = "low",
+        thought_switch_tokens: List[str] = [],
+        min_thinking_tokens: Optional[int] = None,
+        max_thinking_tokens: Optional[int] = None,
+        max_thoughts: Optional[int] = None,
         prefill: str = "",
         start_think_token: str ="<think>",
         end_think_token: str = "</think>",
@@ -1443,15 +1443,21 @@ def create(
                 pipeline.current_model = pipeline.current_model.to(original_dtype)
 
             elif decoding == "thinkdeeper":
-                thinkdeeper_config = {
-                    "thought_switch_tokens": thought_switch_tokens,
-                    "min_thinking_tokens": min_thinking_tokens,
-                    "max_thinking_tokens": max_thinking_tokens,
-                    "max_thoughts": max_thoughts,
-                    "prefill": prefill,
+                # Get base config for reasoning effort
+                thinkdeeper_config = get_effort_profile(reasoning_effort)
+
+                # Override with any custom parameters
+                custom_config = {
+                    "min_thinking_tokens": min_thinking_tokens if min_thinking_tokens is not None else thinkdeeper_config["min_thinking_tokens"],
+                    "max_thinking_tokens": max_thinking_tokens if max_thinking_tokens is not None else thinkdeeper_config["max_thinking_tokens"],
+                    "max_thoughts": max_thoughts if max_thoughts is not None else thinkdeeper_config["max_thoughts"],
+                    "thought_switch_tokens": thought_switch_tokens if thought_switch_tokens else thinkdeeper_config["thought_switch_tokens"],
+                    "prefill": prefill if prefill else thinkdeeper_config["prefill"],
                     "start_think_token": start_think_token,
                     "end_think_token": end_think_token,
                 }
+                thinkdeeper_config.update(custom_config)
+
                 result = thinkdeeper_decode(
                     pipeline.current_model,
                     pipeline.tokenizer,
@@ -1584,3 +1590,65 @@ def parse_model_string(model: str) -> ModelConfig:
         enable_prompt_caching=False,
         dynamic_temperature=False,
     )
+
+# Low Reasoning Effort
+# Suitable for:
+# - Simple, straightforward questions
+# - Quick clarifications
+# - Well-defined tasks with clear steps
+LOW_EFFORT = {
+    "min_thinking_tokens": 256,    # ~100-200 words minimum
+    "max_thinking_tokens": 512,    # ~200-400 words maximum
+    "max_thoughts": 2,             # Allow only one alternative perspective
+    "thought_switch_tokens": [
+        "However,",                # Single alternative consideration
+        "Wait,",
+        "Alternatively,",
+    ],
+    "prefill": "Let me think about this briefly..."
+}
+
+# Medium Reasoning Effort
+# Suitable for:
+# - Moderate complexity problems
+# - Analysis requiring multiple perspectives
+# - Tasks needing detailed explanation
+MEDIUM_EFFORT = {
+    "min_thinking_tokens": 512,    # ~200-400 words minimum
+    "max_thinking_tokens": 1024,   # ~400-800 words maximum
+    "max_thoughts": 4,             # Allow multiple perspective shifts
+    "thought_switch_tokens": [
+        "Additionally,",
+        "Alternatively,",
+        "However,",
+        "Wait,",
+    ],
+    "prefill": "Let me analyze this from multiple angles..."
+}
+
+# High Reasoning Effort
+# Suitable for:
+# - Complex problem solving
+# - Deep analysis tasks
+# - Multi-step reasoning chains
+HIGH_EFFORT = {
+    "min_thinking_tokens": 1024,   # ~400-800 words minimum
+    "max_thinking_tokens": 2048,   # ~800-1600 words maximum
+    "max_thoughts": 6,             # Allow extensive exploration
+    "thought_switch_tokens": [
+        "Additionally,",
+        "Alternatively,",
+        "However,",
+        "Wait,",
+    ],
+    "prefill": "This requires careful analysis. Let me think through it systematically..."
+}
+
+def get_effort_profile(effort_level: str) -> dict:
+    """Get reasoning effort profile based on specified level."""
+    profiles = {
+        "low": LOW_EFFORT,
+        "medium": MEDIUM_EFFORT,
+        "high": HIGH_EFFORT
+    }
+    return profiles.get(effort_level, LOW_EFFORT)
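
Note: a minimal sketch of how a base profile combines with caller overrides, mirroring the merge in the thinkdeeper branch of create() above. The 1536 override is an illustrative value, not part of this commit, and the snippet assumes the definitions above are in scope:

    base = dict(get_effort_profile("medium"))      # unknown levels fall back to LOW_EFFORT
    overrides = {"max_thinking_tokens": 1536}      # illustrative caller-supplied override
    config = {**base, **overrides}                 # explicit parameters take precedence over the profile
    assert config["min_thinking_tokens"] == 512    # inherited from MEDIUM_EFFORT
    assert config["max_thinking_tokens"] == 1536   # overridden by the caller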

optillm/thinkdeeper.py

Lines changed: 82 additions & 36 deletions
@@ -5,6 +5,7 @@
 import logging
 
 logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
 
 DEFAULT_CONFIG = {
     "min_thinking_tokens": 512,
@@ -31,14 +32,40 @@ def __init__(self, config: Dict[str, Any], tokenizer, model):
         self._start_think_token = start_tokens[0] if len(start_tokens) == 1 else start_tokens[1]
         self.end_think_token = end_tokens[0] if len(end_tokens) == 1 else end_tokens[1]
 
-        # Get token IDs for thought switching indicators
-        self.thought_switch_tokens = set()
+        # Store thought switch markers as token sequences
+        self.thought_switch_sequences = []
         for phrase in self.config["thought_switch_tokens"]:
+            # Encode without adding special tokens to get exact sequence
             token_ids = self.tokenizer.encode(phrase, add_special_tokens=False)
-            self.thought_switch_tokens.update(token_ids)
+            self.thought_switch_sequences.append(token_ids)
+            logger.debug(f"Encoded '{phrase}' to token sequence: {token_ids}")
+            logger.debug(f"Decoded back: {self.tokenizer.decode(token_ids)}")
 
         # Track thought switches
         self.thought_count = 0
+        self.current_sequence = []  # Track recent tokens for sequence matching
+        self.max_sequence_length = max(len(seq) for seq in self.thought_switch_sequences)
+
+        for phrase, sequence in zip(self.config["thought_switch_tokens"], self.thought_switch_sequences):
+            logger.debug(f"Thought switch marker '{phrase}' encoded as: {sequence}")
+            logger.debug(f"Decoded back as: {self.tokenizer.decode(sequence)}")
+
+    def is_thought_switch(self, token: int) -> bool:
+        """Check if adding this token creates a thought switch sequence."""
+        # Add new token to current sequence
+        self.current_sequence.append(token)
+
+        # Keep only the most recent tokens that could match our sequences
+        if len(self.current_sequence) > self.max_sequence_length:
+            self.current_sequence = self.current_sequence[-self.max_sequence_length:]
+
+        # Check if current sequence ends with any thought switch sequence
+        for sequence in self.thought_switch_sequences:
+            if len(sequence) <= len(self.current_sequence) and \
+               self.current_sequence[-len(sequence):] == sequence:
+                return True
+
+        return False
 
     @torch.inference_mode()
     def reasoning_effort(self, messages) -> str:
@@ -62,11 +89,16 @@ def reasoning_effort(self, messages) -> str:
             out = self.model(input_ids=tokens, past_key_values=kv, use_cache=True)
             logits = out.logits[0, -1, :]
 
-            # Force end think token if we exceed limits
-            if (n_thinking_tokens >= self.config["max_thinking_tokens"] or
-                self.thought_count >= self.config["max_thoughts"]):
-                next_token = self.end_think_token
+            # Check if we need to force end token
+            force_end = (n_thinking_tokens >= self.config["max_thinking_tokens"] or
+                         self.thought_count >= self.config["max_thoughts"])
+
+            if force_end:
                 logger.debug(f"Forcing end think token. Tokens: {n_thinking_tokens}, Thoughts: {self.thought_count}")
+                next_token = self.end_think_token
+                response_chunks.append(self.tokenizer.decode([next_token]))
+                # Break immediately when forcing end token
+                break
             else:
                 next_token = torch.multinomial(
                     torch.softmax(logits, dim=-1), 1
@@ -76,42 +108,56 @@ def reasoning_effort(self, messages) -> str:
             next_str = self.tokenizer.decode([next_token])
 
             # Check if this is a thought-switching token
-            if next_token in self.thought_switch_tokens:
+            if self.is_thought_switch(next_token):
                 self.thought_count += 1
-                logger.debug(f"Detected thought switch. Total thoughts: {self.thought_count}")
+                logger.debug(f"Detected thought switch marker. Total thoughts: {self.thought_count}")
+                # Clear the sequence after detecting a switch
+                self.current_sequence = []
 
-            # Track if we've seen the end think token
+            # Handle natural end think token
             if next_token == self.end_think_token:
                 seen_end_think = True
                 logger.debug("Found end think token")
-
-            # Need to continue generating if:
-            # 1. We hit end think/eos before min tokens OR
-            # 2. We hit eos without seeing end think token
-            if ((next_token in (self.end_think_token, self.model.config.eos_token_id)
-                and n_thinking_tokens < self.config["min_thinking_tokens"])
-                or (next_token == self.model.config.eos_token_id and not seen_end_think)):
-
-                # Insert thought transition
-                replacement = random.choice(self.config["thought_switch_tokens"])
-                logger.debug(f"Inserting thought transition: '{replacement}' (tokens: {n_thinking_tokens})")
-                response_chunks.append(replacement)
-                replacement_tokens = self.tokenizer.encode(replacement)
-                n_thinking_tokens += len(replacement_tokens)
-                tokens = torch.tensor([replacement_tokens]).to(tokens.device)
-                self.thought_count += 1
-                seen_end_think = False
-
-            elif next_token == self.model.config.eos_token_id and seen_end_think:
-                logger.debug("Reached EOS after end think token - stopping generation")
-                break
 
-            else:
-                response_chunks.append(next_str)
-                n_thinking_tokens += 1
-                tokens = torch.tensor([[next_token]]).to(tokens.device)
+                # If we haven't reached minimum tokens, continue with thought transition
+                if n_thinking_tokens < self.config["min_thinking_tokens"]:
+                    replacement = random.choice(self.config["thought_switch_tokens"])
+                    logger.debug(f"Inserting thought transition: '{replacement}' (tokens: {n_thinking_tokens})")
+                    response_chunks.append(replacement)
+                    replacement_tokens = self.tokenizer.encode(replacement)
+                    n_thinking_tokens += len(replacement_tokens)
+                    tokens = torch.tensor([replacement_tokens]).to(tokens.device)
+                    self.thought_count += 1
+                    seen_end_think = False
+                    continue
+
+            # Handle EOS token
+            if next_token == self.model.config.eos_token_id:
+                if seen_end_think:
+                    logger.debug("Reached EOS after end think token - stopping generation")
+                    break
+                elif n_thinking_tokens < self.config["min_thinking_tokens"]:
+                    # Continue with thought transition if under minimum tokens
+                    replacement = random.choice(self.config["thought_switch_tokens"])
+                    logger.debug(f"Inserting thought transition: '{replacement}' (tokens: {n_thinking_tokens})")
+                    response_chunks.append(replacement)
+                    replacement_tokens = self.tokenizer.encode(replacement)
+                    n_thinking_tokens += len(replacement_tokens)
+                    tokens = torch.tensor([replacement_tokens]).to(tokens.device)
+                    self.thought_count += 1
+                    continue
+                else:
+                    # Force end think token if we haven't seen it
+                    logger.debug("Reached EOS without end think token - adding end token")
+                    response_chunks.append(self.tokenizer.decode([self.end_think_token]))
+                    break
+
+            # Normal token processing
+            response_chunks.append(next_str)
+            n_thinking_tokens += 1
+            tokens = torch.tensor([[next_token]]).to(tokens.device)
 
-        # Join all chunks and trim off the initial prompt
+        # Join all chunks and add framing tokens
        response = "".join(response_chunks)
        full_response = f"{self.config['start_think_token']}\n{self.config['prefill']}{response}"
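
Note: a small standalone sketch of the suffix-matching idea behind is_thought_switch above, using made-up token IDs in place of a real tokenizer's output:

    # Hypothetical token-ID sequences standing in for tokenized markers such as "Wait," or "However,".
    thought_switch_sequences = [[511, 72], [893]]
    max_len = max(len(seq) for seq in thought_switch_sequences)

    def is_thought_switch(current_sequence, token):
        """Return True when appending `token` makes the recent tokens end with a marker sequence."""
        current_sequence.append(token)
        del current_sequence[:-max_len]  # keep only the most recent tokens that could still match
        return any(current_sequence[-len(seq):] == seq for seq in thought_switch_sequences)

    recent = []
    assert not is_thought_switch(recent, 511)  # a prefix of a marker is not yet a switch
    assert is_thought_switch(recent, 72)       # completing [511, 72] counts as a switch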

scripts/eval_optillmbench.py

Lines changed: 5 additions & 1 deletion
@@ -160,7 +160,11 @@ def evaluate_model(
                 {"role": "user", "content": prompt}
             ],
             temperature=0.2,
-            max_tokens=4096
+            max_tokens=4096,
+            reasoning_effort="low",
+            extra_body = {
+                "decoding" : "thinkdeeper",
+            }
         )
 
         # Calculate time taken
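
Note: for reference, a minimal client-side call exercising the new parameter through an OpenAI-compatible optillm endpoint; the base_url, api_key, and model name below are placeholders, not part of this commit:

    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="optillm")  # placeholder proxy endpoint
    response = client.chat.completions.create(
        model="meta-llama/Llama-3.2-1B-Instruct",  # placeholder model
        messages=[{"role": "user", "content": "What is 17 * 24?"}],
        temperature=0.2,
        max_tokens=4096,
        reasoning_effort="low",                    # new parameter added in this commit
        extra_body={"decoding": "thinkdeeper"},    # route decoding through thinkdeeper
    )
    print(response.choices[0].message.content)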
