
Commit 5ac7c99

Merge pull request #336 from Kucharssim/tests-model-comparison-diagnostics
Tests for model comparison diagnostics + ece as a diagnostic metric
2 parents 95d98c6 + f6fbef1 commit 5ac7c99

8 files changed: 172 additions, 72 deletions
bayesflow/diagnostics/metrics/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -1,3 +1,4 @@
 from .calibration_error import calibration_error
 from .posterior_contraction import posterior_contraction
 from .root_mean_squared_error import root_mean_squared_error
+from .expected_calibration_error import expected_calibration_error
bayesflow/diagnostics/metrics/expected_calibration_error.py

Lines changed: 98 additions & 0 deletions (new file)

import numpy as np
from keras import ops
from typing import Sequence, Any, Mapping

from ...utils.exceptions import ShapeError
from sklearn.calibration import calibration_curve


def expected_calibration_error(
    estimates: np.ndarray,
    targets: np.ndarray,
    model_names: Sequence[str] = None,
    n_bins: int = 10,
    return_probs: bool = False,
) -> Mapping[str, Any]:
    """Estimates the expected calibration error (ECE) of a model comparison network according to [1].

    [1] Naeini, M. P., Cooper, G., & Hauskrecht, M. (2015).
    Obtaining well calibrated probabilities using Bayesian binning.
    In Proceedings of the AAAI Conference on Artificial Intelligence (Vol. 29, No. 1).

    Notes
    -----
    Make sure that ``targets`` are **one-hot encoded** classes!

    Parameters
    ----------
    estimates : array of shape (num_sim, num_models)
        The predicted posterior model probabilities.
    targets : array of shape (num_sim, num_models)
        The one-hot-encoded true model indices.
    model_names : Sequence[str], optional (default = None)
        Optional model names to show in the output. By default, models are named "M_" + model index.
    n_bins : int, optional, default: 10
        The number of bins to use for the calibration curves (and marginal histograms).
        Passed into ``sklearn.calibration.calibration_curve()``.
    return_probs : bool, optional, default: False
        Whether to also return the outputs of ``sklearn.calibration.calibration_curve()``.

    Returns
    -------
    result : dict
        Dictionary containing:

        - "values" : np.ndarray
            The expected calibration error per model.
        - "metric_name" : str
            The name of the metric ("Expected Calibration Error").
        - "model_names" : list[str]
            The (inferred) model names.
        - "probs_true" : list[np.ndarray] (only if ``return_probs=True``)
            Outputs of ``sklearn.calibration.calibration_curve()`` per model.
        - "probs_pred" : list[np.ndarray] (only if ``return_probs=True``)
            Outputs of ``sklearn.calibration.calibration_curve()`` per model.
    """

    # Convert tensors to numpy, if passed
    estimates = ops.convert_to_numpy(estimates)
    targets = ops.convert_to_numpy(targets)

    if estimates.shape != targets.shape:
        raise ShapeError("`estimates` and `targets` must have the same shape.")

    if model_names is None:
        model_names = ["M_" + str(i) for i in range(estimates.shape[-1])]
    elif len(model_names) != estimates.shape[-1]:
        raise ShapeError("There must be exactly one `model_name` for each model in `estimates`")

    # Prepare containers
    ece = []
    probs_true = []
    probs_pred = []

    # Convert one-hot encoded targets to model indices
    targets = targets.argmax(axis=-1)

    # Loop over models and compute the calibration error per bin
    for model_index in range(estimates.shape[-1]):
        y_true = (targets == model_index).astype(np.float32)
        y_prob = estimates[..., model_index]
        prob_true, prob_pred = calibration_curve(y_true, y_prob, n_bins=n_bins)

        # Compute ECE by weighting the per-bin errors by the relative bin sizes
        bins = np.linspace(0.0, 1.0, n_bins + 1)
        binids = np.searchsorted(bins[1:-1], y_prob)
        bin_total = np.bincount(binids, minlength=len(bins))
        nonzero = bin_total != 0
        error = np.sum(np.abs(prob_true - prob_pred) * (bin_total[nonzero] / len(y_true)))

        ece.append(error)
        probs_true.append(prob_true)
        probs_pred.append(prob_pred)

    output = dict(values=np.array(ece), metric_name="Expected Calibration Error", model_names=model_names)

    if return_probs:
        output["probs_true"] = probs_true
        output["probs_pred"] = probs_pred

    return output

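For each model, the metric sums |prob_true - prob_pred| over the nonempty probability bins returned by ``sklearn.calibration.calibration_curve()``, weighting each bin by its relative size (bin count / num_sim). Below is a minimal usage sketch, not part of the commit; the synthetic one-hot targets and softmax probabilities are illustrative assumptions.

# Minimal usage sketch for the new metric (synthetic data; not part of the diff).
import numpy as np
import bayesflow as bf

rng = np.random.default_rng(1)
num_sim, num_models = 500, 3

# One-hot encoded true model indices, as the metric requires
true_idx = rng.integers(num_models, size=num_sim)
targets = np.eye(num_models)[true_idx]

# Softmax "posterior model probabilities" centered on the true model
logits = rng.normal(loc=targets)
estimates = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)

result = bf.diagnostics.metrics.expected_calibration_error(
    estimates, targets, model_names=["M_0", "M_1", "M_2"], return_probs=True
)
print(result["values"])       # one ECE value per model
print(result["probs_true"])   # calibration-curve outputs per model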
bayesflow/diagnostics/plots/mc_calibration.py

Lines changed: 14 additions & 9 deletions
@@ -5,19 +5,20 @@


 from bayesflow.utils import (
-    expected_calibration_error,
     prepare_plot_data,
     add_titles_and_labels,
     add_metric,
     prettify_subplots,
 )

+from bayesflow.diagnostics.metrics import expected_calibration_error
+

 def mc_calibration(
     pred_models: dict[str, np.ndarray] | np.ndarray,
     true_models: dict[str, np.ndarray] | np.ndarray,
     model_names: Sequence[str] = None,
-    num_bins: int = 10,
+    n_bins: int = 10,
     label_fontsize: int = 16,
     title_fontsize: int = 18,
     metric_fontsize: int = 14,
@@ -40,7 +41,7 @@ def mc_calibration(
         The one-hot-encoded true model indices per data set.
     model_names : list or None, optional, default: None
         The model names for nice plot titles. Inferred if None.
-    num_bins : int, optional, default: 10
+    n_bins : int, optional, default: 10
         The number of bins to use for the calibration curves (and marginal histograms).
     label_fontsize : int, optional, default: 16
         The font size of the y-label and y-label texts
@@ -77,17 +78,21 @@ def mc_calibration(
         default_name="M",
     )

-    # Compute calibration
-    cal_errors, true_probs, pred_probs = expected_calibration_error(
-        plot_data["targets"], plot_data["estimates"], num_bins
-    )
+    # compute ece and probs
+    ece = expected_calibration_error(
+        estimates=pred_models,
+        targets=true_models,
+        model_names=plot_data["variable_names"],
+        n_bins=n_bins,
+        return_probs=True,
+    )

     for j, ax in enumerate(plot_data["axes"].flat):
         # Plot calibration curve
-        ax.plot(pred_probs[j], true_probs[j], "o-", color=color)
+        ax.plot(ece["probs_pred"][j], ece["probs_true"][j], "o-", color=color)

         # Plot PMP distribution over bins
-        uniform_bins = np.linspace(0.0, 1.0, num_bins + 1)
+        uniform_bins = np.linspace(0.0, 1.0, n_bins + 1)
         norm_weights = np.ones_like(plot_data["estimates"]) / len(plot_data["estimates"])
         ax.hist(plot_data["estimates"][:, j], bins=uniform_bins, weights=norm_weights[:, j], color="grey", alpha=0.3)

@@ -104,7 +109,7 @@ def mc_calibration(
         add_metric(
             ax,
             metric_text=r"$\widehat{{\mathrm{{ECE}}}}$",
-            metric_value=cal_errors[j],
+            metric_value=ece["values"][j],
             metric_fontsize=metric_fontsize,
         )

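After this change, the plot delegates the ECE computation to the metric above and exposes the renamed ``n_bins`` argument. A hedged sketch of calling the updated function, reusing the synthetic ``estimates`` and ``targets`` from the sketch above:

# Sketch: the updated plotting call after this commit (inputs reused from the sketch above).
fig = bf.diagnostics.plots.mc_calibration(
    pred_models=estimates,
    true_models=targets,
    model_names=[r"$\mathcal{M}_0$", r"$\mathcal{M}_1$", r"$\mathcal{M}_2$"],
    n_bins=10,
)
print(len(fig.axes))  # one panel per model, as asserted in the tests below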
bayesflow/utils/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -4,7 +4,6 @@
     numpy_utils,
 )
 from .callbacks import detailed_loss_callback
-from .comp_utils import expected_calibration_error
 from .devices import devices
 from .dict_utils import (
     convert_args,

bayesflow/utils/comp_utils.py

Lines changed: 0 additions & 62 deletions
This file was deleted.

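With ``comp_utils.py`` removed and the re-export dropped from ``bayesflow.utils``, downstream code must import the metric from its new location, exactly as ``mc_calibration.py`` above now does:

# Old import path, removed by this commit:
# from bayesflow.utils import expected_calibration_error

# New import path:
from bayesflow.diagnostics.metrics import expected_calibration_error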
tests/test_diagnostics/conftest.py

Lines changed: 20 additions & 0 deletions
@@ -1,5 +1,6 @@
 import numpy as np
 import pytest
+from bayesflow.utils.numpy_utils import softmax


 @pytest.fixture()
@@ -31,3 +32,22 @@ def random_priors():
         "sigma": np.random.standard_normal(size=(64, 1)),
         "y": np.random.standard_normal(size=(64, 3, 1)),
     }
+
+
+@pytest.fixture()
+def model_names():
+    return [r"$\mathcal{M}_0$", r"$\mathcal{M}_1$", r"$\mathcal{M}_2$"]
+
+
+@pytest.fixture()
+def true_models():
+    true_models = np.random.choice(3, 100)
+    true_models = np.eye(3)[true_models].astype(np.int32)
+    return true_models
+
+
+@pytest.fixture()
+def pred_models(true_models):
+    pred_models = np.random.normal(loc=true_models)
+    pred_models = softmax(pred_models, axis=-1)
+    return pred_models

tests/test_diagnostics/test_diagnostics_metrics.py

Lines changed: 24 additions & 0 deletions
@@ -1,4 +1,5 @@
 import bayesflow as bf
+import pytest


 def num_variables(x: dict):
@@ -47,3 +48,26 @@ def test_root_mean_squared_error(random_estimates, random_targets):
     assert out["values"].shape == (num_variables(random_estimates),)
     assert out["metric_name"] == "NRMSE"
     assert out["variable_names"] == ["beta_0", "beta_1", "sigma"]
+
+
+def test_expected_calibration_error(pred_models, true_models, model_names):
+    out = bf.diagnostics.metrics.expected_calibration_error(pred_models, true_models, model_names=model_names)
+    assert list(out.keys()) == ["values", "metric_name", "model_names"]
+    assert out["values"].shape == (pred_models.shape[-1],)
+    assert out["metric_name"] == "Expected Calibration Error"
+    assert out["model_names"] == [r"$\mathcal{M}_0$", r"$\mathcal{M}_1$", r"$\mathcal{M}_2$"]
+
+    # also returns the calibration-curve probabilities?
+    out = bf.diagnostics.metrics.expected_calibration_error(pred_models, true_models, return_probs=True)
+    assert list(out.keys()) == ["values", "metric_name", "model_names", "probs_true", "probs_pred"]
+    assert len(out["probs_true"]) == pred_models.shape[-1]
+    assert len(out["probs_pred"]) == pred_models.shape[-1]
+    # default: auto model names
+    assert out["model_names"] == ["M_0", "M_1", "M_2"]
+
+    # handles incorrect input?
+    with pytest.raises(Exception):
+        out = bf.diagnostics.metrics.expected_calibration_error(pred_models, true_models, model_names=["a"])
+
+    with pytest.raises(Exception):
+        out = bf.diagnostics.metrics.expected_calibration_error(pred_models, true_models.transpose())

tests/test_diagnostics/test_diagnostics_plots.py

Lines changed: 15 additions & 0 deletions
@@ -102,3 +102,18 @@ def test_pairs_posterior(random_estimates, random_targets, random_priors):
         priors=random_priors,
         dataset_id=[1, 3],
     )
+
+
+def test_mc_calibration(pred_models, true_models, model_names):
+    out = bf.diagnostics.plots.mc_calibration(pred_models, true_models, model_names=model_names)
+    assert len(out.axes) == pred_models.shape[-1]
+    assert out.axes[0].get_ylabel() == "True Probability"
+    assert out.axes[0].get_xlabel() == "Predicted Probability"
+    assert out.axes[-1].get_title() == r"$\mathcal{M}_2$"
+
+
+def test_mc_confusion_matrix(pred_models, true_models, model_names):
+    out = bf.diagnostics.plots.mc_confusion_matrix(pred_models, true_models, model_names, normalize="true")
+    assert out.axes[0].get_ylabel() == "True model"
+    assert out.axes[0].get_xlabel() == "Predicted model"
+    assert out.axes[0].get_title() == "Confusion Matrix"
