Skip to content

Commit 95d98c6

Browse files
Merge pull request #335 from bayesflow-org/test-diagnostics
Add initial tests for diagnostic metrics
2 parents 93ca6a9 + 900fafd commit 95d98c6

File tree

6 files changed

+185
-5
lines changed

6 files changed

+185
-5
lines changed

bayesflow/utils/dict_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def split_arrays(data: Mapping[str, np.ndarray], axis: int = -1) -> Mapping[str,
124124
splits = [np.squeeze(split, axis=axis) for split in splits]
125125

126126
for i, split in enumerate(splits):
127-
result[f"{key}_{i + 1}"] = split
127+
result[f"{key}_{i}"] = split
128128

129129
return result
130130

@@ -214,7 +214,7 @@ def make_variable_array(
214214

215215
# use default names if not otherwise specified
216216
if variable_names is None:
217-
variable_names = [f"${default_name}_{{{i}}}$" for i in range(x.shape[-1])]
217+
variable_names = [f"{default_name}_{i}" for i in range(x.shape[-1])]
218218

219219
if dataset_ids is not None:
220220
x = x[dataset_ids]

tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def pytest_make_parametrize_id(config, val, argname):
2222
return f"{argname}={repr(val)}"
2323

2424

25-
@pytest.fixture(params=[2, 3], scope="session", autouse=True)
25+
@pytest.fixture(params=[2, 3], scope="session")
2626
def batch_size(request):
2727
return request.param
2828

tests/test_diagnostics/conftest.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,33 @@
1+
import numpy as np
12
import pytest
23

34

45
@pytest.fixture()
5-
def num_samples():
6-
return 1000
6+
def var_names():
7+
return [r"$\beta_0$", r"$\beta_1$", r"$\sigma$"]
8+
9+
10+
@pytest.fixture()
11+
def random_estimates():
12+
return {
13+
"beta": np.random.standard_normal(size=(32, 10, 2)),
14+
"sigma": np.random.standard_normal(size=(32, 10, 1)),
15+
}
16+
17+
18+
@pytest.fixture()
19+
def random_targets():
20+
return {
21+
"beta": np.random.standard_normal(size=(32, 2)),
22+
"sigma": np.random.standard_normal(size=(32, 1)),
23+
"y": np.random.standard_normal(size=(32, 3, 1)),
24+
}
25+
26+
27+
@pytest.fixture()
28+
def random_priors():
29+
return {
30+
"beta": np.random.standard_normal(size=(64, 2)),
31+
"sigma": np.random.standard_normal(size=(64, 1)),
32+
"y": np.random.standard_normal(size=(64, 3, 1)),
33+
}

tests/test_diagnostics/test_diagnostics.py

Whitespace-only changes.
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import bayesflow as bf
2+
3+
4+
def num_variables(x: dict):
5+
return sum(arr.shape[-1] for arr in x.values())
6+
7+
8+
def test_metric_calibration_error(random_estimates, random_targets, var_names):
9+
# basic functionality: automatic variable names
10+
out = bf.diagnostics.metrics.calibration_error(random_estimates, random_targets)
11+
assert list(out.keys()) == ["values", "metric_name", "variable_names"]
12+
assert out["values"].shape == (num_variables(random_estimates),)
13+
assert out["metric_name"] == "Calibration Error"
14+
assert out["variable_names"] == ["beta_0", "beta_1", "sigma"]
15+
16+
# user-specified variable names
17+
out = bf.diagnostics.metrics.calibration_error(
18+
estimates=random_estimates,
19+
targets=random_targets,
20+
variable_names=var_names,
21+
)
22+
assert out["variable_names"] == var_names
23+
24+
# user-specified keys and scalar variable
25+
out = bf.diagnostics.metrics.calibration_error(
26+
estimates=random_estimates,
27+
targets=random_targets,
28+
variable_keys="sigma",
29+
)
30+
assert out["values"].shape == (random_estimates["sigma"].shape[-1],)
31+
assert out["variable_names"] == ["sigma"]
32+
33+
34+
def test_posterior_contraction(random_estimates, random_targets):
35+
# basic functionality: automatic variable names
36+
out = bf.diagnostics.metrics.posterior_contraction(random_estimates, random_targets)
37+
assert list(out.keys()) == ["values", "metric_name", "variable_names"]
38+
assert out["values"].shape == (num_variables(random_estimates),)
39+
assert out["metric_name"] == "Posterior Contraction"
40+
assert out["variable_names"] == ["beta_0", "beta_1", "sigma"]
41+
42+
43+
def test_root_mean_squared_error(random_estimates, random_targets):
44+
# basic functionality: automatic variable names
45+
out = bf.diagnostics.metrics.root_mean_squared_error(random_estimates, random_targets)
46+
assert list(out.keys()) == ["values", "metric_name", "variable_names"]
47+
assert out["values"].shape == (num_variables(random_estimates),)
48+
assert out["metric_name"] == "NRMSE"
49+
assert out["variable_names"] == ["beta_0", "beta_1", "sigma"]
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
import bayesflow as bf
2+
import pytest
3+
4+
5+
def num_variables(x: dict):
6+
return sum(arr.shape[-1] for arr in x.values())
7+
8+
9+
def test_calibration_ecdf(random_estimates, random_targets, var_names):
10+
# basic functionality: automatic variable names
11+
out = bf.diagnostics.plots.calibration_ecdf(random_estimates, random_targets)
12+
assert len(out.axes) == num_variables(random_estimates)
13+
assert out.axes[1].title._text == "beta_1"
14+
15+
# custom variable names
16+
out = bf.diagnostics.plots.calibration_ecdf(
17+
estimates=random_estimates,
18+
targets=random_targets,
19+
variable_names=var_names,
20+
)
21+
assert len(out.axes) == num_variables(random_estimates)
22+
assert out.axes[1].title._text == "$\\beta_1$"
23+
24+
# subset of keys with a single scalar key
25+
out = bf.diagnostics.plots.calibration_ecdf(
26+
estimates=random_estimates, targets=random_targets, variable_keys="sigma"
27+
)
28+
assert len(out.axes) == random_estimates["sigma"].shape[-1]
29+
assert out.axes[0].title._text == "sigma"
30+
31+
# use single array instead of dict of arrays as input
32+
out = bf.diagnostics.plots.calibration_ecdf(
33+
estimates=random_estimates["beta"],
34+
targets=random_targets["beta"],
35+
)
36+
assert len(out.axes) == random_estimates["beta"].shape[-1]
37+
# cannot infer the variable names from an array so default names are used
38+
assert out.axes[1].title._text == "v_1"
39+
40+
41+
def test_calibration_histogram(random_estimates, random_targets):
42+
# basic functionality: automatic variable names
43+
out = bf.diagnostics.plots.calibration_histogram(random_estimates, random_targets)
44+
assert len(out.axes) == num_variables(random_estimates)
45+
assert out.axes[0].title._text == "beta_0"
46+
47+
48+
def test_recovery(random_estimates, random_targets):
49+
# basic functionality: automatic variable names
50+
out = bf.diagnostics.plots.recovery(random_estimates, random_targets)
51+
assert len(out.axes) == num_variables(random_estimates)
52+
assert out.axes[2].title._text == "sigma"
53+
54+
55+
def test_z_score_contraction(random_estimates, random_targets):
56+
# basic functionality: automatic variable names
57+
out = bf.diagnostics.plots.z_score_contraction(random_estimates, random_targets)
58+
assert len(out.axes) == num_variables(random_estimates)
59+
assert out.axes[1].title._text == "beta_1"
60+
61+
62+
def test_pairs_samples(random_priors):
63+
out = bf.diagnostics.plots.pairs_samples(
64+
samples=random_priors,
65+
variable_keys=["beta", "sigma"],
66+
)
67+
num_vars = random_priors["sigma"].shape[-1] + random_priors["beta"].shape[-1]
68+
assert out.axes.shape == (num_vars, num_vars)
69+
assert out.axes[0, 0].get_ylabel() == "beta_0"
70+
assert out.axes[2, 2].get_xlabel() == "sigma"
71+
72+
73+
def test_pairs_posterior(random_estimates, random_targets, random_priors):
74+
# basic functionality: automatic variable names
75+
out = bf.diagnostics.plots.pairs_posterior(
76+
random_estimates,
77+
random_targets,
78+
dataset_id=1,
79+
)
80+
num_vars = num_variables(random_estimates)
81+
assert out.axes.shape == (num_vars, num_vars)
82+
assert out.axes[0, 0].get_ylabel() == "beta_0"
83+
assert out.axes[2, 2].get_xlabel() == "sigma"
84+
85+
# also plot priors
86+
out = bf.diagnostics.plots.pairs_posterior(
87+
estimates=random_estimates,
88+
targets=random_targets,
89+
priors=random_priors,
90+
dataset_id=1,
91+
)
92+
num_vars = num_variables(random_estimates)
93+
assert out.axes.shape == (num_vars, num_vars)
94+
assert out.axes[0, 0].get_ylabel() == "beta_0"
95+
assert out.axes[2, 2].get_xlabel() == "sigma"
96+
assert out.figure.legends[0].get_texts()[0]._text == "Prior"
97+
98+
with pytest.raises(ValueError):
99+
bf.diagnostics.plots.pairs_posterior(
100+
estimates=random_estimates,
101+
targets=random_targets,
102+
priors=random_priors,
103+
dataset_id=[1, 3],
104+
)

0 commit comments

Comments
 (0)