Skip to content

Commit 3cee913

Browse files
committed
Tests for calibration_ecdf with test_quantities
1 parent 63d743a commit 3cee913

File tree

1 file changed

+19
-0
lines changed

1 file changed

+19
-0
lines changed

tests/test_diagnostics/test_diagnostics_plots.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import bayesflow as bf
2+
import numpy as np
23
import pytest
34

45

@@ -16,6 +17,8 @@ def test_backend():
1617

1718

1819
def test_calibration_ecdf(random_estimates, random_targets, var_names):
20+
print(random_estimates, random_targets, var_names)
21+
1922
# basic functionality: automatic variable names
2023
out = bf.diagnostics.plots.calibration_ecdf(random_estimates, random_targets)
2124
assert len(out.axes) == num_variables(random_estimates)
@@ -46,6 +49,22 @@ def test_calibration_ecdf(random_estimates, random_targets, var_names):
4649
# cannot infer the variable names from an array so default names are used
4750
assert out.axes[1].title._text == "v_1"
4851

52+
# test quantities plots are shown
53+
test_quantities = {
54+
r"$\beta_1 + \beta_2$": lambda data: np.sum(data["beta"], axis=-1),
55+
r"$\beta_1 \cdot \beta_2$": lambda data: np.prod(data["beta"], axis=-1),
56+
}
57+
out = bf.diagnostics.plots.calibration_ecdf(random_estimates, random_targets, test_quantities=test_quantities)
58+
assert len(out.axes) == len(test_quantities) + num_variables(random_estimates)
59+
assert out.axes[1].title._text == r"$\beta_1 \cdot \beta_2$"
60+
assert out.axes[-1].title._text == r"sigma"
61+
62+
# test plot titles changed to variable_names in case test quantities exist
63+
out = bf.diagnostics.plots.calibration_ecdf(
64+
random_estimates, random_targets, test_quantities=test_quantities, variable_names=var_names
65+
)
66+
assert out.axes[-1].title._text == r"$\sigma$"
67+
4968

5069
def test_calibration_histogram(random_estimates, random_targets):
5170
# basic functionality: automatic variable names

0 commit comments

Comments
 (0)