
Commit f619d63

HFooladi and claude committed
feat: implement Calibration Metrics for uncertainty quantification
Add comprehensive calibration metrics module (molax/metrics) including:

Metrics:
- negative_log_likelihood: Proper scoring rule for probabilistic predictions
- expected_calibration_error: Average gap between confidence and accuracy
- compute_calibration_curve: Data for reliability diagrams
- sharpness: Average predicted uncertainty
- calibration_error_per_sample: Per-sample z-scores
- evaluate_calibration: Comprehensive metrics in one call

Calibration:
- TemperatureScaling: Post-hoc calibration via temperature optimization on validation set to minimize NLL

Visualization:
- plot_reliability_diagram: Calibration quality visualization
- plot_calibration_comparison: Compare multiple models side-by-side
- plot_uncertainty_vs_error: Scatter of predicted vs actual uncertainty
- plot_confidence_histogram: Distribution of predicted uncertainties
- plot_z_score_histogram: Z-score distribution vs expected N(0,1)
- create_calibration_report: Comprehensive multi-plot report

Also includes:
- 43 comprehensive tests (all passing)
- Example script comparing MC Dropout, Ensemble, and Evidential calibration
- Updated roadmap marking 1.3 as complete

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent ea3f94a commit f619d63
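Of the metrics listed in the commit message, sharpness and the per-sample z-scores are not shown in the roadmap diff below. As a rough sketch of what they compute for Gaussian predictive distributions (illustrative names and signatures only, not necessarily the molax/metrics API):

```python
import jax.numpy as jnp

def z_scores(mean, var, targets):
    """Per-sample standardized residuals; a well-calibrated model gives z ~ N(0, 1)."""
    return (targets - mean) / jnp.sqrt(var)

def sharpness(var):
    """Average predicted standard deviation; lower means more confident predictions."""
    return jnp.mean(jnp.sqrt(var))
```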

File tree: 6 files changed (+2036, -53 lines)


docs/roadmap.md

Lines changed: 42 additions & 53 deletions
@@ -141,7 +141,9 @@ def evidential_uncertainty(nu, alpha, beta):
 
 ---
 
-### 1.3 Calibration Metrics
+### 1.3 Calibration Metrics ✅
+
+**Status:** Implemented in `molax/metrics/`
 
 **What:** Quantify how well predicted uncertainties match actual error frequencies.
 
@@ -151,61 +153,48 @@ def evidential_uncertainty(nu, alpha, beta):
 
 ```python
 # molax/metrics/calibration.py
-import jax.numpy as jnp
-
-def expected_calibration_error(
-    predictions: jnp.ndarray,
-    uncertainties: jnp.ndarray,
-    targets: jnp.ndarray,
-    n_bins: int = 10
-) -> float:
-    """
-    ECE: Average gap between confidence and accuracy across bins.
-    Lower is better. Perfect calibration = 0.
-    """
-    errors = jnp.abs(predictions - targets)
-    confidences = 1.0 / (1.0 + uncertainties)  # Convert variance to confidence
-
-    bin_boundaries = jnp.linspace(0, 1, n_bins + 1)
-    ece = 0.0
-
-    for i in range(n_bins):
-        mask = (confidences >= bin_boundaries[i]) & (confidences < bin_boundaries[i+1])
-        if jnp.sum(mask) > 0:
-            bin_confidence = jnp.mean(confidences[mask])
-            bin_accuracy = 1.0 - jnp.mean(errors[mask])  # Normalized
-            ece += jnp.sum(mask) * jnp.abs(bin_accuracy - bin_confidence)
-
-    return ece / len(predictions)
-
-def reliability_diagram_data(
-    predictions: jnp.ndarray,
-    uncertainties: jnp.ndarray,
-    targets: jnp.ndarray,
-    n_bins: int = 10
-) -> dict:
-    """Returns data for plotting reliability diagrams."""
-    # ... bin confidences and accuracies for visualization
-    pass
-
-def negative_log_likelihood(mean, var, targets):
-    """Proper scoring rule for probabilistic predictions."""
-    return 0.5 * (jnp.log(2 * jnp.pi * var) + (targets - mean)**2 / var)
-
-def calibration_temperature_scaling(
-    val_predictions, val_uncertainties, val_targets
-) -> float:
-    """Learn temperature T to scale uncertainties for calibration."""
-    # ... optimize T to minimize NLL on validation set
-    pass
+from molax.metrics import (
+    expected_calibration_error,
+    negative_log_likelihood,
+    compute_calibration_curve,
+    sharpness,
+    evaluate_calibration,
+    TemperatureScaling,
+    plot_reliability_diagram,
+    plot_calibration_comparison,
+    create_calibration_report,
+)
+
+# Compute ECE
+ece = expected_calibration_error(predictions, uncertainties, targets, n_bins=10)
+
+# Compute NLL (proper scoring rule)
+nll = negative_log_likelihood(mean, var, targets)
+
+# Comprehensive evaluation
+metrics = evaluate_calibration(mean, var, targets)
+# Returns: {'nll': ..., 'ece': ..., 'rmse': ..., 'sharpness': ..., 'mean_z_score': ...}
+
+# Temperature scaling for post-hoc calibration
+scaler = TemperatureScaling()
+scaler.fit(val_mean, val_var, val_targets)
+calibrated_var = scaler.transform(test_var)
+print(f"Learned temperature: {scaler.temperature}")
+
+# Visualization
+plot_reliability_diagram(predictions, uncertainties, targets)
+fig = plot_calibration_comparison({
+    "Model A": (preds_a, var_a, targets),
+    "Model B": (preds_b, var_b, targets),
+})
 ```
 
 **Acceptance Criteria:**
-- [ ] ECE computation (Expected Calibration Error)
-- [ ] Reliability diagram plotting utility
-- [ ] NLL as proper scoring rule
-- [ ] Temperature scaling for post-hoc calibration
-- [ ] Integration into evaluation pipeline
+- [x] ECE computation (Expected Calibration Error)
+- [x] Reliability diagram plotting utility
+- [x] NLL as proper scoring rule
+- [x] Temperature scaling for post-hoc calibration
+- [x] Integration into evaluation pipeline
 
 ---
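For context on the TemperatureScaling usage shown in the diff, here is a minimal sketch of the underlying idea: predicted variances are scaled by a single factor T, chosen on a validation set to minimize the Gaussian NLL. This is illustrative only and assumes Gaussian predictions; the actual molax implementation may differ.

```python
import jax
import jax.numpy as jnp

def gaussian_nll(mean, var, targets):
    """Mean Gaussian negative log-likelihood."""
    return jnp.mean(0.5 * (jnp.log(2 * jnp.pi * var) + (targets - mean) ** 2 / var))

def fit_temperature(val_mean, val_var, val_targets, steps=200, lr=1e-2):
    """Gradient descent on log T so the scaled variances T * var minimize validation NLL."""
    loss = lambda log_t: gaussian_nll(val_mean, jnp.exp(log_t) * val_var, val_targets)
    grad_fn = jax.grad(loss)
    log_t = 0.0  # start at T = 1 (no rescaling)
    for _ in range(steps):
        log_t = log_t - lr * grad_fn(log_t)
    return jnp.exp(log_t)  # calibrated variance = T * var
```

In this purely Gaussian setting the optimum even has a closed form, T* = mean((targets - mean)² / var), i.e. the average squared z-score, so the gradient loop is mainly there to mirror how a generic optimizer-based fit would look.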

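Similarly, the reliability-diagram data that compute_calibration_curve and plot_reliability_diagram build on can be sketched for regression as observed vs. expected coverage of central prediction intervals (again an illustrative sketch under a Gaussian assumption, not the molax API):

```python
import jax.numpy as jnp
from jax.scipy.stats import norm

def calibration_curve(mean, var, targets, n_levels=10):
    """Observed vs. expected coverage of central prediction intervals under N(mean, var)."""
    # Probability integral transform: u is Uniform(0, 1) for a perfectly calibrated model.
    u = norm.cdf((targets - mean) / jnp.sqrt(var))
    expected = jnp.linspace(0.05, 0.95, n_levels)
    # The central p-interval corresponds to u lying in [0.5 - p/2, 0.5 + p/2].
    observed = jnp.array([jnp.mean(jnp.abs(u - 0.5) <= p / 2) for p in expected])
    return expected, observed
```

Plotting observed against expected gives the reliability diagram; points below the diagonal indicate intervals that are too narrow, i.e. overconfident uncertainty estimates.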