
Commit 23de918: init implementation

1 parent 2ab4e6e

File tree: 11 files changed, +1452 -1 lines

README.md

Lines changed: 2 additions & 1 deletion

```diff
@@ -343,7 +343,7 @@ Check this log file for connection issues, tool execution errors, and other diag
 | Approach                              | Slug               | Description                                                                                      |
 | ------------------------------------ | ------------------ | ---------------------------------------------------------------------------------------------- |
-| Cerebras Planning and Optimization    | `cepo`             | Combines Best of N, Chain-of-Thought, Self-Reflection, Self-Improvement, and various prompting techniques |
+| Cerebras Planning and Optimization    | `cepo`             | Combines Best of N, Chain-of-Thought, Self-Reflection, Self-Improvement, and various prompting techniques |
 | CoT with Reflection                   | `cot_reflection`   | Implements chain-of-thought reasoning with \<thinking\>, \<reflection\> and \<output\> sections |
 | PlanSearch                            | `plansearch`       | Implements a search algorithm over candidate plans for solving a problem in natural language    |
 | ReRead                                | `re2`              | Implements rereading to improve reasoning by processing queries twice                           |
@@ -359,6 +359,7 @@ Check this log file for connection issues, tool execution errors, and other diag
 | CoT Decoding                          | N/A for proxy      | Implements chain-of-thought decoding to elicit reasoning without explicit prompting             |
 | Entropy Decoding                      | N/A for proxy      | Implements adaptive sampling based on the uncertainty of tokens during generation               |
 | Thinkdeeper                           | N/A for proxy      | Implements the `reasoning_effort` param from OpenAI for reasoning models like DeepSeek R1       |
+| AutoThink                             | N/A for proxy      | Combines query complexity classification with steering vectors to enhance reasoning            |

 ## Implemented plugins
```

optillm/autothink/README.md

Lines changed: 95 additions & 0 deletions
# AutoThink

AutoThink is an adaptive thinking approach for Large Language Models that combines query complexity classification with steering vector guidance to enhance model reasoning capabilities.

## Overview

AutoThink combines several advanced techniques to optimize the thinking process of LLMs:

1. **Query Complexity Classification**: Uses an adaptive classifier to determine if a query requires HIGH or LOW complexity reasoning
2. **Token Budget Allocation**: Dynamically allocates thinking tokens based on query complexity
3. **Steering Vector Guidance**: Applies activation-based steering vectors to guide the model's reasoning process
4. **Controlled Thinking Process**: Manages explicit thinking phases with start and end tokens

## How It Works

### 1. Query Classification

AutoThink uses the `adaptive-classifier/llm-router` model to classify incoming queries:

- **HIGH**: Complex queries requiring deep reasoning, multi-step calculations, or thorough exploration
- **LOW**: Simpler queries requiring less extensive reasoning

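The classification can be exercised directly through the `ComplexityClassifier` wrapper added in this commit's `classifier.py` (requires the `adaptive-classifier` package; the printed scores are illustrative):

```python
from optillm.autothink.classifier import ComplexityClassifier

clf = ComplexityClassifier("adaptive-classifier/llm-router")

# predict() returns (label, score) tuples sorted by confidence,
# e.g. [("HIGH", 0.83), ("LOW", 0.17)]
print(clf.predict("Prove that the product of two odd integers is odd."))
print(clf.is_high_complexity("What is the capital of France?"))  # likely False
```
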
### 2. Token Budget

Based on the classification, AutoThink allocates different token budgets for the thinking phase:

- **HIGH**: 70-90% of max tokens allocated for thinking
- **LOW**: 20-40% of max tokens allocated for thinking

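The budget selection itself happens in the processor, which is not part of the excerpt shown here; a minimal sketch of the idea, assuming only the min/max token keys from the Configuration section below:

```python
def thinking_budget(is_high_complexity: bool, config: dict) -> tuple[int, int]:
    """Illustrative sketch: pick a (min, max) thinking-token range by complexity."""
    if is_high_complexity:
        return (config.get("high_complexity_min_tokens", 1024),
                config.get("high_complexity_max_tokens", 4096))
    return (config.get("low_complexity_min_tokens", 256),
            config.get("low_complexity_max_tokens", 1024))
```
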
### 3. Steering Vectors

AutoThink uses pre-extracted steering vectors from datasets like `codelion/Qwen3-0.6B-pts-steering-vectors`. These vectors represent different reasoning patterns:

- **Depth and thoroughness**: Encourages detailed, step-by-step reasoning
- **Numerical accuracy**: Promotes precise calculations and verification
- **Self-correction**: Facilitates error detection and correction
- **Exploration**: Supports considering multiple approaches
- **Organization**: Improves logical structure in responses

During inference, the model's internal activations are modified based on these vectors to enhance specific reasoning capabilities.

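The injection mechanics also live in the processor; conceptually, a steering vector is added to the hidden states of the configured `target_layer` during the forward pass. A minimal PyTorch sketch of that idea (illustrative only, not the exact implementation; the layer path varies by architecture):

```python
import torch

def make_steering_hook(vector: torch.Tensor, strength: float):
    """Return a forward hook that adds a scaled steering vector to a layer's hidden states."""
    def hook(module, inputs, output):
        hidden = output[0] if isinstance(output, tuple) else output
        steered = hidden + strength * vector.to(hidden.device, hidden.dtype)
        return (steered,) + output[1:] if isinstance(output, tuple) else steered
    return hook

# Illustrative attachment to layer 19 of a Qwen/Llama-style HF model:
# handle = model.model.layers[19].register_forward_hook(make_steering_hook(vec, 2.5))
# ... generate ...
# handle.remove()
```
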
### 4. Controlled Thinking Process

The generation process includes:

1. A thinking phase marked by `<think>` and `</think>` tokens
2. Automatic adjustment of thinking time based on query complexity
3. Dynamic application of steering vectors
4. Graceful transition to the final response

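Put together, a completion has roughly this shape (illustrative):

```
<think>
...complexity-appropriate reasoning, steered toward the patterns above...
</think>
The final answer presented to the user.
```
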
## Configuration

AutoThink can be configured with:

```python
{
    "model_name": "your-model-name",
    "classifier_model": "adaptive-classifier/llm-router",
    "steering_dataset": "codelion/Qwen3-0.6B-pts-steering-vectors",
    "target_layer": 19,  # Layer to apply steering vectors
    "high_complexity_min_tokens": 1024,
    "high_complexity_max_tokens": 4096,
    "low_complexity_min_tokens": 256,
    "low_complexity_max_tokens": 1024,
    "pattern_strengths": {  # Steering strength for each pattern
        "depth_and_thoroughness": 2.5,
        "numerical_accuracy": 2.0,
        "self_correction": 3.0,
        "exploration": 2.0,
        "organization": 1.5
    }
}
```

74+
## Usage
75+
76+
```python
77+
from optillm.autothink import autothink_decode
78+
79+
response = autothink_decode(
80+
model,
81+
tokenizer,
82+
messages,
83+
{
84+
"steering_dataset": "codelion/Qwen3-0.6B-pts-steering-vectors",
85+
"target_layer": 19
86+
}
87+
)
88+
```
89+
90+
## Benefits

- **Adaptive Resource Usage**: Models think more on complex problems and less on simple ones
- **Enhanced Reasoning**: Steering vectors guide the model toward better reasoning patterns
- **Efficiency**: Better performance without increasing model size
- **Customizability**: Can be tailored to different domains using domain-specific steering vector datasets

optillm/autothink/__init__.py

Lines changed: 7 additions & 0 deletions

```python
"""
AutoThink - Adaptive thinking approach for LLMs with query complexity classification and steering vectors.
"""

from .autothink import autothink_decode, AutoThinkProcessor

__all__ = ["autothink_decode", "AutoThinkProcessor"]
```

optillm/autothink/autothink.py

Lines changed: 88 additions & 0 deletions

```python
"""
AutoThink main implementation.

This module provides the main implementation of AutoThink, combining
query complexity classification with steering vectors to enhance reasoning.
"""

import logging
from typing import Dict, List, Any, Optional

from transformers import PreTrainedModel, PreTrainedTokenizer

# Alias the internal processor so the public wrapper class below does not
# shadow it; without the alias, _create_processor would instantiate the
# wrapper itself with the wrong argument order.
from .processor import AutoThinkProcessor as _InternalAutoThinkProcessor

logger = logging.getLogger(__name__)

class AutoThinkProcessor:
    """
    Main AutoThink processor class for external use.
    Wraps the internal processor implementation.
    """

    def __init__(self, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, config: Optional[Dict[str, Any]] = None):
        """
        Initialize the AutoThink processor.

        Args:
            model: Language model
            tokenizer: Model tokenizer
            config: Configuration dictionary
        """
        self.config = config or {}
        self.processor = None
        self.model = model
        self.tokenizer = tokenizer

    def __call__(self, messages: List[Dict[str, str]]) -> str:
        """
        Process messages with AutoThink's controlled thinking.

        Args:
            messages: List of message dictionaries

        Returns:
            Generated response
        """
        # Create processor on first use to allow for model loading
        if self.processor is None:
            self.processor = self._create_processor()

        return self.processor.process(messages)

    def _create_processor(self):
        """Create the internal processor instance."""
        return _InternalAutoThinkProcessor(self.config, self.tokenizer, self.model)

def autothink_decode(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    messages: List[Dict[str, str]],
    request_config: Optional[Dict[str, Any]] = None
) -> str:
    """
    Main plugin execution function with AutoThink's controlled thinking process.

    Args:
        model: Language model
        tokenizer: Model tokenizer
        messages: List of message dictionaries
        request_config: Optional configuration dictionary

    Returns:
        Generated response with thinking process
    """
    logger.info("Starting AutoThink processing")

    # Build the config dictionary from the optional request config
    config = {}
    if request_config:
        config.update(request_config)

    try:
        processor = AutoThinkProcessor(model, tokenizer, config)
        return processor(messages)
    except Exception as e:
        logger.error(f"Error in AutoThink processing: {str(e)}")
        raise
```
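
For completeness, a sketch of driving the wrapper class directly instead of calling `autothink_decode`; the checkpoint here is only an example, chosen to match the steering-vector dataset referenced above:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from optillm.autothink import AutoThinkProcessor

# Example model choice (assumption): a Qwen3-0.6B checkpoint to pair with
# the codelion/Qwen3-0.6B-pts-steering-vectors dataset.
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")

# The internal processor is created lazily on the first call
processor = AutoThinkProcessor(model, tokenizer, {"target_layer": 19})
reply = processor([{"role": "user", "content": "Explain why the sky is blue."}])
print(reply)
```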

optillm/autothink/classifier.py

Lines changed: 152 additions & 0 deletions

```python
"""
Query complexity classifier for AutoThink.

This module provides functionality to classify queries as HIGH or LOW complexity
using the adaptive-classifier model.
"""

import logging
import os
import sys
from typing import Dict, Any, Tuple, Optional, List, Union

logger = logging.getLogger(__name__)

class ComplexityClassifier:
    """
    Classifies queries as HIGH or LOW complexity for token budget allocation.
    Uses the adaptive-classifier model for classification.
    """

    def __init__(self, model_name: str = "adaptive-classifier/llm-router"):
        """
        Initialize the complexity classifier.

        Args:
            model_name: HuggingFace model name or path for the classifier
        """
        self.model_name = model_name
        self.classifier = None

        # Load model
        self._load_model()

    def _load_model(self):
        """Load the classification model using the adaptive-classifier library."""
        try:
            # Install adaptive-classifier on the fly if it is missing
            try:
                import adaptive_classifier
            except ImportError:
                logger.info("Installing adaptive-classifier library...")
                os.system(f"{sys.executable} -m pip install adaptive-classifier")
                import adaptive_classifier

            from adaptive_classifier import AdaptiveClassifier

            logger.info(f"Loading complexity classifier model: {self.model_name}")
            self.classifier = AdaptiveClassifier.from_pretrained(self.model_name)
            logger.info("Classifier loaded successfully")

        except Exception as e:
            logger.error(f"Error loading complexity classifier: {e}")
            # Fall back to heuristic classification if the model fails to load
            self.classifier = None

    def predict(self, text: str) -> List[Tuple[str, float]]:
        """
        Predict the complexity label for a given text.

        Args:
            text: The query text to classify

        Returns:
            List of (label, score) tuples sorted by confidence
        """
        if self.classifier is None:
            logger.warning("Classifier not loaded. Using fallback classification.")
            return self._fallback_classification(text)

        try:
            # Make prediction using the AdaptiveClassifier
            predictions = self.classifier.predict(text)
            logger.debug(f"Classifier predictions: {predictions}")

            # Make sure predictions are in the expected format
            if isinstance(predictions, list) and all(isinstance(p, tuple) and len(p) == 2 for p in predictions):
                # Sort by confidence (higher score = higher confidence)
                predictions.sort(key=lambda x: x[1], reverse=True)
                return predictions
            else:
                logger.warning(f"Unexpected prediction format: {predictions}")
                return self._fallback_classification(text)

        except Exception as e:
            logger.error(f"Error during classification: {e}")
            return self._fallback_classification(text)

    def _fallback_classification(self, text: str) -> List[Tuple[str, float]]:
        """
        Simple heuristic classification when the model isn't available.

        Args:
            text: The query text

        Returns:
            List of (label, score) tuples
        """
        # Keywords that tend to indicate a complex query
        complexity_indicators = [
            "explain", "analyze", "compare", "evaluate", "synthesize",
            "how", "why", "complex", "detail", "thorough", "comprehensive",
            "step by step", "calculate", "prove", "justify", "multiple",
            "consequences", "implications", "differentiate", "frameworks"
        ]

        # Count mentions of complexity indicators
        count = sum(1 for indicator in complexity_indicators if indicator.lower() in text.lower())

        # Calculate complexity probability based on count and text length
        text_length_factor = min(len(text) / 100, 2.0)  # Cap at 2.0
        indicator_factor = min(count / 3, 1.5)          # Cap at 1.5

        # Combined factor determines HIGH vs LOW
        complexity_score = text_length_factor * indicator_factor

        if complexity_score > 1.0:
            return [("HIGH", 0.7), ("LOW", 0.3)]
        else:
            return [("LOW", 0.8), ("HIGH", 0.2)]

    def is_high_complexity(self, text: str, threshold: float = 0.5) -> bool:
        """
        Determine if a query is high complexity.

        Args:
            text: The query text
            threshold: Confidence threshold for HIGH classification

        Returns:
            Boolean indicating if the query is high complexity
        """
        predictions = self.predict(text)

        for label, score in predictions:
            if label == "HIGH" and score >= threshold:
                return True

        return False

    def get_complexity_with_confidence(self, text: str) -> Tuple[str, float]:
        """
        Get the complexity label and confidence score.

        Args:
            text: The query text

        Returns:
            Tuple of (complexity_label, confidence_score)
        """
        predictions = self.predict(text)
        return predictions[0]  # Return highest-confidence prediction
```
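
As a worked example of the fallback heuristic: a 150-character query containing two indicator words scores `text_length_factor = min(150/100, 2.0) = 1.5` and `indicator_factor = min(2/3, 1.5) ≈ 0.67`, so `complexity_score ≈ 1.0`, which does not exceed the 1.0 threshold and is classified LOW. A 200-character query with three indicators scores `2.0 * 1.0 = 2.0` and is classified HIGH.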
