|
| 1 | +import bayesflow as bf |
| 2 | +import pytest |
| 3 | + |
| 4 | + |
| 5 | +def num_variables(x: dict): |
| 6 | + return sum(arr.shape[-1] for arr in x.values()) |
| 7 | + |
| 8 | + |
| 9 | +def test_calibration_ecdf(random_estimates, random_targets, var_names): |
| 10 | + # basic functionality: automatic variable names |
| 11 | + out = bf.diagnostics.plots.calibration_ecdf(random_estimates, random_targets) |
| 12 | + assert len(out.axes) == num_variables(random_estimates) |
| 13 | + assert out.axes[1].title._text == "beta_1" |
| 14 | + |
| 15 | + # custom variable names |
| 16 | + out = bf.diagnostics.plots.calibration_ecdf( |
| 17 | + estimates=random_estimates, |
| 18 | + targets=random_targets, |
| 19 | + variable_names=var_names, |
| 20 | + ) |
| 21 | + assert len(out.axes) == num_variables(random_estimates) |
| 22 | + assert out.axes[1].title._text == "$\\beta_1$" |
| 23 | + |
| 24 | + # subset of keys with a single scalar key |
| 25 | + out = bf.diagnostics.plots.calibration_ecdf( |
| 26 | + estimates=random_estimates, targets=random_targets, variable_keys="sigma" |
| 27 | + ) |
| 28 | + assert len(out.axes) == random_estimates["sigma"].shape[-1] |
| 29 | + assert out.axes[0].title._text == "sigma" |
| 30 | + |
| 31 | + # use single array instead of dict of arrays as input |
| 32 | + out = bf.diagnostics.plots.calibration_ecdf( |
| 33 | + estimates=random_estimates["beta"], |
| 34 | + targets=random_targets["beta"], |
| 35 | + ) |
| 36 | + assert len(out.axes) == random_estimates["beta"].shape[-1] |
| 37 | + # cannot infer the variable names from an array so default names are used |
| 38 | + assert out.axes[1].title._text == "v_1" |
| 39 | + |
| 40 | + |
| 41 | +def test_calibration_histogram(random_estimates, random_targets): |
| 42 | + # basic functionality: automatic variable names |
| 43 | + out = bf.diagnostics.plots.calibration_histogram(random_estimates, random_targets) |
| 44 | + assert len(out.axes) == num_variables(random_estimates) |
| 45 | + assert out.axes[0].title._text == "beta_0" |
| 46 | + |
| 47 | + |
| 48 | +def test_recovery(random_estimates, random_targets): |
| 49 | + # basic functionality: automatic variable names |
| 50 | + out = bf.diagnostics.plots.recovery(random_estimates, random_targets) |
| 51 | + assert len(out.axes) == num_variables(random_estimates) |
| 52 | + assert out.axes[2].title._text == "sigma" |
| 53 | + |
| 54 | + |
| 55 | +def test_z_score_contraction(random_estimates, random_targets): |
| 56 | + # basic functionality: automatic variable names |
| 57 | + out = bf.diagnostics.plots.z_score_contraction(random_estimates, random_targets) |
| 58 | + assert len(out.axes) == num_variables(random_estimates) |
| 59 | + assert out.axes[1].title._text == "beta_1" |
| 60 | + |
| 61 | + |
| 62 | +def test_pairs_samples(random_priors): |
| 63 | + out = bf.diagnostics.plots.pairs_samples( |
| 64 | + samples=random_priors, |
| 65 | + variable_keys=["beta", "sigma"], |
| 66 | + ) |
| 67 | + num_vars = random_priors["sigma"].shape[-1] + random_priors["beta"].shape[-1] |
| 68 | + assert out.axes.shape == (num_vars, num_vars) |
| 69 | + assert out.axes[0, 0].get_ylabel() == "beta_0" |
| 70 | + assert out.axes[2, 2].get_xlabel() == "sigma" |
| 71 | + |
| 72 | + |
| 73 | +def test_pairs_posterior(random_estimates, random_targets, random_priors): |
| 74 | + # basic functionality: automatic variable names |
| 75 | + out = bf.diagnostics.plots.pairs_posterior( |
| 76 | + random_estimates, |
| 77 | + random_targets, |
| 78 | + dataset_id=1, |
| 79 | + ) |
| 80 | + num_vars = num_variables(random_estimates) |
| 81 | + assert out.axes.shape == (num_vars, num_vars) |
| 82 | + assert out.axes[0, 0].get_ylabel() == "beta_0" |
| 83 | + assert out.axes[2, 2].get_xlabel() == "sigma" |
| 84 | + |
| 85 | + # also plot priors |
| 86 | + out = bf.diagnostics.plots.pairs_posterior( |
| 87 | + estimates=random_estimates, |
| 88 | + targets=random_targets, |
| 89 | + priors=random_priors, |
| 90 | + dataset_id=1, |
| 91 | + ) |
| 92 | + num_vars = num_variables(random_estimates) |
| 93 | + assert out.axes.shape == (num_vars, num_vars) |
| 94 | + assert out.axes[0, 0].get_ylabel() == "beta_0" |
| 95 | + assert out.axes[2, 2].get_xlabel() == "sigma" |
| 96 | + assert out.figure.legends[0].get_texts()[0]._text == "Prior" |
| 97 | + |
| 98 | + with pytest.raises(ValueError): |
| 99 | + bf.diagnostics.plots.pairs_posterior( |
| 100 | + estimates=random_estimates, |
| 101 | + targets=random_targets, |
| 102 | + priors=random_priors, |
| 103 | + dataset_id=[1, 3], |
| 104 | + ) |
0 commit comments