
Commit 7f09744

add test: expected calibration error metric
1 parent edf91c2

File tree

tests/test_diagnostics/conftest.py
tests/test_diagnostics/test_diagnostics_metrics.py

2 files changed: 44 additions, 0 deletions


tests/test_diagnostics/conftest.py

Lines changed: 20 additions & 0 deletions
@@ -1,5 +1,6 @@
 import numpy as np
 import pytest
+from bayesflow.utils.numpy_utils import softmax
 
 
 @pytest.fixture()
@@ -31,3 +32,22 @@ def random_priors():
         "sigma": np.random.standard_normal(size=(64, 1)),
         "y": np.random.standard_normal(size=(64, 3, 1)),
     }
+
+
+@pytest.fixture()
+def model_names():
+    return [r"$\mathcal{M}_0$", r"$\mathcal{M}_1$", r"$\mathcal{M}_2$"]
+
+
+@pytest.fixture()
+def true_models():
+    true_models = np.random.choice(3, 100)
+    true_models = np.eye(3)[true_models].astype(np.int32)
+    return true_models
+
+
+@pytest.fixture()
+def pred_models(true_models):
+    pred_models = np.random.normal(loc=true_models)
+    pred_models = softmax(pred_models, axis=-1)
+    return pred_models
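
Note: the new fixtures set up a small synthetic model-comparison problem: true_models is a one-hot encoding of 100 randomly drawn model indices, and pred_models perturbs it with Gaussian noise and normalizes each row into probabilities. The sketch below approximates what these fixtures produce, assuming bayesflow.utils.numpy_utils.softmax behaves like a standard numerically stable softmax over the last axis (its implementation is not part of this commit).

# Sketch only: standalone approximation of the new conftest fixtures.
import numpy as np

def softmax(x, axis=-1):
    # subtract the max for numerical stability before exponentiating
    x = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(x)
    return e / np.sum(e, axis=axis, keepdims=True)

rng = np.random.default_rng()
true_models = np.eye(3)[rng.choice(3, 100)].astype(np.int32)  # shape (100, 3), one-hot
pred_models = softmax(rng.normal(loc=true_models), axis=-1)    # shape (100, 3), rows sum to 1

assert true_models.shape == pred_models.shape == (100, 3)
assert np.allclose(pred_models.sum(axis=-1), 1.0)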

tests/test_diagnostics/test_diagnostics_metrics.py

Lines changed: 24 additions & 0 deletions
@@ -1,4 +1,5 @@
 import bayesflow as bf
+import pytest
 
 
 def num_variables(x: dict):
@@ -47,3 +48,26 @@ def test_root_mean_squared_error(random_estimates, random_targets):
     assert out["values"].shape == (num_variables(random_estimates),)
     assert out["metric_name"] == "NRMSE"
     assert out["variable_names"] == ["beta_0", "beta_1", "sigma"]
+
+
+def test_expected_calibration_error(pred_models, true_models, model_names):
+    out = bf.diagnostics.metrics.expected_calibration_error(pred_models, true_models, model_names=model_names)
+    assert list(out.keys()) == ["values", "metric_name", "model_names"]
+    assert out["values"].shape == (pred_models.shape[-1],)
+    assert out["metric_name"] == "Expected Calibration Error"
+    assert out["model_names"] == [r"$\mathcal{M}_0$", r"$\mathcal{M}_1$", r"$\mathcal{M}_2$"]
+
+    # returns probs?
+    out = bf.diagnostics.metrics.expected_calibration_error(pred_models, true_models, return_probs=True)
+    assert list(out.keys()) == ["values", "metric_name", "model_names", "probs_true", "probs_pred"]
+    assert len(out["probs_true"]) == pred_models.shape[-1]
+    assert len(out["probs_pred"]) == pred_models.shape[-1]
+    # default: auto model names
+    assert out["model_names"] == ["M_0", "M_1", "M_2"]
+
+    # handles incorrect input?
+    with pytest.raises(Exception):
+        out = bf.diagnostics.metrics.expected_calibration_error(pred_models, true_models, model_names=["a"])
+
+    with pytest.raises(Exception):
+        out = bf.diagnostics.metrics.expected_calibration_error(pred_models, true_models.transpose)
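
Note: the test only checks the keys and shapes of the metric's output; the implementation itself is not part of this commit. As a rough reference, expected calibration error is commonly computed per model by binning the predicted probabilities and comparing each bin's mean confidence against the empirical frequency of that model being true. The sketch below assumes that standard binning scheme; it is not the bayesflow implementation, and the names ece_per_model and n_bins are hypothetical.

# Hedged sketch of a standard binned expected calibration error,
# not the bayesflow.diagnostics.metrics implementation.
import numpy as np

def ece_per_model(pred_models, true_models, n_bins=10):
    # pred_models: (N, M) predicted model probabilities, rows sum to 1
    # true_models: (N, M) one-hot encoding of the true model index
    n_models = pred_models.shape[-1]
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    ece = np.zeros(n_models)
    for m in range(n_models):
        probs = pred_models[:, m]
        truth = true_models[:, m].astype(float)
        bins = np.digitize(probs, edges[1:-1])  # bin index 0 .. n_bins-1 per sample
        for b in range(n_bins):
            mask = bins == b
            if mask.any():
                # |mean confidence - empirical frequency|, weighted by bin size
                gap = abs(probs[mask].mean() - truth[mask].mean())
                ece[m] += mask.mean() * gap
    return ece  # shape (M,), matching out["values"].shape in the test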
