144 changes: 144 additions & 0 deletions optillm/deepconf/README.md
@@ -0,0 +1,144 @@
# DeepConf: Deep Think with Confidence

DeepConf is a confidence-aware reasoning approach for large language models that uses model-internal confidence signals to dynamically filter out low-quality reasoning traces during generation, improving both efficiency and accuracy.

## Overview

Based on the paper "Deep Think with Confidence" by Fu et al. (2025), DeepConf implements:

- **Token-level confidence scoring** using entropy and log-probability metrics
- **Online mode with early termination** to save computational resources
- **Warmup phase for threshold calibration**
- **Consensus-based stopping** when high agreement is reached
- **Weighted majority voting** for final answer selection

## Features

- ✅ **Local models only** - Works with OptILLM's local inference engine
- ✅ **Two variants**: `low` (aggressive, top 10%) and `high` (conservative, top 90%)
- ✅ **Configurable parameters** for different use cases
- ✅ **Early termination** to reduce token usage by 50-70%
- ✅ **Automatic quality control** without external evaluation

## Usage

### Basic Usage

Set up OptILLM for local inference:

```bash
export OPTILLM_API_KEY=optillm
python optillm.py --model your-local-model
```

Then make a request with DeepConf decoding:

```python
import openai

client = openai.OpenAI(
api_key="optillm",
base_url="http://localhost:8000/v1"
)

response = client.chat.completions.create(
model="your-model",
messages=[
{"role": "user", "content": "Solve this math problem: What is the derivative of x^3 + 2x^2 - 5x + 1?"}
],
extra_body={
"decoding": "deepconf",
"variant": "low", # "low" or "high"
"warmup_samples": 16, # Number of calibration traces
"max_traces": 64, # Maximum total traces
"consensus_threshold": 0.95 # Stop when consensus reached
}
)

print(response.choices[0].message.content)
```

### Configuration Parameters

| Parameter | Default | Description |
|-----------|---------|-------------|
| `variant` | `"low"` | Filtering strategy: `"low"` (top 10%, aggressive) or `"high"` (top 90%, conservative) |
| `warmup_samples` | `16` | Number of initial traces for threshold calibration |
| `consensus_threshold` | `0.95` | Stop generation when this level of agreement is reached |
| `max_traces` | `128` | Maximum number of traces to generate |
| `window_size` | `2048` | Sliding window size for group confidence calculation |
| `top_k` | `5` | Number of top tokens for confidence calculation |
| `min_trace_length` | `100` | Minimum tokens before allowing early termination |
| `max_tokens_per_trace` | `4096` | Maximum tokens per individual trace |
| `confidence_metric` | `"average_confidence"` | Metric used for threshold calculation |
| `include_stats` | `false` | Include processing statistics in response |

### Advanced Usage

Include statistics in the response for debugging:

```python
response = client.chat.completions.create(
model="your-model",
messages=[...],
extra_body={
"decoding": "deepconf",
"variant": "high",
"include_stats": true,
"warmup_samples": 8,
"max_traces": 32
}
)
```

## How It Works

1. **Warmup Phase**: Generate initial traces to calibrate the confidence threshold
2. **Online Generation**: Generate further traces with early termination based on confidence
3. **Consensus Check**: Stop when sufficient agreement is reached
4. **Final Selection**: Use confidence-weighted majority voting to select the best answer (see the sketch below)

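The consensus check and the final vote happen in the plugin's `deepconf` module, which is not shown in this README. The following is a minimal, hypothetical sketch of how steps 3 and 4 could be combined; the trace data and helper names are illustrative only, not the plugin's actual API:

```python
from collections import Counter, defaultdict

# Hypothetical (answer, trace_confidence) pairs from traces that survived filtering.
traces = [("3x^2 + 4x - 5", 6.1), ("3x^2 + 4x - 5", 5.8), ("3x^2 + 4x + 5", 2.3)]

def consensus_ratio(traces):
    """Fraction of traces that agree on the most common answer."""
    counts = Counter(answer for answer, _ in traces)
    return counts.most_common(1)[0][1] / len(traces)

def weighted_majority_vote(traces):
    """Each candidate answer accumulates the confidence of the traces that produced it."""
    weights = defaultdict(float)
    for answer, confidence in traces:
        weights[answer] += confidence
    return max(weights, key=weights.get)

if consensus_ratio(traces) >= 0.95:   # consensus_threshold
    print("Consensus reached - stop generating further traces")
print(weighted_majority_vote(traces))  # -> "3x^2 + 4x - 5"
```
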
### Confidence Metrics

- **Token Entropy**: `H = -∑P(j) log P(j)`
- **Token Confidence**: `C = -(1/k) ∑log P(j)` for the top-k tokens (illustrated below)
- **Group Confidence**: Sliding window averages over token confidences
- **Trace Confidence**: Average confidence across all tokens in a trace

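As a quick standalone illustration (not part of the plugin), the snippet below shows that a peaked next-token distribution scores a higher token confidence than a flat one, which is why traces whose confidence falls below the calibrated threshold are treated as low quality. The toy logits are made up for the example:

```python
import torch
import torch.nn.functional as F

def token_confidence(logits: torch.Tensor, k: int = 5) -> float:
    # C = -(1/k) * sum(log P(j)) over the top-k tokens
    top_log_probs, _ = torch.topk(F.log_softmax(logits, dim=-1), k=k)
    return -top_log_probs.mean().item()

peaked = torch.tensor([8.0, 1.0, 0.5, 0.2, 0.1])  # model is nearly certain of token 0
flat = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0])    # model is unsure

print(token_confidence(peaked))  # ~6.0  -> high confidence
print(token_confidence(flat))    # ~1.61 -> low confidence (equals ln 5)
```
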
### Variants

- **DeepConf-low**: Uses 90th percentile threshold (keeps top 10% traces) - more aggressive filtering
- **DeepConf-high**: Uses 10th percentile threshold (keeps top 90% traces) - more conservative filtering

## Performance

DeepConf typically achieves:
- **50-70% reduction in token usage** compared to standard sampling
- **Maintained or improved accuracy** through confidence-based filtering
- **Automatic quality control** without requiring external evaluation models

## Requirements

- Local model inference (PyTorch)
- OptILLM with `OPTILLM_API_KEY=optillm`
- Compatible with transformer models that provide logits access

## Limitations

- **Local models only** - Cannot work with external API providers (OpenAI, Anthropic, etc.)
- **Requires logits access** - Model must provide token-level probability distributions
- **Not compatible with MLX** - Currently only supports PyTorch-based models

## Testing

Run the test suite to verify the implementation:

```bash
python test_deepconf.py
```

## References

- **Paper**: "Deep Think with Confidence" by Fu et al. (2025)
- **arXiv**: https://arxiv.org/abs/2508.15260
- **Authors**: Yichao Fu (UCSD), Xuewei Wang (Meta AI), Yuandong Tian (Meta AI), Jiawei Zhao (Meta AI)
10 changes: 10 additions & 0 deletions optillm/deepconf/__init__.py
@@ -0,0 +1,10 @@
"""
DeepConf plugin for OptILLM

Implements confidence-aware reasoning with early termination for local models.
Based on "Deep Think with Confidence" by Fu et al.
"""

from .deepconf import deepconf_decode

__all__ = ['deepconf_decode']
240 changes: 240 additions & 0 deletions optillm/deepconf/confidence.py
@@ -0,0 +1,240 @@
"""
Confidence calculation utilities for DeepConf.

Implements various confidence metrics based on token-level probabilities:
- Token Entropy: H = -∑P(j) log P(j)
- Token Confidence: C = -(1/k) ∑log P(j) for top-k tokens
- Group Confidence: Sliding window averages
"""

import torch
import torch.nn.functional as F
import numpy as np
from typing import List, Dict, Tuple, Optional
import logging

logger = logging.getLogger(__name__)

class ConfidenceCalculator:
"""
Calculates various confidence metrics for token-level assessment.
"""

def __init__(self, window_size: int = 2048, top_k: int = 5):
"""
Initialize the confidence calculator.

Args:
window_size: Size of sliding window for group confidence
top_k: Number of top tokens for token confidence calculation
"""
self.window_size = window_size
self.top_k = top_k
self.token_confidences = []
self.group_confidences = []

def reset(self):
"""Reset internal state for new trace."""
self.token_confidences = []
self.group_confidences = []

def calculate_token_entropy(self, logits: torch.Tensor) -> float:
"""
Calculate token entropy: H = -∑P(j) log P(j)

Args:
logits: Raw logits tensor for current token position

Returns:
Token entropy value
"""
probs = F.softmax(logits, dim=-1)
log_probs = F.log_softmax(logits, dim=-1)

# Calculate entropy: -∑P(j) log P(j)
entropy = -(probs * log_probs).sum().item()

return entropy

def calculate_token_confidence(self, logits: torch.Tensor, k: Optional[int] = None) -> float:
"""
Calculate token confidence: C = -(1/k) ∑log P(j) for top-k tokens

Args:
logits: Raw logits tensor for current token position
k: Number of top tokens to consider (default: self.top_k)

Returns:
Token confidence value
"""
if k is None:
k = self.top_k

log_probs = F.log_softmax(logits, dim=-1)

# Get top-k log probabilities
top_log_probs, _ = torch.topk(log_probs, k=k)

# Calculate confidence: -(1/k) ∑log P(j)
confidence = -top_log_probs.mean().item()

return confidence

def add_token_confidence(self, logits: torch.Tensor) -> float:
"""
Add a new token's confidence and update group statistics.

Args:
logits: Raw logits tensor for current token position

Returns:
Token confidence value
"""
confidence = self.calculate_token_confidence(logits)
self.token_confidences.append(confidence)

# Update group confidence if we have enough tokens
if len(self.token_confidences) >= self.window_size:
self._update_group_confidence()

return confidence

def _update_group_confidence(self):
"""Update group confidence based on current sliding window."""
if len(self.token_confidences) < self.window_size:
return

# Calculate group confidence for current window
start_idx = len(self.token_confidences) - self.window_size
window_confidences = self.token_confidences[start_idx:]
group_confidence = np.mean(window_confidences)

self.group_confidences.append(group_confidence)

def get_current_group_confidence(self) -> Optional[float]:
"""
Get the most recent group confidence.

Returns:
Most recent group confidence or None if not available
"""
if not self.group_confidences:
return None
return self.group_confidences[-1]

def get_average_trace_confidence(self) -> float:
"""
Calculate average confidence across all tokens in the trace.

Returns:
Average confidence value
"""
if not self.token_confidences:
return 0.0
return np.mean(self.token_confidences)

def get_bottom_10_percent_confidence(self) -> float:
"""
Calculate average confidence of bottom 10% groups.

Returns:
Bottom 10% group confidence
"""
if not self.group_confidences:
return 0.0

num_bottom = max(1, len(self.group_confidences) // 10)
sorted_confidences = sorted(self.group_confidences)
bottom_confidences = sorted_confidences[:num_bottom]

return np.mean(bottom_confidences)

def get_lowest_group_confidence(self) -> float:
"""
Get the minimum confidence across all groups.

Returns:
Lowest group confidence
"""
if not self.group_confidences:
return 0.0
return min(self.group_confidences)

def get_trace_statistics(self) -> Dict[str, float]:
"""
Get comprehensive confidence statistics for the current trace.

Returns:
Dictionary with various confidence metrics
"""
return {
"average_confidence": self.get_average_trace_confidence(),
"bottom_10_percent": self.get_bottom_10_percent_confidence(),
"lowest_group": self.get_lowest_group_confidence(),
"current_group": self.get_current_group_confidence() or 0.0,
"num_tokens": len(self.token_confidences),
"num_groups": len(self.group_confidences)
}

class ConfidenceThresholdCalibrator:
"""
Calibrates confidence thresholds based on warmup traces.
"""

def __init__(self, variant: str = "low"):
"""
Initialize the threshold calibrator.

Args:
variant: "low" (aggressive, top 10%) or "high" (conservative, top 90%)
"""
self.variant = variant
self.warmup_confidences = []

def add_warmup_trace(self, confidence_stats: Dict[str, float]):
"""
Add confidence statistics from a warmup trace.

Args:
confidence_stats: Dictionary with confidence metrics
"""
self.warmup_confidences.append(confidence_stats)

def calculate_threshold(self, metric: str = "average_confidence") -> float:
"""
Calculate the confidence threshold based on warmup traces.

Args:
metric: Which confidence metric to use for threshold calculation

Returns:
Calculated threshold value
"""
if not self.warmup_confidences:
logger.warning("No warmup traces available for threshold calculation")
return 0.0

confidences = [stats[metric] for stats in self.warmup_confidences]

if self.variant == "low":
# DeepConf-low: 90th percentile (keeps top 10%)
threshold = np.percentile(confidences, 90)
else:
# DeepConf-high: 10th percentile (keeps top 90%)
threshold = np.percentile(confidences, 10)

logger.info(f"Calculated {self.variant} threshold: {threshold:.4f} for metric: {metric}")
return threshold

def should_terminate_trace(self, current_confidence: float, threshold: float) -> bool:
"""
Determine if current trace should be terminated based on confidence.

Args:
current_confidence: Current confidence value
threshold: Threshold for termination

Returns:
True if trace should be terminated
"""
return current_confidence < threshold
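
A minimal sketch of how these two classes might be wired together during the warmup and online phases; the random logits stand in for real model outputs, and the actual orchestration in the plugin's `deepconf` module may differ:

```python
import torch
from optillm.deepconf.confidence import ConfidenceCalculator, ConfidenceThresholdCalibrator

VOCAB = 1000  # toy vocabulary size; real logits come from the local model

calibrator = ConfidenceThresholdCalibrator(variant="low")

# Warmup phase: score a handful of complete traces to calibrate the threshold.
for _ in range(4):
    calc = ConfidenceCalculator(window_size=8, top_k=5)
    for _ in range(32):                       # tokens in this warmup trace
        calc.add_token_confidence(torch.randn(VOCAB))
    calibrator.add_warmup_trace(calc.get_trace_statistics())

threshold = calibrator.calculate_threshold(metric="average_confidence")

# Online phase: stop a new trace early if its group confidence falls below the threshold.
calc = ConfidenceCalculator(window_size=8, top_k=5)
for step in range(32):
    calc.add_token_confidence(torch.randn(VOCAB))
    group_conf = calc.get_current_group_confidence()
    if group_conf is not None and calibrator.should_terminate_trace(group_conf, threshold):
        print(f"Terminated early at token {step + 1}")
        break
```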