Skip to content

Commit c4cc133

Browse files
committed
standardization: add test for multi-input values (failing)
This test reveals to bugs in the standarization layer: - count is updated multiple times - batch_count is too small, as the sizes from reduce_axes have to be multiplied
1 parent 057f3fd commit c4cc133

File tree

1 file changed

+33
-0
lines changed

1 file changed

+33
-0
lines changed

tests/test_networks/test_standardization.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,39 @@ def test_nested_consistency_forward_inverse():
9191
np.testing.assert_allclose(random_input["b"], recovered["b"], atol=1e-4)
9292

9393

94+
def test_nested_accuracy_forward():
95+
from bayesflow.utils import tree_concatenate
96+
97+
# create inputs for two training passes
98+
random_input_a_1 = keras.random.normal((2, 3, 5))
99+
random_input_b_1 = keras.random.normal((4, 3))
100+
random_input_1 = {"a": random_input_a_1, "b": random_input_b_1}
101+
102+
random_input_a_2 = keras.random.normal((3, 3, 5))
103+
random_input_b_2 = keras.random.normal((3, 3))
104+
random_input_2 = {"a": random_input_a_2, "b": random_input_b_2}
105+
106+
# complete data for testing mean and std are 0 and 1
107+
random_input = tree_concatenate([random_input_1, random_input_2], axis=0)
108+
109+
layer = Standardization()
110+
111+
_ = layer(random_input_1, stage="training", forward=True)
112+
_ = layer(random_input_2, stage="training", forward=True)
113+
114+
standardized = layer(random_input, stage="inference", forward=True)
115+
standardized = keras.tree.map_structure(keras.ops.convert_to_numpy, standardized)
116+
117+
np.testing.assert_allclose(
118+
np.mean(standardized["a"], axis=tuple(range(standardized["a"].ndim - 1))), 0.0, atol=1e-4
119+
)
120+
np.testing.assert_allclose(
121+
np.mean(standardized["b"], axis=tuple(range(standardized["b"].ndim - 1))), 0.0, atol=1e-4
122+
)
123+
np.testing.assert_allclose(np.std(standardized["a"], axis=tuple(range(standardized["a"].ndim - 1))), 1.0, atol=1e-4)
124+
np.testing.assert_allclose(np.std(standardized["b"], axis=tuple(range(standardized["b"].ndim - 1))), 1.0, atol=1e-4)
125+
126+
94127
def test_transformation_type_both_sides_scale():
95128
# Fix a known covariance and mean in original (not standardized space)
96129
covariance = np.array([[1, 0.5], [0.5, 2.0]], dtype="float32")

0 commit comments

Comments
 (0)