Commit b32b303

Add DeepConf confidence-aware decoding to OptILLM
Introduces the DeepConf plugin for confidence-aware reasoning with early termination for local models, based on the 'Deep Think with Confidence' paper. Adds core modules for confidence calculation, threshold calibration, and online processing with consensus-based stopping and weighted majority voting. Integrates DeepConf decoding into the inference pipeline and provides a test suite for validation.
1 parent aad79ca commit b32b303

7 files changed: +1143 -1 lines changed

optillm/deepconf/README.md

Lines changed: 144 additions & 0 deletions
@@ -0,0 +1,144 @@
# DeepConf: Deep Think with Confidence

DeepConf is a confidence-aware reasoning approach for large language models. It uses model-internal confidence signals to filter out low-quality reasoning traces during generation, improving both efficiency and accuracy.

## Overview

Based on the paper "Deep Think with Confidence" by Fu et al. (2025), DeepConf implements:

- **Token-level confidence scoring** using entropy and log-probability metrics
- **Online mode with early termination** to save computational resources
- **Warmup phase for threshold calibration**
- **Consensus-based stopping** when high agreement is reached
- **Weighted majority voting** for final answer selection

## Features

- **Local models only** - Works with OptILLM's local inference engine
- **Two variants**: `low` (aggressive, top 10%) and `high` (conservative, top 90%)
- **Configurable parameters** for different use cases
- **Early termination** to reduce token usage by 50-70%
- **Automatic quality control** without external evaluation

## Usage

### Basic Usage

Set up OptILLM for local inference:

```bash
export OPTILLM_API_KEY=optillm
python optillm.py --model your-local-model
```

Then make a request with DeepConf decoding:

```python
import openai

client = openai.OpenAI(
    api_key="optillm",
    base_url="http://localhost:8000/v1"
)

response = client.chat.completions.create(
    model="your-model",
    messages=[
        {"role": "user", "content": "Solve this math problem: What is the derivative of x^3 + 2x^2 - 5x + 1?"}
    ],
    extra_body={
        "decoding": "deepconf",
        "variant": "low",              # "low" or "high"
        "warmup_samples": 16,          # Number of calibration traces
        "max_traces": 64,              # Maximum total traces
        "consensus_threshold": 0.95    # Stop when consensus reached
    }
)

print(response.choices[0].message.content)
```

### Configuration Parameters

| Parameter | Default | Description |
|-----------|---------|-------------|
| `variant` | `"low"` | Filtering strategy: `"low"` (top 10%, aggressive) or `"high"` (top 90%, conservative) |
| `warmup_samples` | `16` | Number of initial traces for threshold calibration |
| `consensus_threshold` | `0.95` | Stop generation when this level of agreement is reached |
| `max_traces` | `128` | Maximum number of traces to generate |
| `window_size` | `2048` | Sliding window size for group confidence calculation |
| `top_k` | `5` | Number of top tokens for confidence calculation |
| `min_trace_length` | `100` | Minimum tokens before allowing early termination |
| `max_tokens_per_trace` | `4096` | Maximum tokens per individual trace |
| `confidence_metric` | `"average_confidence"` | Metric used for threshold calculation |
| `include_stats` | `false` | Include processing statistics in response |

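As a rough illustration of how the table above can be used from client code, the sketch below merges caller overrides with the documented defaults before passing the result as `extra_body`. The helper name `build_deepconf_body` and both validation checks are assumptions made for this example; they are not part of OptILLM.

```python
# Defaults copied from the parameter table; the helper itself is hypothetical.
DEEPCONF_DEFAULTS = {
    "variant": "low",
    "warmup_samples": 16,
    "consensus_threshold": 0.95,
    "max_traces": 128,
    "window_size": 2048,
    "top_k": 5,
    "min_trace_length": 100,
    "max_tokens_per_trace": 4096,
    "confidence_metric": "average_confidence",
    "include_stats": False,
}


def build_deepconf_body(**overrides):
    """Return an extra_body dict: documented defaults plus caller overrides."""
    unknown = set(overrides) - set(DEEPCONF_DEFAULTS)
    if unknown:
        # Catch typos such as "warmup_sample" before the request is sent.
        raise ValueError(f"Unknown DeepConf parameters: {sorted(unknown)}")
    body = {"decoding": "deepconf", **DEEPCONF_DEFAULTS, **overrides}
    if body["variant"] not in ("low", "high"):
        raise ValueError('variant must be "low" or "high"')
    # Assumption: the warmup traces count toward the max_traces budget.
    if body["warmup_samples"] > body["max_traces"]:
        raise ValueError("warmup_samples cannot exceed max_traces")
    return body


body = build_deepconf_body(variant="high", max_traces=32)
```

The resulting `body` can be passed as `extra_body=body` in the request shown under Basic Usage.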
### Advanced Usage

Include statistics in the response for debugging:

```python
response = client.chat.completions.create(
    model="your-model",
    messages=[...],
    extra_body={
        "decoding": "deepconf",
        "variant": "high",
        "include_stats": True,
        "warmup_samples": 8,
        "max_traces": 32
    }
)
```

## How It Works

1. **Warmup Phase**: Generate initial traces to calibrate confidence threshold
2. **Online Generation**: Generate traces with early termination based on confidence
3. **Consensus Check**: Stop when sufficient agreement is reached
4. **Final Selection**: Use weighted majority voting to select the best answer

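The four steps above can be sketched in plain Python. This is a simplified illustration, not the plugin's actual control flow: it assumes each trace is already reduced to an `(answer, confidence)` pair, that agreement is measured as the leading answer's share of the total confidence weight, and that warmup traces count toward `max_traces`. The function names are hypothetical.

```python
from collections import defaultdict


def weighted_vote(traces):
    """Return (winning answer, its share of the total confidence weight)."""
    weights = defaultdict(float)
    for answer, confidence in traces:
        weights[answer] += confidence
    total = sum(weights.values())
    best = max(weights, key=weights.get)
    return best, (weights[best] / total if total > 0 else 0.0)


def select_answer(warmup, stream, threshold,
                  consensus_threshold=0.95, max_traces=64):
    """Filter traces by confidence, stop on consensus, then vote."""
    # Steps 1-2: keep only traces whose confidence clears the threshold.
    kept = [t for t in warmup if t[1] >= threshold]
    used = len(warmup)
    for trace in stream:
        # Step 3: stop once the leading answer holds enough vote weight.
        if kept and weighted_vote(kept)[1] >= consensus_threshold:
            break
        if used >= max_traces:
            break
        used += 1
        if trace[1] >= threshold:
            kept.append(trace)
    # Step 4: confidence-weighted majority vote over the surviving traces.
    answer = weighted_vote(kept)[0] if kept else None
    return answer, used


warmup = [("a", 3.0), ("b", 3.0)]
stream = [("a", 3.0)] * 5
result = select_answer(warmup, stream, threshold=2.0, consensus_threshold=0.7)
```

In the real online mode a low-confidence trace is cut off while it is still being generated; the sketch treats confidence as a single per-trace number to keep the voting logic visible.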
### Confidence Metrics

- **Token Entropy**: `H = -∑_j P(j) log P(j)`, summed over the full next-token distribution
- **Token Confidence**: `C = -(1/k) ∑_{j=1..k} log P(j)`, the negative mean log-probability of the top-k tokens; higher values indicate a more peaked distribution
- **Group Confidence**: Mean token confidence over a sliding window of `window_size` tokens
- **Trace Confidence**: Mean token confidence across all tokens in a trace

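To make the metrics above concrete, here is a minimal standard-library sketch that works on an explicit probability list instead of a logits tensor. The function names are illustrative and the inputs are assumed to be strictly positive probabilities, as a softmax would produce.

```python
import math
from collections import deque


def token_entropy(probs):
    """H = -sum_j P(j) log P(j) over the full next-token distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def token_confidence(probs, k=5):
    """C = -(1/k) * sum of log P(j) over the k most likely tokens."""
    top = sorted(probs, reverse=True)[:k]
    return -sum(math.log(p) for p in top) / len(top)


def group_confidences(token_confs, window_size):
    """Mean token confidence over each full trailing window."""
    window, total, out = deque(), 0.0, []
    for c in token_confs:
        window.append(c)
        total += c
        if len(window) > window_size:
            total -= window.popleft()
        if len(window) == window_size:
            out.append(total / window_size)
    return out


peaked = [0.96, 0.01, 0.01, 0.01, 0.01]  # model strongly prefers one token
flat = [0.2] * 5                          # model is undecided

peaked_conf = token_confidence(peaked)
flat_conf = token_confidence(flat)
groups = group_confidences([1.0, 2.0, 3.0, 4.0], window_size=2)
```

The peaked distribution has lower entropy but a higher token confidence than the flat one, because the unlikely tokens in its top-k contribute large negative log-probabilities.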
### Variants

- **DeepConf-low**: Uses the 90th percentile of warmup confidences as the threshold (keeps the top 10% of traces) - more aggressive filtering
- **DeepConf-high**: Uses the 10th percentile of warmup confidences as the threshold (keeps the top 90% of traces) - more conservative filtering

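The effect of the two variants can be checked with a small worked example. The sketch below assumes a non-empty list of warmup confidences and uses a hand-written, linearly interpolated percentile (the same scheme as NumPy's default) so that it runs without dependencies; `calibrate_threshold` is an illustrative name.

```python
def percentile(values, q):
    """Linearly interpolated percentile of a non-empty list, q in [0, 100]."""
    ordered = sorted(values)
    pos = (len(ordered) - 1) * q / 100.0
    lo = int(pos)
    hi = min(lo + 1, len(ordered) - 1)
    return ordered[lo] + (ordered[hi] - ordered[lo]) * (pos - lo)


def calibrate_threshold(warmup_confidences, variant="low"):
    """90th percentile for "low", 10th percentile for "high"."""
    q = 90 if variant == "low" else 10
    return percentile(warmup_confidences, q)


# Ten warmup traces with confidences 1.0, 2.0, ..., 10.0.
warmup = [float(i) for i in range(1, 11)]

low = calibrate_threshold(warmup, "low")
high = calibrate_threshold(warmup, "high")

# A trace survives when its confidence is at or above the threshold.
kept_low = [c for c in warmup if c >= low]
kept_high = [c for c in warmup if c >= high]
```

With these ten traces, the `low` threshold keeps one trace and the `high` threshold keeps nine, matching the "top 10%" and "top 90%" descriptions.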
## Performance

DeepConf typically achieves:
- **50-70% reduction in token usage** compared to standard sampling
- **Maintained or improved accuracy** through confidence-based filtering
- **Automatic quality control** without requiring external evaluation models

## Requirements

- Local model inference (PyTorch)
- OptILLM with `OPTILLM_API_KEY=optillm`
- Compatible with transformer models that provide logits access

## Limitations

- **Local models only** - Cannot work with external API providers (OpenAI, Anthropic, etc.)
- **Requires logits access** - Model must provide token-level probability distributions
- **Not compatible with MLX** - Currently only supports PyTorch-based models

## Testing

Run the test suite to verify the implementation:

```bash
python test_deepconf.py
```

## References

- **Paper**: "Deep Think with Confidence" by Fu et al. (2025)
- **arXiv**: https://arxiv.org/abs/2508.15260
- **Authors**: Yichao Fu (UCSD), Xuewei Wang (Meta AI), Yuandong Tian (Meta AI), Jiawei Zhao (Meta AI)

optillm/deepconf/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
"""
DeepConf plugin for OptILLM

Implements confidence-aware reasoning with early termination for local models.
Based on "Deep Think with Confidence" by Fu et al.
"""

from .deepconf import deepconf_decode

__all__ = ['deepconf_decode']

optillm/deepconf/confidence.py

Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,240 @@
"""
Confidence calculation utilities for DeepConf.

Implements various confidence metrics based on token-level probabilities:
- Token Entropy: H = -∑P(j) log P(j)
- Token Confidence: C = -(1/k) ∑log P(j) for top-k tokens
- Group Confidence: Sliding window averages
"""

import torch
import torch.nn.functional as F
import numpy as np
from typing import List, Dict, Tuple, Optional
import logging

logger = logging.getLogger(__name__)

class ConfidenceCalculator:
    """
    Calculates various confidence metrics for token-level assessment.
    """

    def __init__(self, window_size: int = 2048, top_k: int = 5):
        """
        Initialize the confidence calculator.

        Args:
            window_size: Size of sliding window for group confidence
            top_k: Number of top tokens for token confidence calculation
        """
        self.window_size = window_size
        self.top_k = top_k
        self.token_confidences = []
        self.group_confidences = []

    def reset(self):
        """Reset internal state for new trace."""
        self.token_confidences = []
        self.group_confidences = []

    def calculate_token_entropy(self, logits: torch.Tensor) -> float:
        """
        Calculate token entropy: H = -∑P(j) log P(j)

        Args:
            logits: Raw logits tensor for current token position

        Returns:
            Token entropy value
        """
        probs = F.softmax(logits, dim=-1)
        log_probs = F.log_softmax(logits, dim=-1)

        # Calculate entropy: -∑P(j) log P(j)
        entropy = -(probs * log_probs).sum().item()

        return entropy

    def calculate_token_confidence(self, logits: torch.Tensor, k: Optional[int] = None) -> float:
        """
        Calculate token confidence: C = -(1/k) ∑log P(j) for top-k tokens

        Higher values indicate a more peaked (more confident) distribution.

        Args:
            logits: Raw logits tensor for current token position
            k: Number of top tokens to consider (default: self.top_k)

        Returns:
            Token confidence value
        """
        if k is None:
            k = self.top_k

        # torch.topk raises if k exceeds the vocabulary dimension
        k = min(k, logits.shape[-1])

        log_probs = F.log_softmax(logits, dim=-1)

        # Get top-k log probabilities
        top_log_probs, _ = torch.topk(log_probs, k=k)

        # Calculate confidence: -(1/k) ∑log P(j)
        confidence = -top_log_probs.mean().item()

        return confidence

    def add_token_confidence(self, logits: torch.Tensor) -> float:
        """
        Add a new token's confidence and update group statistics.

        Args:
            logits: Raw logits tensor for current token position

        Returns:
            Token confidence value
        """
        confidence = self.calculate_token_confidence(logits)
        self.token_confidences.append(confidence)

        # Update group confidence if we have enough tokens
        if len(self.token_confidences) >= self.window_size:
            self._update_group_confidence()

        return confidence

    def _update_group_confidence(self):
        """Update group confidence based on current sliding window."""
        if len(self.token_confidences) < self.window_size:
            return

        # Calculate group confidence for current window
        start_idx = len(self.token_confidences) - self.window_size
        window_confidences = self.token_confidences[start_idx:]
        group_confidence = np.mean(window_confidences)

        self.group_confidences.append(group_confidence)

    def get_current_group_confidence(self) -> Optional[float]:
        """
        Get the most recent group confidence.

        Returns:
            Most recent group confidence or None if not available
        """
        if not self.group_confidences:
            return None
        return self.group_confidences[-1]

    def get_average_trace_confidence(self) -> float:
        """
        Calculate average confidence across all tokens in the trace.

        Returns:
            Average confidence value
        """
        if not self.token_confidences:
            return 0.0
        return np.mean(self.token_confidences)

    def get_bottom_10_percent_confidence(self) -> float:
        """
        Calculate average confidence of bottom 10% groups.

        Returns:
            Bottom 10% group confidence
        """
        if not self.group_confidences:
            return 0.0

        num_bottom = max(1, len(self.group_confidences) // 10)
        sorted_confidences = sorted(self.group_confidences)
        bottom_confidences = sorted_confidences[:num_bottom]

        return np.mean(bottom_confidences)

    def get_lowest_group_confidence(self) -> float:
        """
        Get the minimum confidence across all groups.

        Returns:
            Lowest group confidence
        """
        if not self.group_confidences:
            return 0.0
        return min(self.group_confidences)

    def get_trace_statistics(self) -> Dict[str, float]:
        """
        Get comprehensive confidence statistics for the current trace.

        Returns:
            Dictionary with various confidence metrics
        """
        return {
            "average_confidence": self.get_average_trace_confidence(),
            "bottom_10_percent": self.get_bottom_10_percent_confidence(),
            "lowest_group": self.get_lowest_group_confidence(),
            "current_group": self.get_current_group_confidence() or 0.0,
            "num_tokens": len(self.token_confidences),
            "num_groups": len(self.group_confidences)
        }

class ConfidenceThresholdCalibrator:
    """
    Calibrates confidence thresholds based on warmup traces.
    """

    def __init__(self, variant: str = "low"):
        """
        Initialize the threshold calibrator.

        Args:
            variant: "low" (aggressive, top 10%) or "high" (conservative, top 90%)
        """
        self.variant = variant
        self.warmup_confidences = []

    def add_warmup_trace(self, confidence_stats: Dict[str, float]):
        """
        Add confidence statistics from a warmup trace.

        Args:
            confidence_stats: Dictionary with confidence metrics
        """
        self.warmup_confidences.append(confidence_stats)

    def calculate_threshold(self, metric: str = "average_confidence") -> float:
        """
        Calculate the confidence threshold based on warmup traces.

        Args:
            metric: Which confidence metric to use for threshold calculation

        Returns:
            Calculated threshold value
        """
        if not self.warmup_confidences:
            logger.warning("No warmup traces available for threshold calculation")
            return 0.0

        confidences = [stats[metric] for stats in self.warmup_confidences]

        if self.variant == "low":
            # DeepConf-low: 90th percentile (keeps top 10%)
            threshold = np.percentile(confidences, 90)
        else:
            # DeepConf-high: 10th percentile (keeps top 90%)
            threshold = np.percentile(confidences, 10)

        logger.info(f"Calculated {self.variant} threshold: {threshold:.4f} for metric: {metric}")
        return threshold

    def should_terminate_trace(self, current_confidence: float, threshold: float) -> bool:
        """
        Determine if current trace should be terminated based on confidence.

        Args:
            current_confidence: Current confidence value
            threshold: Threshold for termination

        Returns:
            True if trace should be terminated
        """
        return current_confidence < threshold
