Commit 7083aaa

implement ECE as a diagnostics metric
1 parent 95d98c6 commit 7083aaa

3 files changed (+113 −9 lines)
bayesflow/diagnostics/metrics/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -1,3 +1,4 @@
 from .calibration_error import calibration_error
 from .posterior_contraction import posterior_contraction
 from .root_mean_squared_error import root_mean_squared_error
+from .expected_calibration_error import expected_calibration_error
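
With this re-export, the new metric becomes importable from the metrics subpackage, which is the same path the plotting module below switches to:

from bayesflow.diagnostics.metrics import expected_calibration_error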
bayesflow/diagnostics/metrics/expected_calibration_error.py (new file)

Lines changed: 98 additions & 0 deletions

@@ -0,0 +1,98 @@
from typing import Any, Mapping, Sequence

import numpy as np
from keras import ops
from sklearn.calibration import calibration_curve

from ...utils.exceptions import ShapeError


def expected_calibration_error(
    estimates: np.ndarray,
    targets: np.ndarray,
    model_names: Sequence[str] = None,
    n_bins: int = 10,
    return_probs: bool = False,
) -> Mapping[str, Any]:
    """Estimates the expected calibration error (ECE) of a model comparison network according to [1].

    [1] Naeini, M. P., Cooper, G., & Hauskrecht, M. (2015).
    Obtaining well calibrated probabilities using Bayesian binning.
    In Proceedings of the AAAI Conference on Artificial Intelligence (Vol. 29, No. 1).

    Notes
    -----
    Make sure that ``targets`` are **one-hot encoded** classes!

    Parameters
    ----------
    estimates : array of shape (num_sim, num_models)
        The predicted posterior model probabilities.
    targets : array of shape (num_sim, num_models)
        The one-hot-encoded true model indices.
    model_names : Sequence[str], optional (default = None)
        Optional model names to show in the output. By default, models are named "M_" + model index.
    n_bins : int, optional, default: 10
        The number of bins to use for the calibration curves (and marginal histograms).
        Passed into ``sklearn.calibration.calibration_curve()``.
    return_probs : bool, optional, default: False
        If True, the outputs of ``sklearn.calibration.calibration_curve()`` are included in the result.

    Returns
    -------
    result : dict
        Dictionary containing:

        - "values" : list of float
            The expected calibration error per model.
        - "metric_name" : str
            The name of the metric ("Expected Calibration Error").
        - "model_names" : Sequence[str]
            The given (or inferred) model names.
        - "probs_true" : list of np.ndarray, optional
            The fraction of positives per bin for each model, as returned by
            ``sklearn.calibration.calibration_curve()``.
        - "probs_pred" : list of np.ndarray, optional
            The mean predicted probability per bin for each model, as returned by
            ``sklearn.calibration.calibration_curve()``.
    """

    # Convert tensors to numpy, if passed
    estimates = ops.convert_to_numpy(estimates)
    targets = ops.convert_to_numpy(targets)

    if estimates.shape != targets.shape:
        raise ShapeError("`estimates` and `targets` must have the same shape.")

    if model_names is None:
        model_names = ["M_" + str(i) for i in range(estimates.shape[-1])]
    elif len(model_names) != estimates.shape[-1]:
        raise ShapeError("There must be exactly one `model_name` for each model in `estimates`.")

    # Prepare containers
    ece = []
    probs_true = []
    probs_pred = []

    # Recover true model indices from the one-hot encoding
    targets = targets.argmax(axis=-1)

    # Loop over models and compute the calibration error per bin
    for model_index in range(estimates.shape[-1]):
        y_true = (targets == model_index).astype(np.float32)
        y_prob = estimates[..., model_index]
        prob_true, prob_pred = calibration_curve(y_true, y_prob, n_bins=n_bins)

        # Compute ECE by weighting bin errors by relative bin size; the binning
        # mirrors the uniform strategy of ``calibration_curve()``, so the
        # nonzero mask aligns with its (nonempty) output bins
        bins = np.linspace(0.0, 1.0, n_bins + 1)
        binids = np.searchsorted(bins[1:-1], y_prob)
        bin_total = np.bincount(binids, minlength=len(bins))
        nonzero = bin_total != 0
        error = np.sum(np.abs(prob_true - prob_pred) * (bin_total[nonzero] / len(y_true)))

        ece.append(error)
        probs_true.append(prob_true)
        probs_pred.append(prob_pred)

    output = dict(values=ece, metric_name="Expected Calibration Error", model_names=model_names)

    if return_probs:
        output["probs_true"] = probs_true
        output["probs_pred"] = probs_pred

    return output
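
For reference, the loop above computes the standard binned ECE estimator from [1]. With B_b denoting the set of predictions whose probability for a given model falls into bin b, N the total number of simulations, and empty bins skipped:

\widehat{\mathrm{ECE}} = \sum_{b} \frac{|B_b|}{N} \left| \mathrm{acc}(B_b) - \mathrm{conf}(B_b) \right|

where acc(B_b) is the empirical frequency of the true model in the bin (``prob_true``) and conf(B_b) is the mean predicted probability in the bin (``prob_pred``).

A minimal usage sketch with toy inputs (the Dirichlet-distributed ``estimates`` and the seed are made up for illustration; any probability array with rows summing to one works):

import numpy as np

from bayesflow.diagnostics.metrics import expected_calibration_error

rng = np.random.default_rng(2024)
num_sim, num_models = 1000, 3

# One-hot encode randomly drawn true model indices
targets = np.eye(num_models)[rng.integers(0, num_models, size=num_sim)]

# Toy posterior model probabilities; each row sums to one
estimates = rng.dirichlet(np.ones(num_models), size=num_sim)

result = expected_calibration_error(estimates=estimates, targets=targets, n_bins=10, return_probs=True)
print(result["values"])       # one ECE estimate per model
print(result["model_names"])  # inferred as ["M_0", "M_1", "M_2"]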

bayesflow/diagnostics/plots/mc_calibration.py

Lines changed: 14 additions & 9 deletions
@@ -5,19 +5,20 @@
 from bayesflow.utils import (
-    expected_calibration_error,
     prepare_plot_data,
     add_titles_and_labels,
     add_metric,
     prettify_subplots,
 )

+from bayesflow.diagnostics.metrics import expected_calibration_error
+

 def mc_calibration(
     pred_models: dict[str, np.ndarray] | np.ndarray,
     true_models: dict[str, np.ndarray] | np.ndarray,
     model_names: Sequence[str] = None,
-    num_bins: int = 10,
+    n_bins: int = 10,
     label_fontsize: int = 16,
     title_fontsize: int = 18,
     metric_fontsize: int = 14,

@@ -40,7 +41,7 @@ def mc_calibration(
         The one-hot-encoded true model indices per data set.
     model_names : list or None, optional, default: None
         The model names for nice plot titles. Inferred if None.
-    num_bins : int, optional, default: 10
+    n_bins : int, optional, default: 10
         The number of bins to use for the calibration curves (and marginal histograms).
     label_fontsize : int, optional, default: 16
         The font size of the x-label and y-label texts

@@ -77,17 +78,21 @@ def mc_calibration(
         default_name="M",
     )

-    # Compute calibration
-    cal_errors, true_probs, pred_probs = expected_calibration_error(
-        plot_data["targets"], plot_data["estimates"], num_bins
+    # Compute the ECE and the calibration curves per model
+    ece = expected_calibration_error(
+        estimates=pred_models,
+        targets=true_models,
+        model_names=plot_data["variable_names"],
+        n_bins=n_bins,
+        return_probs=True,
     )

     for j, ax in enumerate(plot_data["axes"].flat):
         # Plot calibration curve
-        ax.plot(pred_probs[j], true_probs[j], "o-", color=color)
+        ax.plot(ece["probs_pred"][j], ece["probs_true"][j], "o-", color=color)

         # Plot PMP distribution over bins
-        uniform_bins = np.linspace(0.0, 1.0, num_bins + 1)
+        uniform_bins = np.linspace(0.0, 1.0, n_bins + 1)
         norm_weights = np.ones_like(plot_data["estimates"]) / len(plot_data["estimates"])
         ax.hist(plot_data["estimates"][:, j], bins=uniform_bins, weights=norm_weights[:, j], color="grey", alpha=0.3)

@@ -104,7 +109,7 @@ def mc_calibration(
         add_metric(
             ax,
             metric_text=r"$\widehat{{\mathrm{{ECE}}}}$",
-            metric_value=cal_errors[j],
+            metric_value=ece["values"][j],
             metric_fontsize=metric_fontsize,
         )
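
A minimal call sketch for the updated plotting function. The import path assumes ``mc_calibration`` is exposed via ``bayesflow.diagnostics.plots``, and the figure return value is inferred rather than visible in this diff; ``estimates`` and ``targets`` are the toy arrays from the metric example above:

from bayesflow.diagnostics.plots import mc_calibration

fig = mc_calibration(
    pred_models=estimates,  # posterior model probabilities, shape (num_sim, num_models)
    true_models=targets,    # one-hot true model indices, same shape
    model_names=["M_0", "M_1", "M_2"],
    n_bins=10,              # renamed from ``num_bins`` in this commit
)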
