
Commit d2ed063

Add MLX ThinkDeeper support and update eval configs
Introduces a new MLX-compatible ThinkDeeper implementation (`thinkdeeper_mlx.py`) and integrates it into the inference pipeline for MLX models. Updates the inference logic to select the appropriate ThinkDeeper version based on the backend, and refines fallback and error handling for MLX generation. The evaluation script is updated with new test-time compute scaling approaches, including revised ThinkDeeper and majority voting configurations, and improved reporting for test-time compute experiments.
1 parent: e4d0925
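
The backend selection described in the commit message reduces to an isinstance check on the inference pipeline. The following is a condensed sketch of that dispatch based on the diff below; the helper function name is ours, and importing `MLXInferencePipeline` from `optillm.inference` is an assumption, not something shown in this commit.

# Condensed sketch of the backend-based ThinkDeeper dispatch added in this commit.
# The decode calls mirror the diff below; the MLXInferencePipeline import path is assumed.
from optillm.thinkdeeper import thinkdeeper_decode
from optillm.thinkdeeper_mlx import thinkdeeper_decode_mlx
from optillm.inference import MLXInferencePipeline  # assumed import location

def run_thinkdeeper(pipeline, messages, thinkdeeper_config, max_thinking_tokens, user_max_tokens=512):
    if isinstance(pipeline, MLXInferencePipeline):
        # MLX path: pass an explicit token budget covering thinking plus a response buffer.
        config = dict(thinkdeeper_config)
        config["max_tokens"] = max(user_max_tokens, max_thinking_tokens + 512)
        return thinkdeeper_decode_mlx(pipeline.model, pipeline.tokenizer, messages, config)
    # PyTorch path: unchanged call into the existing implementation.
    return thinkdeeper_decode(pipeline.current_model, pipeline.tokenizer, messages, thinkdeeper_config)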

3 files changed: +502 −157 lines

optillm/inference.py

Lines changed: 129 additions & 133 deletions
@@ -22,6 +22,7 @@
 from optillm.cot_decoding import cot_decode
 from optillm.entropy_decoding import entropy_decode
 from optillm.thinkdeeper import thinkdeeper_decode
+from optillm.thinkdeeper_mlx import thinkdeeper_decode_mlx
 from optillm.autothink import autothink_decode

 # Configure logging
@@ -33,6 +34,7 @@
     import mlx.core as mx
     from mlx_lm import load as mlx_load, generate as mlx_generate
     from mlx_lm.tokenizer_utils import TokenizerWrapper
+    from mlx_lm.sample_utils import make_sampler
     MLX_AVAILABLE = True
     logger.info("MLX framework available")
 except ImportError:
@@ -349,85 +351,46 @@ def generate(
         return responses, token_counts, logprobs_results

     def _robust_mlx_generate(self, prompt: str, max_tokens: int, temperature: float, top_p: float, repetition_penalty: float) -> str:
-        """Robust MLX generation with multiple parameter combinations"""
-
-        # Try different parameter combinations based on MLX-LM version
-        parameter_combinations = [
-            # Version 1: Current style with positional args and temp
-            {
-                "style": "positional_temp",
-                "args": (self.model, self.tokenizer, prompt),
-                "kwargs": {
-                    "max_tokens": max_tokens,
-                    "temp": temperature,
-                    "top_p": top_p,
-                    "repetition_penalty": repetition_penalty,
-                    "verbose": False
-                }
-            },
-            # Version 2: All keyword arguments with temp
-            {
-                "style": "keyword_temp",
-                "args": (),
-                "kwargs": {
-                    "model": self.model,
-                    "tokenizer": self.tokenizer,
-                    "prompt": prompt,
-                    "max_tokens": max_tokens,
-                    "temp": temperature,
-                    "top_p": top_p,
-                    "repetition_penalty": repetition_penalty,
-                    "verbose": False
-                }
-            },
-            # Version 3: Using temperature instead of temp
-            {
-                "style": "positional_temperature",
-                "args": (self.model, self.tokenizer, prompt),
-                "kwargs": {
-                    "max_tokens": max_tokens,
-                    "temperature": temperature,
-                    "top_p": top_p,
-                    "repetition_penalty": repetition_penalty,
-                    "verbose": False
-                }
-            },
-            # Version 4: Minimal parameters only
-            {
-                "style": "minimal",
-                "args": (self.model, self.tokenizer, prompt),
-                "kwargs": {
-                    "max_tokens": max_tokens,
-                    "temp": temperature,
-                    "verbose": False
-                }
-            },
-            # Version 5: Just essential parameters
-            {
-                "style": "essential",
-                "args": (self.model, self.tokenizer, prompt),
-                "kwargs": {
-                    "max_tokens": max_tokens
-                }
-            }
-        ]
+        """Robust MLX generation using sampler approach"""

-        last_error = None
-
-        for combo in parameter_combinations:
+        try:
+            # Create sampler with generation parameters
+            sampler = make_sampler(
+                temp=temperature,
+                top_p=top_p,
+                min_p=0.0,  # Default min_p
+                min_tokens_to_keep=1  # Default min_tokens_to_keep
+            )
+
+            # Generate using the sampler
+            response = mlx_generate(
+                self.model,
+                self.tokenizer,
+                prompt,
+                max_tokens=max_tokens,
+                sampler=sampler,
+                verbose=False
+            )
+
+            return response
+
+        except Exception as e:
+            logger.error(f"MLX generation with sampler failed: {str(e)}")
+
+            # Fallback: Try minimal parameters without sampler
             try:
-                logger.debug(f"Trying MLX generation with style: {combo['style']}")
-                response = mlx_generate(*combo["args"], **combo["kwargs"])
-                logger.debug(f"Successfully generated with style: {combo['style']}")
+                logger.debug("Attempting MLX generation without sampler")
+                response = mlx_generate(
+                    self.model,
+                    self.tokenizer,
+                    prompt,
+                    max_tokens=max_tokens,
+                    verbose=False
+                )
                 return response
-
-            except Exception as e:
-                last_error = e
-                logger.debug(f"Failed with style {combo['style']}: {str(e)}")
-                continue
-
-        # If all combinations failed, raise the last error
-        raise RuntimeError(f"All MLX generation methods failed. Last error: {str(last_error)}")
+            except Exception as fallback_e:
+                logger.error(f"MLX fallback generation also failed: {str(fallback_e)}")
+                raise

     def format_chat_prompt(self, system_prompt: str, user_prompt: str) -> str:
         """Format the prompt according to model's chat template"""
@@ -1691,37 +1654,47 @@ def create(
         if decoding:
             logger.info(f"Using specialized decoding approach: {decoding}")

-            # Ensure model is in eval mode and on correct device
-            pipeline.current_model.eval()
-            device = pipeline.current_model.device
+            # Check if this decoding approach is supported for MLX
+            mlx_unsupported_decodings = ["cot_decoding", "entropy_decoding", "autothink"]
+            if isinstance(pipeline, MLXInferencePipeline) and decoding in mlx_unsupported_decodings:
+                logger.warning(f"{decoding} is not supported for MLX models. Falling back to standard generation.")
+                decoding = None
+
+        if decoding:
+            # For PyTorch pipelines, ensure model is in eval mode and get device
+            # MLX pipelines handle this differently
+            if not isinstance(pipeline, MLXInferencePipeline):
+                pipeline.current_model.eval()
+                device = pipeline.current_model.device
+            else:
+                device = None  # MLX doesn't use torch devices

             if decoding == "cot_decoding":
                 # Use directly available parameters for CoT
-                cot_params = {
-                    "k": k,
-                    "num_beams": num_beams,
-                    "max_new_tokens": max_tokens if max_tokens is not None else 512,
-                    "temperature": temperature,
-                    "top_p": top_p,
-                    "repetition_penalty": 1.0,
-                    "length_penalty": length_penalty,
-                    "no_repeat_ngram_size": no_repeat_ngram_size,
-                    "early_stopping": early_stopping,
-                    "aggregate_paths": aggregate_paths,
-                }
-
-                result, confidence = cot_decode(
-                    pipeline.current_model,
-                    pipeline.tokenizer,
-                    messages,
-                    **cot_params
-                )
-                responses = [result]
-                logprobs_results = [{"confidence_score": confidence} if confidence is not None else None]
-                completion_tokens = len(pipeline.tokenizer.encode(result))
+                cot_params = {
+                    "k": k,
+                    "num_beams": num_beams,
+                    "max_new_tokens": max_tokens if max_tokens is not None else 512,
+                    "temperature": temperature,
+                    "top_p": top_p,
+                    "repetition_penalty": 1.0,
+                    "length_penalty": length_penalty,
+                    "no_repeat_ngram_size": no_repeat_ngram_size,
+                    "early_stopping": early_stopping,
+                    "aggregate_paths": aggregate_paths,
+                }
+
+                result, confidence = cot_decode(
+                    pipeline.current_model,
+                    pipeline.tokenizer,
+                    messages,
+                    **cot_params
+                )
+                responses = [result]
+                logprobs_results = [{"confidence_score": confidence} if confidence is not None else None]
+                completion_tokens = len(pipeline.tokenizer.encode(result))

             elif decoding == "entropy_decoding":
-
                 # Ensure model is using full precision
                 original_dtype = pipeline.current_model.dtype
                 pipeline.current_model = pipeline.current_model.to(torch.float32)
@@ -1778,43 +1751,66 @@ def create(
                 }
                 thinkdeeper_config.update(custom_config)

-                result = thinkdeeper_decode(
-                    pipeline.current_model,
-                    pipeline.tokenizer,
-                    messages,
-                    thinkdeeper_config
+                # Check if we're using MLX pipeline
+                if isinstance(pipeline, MLXInferencePipeline):
+                    logger.info("Using MLX ThinkDeeper implementation")
+
+                    # Ensure we have enough tokens for thinking + response
+                    user_max_tokens = max_tokens if max_tokens is not None else 512
+                    total_tokens_needed = max_thinking_tokens + 512  # thinking + response buffer
+                    adjusted_max_tokens = max(user_max_tokens, total_tokens_needed)
+
+                    # Add max_tokens to thinkdeeper config
+                    thinkdeeper_config_with_tokens = thinkdeeper_config.copy()
+                    thinkdeeper_config_with_tokens["max_tokens"] = adjusted_max_tokens
+
+                    logger.debug(f"ThinkDeeper tokens: user={user_max_tokens}, thinking={max_thinking_tokens}, adjusted={adjusted_max_tokens}")
+
+                    result = thinkdeeper_decode_mlx(
+                        pipeline.model,
+                        pipeline.tokenizer,
+                        messages,
+                        thinkdeeper_config_with_tokens
+                    )
+                else:
+                    logger.info("Using PyTorch ThinkDeeper implementation")
+                    result = thinkdeeper_decode(
+                        pipeline.current_model,
+                        pipeline.tokenizer,
+                        messages,
+                        thinkdeeper_config
                 )
                 responses = [result]
                 logprobs_results = [None]
                 completion_tokens = len(pipeline.tokenizer.encode(result))
             elif decoding == "autothink":
                 # Get steering dataset configuration
-                steering_dataset = kwargs.get("steering_dataset", "codelion/Qwen3-0.6B-pts-steering-vectors")
-                target_layer = kwargs.get("target_layer", 19)
-
-                # Prepare AutoThink configuration
-                autothink_config = {
-                    "steering_dataset": steering_dataset,
-                    "target_layer": target_layer,
-                    "pattern_strengths": kwargs.get("pattern_strengths", {
-                        "depth_and_thoroughness": 2.5,
-                        "numerical_accuracy": 2.0,
-                        "self_correction": 3.0,
-                        "exploration": 2.0,
-                        "organization": 1.5
-                    })
-                }
-
-                # Process with AutoThink
-                result = autothink_decode(
-                    pipeline.current_model,
-                    pipeline.tokenizer,
-                    messages,
-                    autothink_config
-                )
-                responses = [result]
-                logprobs_results = [None]
-                completion_tokens = len(pipeline.tokenizer.encode(result))
+                steering_dataset = kwargs.get("steering_dataset", "codelion/Qwen3-0.6B-pts-steering-vectors")
+                target_layer = kwargs.get("target_layer", 19)
+
+                # Prepare AutoThink configuration
+                autothink_config = {
+                    "steering_dataset": steering_dataset,
+                    "target_layer": target_layer,
+                    "pattern_strengths": kwargs.get("pattern_strengths", {
+                        "depth_and_thoroughness": 2.5,
+                        "numerical_accuracy": 2.0,
+                        "self_correction": 3.0,
+                        "exploration": 2.0,
+                        "organization": 1.5
+                    })
+                }
+
+                # Process with AutoThink
+                result = autothink_decode(
+                    pipeline.current_model,
+                    pipeline.tokenizer,
+                    messages,
+                    autothink_config
+                )
+                responses = [result]
+                logprobs_results = [None]
+                completion_tokens = len(pipeline.tokenizer.encode(result))
             else:
                 raise ValueError(f"Unknown specialized decoding approach: {decoding}")