11import keras
2+ import pytest
23from tests .utils import assert_models_equal
34
45
@@ -16,3 +17,54 @@ def test_save_and_load(tmp_path, approximator, train_dataset, validation_dataset
1617 loaded = keras .saving .load_model (tmp_path / "model.keras" )
1718
1819 assert_models_equal (approximator , loaded )
20+
21+
22+ def test_save_and_load_all_variants (
23+ tmp_path , adapter , inference_network , summary_network , train_dataset , validation_dataset
24+ ):
25+ """Run the same save/load assertions for all `standardize` options in one test node.
26+
27+ This avoids relying on pytest's bracketed node-id selection; it constructs
28+ an approximator for each `standardize` value and runs the same checks.
29+ Any failures across variants are aggregated and reported at the end.
30+ """
31+ from bayesflow import ContinuousApproximator
32+
33+ standardize_values = [
34+ "all" ,
35+ None ,
36+ "inference_variables" ,
37+ "summary_variables" ,
38+ ("inference_variables" , "summary_variables" , "inference_conditions" ),
39+ ]
40+
41+ failures = []
42+
43+ for standardize in standardize_values :
44+ approximator = ContinuousApproximator (
45+ adapter = adapter ,
46+ inference_network = inference_network ,
47+ summary_network = summary_network ,
48+ standardize = standardize ,
49+ )
50+
51+ try :
52+ data_shapes = keras .tree .map_structure (keras .ops .shape , train_dataset [0 ])
53+ approximator .build (data_shapes )
54+ for layer in approximator .standardize_layers .values ():
55+ assert layer .built
56+ for count in layer .count :
57+ assert count == 0.0
58+ approximator .compute_metrics (** train_dataset [0 ])
59+
60+ model_path = tmp_path / f"model_{ str (standardize )} .keras"
61+ keras .saving .save_model (approximator , model_path )
62+ loaded = keras .saving .load_model (model_path )
63+
64+ assert_models_equal (approximator , loaded )
65+ except Exception as exc : # collect failures and continue
66+ failures .append ((standardize , repr (exc )))
67+
68+ if failures :
69+ msgs = ", " .join ([f"{ s !r} : { m } " for s , m in failures ])
70+ pytest .fail (f"One or more standardize variants failed: { msgs } " )
0 commit comments