Skip to content

Commit 96589ce

Browse files
committed
add more diagnostics tests
1 parent 65d251f commit 96589ce

File tree

1 file changed

+44
-0
lines changed

1 file changed

+44
-0
lines changed

tests/test_diagnostics/test_diagnostics_plots.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import bayesflow as bf
2+
import pytest
23

34

45
def num_variables(x: dict):
@@ -56,3 +57,46 @@ def test_z_score_contraction(random_estimates, random_targets):
5657
out = bf.diagnostics.plots.z_score_contraction(random_estimates, random_targets)
5758
assert len(out.axes) == num_variables(random_estimates)
5859
assert out.axes[1].title._text == "beta_1"
60+
61+
def test_pairs_samples(random_priors):
62+
out = bf.diagnostics.plots.pairs_samples(
63+
samples = random_priors,
64+
variable_keys = ["beta", "sigma"],
65+
)
66+
num_vars = random_priors["sigma"].shape[-1] + random_priors["beta"].shape[-1]
67+
assert out.axes.shape == (num_vars, num_vars)
68+
assert out.axes[0, 0].get_ylabel() == "beta_0"
69+
assert out.axes[2, 2].get_xlabel() == "sigma"
70+
71+
def test_pairs_posterior(random_estimates, random_targets, random_priors):
72+
# basic functionality: automatic variable names
73+
out = bf.diagnostics.plots.pairs_posterior(
74+
random_estimates,
75+
random_targets,
76+
dataset_id=1,
77+
)
78+
num_vars = num_variables(random_estimates)
79+
assert out.axes.shape == (num_vars, num_vars)
80+
assert out.axes[0, 0].get_ylabel() == "beta_0"
81+
assert out.axes[2, 2].get_xlabel() == "sigma"
82+
83+
# also plot priors
84+
out = bf.diagnostics.plots.pairs_posterior(
85+
estimates=random_estimates,
86+
targets=random_targets,
87+
priors=random_priors,
88+
dataset_id=1,
89+
)
90+
num_vars = num_variables(random_estimates)
91+
assert out.axes.shape == (num_vars, num_vars)
92+
assert out.axes[0, 0].get_ylabel() == "beta_0"
93+
assert out.axes[2, 2].get_xlabel() == "sigma"
94+
assert out.figure.legends[0].get_texts()[0]._text == "Prior"
95+
96+
with pytest.raises(ValueError):
97+
bf.diagnostics.plots.pairs_posterior(
98+
estimates=random_estimates,
99+
targets=random_targets,
100+
priors=random_priors,
101+
dataset_id=[1, 3],
102+
)

0 commit comments

Comments
 (0)