
Commit bb3e4d1: Adding tests
1 parent d21d058

4 files changed: +83, -6 lines

tests/test_diagnostics/test_diagnostics_metrics.py

Lines changed: 37 additions & 0 deletions
@@ -35,6 +35,13 @@ def test_metric_calibration_error(random_estimates, random_targets, var_names):
     assert out["values"].shape == (random_estimates["sigma"].shape[-1],)
     assert out["variable_names"] == ["sigma"]
 
+    # test quantities
+    test_quantities = {
+        r"$\beta_1 + \beta_2$": lambda data: np.sum(data["beta"], axis=-1),
+        r"$\beta_1 \cdot \beta_2$": lambda data: np.prod(data["beta"], axis=-1),
+    }
+    out = bf.diagnostics.metrics.calibration_error(random_estimates, random_targets, test_quantities=test_quantities)
+    assert out["values"].shape[0] == len(test_quantities) + num_variables(random_estimates)
 
 
 def test_posterior_contraction(random_estimates, random_targets):
     # basic functionality: automatic variable names

@@ -47,6 +54,16 @@ def test_posterior_contraction(random_estimates, random_targets):
     out = bf.diagnostics.metrics.posterior_contraction(random_estimates, random_targets, aggregation=None)
     assert out["values"].shape == (random_estimates["sigma"].shape[0], num_variables(random_estimates))
 
+    # test quantities
+    test_quantities = {
+        r"$\beta_1 + \beta_2$": lambda data: np.sum(data["beta"], axis=-1),
+        r"$\beta_1 \cdot \beta_2$": lambda data: np.prod(data["beta"], axis=-1),
+    }
+    out = bf.diagnostics.metrics.posterior_contraction(
+        random_estimates, random_targets, test_quantities=test_quantities
+    )
+    assert out["values"].shape[0] == len(test_quantities) + num_variables(random_estimates)
+
 
 def test_root_mean_squared_error(random_estimates, random_targets):
     # basic functionality: automatic variable names

@@ -56,6 +73,16 @@ def test_root_mean_squared_error(random_estimates, random_targets):
     assert out["metric_name"] == "NRMSE"
     assert out["variable_names"] == ["beta_0", "beta_1", "sigma"]
 
+    # test quantities
+    test_quantities = {
+        r"$\beta_1 + \beta_2$": lambda data: np.sum(data["beta"], axis=-1),
+        r"$\beta_1 \cdot \beta_2$": lambda data: np.prod(data["beta"], axis=-1),
+    }
+    out = bf.diagnostics.metrics.root_mean_squared_error(
+        random_estimates, random_targets, test_quantities=test_quantities
+    )
+    assert out["values"].shape[0] == len(test_quantities) + num_variables(random_estimates)
+
 
 def test_classifier_two_sample_test(random_samples_a, random_samples_b):
     metric = bf.diagnostics.metrics.classifier_two_sample_test(estimates=random_samples_a, targets=random_samples_a)

@@ -95,6 +122,16 @@ def test_calibration_log_gamma(random_estimates, random_targets):
     assert out["metric_name"] == "Log Gamma"
     assert out["variable_names"] == ["beta_0", "beta_1", "sigma"]
 
+    # test quantities
+    test_quantities = {
+        r"$\beta_1 + \beta_2$": lambda data: np.sum(data["beta"], axis=-1),
+        r"$\beta_1 \cdot \beta_2$": lambda data: np.prod(data["beta"], axis=-1),
+    }
+    out = bf.diagnostics.metrics.calibration_log_gamma(
+        random_estimates, random_targets, test_quantities=test_quantities
+    )
+    assert out["values"].shape[0] == len(test_quantities) + num_variables(random_estimates)
+
 
 def test_calibration_log_gamma_end_to_end():
     # This is a function test for simulation-based calibration.
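The test_quantities argument exercised above maps display names to callables that reduce the raw sample dict to one derived value per draw. A minimal, hypothetical sketch of that contract (not BayesFlow's internal implementation; the helper name, the fixture shapes, and the trailing variable axis are assumptions):

import numpy as np

def expand_with_test_quantities(samples: dict, test_quantities: dict) -> dict:
    # Hypothetical helper: append each derived quantity as an extra variable of width 1.
    expanded = dict(samples)
    for name, fn in test_quantities.items():
        derived = np.asarray(fn(samples))          # callable reduces the last (parameter) axis
        expanded[name] = derived[..., np.newaxis]  # keep a trailing variable axis
    return expanded

# Usage mirroring the tests: two derived quantities on top of the original variables.
test_quantities = {
    r"$\beta_1 + \beta_2$": lambda data: np.sum(data["beta"], axis=-1),
    r"$\beta_1 \cdot \beta_2$": lambda data: np.prod(data["beta"], axis=-1),
}
samples = {
    "beta": np.random.normal(size=(8, 100, 2)),
    "sigma": np.abs(np.random.normal(size=(8, 100, 1))),
}
expanded = expand_with_test_quantities(samples, test_quantities)
assert len(expanded) == len(samples) + len(test_quantities)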

tests/test_diagnostics/test_diagnostics_plots.py

Lines changed: 40 additions & 0 deletions
@@ -85,6 +85,16 @@ def test_calibration_histogram(random_estimates, random_targets):
     assert len(out.axes) == num_variables(random_estimates)
     assert out.axes[0].title._text == "beta_0"
 
+    # test quantities
+    test_quantities = {
+        r"$\beta_1 + \beta_2$": lambda data: np.sum(data["beta"], axis=-1),
+        r"$\beta_1 \cdot \beta_2$": lambda data: np.prod(data["beta"], axis=-1),
+    }
+    out = bf.diagnostics.plots.calibration_histogram(random_estimates, random_targets, test_quantities=test_quantities)
+    assert len(out.axes) == len(test_quantities) + num_variables(random_estimates)
+    assert out.axes[1].title._text == r"$\beta_1 \cdot \beta_2$"
+    assert out.axes[-1].title._text == r"sigma"
+
 
 def test_loss(history):
     out = bf.diagnostics.loss(history)

@@ -102,6 +112,16 @@ def test_recovery_bounds(random_estimates, random_targets):
     assert len(out.axes) == num_variables(random_estimates)
     assert out.axes[2].title._text == "sigma"
 
+    # test quantities
+    test_quantities = {
+        r"$\beta_1 + \beta_2$": lambda data: np.sum(data["beta"], axis=-1),
+        r"$\beta_1 \cdot \beta_2$": lambda data: np.prod(data["beta"], axis=-1),
+    }
+    out = bf.diagnostics.plots.calibration_histogram(random_estimates, random_targets, test_quantities=test_quantities)
+    assert len(out.axes) == len(test_quantities) + num_variables(random_estimates)
+    assert out.axes[1].title._text == r"$\beta_1 \cdot \beta_2$"
+    assert out.axes[-1].title._text == r"sigma"
+
 
 def test_recovery_symmetric(random_estimates, random_targets):
     # basic functionality: automatic variable names

@@ -127,6 +147,16 @@ def test_z_score_contraction(random_estimates, random_targets):
     assert len(out.axes) == num_variables(random_estimates)
     assert out.axes[1].title._text == "beta_1"
 
+    # test quantities
+    test_quantities = {
+        r"$\beta_1 + \beta_2$": lambda data: np.sum(data["beta"], axis=-1),
+        r"$\beta_1 \cdot \beta_2$": lambda data: np.prod(data["beta"], axis=-1),
+    }
+    out = bf.diagnostics.plots.z_score_contraction(random_estimates, random_targets, test_quantities=test_quantities)
+    assert len(out.axes) == len(test_quantities) + num_variables(random_estimates)
+    assert out.axes[1].title._text == r"$\beta_1 \cdot \beta_2$"
+    assert out.axes[-1].title._text == r"sigma"
+
 
 def test_pairs_samples(random_priors):
     out = bf.diagnostics.plots.pairs_samples(

@@ -291,6 +321,16 @@ def test_coverage(random_estimates, random_targets):
     assert out.axes[0].get_xlabel() == "Central interval width"
     assert out.axes[0].get_ylabel() == "Empirical coverage"
 
+    # test quantities
+    test_quantities = {
+        r"$\beta_1 + \beta_2$": lambda data: np.sum(data["beta"], axis=-1),
+        r"$\beta_1 \cdot \beta_2$": lambda data: np.prod(data["beta"], axis=-1),
+    }
+    out = bf.diagnostics.plots.coverage(random_estimates, random_targets, test_quantities=test_quantities)
+    assert len(out.axes) == len(test_quantities) + num_variables(random_estimates)
+    assert out.axes[1].title._text == r"$\beta_1 \cdot \beta_2$"
+    assert out.axes[-1].title._text == r"sigma"
+
 
 def test_coverage_diff(random_estimates, random_targets):
     # basic functionality: automatic variable names
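The plot tests read panel titles through Matplotlib internals (out.axes[i].title._text). For reference, the same assertions can be phrased against the public Matplotlib API; this self-contained sketch (not part of the commit) builds a dummy figure only to show the pattern:

import matplotlib
matplotlib.use("Agg")  # headless backend, as a test suite would typically use
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 3)
for ax, name in zip(axes, ["beta_0", r"$\beta_1 \cdot \beta_2$", "sigma"]):
    ax.set_title(name)

assert len(fig.axes) == 3
assert fig.axes[1].get_title() == r"$\beta_1 \cdot \beta_2$"
assert fig.axes[-1].get_title() == "sigma"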

tests/test_links/test_links.py

Lines changed: 3 additions & 3 deletions
@@ -23,9 +23,9 @@ def check_ordering(output, axis):
     assert np.all(np.diff(output, axis=axis) > 0), f"is not ordered along specified axis: {axis}."
     for i in range(output.ndim):
         if i != axis % output.ndim:
-            assert not np.all(np.diff(output, axis=i) > 0), (
-                f"is ordered along axis which is not meant to be ordered: {i}."
-            )
+            assert not np.all(
+                np.diff(output, axis=i) > 0
+            ), f"is ordered along axis which is not meant to be ordered: {i}."
 
 
 @pytest.mark.parametrize("axis", [0, 1, 2])
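check_ordering relies on np.diff being strictly positive only along the axis that is supposed to be ordered. A toy illustration of that property (not part of the diff; shapes and seed are arbitrary):

import numpy as np

rng = np.random.default_rng(0)
output = np.sort(rng.normal(size=(4, 5, 6)), axis=1)  # ordered along axis 1 only

assert np.all(np.diff(output, axis=1) > 0)      # strictly increasing along the sorted axis
assert not np.all(np.diff(output, axis=0) > 0)  # other axes remain unordered for this random draw
assert not np.all(np.diff(output, axis=2) > 0)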

tests/test_networks/test_point_inference_network/test_point_inference_network.py

Lines changed: 3 additions & 3 deletions
@@ -44,9 +44,9 @@ def test_save_and_load(tmp_path, point_inference_network, random_samples, random
 
     for key_outer in out1.keys():
         for key_inner in out1[key_outer].keys():
-            assert keras.ops.all(keras.ops.isclose(out1[key_outer][key_inner], out2[key_outer][key_inner])), (
-                "Output of original and loaded model differs significantly."
-            )
+            assert keras.ops.all(
+                keras.ops.isclose(out1[key_outer][key_inner], out2[key_outer][key_inner])
+            ), "Output of original and loaded model differs significantly."
 
 
 def test_copy_unequal(point_inference_network, random_samples, random_conditions):
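The reformatted assertion walks a two-level dict of network outputs and checks element-wise closeness with keras.ops. A self-contained version of the same check on dummy tensors (the key names below are made up, not the fixture's):

import keras
import numpy as np

out1 = {"head": {"value": np.ones((2, 3)), "log_prob": np.zeros((2,))}}
out2 = {k: {kk: vv + 1e-9 for kk, vv in v.items()} for k, v in out1.items()}

for key_outer in out1:
    for key_inner in out1[key_outer]:
        assert keras.ops.all(
            keras.ops.isclose(out1[key_outer][key_inner], out2[key_outer][key_inner])
        ), "Output of original and loaded model differs significantly."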
