|
1 | 1 | import bayesflow as bf |
| 2 | +import pytest |
2 | 3 |
|
3 | 4 |
|
4 | 5 | def num_variables(x: dict): |
@@ -56,3 +57,46 @@ def test_z_score_contraction(random_estimates, random_targets): |
56 | 57 | out = bf.diagnostics.plots.z_score_contraction(random_estimates, random_targets) |
57 | 58 | assert len(out.axes) == num_variables(random_estimates) |
58 | 59 | assert out.axes[1].title._text == "beta_1" |
| 60 | + |
| 61 | +def test_pairs_samples(random_priors): |
| 62 | + out = bf.diagnostics.plots.pairs_samples( |
| 63 | + samples = random_priors, |
| 64 | + variable_keys = ["beta", "sigma"], |
| 65 | + ) |
| 66 | + num_vars = random_priors["sigma"].shape[-1] + random_priors["beta"].shape[-1] |
| 67 | + assert out.axes.shape == (num_vars, num_vars) |
| 68 | + assert out.axes[0, 0].get_ylabel() == "beta_0" |
| 69 | + assert out.axes[2, 2].get_xlabel() == "sigma" |
| 70 | + |
| 71 | +def test_pairs_posterior(random_estimates, random_targets, random_priors): |
| 72 | + # basic functionality: automatic variable names |
| 73 | + out = bf.diagnostics.plots.pairs_posterior( |
| 74 | + random_estimates, |
| 75 | + random_targets, |
| 76 | + dataset_id=1, |
| 77 | + ) |
| 78 | + num_vars = num_variables(random_estimates) |
| 79 | + assert out.axes.shape == (num_vars, num_vars) |
| 80 | + assert out.axes[0, 0].get_ylabel() == "beta_0" |
| 81 | + assert out.axes[2, 2].get_xlabel() == "sigma" |
| 82 | + |
| 83 | + # also plot priors |
| 84 | + out = bf.diagnostics.plots.pairs_posterior( |
| 85 | + estimates=random_estimates, |
| 86 | + targets=random_targets, |
| 87 | + priors=random_priors, |
| 88 | + dataset_id=1, |
| 89 | + ) |
| 90 | + num_vars = num_variables(random_estimates) |
| 91 | + assert out.axes.shape == (num_vars, num_vars) |
| 92 | + assert out.axes[0, 0].get_ylabel() == "beta_0" |
| 93 | + assert out.axes[2, 2].get_xlabel() == "sigma" |
| 94 | + assert out.figure.legends[0].get_texts()[0]._text == "Prior" |
| 95 | + |
| 96 | + with pytest.raises(ValueError): |
| 97 | + bf.diagnostics.plots.pairs_posterior( |
| 98 | + estimates=random_estimates, |
| 99 | + targets=random_targets, |
| 100 | + priors=random_priors, |
| 101 | + dataset_id=[1, 3], |
| 102 | + ) |
0 commit comments