Skip to content

Commit ed5f9c3

Browse files
committed
Bring updated approximator_standardization test from upstream
1 parent 628227a commit ed5f9c3

File tree

1 file changed

+0
-52
lines changed

1 file changed

+0
-52
lines changed
Lines changed: 0 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import keras
2-
import pytest
32
from tests.utils import assert_models_equal
43

54

@@ -17,54 +16,3 @@ def test_save_and_load(tmp_path, approximator, train_dataset, validation_dataset
1716
loaded = keras.saving.load_model(tmp_path / "model.keras")
1817

1918
assert_models_equal(approximator, loaded)
20-
21-
22-
def test_save_and_load_all_variants(
23-
tmp_path, adapter, inference_network, summary_network, train_dataset, validation_dataset
24-
):
25-
"""Run the same save/load assertions for all `standardize` options in one test node.
26-
27-
This avoids relying on pytest's bracketed node-id selection; it constructs
28-
an approximator for each `standardize` value and runs the same checks.
29-
Any failures across variants are aggregated and reported at the end.
30-
"""
31-
from bayesflow import ContinuousApproximator
32-
33-
standardize_values = [
34-
"all",
35-
None,
36-
"inference_variables",
37-
"summary_variables",
38-
("inference_variables", "summary_variables", "inference_conditions"),
39-
]
40-
41-
failures = []
42-
43-
for standardize in standardize_values:
44-
approximator = ContinuousApproximator(
45-
adapter=adapter,
46-
inference_network=inference_network,
47-
summary_network=summary_network,
48-
standardize=standardize,
49-
)
50-
51-
try:
52-
data_shapes = keras.tree.map_structure(keras.ops.shape, train_dataset[0])
53-
approximator.build(data_shapes)
54-
for layer in approximator.standardize_layers.values():
55-
assert layer.built
56-
for count in layer.count:
57-
assert count == 0.0
58-
approximator.compute_metrics(**train_dataset[0])
59-
60-
model_path = tmp_path / f"model_{str(standardize)}.keras"
61-
keras.saving.save_model(approximator, model_path)
62-
loaded = keras.saving.load_model(model_path)
63-
64-
assert_models_equal(approximator, loaded)
65-
except Exception as exc: # collect failures and continue
66-
failures.append((standardize, repr(exc)))
67-
68-
if failures:
69-
msgs = ", ".join([f"{s!r}: {m}" for s, m in failures])
70-
pytest.fail(f"One or more standardize variants failed: {msgs}")

0 commit comments

Comments
 (0)