11import keras
2- import pytest
32from tests .utils import assert_models_equal
43
54
@@ -17,54 +16,3 @@ def test_save_and_load(tmp_path, approximator, train_dataset, validation_dataset
1716 loaded = keras .saving .load_model (tmp_path / "model.keras" )
1817
1918 assert_models_equal (approximator , loaded )
20-
21-
22- def test_save_and_load_all_variants (
23- tmp_path , adapter , inference_network , summary_network , train_dataset , validation_dataset
24- ):
25- """Run the same save/load assertions for all `standardize` options in one test node.
26-
27- This avoids relying on pytest's bracketed node-id selection; it constructs
28- an approximator for each `standardize` value and runs the same checks.
29- Any failures across variants are aggregated and reported at the end.
30- """
31- from bayesflow import ContinuousApproximator
32-
33- standardize_values = [
34- "all" ,
35- None ,
36- "inference_variables" ,
37- "summary_variables" ,
38- ("inference_variables" , "summary_variables" , "inference_conditions" ),
39- ]
40-
41- failures = []
42-
43- for standardize in standardize_values :
44- approximator = ContinuousApproximator (
45- adapter = adapter ,
46- inference_network = inference_network ,
47- summary_network = summary_network ,
48- standardize = standardize ,
49- )
50-
51- try :
52- data_shapes = keras .tree .map_structure (keras .ops .shape , train_dataset [0 ])
53- approximator .build (data_shapes )
54- for layer in approximator .standardize_layers .values ():
55- assert layer .built
56- for count in layer .count :
57- assert count == 0.0
58- approximator .compute_metrics (** train_dataset [0 ])
59-
60- model_path = tmp_path / f"model_{ str (standardize )} .keras"
61- keras .saving .save_model (approximator , model_path )
62- loaded = keras .saving .load_model (model_path )
63-
64- assert_models_equal (approximator , loaded )
65- except Exception as exc : # collect failures and continue
66- failures .append ((standardize , repr (exc )))
67-
68- if failures :
69- msgs = ", " .join ([f"{ s !r} : { m } " for s , m in failures ])
70- pytest .fail (f"One or more standardize variants failed: { msgs } " )
0 commit comments