
Commit ad262d8

Merge pull request #208 from codelion/feat-maj-at-k-plugin
Add majority voting plugin for candidate selection
2 parents 7904463 + cf2578e commit ad262d8
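
The plugin itself (maj@k-style candidate selection) is not among the hunks shown below; only the version bump and supporting changes appear here. As a rough illustration only, majority voting normalizes the k candidate answers and returns the most frequent one. The helper below is a hypothetical sketch, not the plugin's actual code.

from collections import Counter

def majority_vote(candidates: list[str]) -> str:
    # Hypothetical sketch: pick the candidate whose normalized form occurs most often.
    normalized = [c.strip().lower() for c in candidates]
    winner, _ = Counter(normalized).most_common(1)[0]
    # Return the first original candidate matching the winning normalized form.
    return next(c for c in candidates if c.strip().lower() == winner)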

File tree

9 files changed: +997 -200 lines changed


optillm/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 import os
 
 # Version information
-__version__ = "0.1.19"
+__version__ = "0.1.20"
 
 # Get the path to the root optillm.py
 spec = util.spec_from_file_location(

optillm/bon.py

Lines changed: 39 additions & 10 deletions
@@ -10,16 +10,45 @@ def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: st
 
     completions = []
 
-    response = client.chat.completions.create(
-        model=model,
-        messages=messages,
-        max_tokens=4096,
-        n=n,
-        temperature=1
-    )
-    completions = [choice.message.content for choice in response.choices]
-    logger.info(f"Generated {len(completions)} initial completions. Tokens used: {response.usage.completion_tokens}")
-    bon_completion_tokens += response.usage.completion_tokens
+    try:
+        # Try to generate n completions in a single API call using n parameter
+        response = client.chat.completions.create(
+            model=model,
+            messages=messages,
+            max_tokens=4096,
+            n=n,
+            temperature=1
+        )
+        completions = [choice.message.content for choice in response.choices]
+        logger.info(f"Generated {len(completions)} initial completions using n parameter. Tokens used: {response.usage.completion_tokens}")
+        bon_completion_tokens += response.usage.completion_tokens
+
+    except Exception as e:
+        logger.warning(f"n parameter not supported by provider: {str(e)}")
+        logger.info(f"Falling back to generating {n} completions one by one")
+
+        # Fallback: Generate completions one by one in a loop
+        for i in range(n):
+            try:
+                response = client.chat.completions.create(
+                    model=model,
+                    messages=messages,
+                    max_tokens=4096,
+                    temperature=1
+                )
+                completions.append(response.choices[0].message.content)
+                bon_completion_tokens += response.usage.completion_tokens
+                logger.debug(f"Generated completion {i+1}/{n}")
+
+            except Exception as fallback_error:
+                logger.error(f"Error generating completion {i+1}: {str(fallback_error)}")
+                continue
+
+        if not completions:
+            logger.error("Failed to generate any completions")
+            return "Error: Could not generate any completions", 0
+
+        logger.info(f"Generated {len(completions)} completions using fallback method. Total tokens used: {bon_completion_tokens}")
 
     # Rate the completions
     rating_messages = messages.copy()
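
The change above wraps the single n-completions request in a try/except and falls back to n sequential requests when a provider rejects the n parameter. A condensed standalone sketch of the same pattern, using a hypothetical generate_n helper and assuming an OpenAI-compatible client object:

from typing import List

def generate_n(client, model: str, messages: list, n: int) -> List[str]:
    # Prefer a single request that returns n choices.
    try:
        resp = client.chat.completions.create(model=model, messages=messages, n=n, max_tokens=4096, temperature=1)
        return [choice.message.content for choice in resp.choices]
    except Exception:
        # Provider rejected n: fall back to n independent single-choice requests.
        results = []
        for _ in range(n):
            resp = client.chat.completions.create(model=model, messages=messages, max_tokens=4096, temperature=1)
            results.append(resp.choices[0].message.content)
        return results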

optillm/inference.py

Lines changed: 129 additions & 133 deletions
@@ -22,6 +22,7 @@
 from optillm.cot_decoding import cot_decode
 from optillm.entropy_decoding import entropy_decode
 from optillm.thinkdeeper import thinkdeeper_decode
+from optillm.thinkdeeper_mlx import thinkdeeper_decode_mlx
 from optillm.autothink import autothink_decode
 
 # Configure logging
@@ -33,6 +34,7 @@
     import mlx.core as mx
     from mlx_lm import load as mlx_load, generate as mlx_generate
     from mlx_lm.tokenizer_utils import TokenizerWrapper
+    from mlx_lm.sample_utils import make_sampler
     MLX_AVAILABLE = True
     logger.info("MLX framework available")
 except ImportError:
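
For context, the new make_sampler import sits inside the module's existing guarded MLX import block; a condensed sketch of that pattern follows (the except branch is an assumption, since this hunk does not show it):

try:
    import mlx.core as mx
    from mlx_lm import load as mlx_load, generate as mlx_generate
    from mlx_lm.sample_utils import make_sampler
    MLX_AVAILABLE = True
except ImportError:
    MLX_AVAILABLE = False  # assumed fallback so non-MLX code paths still work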
@@ -349,85 +351,46 @@ def generate(
         return responses, token_counts, logprobs_results
 
     def _robust_mlx_generate(self, prompt: str, max_tokens: int, temperature: float, top_p: float, repetition_penalty: float) -> str:
-        """Robust MLX generation with multiple parameter combinations"""
-
-        # Try different parameter combinations based on MLX-LM version
-        parameter_combinations = [
-            # Version 1: Current style with positional args and temp
-            {
-                "style": "positional_temp",
-                "args": (self.model, self.tokenizer, prompt),
-                "kwargs": {
-                    "max_tokens": max_tokens,
-                    "temp": temperature,
-                    "top_p": top_p,
-                    "repetition_penalty": repetition_penalty,
-                    "verbose": False
-                }
-            },
-            # Version 2: All keyword arguments with temp
-            {
-                "style": "keyword_temp",
-                "args": (),
-                "kwargs": {
-                    "model": self.model,
-                    "tokenizer": self.tokenizer,
-                    "prompt": prompt,
-                    "max_tokens": max_tokens,
-                    "temp": temperature,
-                    "top_p": top_p,
-                    "repetition_penalty": repetition_penalty,
-                    "verbose": False
-                }
-            },
-            # Version 3: Using temperature instead of temp
-            {
-                "style": "positional_temperature",
-                "args": (self.model, self.tokenizer, prompt),
-                "kwargs": {
-                    "max_tokens": max_tokens,
-                    "temperature": temperature,
-                    "top_p": top_p,
-                    "repetition_penalty": repetition_penalty,
-                    "verbose": False
-                }
-            },
-            # Version 4: Minimal parameters only
-            {
-                "style": "minimal",
-                "args": (self.model, self.tokenizer, prompt),
-                "kwargs": {
-                    "max_tokens": max_tokens,
-                    "temp": temperature,
-                    "verbose": False
-                }
-            },
-            # Version 5: Just essential parameters
-            {
-                "style": "essential",
-                "args": (self.model, self.tokenizer, prompt),
-                "kwargs": {
-                    "max_tokens": max_tokens
-                }
-            }
-        ]
+        """Robust MLX generation using sampler approach"""
 
-        last_error = None
-
-        for combo in parameter_combinations:
+        try:
+            # Create sampler with generation parameters
+            sampler = make_sampler(
+                temp=temperature,
+                top_p=top_p,
+                min_p=0.0,  # Default min_p
+                min_tokens_to_keep=1  # Default min_tokens_to_keep
+            )
+
+            # Generate using the sampler
+            response = mlx_generate(
+                self.model,
+                self.tokenizer,
+                prompt,
+                max_tokens=max_tokens,
+                sampler=sampler,
+                verbose=False
+            )
+
+            return response
+
+        except Exception as e:
+            logger.error(f"MLX generation with sampler failed: {str(e)}")
+
+            # Fallback: Try minimal parameters without sampler
             try:
-                logger.debug(f"Trying MLX generation with style: {combo['style']}")
-                response = mlx_generate(*combo["args"], **combo["kwargs"])
-                logger.debug(f"Successfully generated with style: {combo['style']}")
+                logger.debug("Attempting MLX generation without sampler")
+                response = mlx_generate(
+                    self.model,
+                    self.tokenizer,
+                    prompt,
+                    max_tokens=max_tokens,
+                    verbose=False
+                )
                 return response
-
-            except Exception as e:
-                last_error = e
-                logger.debug(f"Failed with style {combo['style']}: {str(e)}")
-                continue
-
-        # If all combinations failed, raise the last error
-        raise RuntimeError(f"All MLX generation methods failed. Last error: {str(last_error)}")
+            except Exception as fallback_e:
+                logger.error(f"MLX fallback generation also failed: {str(fallback_e)}")
+                raise
 
     def format_chat_prompt(self, system_prompt: str, user_prompt: str) -> str:
         """Format the prompt according to model's chat template"""
@@ -1691,37 +1654,47 @@ def create(
         if decoding:
             logger.info(f"Using specialized decoding approach: {decoding}")
 
-            # Ensure model is in eval mode and on correct device
-            pipeline.current_model.eval()
-            device = pipeline.current_model.device
+            # Check if this decoding approach is supported for MLX
+            mlx_unsupported_decodings = ["cot_decoding", "entropy_decoding", "autothink"]
+            if isinstance(pipeline, MLXInferencePipeline) and decoding in mlx_unsupported_decodings:
+                logger.warning(f"{decoding} is not supported for MLX models. Falling back to standard generation.")
+                decoding = None
+
+        if decoding:
+            # For PyTorch pipelines, ensure model is in eval mode and get device
+            # MLX pipelines handle this differently
+            if not isinstance(pipeline, MLXInferencePipeline):
+                pipeline.current_model.eval()
+                device = pipeline.current_model.device
+            else:
+                device = None  # MLX doesn't use torch devices
 
             if decoding == "cot_decoding":
                 # Use directly available parameters for CoT
-                cot_params = {
-                    "k": k,
-                    "num_beams": num_beams,
-                    "max_new_tokens": max_tokens if max_tokens is not None else 512,
-                    "temperature": temperature,
-                    "top_p": top_p,
-                    "repetition_penalty": 1.0,
-                    "length_penalty": length_penalty,
-                    "no_repeat_ngram_size": no_repeat_ngram_size,
-                    "early_stopping": early_stopping,
-                    "aggregate_paths": aggregate_paths,
-                }
-
-                result, confidence = cot_decode(
-                    pipeline.current_model,
-                    pipeline.tokenizer,
-                    messages,
-                    **cot_params
-                )
-                responses = [result]
-                logprobs_results = [{"confidence_score": confidence} if confidence is not None else None]
-                completion_tokens = len(pipeline.tokenizer.encode(result))
+                cot_params = {
+                    "k": k,
+                    "num_beams": num_beams,
+                    "max_new_tokens": max_tokens if max_tokens is not None else 512,
+                    "temperature": temperature,
+                    "top_p": top_p,
+                    "repetition_penalty": 1.0,
+                    "length_penalty": length_penalty,
+                    "no_repeat_ngram_size": no_repeat_ngram_size,
+                    "early_stopping": early_stopping,
+                    "aggregate_paths": aggregate_paths,
+                }
+
+                result, confidence = cot_decode(
+                    pipeline.current_model,
+                    pipeline.tokenizer,
+                    messages,
+                    **cot_params
+                )
+                responses = [result]
+                logprobs_results = [{"confidence_score": confidence} if confidence is not None else None]
+                completion_tokens = len(pipeline.tokenizer.encode(result))
 
             elif decoding == "entropy_decoding":
-
                 # Ensure model is using full precision
                 original_dtype = pipeline.current_model.dtype
                 pipeline.current_model = pipeline.current_model.to(torch.float32)
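
The first added block above simply disables unsupported decoding approaches when the active pipeline is MLX-backed, so the request falls through to standard generation. A condensed sketch of that check (a hypothetical helper; pipeline and MLXInferencePipeline are the names used in this module, not newly introduced here):

MLX_UNSUPPORTED_DECODINGS = ["cot_decoding", "entropy_decoding", "autothink"]

def effective_decoding(pipeline, decoding):
    # Drop decoding approaches the MLX pipeline cannot run; None means standard generation.
    if isinstance(pipeline, MLXInferencePipeline) and decoding in MLX_UNSUPPORTED_DECODINGS:
        return None
    return decoding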
@@ -1778,43 +1751,66 @@ def create(
                 }
                 thinkdeeper_config.update(custom_config)
 
-                result = thinkdeeper_decode(
-                    pipeline.current_model,
-                    pipeline.tokenizer,
-                    messages,
-                    thinkdeeper_config
+                # Check if we're using MLX pipeline
+                if isinstance(pipeline, MLXInferencePipeline):
+                    logger.info("Using MLX ThinkDeeper implementation")
+
+                    # Ensure we have enough tokens for thinking + response
+                    user_max_tokens = max_tokens if max_tokens is not None else 512
+                    total_tokens_needed = max_thinking_tokens + 512  # thinking + response buffer
+                    adjusted_max_tokens = max(user_max_tokens, total_tokens_needed)
+
+                    # Add max_tokens to thinkdeeper config
+                    thinkdeeper_config_with_tokens = thinkdeeper_config.copy()
+                    thinkdeeper_config_with_tokens["max_tokens"] = adjusted_max_tokens
+
+                    logger.debug(f"ThinkDeeper tokens: user={user_max_tokens}, thinking={max_thinking_tokens}, adjusted={adjusted_max_tokens}")
+
+                    result = thinkdeeper_decode_mlx(
+                        pipeline.model,
+                        pipeline.tokenizer,
+                        messages,
+                        thinkdeeper_config_with_tokens
+                    )
+                else:
+                    logger.info("Using PyTorch ThinkDeeper implementation")
+                    result = thinkdeeper_decode(
+                        pipeline.current_model,
+                        pipeline.tokenizer,
+                        messages,
+                        thinkdeeper_config
                 )
                 responses = [result]
                 logprobs_results = [None]
                 completion_tokens = len(pipeline.tokenizer.encode(result))
             elif decoding == "autothink":
                 # Get steering dataset configuration
-                steering_dataset = kwargs.get("steering_dataset", "codelion/Qwen3-0.6B-pts-steering-vectors")
-                target_layer = kwargs.get("target_layer", 19)
-
-                # Prepare AutoThink configuration
-                autothink_config = {
-                    "steering_dataset": steering_dataset,
-                    "target_layer": target_layer,
-                    "pattern_strengths": kwargs.get("pattern_strengths", {
-                        "depth_and_thoroughness": 2.5,
-                        "numerical_accuracy": 2.0,
-                        "self_correction": 3.0,
-                        "exploration": 2.0,
-                        "organization": 1.5
-                    })
-                }
-
-                # Process with AutoThink
-                result = autothink_decode(
-                    pipeline.current_model,
-                    pipeline.tokenizer,
-                    messages,
-                    autothink_config
-                )
-                responses = [result]
-                logprobs_results = [None]
-                completion_tokens = len(pipeline.tokenizer.encode(result))
+                steering_dataset = kwargs.get("steering_dataset", "codelion/Qwen3-0.6B-pts-steering-vectors")
+                target_layer = kwargs.get("target_layer", 19)
+
+                # Prepare AutoThink configuration
+                autothink_config = {
+                    "steering_dataset": steering_dataset,
+                    "target_layer": target_layer,
+                    "pattern_strengths": kwargs.get("pattern_strengths", {
+                        "depth_and_thoroughness": 2.5,
+                        "numerical_accuracy": 2.0,
+                        "self_correction": 3.0,
+                        "exploration": 2.0,
+                        "organization": 1.5
+                    })
+                }
+
+                # Process with AutoThink
+                result = autothink_decode(
+                    pipeline.current_model,
+                    pipeline.tokenizer,
+                    messages,
+                    autothink_config
+                )
+                responses = [result]
+                logprobs_results = [None]
+                completion_tokens = len(pipeline.tokenizer.encode(result))
             else:
                 raise ValueError(f"Unknown specialized decoding approach: {decoding}")
 
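
In the MLX ThinkDeeper branch, the token budget becomes the larger of the user's max_tokens and max_thinking_tokens plus a 512-token response buffer. A small worked example (the numbers are illustrative, not defaults from this repo):

# Illustrative numbers only.
max_thinking_tokens = 2048
user_max_tokens = 512
adjusted_max_tokens = max(user_max_tokens, max_thinking_tokens + 512)
print(adjusted_max_tokens)  # 2560: room for thinking plus the response buffer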
