Skip to content

Commit 6c97e76

Browse files
authored
Implement comprehensive save/load tests for variants
Added a test to validate saving and loading for all standardize variants of the ContinuousApproximator.
1 parent 12aaeb5 commit 6c97e76

File tree

1 file changed

+52
-0
lines changed

1 file changed

+52
-0
lines changed

tests/test_approximators/test_approximator_standardization/test_approximator_standardization.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import keras
2+
import pytest
23
from tests.utils import assert_models_equal
34

45

@@ -16,3 +17,54 @@ def test_save_and_load(tmp_path, approximator, train_dataset, validation_dataset
1617
loaded = keras.saving.load_model(tmp_path / "model.keras")
1718

1819
assert_models_equal(approximator, loaded)
20+
21+
22+
def test_save_and_load_all_variants(
23+
tmp_path, adapter, inference_network, summary_network, train_dataset, validation_dataset
24+
):
25+
"""Run the same save/load assertions for all `standardize` options in one test node.
26+
27+
This avoids relying on pytest's bracketed node-id selection; it constructs
28+
an approximator for each `standardize` value and runs the same checks.
29+
Any failures across variants are aggregated and reported at the end.
30+
"""
31+
from bayesflow import ContinuousApproximator
32+
33+
standardize_values = [
34+
"all",
35+
None,
36+
"inference_variables",
37+
"summary_variables",
38+
("inference_variables", "summary_variables", "inference_conditions"),
39+
]
40+
41+
failures = []
42+
43+
for standardize in standardize_values:
44+
approximator = ContinuousApproximator(
45+
adapter=adapter,
46+
inference_network=inference_network,
47+
summary_network=summary_network,
48+
standardize=standardize,
49+
)
50+
51+
try:
52+
data_shapes = keras.tree.map_structure(keras.ops.shape, train_dataset[0])
53+
approximator.build(data_shapes)
54+
for layer in approximator.standardize_layers.values():
55+
assert layer.built
56+
for count in layer.count:
57+
assert count == 0.0
58+
approximator.compute_metrics(**train_dataset[0])
59+
60+
model_path = tmp_path / f"model_{str(standardize)}.keras"
61+
keras.saving.save_model(approximator, model_path)
62+
loaded = keras.saving.load_model(model_path)
63+
64+
assert_models_equal(approximator, loaded)
65+
except Exception as exc: # collect failures and continue
66+
failures.append((standardize, repr(exc)))
67+
68+
if failures:
69+
msgs = ", ".join([f"{s!r}: {m}" for s, m in failures])
70+
pytest.fail(f"One or more standardize variants failed: {msgs}")

0 commit comments

Comments
 (0)