diff --git a/bayesflow/networks/standardization/standardization.py b/bayesflow/networks/standardization/standardization.py
index a30a9a1e9..9dfdf2cb2 100644
--- a/bayesflow/networks/standardization/standardization.py
+++ b/bayesflow/networks/standardization/standardization.py
@@ -40,7 +40,7 @@ def moving_std(self, index: int) -> Tensor:
         """
         return keras.ops.where(
             self.moving_m2[index] > 0,
-            keras.ops.sqrt(self.moving_m2[index] / self.count),
+            keras.ops.sqrt(self.moving_m2[index] / self.count[index]),
             1.0,
         )
 
@@ -53,7 +53,7 @@ def build(self, input_shape: Shape):
         self.moving_m2 = [
             self.add_weight(shape=(shape[-1],), initializer="zeros", trainable=False) for shape in flattened_shapes
         ]
-        self.count = self.add_weight(shape=(), initializer="zeros", trainable=False)
+        self.count = [self.add_weight(shape=(), initializer="zeros", trainable=False) for _ in flattened_shapes]
 
     def call(
         self,
@@ -150,7 +150,7 @@ def _update_moments(self, x: Tensor, index: int):
         """
         reduce_axes = tuple(range(x.ndim - 1))
 
-        batch_count = keras.ops.cast(keras.ops.shape(x)[0], self.count.dtype)
+        batch_count = keras.ops.cast(keras.ops.prod(keras.ops.shape(x)[:-1]), self.count[index].dtype)
 
         # Compute batch mean and M2 per feature
         batch_mean = keras.ops.mean(x, axis=reduce_axes)
@@ -159,7 +159,7 @@
         # Read current totals
         mean = self.moving_mean[index]
         m2 = self.moving_m2[index]
-        count = self.count
+        count = self.count[index]
 
         total_count = count + batch_count
         delta = batch_mean - mean
@@ -169,4 +169,4 @@
 
         self.moving_mean[index].assign(new_mean)
         self.moving_m2[index].assign(new_m2)
-        self.count.assign(total_count)
+        self.count[index].assign(total_count)
diff --git a/tests/test_approximators/test_approximator_standardization/test_approximator_standardization.py b/tests/test_approximators/test_approximator_standardization/test_approximator_standardization.py
index 9c73d4717..7cf0f6aba 100644
--- a/tests/test_approximators/test_approximator_standardization/test_approximator_standardization.py
+++ b/tests/test_approximators/test_approximator_standardization/test_approximator_standardization.py
@@ -8,7 +8,8 @@ def test_save_and_load(tmp_path, approximator, train_dataset, validation_dataset
     approximator.build(data_shapes)
     for layer in approximator.standardize_layers.values():
         assert layer.built
-        assert layer.count == 0
+        for count in layer.count:
+            assert count == 0.0
 
     approximator.compute_metrics(**train_dataset[0])
     keras.saving.save_model(approximator, tmp_path / "model.keras")
diff --git a/tests/test_approximators/test_build.py b/tests/test_approximators/test_build.py
index 5947783dd..bea897115 100644
--- a/tests/test_approximators/test_build.py
+++ b/tests/test_approximators/test_build.py
@@ -14,4 +14,5 @@ def test_build(approximator, simulator, batch_size, adapter):
     approximator.build(batch_shapes)
     for layer in approximator.standardize_layers.values():
         assert layer.built
-        assert layer.count == 0
+        for count in layer.count:
+            assert count == 0.0
diff --git a/tests/test_networks/test_standardization.py b/tests/test_networks/test_standardization.py
index 8b83de498..86881a384 100644
--- a/tests/test_networks/test_standardization.py
+++ b/tests/test_networks/test_standardization.py
@@ -91,6 +91,39 @@ def test_nested_consistency_forward_inverse():
     np.testing.assert_allclose(random_input["b"], recovered["b"], atol=1e-4)
 
 
+def test_nested_accuracy_forward():
+    from bayesflow.utils import tree_concatenate
+
+    # create inputs for two training passes
+    random_input_a_1 = keras.random.normal((2, 3, 5))
+    random_input_b_1 = keras.random.normal((4, 3))
+    random_input_1 = {"a": random_input_a_1, "b": random_input_b_1}
+
+    random_input_a_2 = keras.random.normal((3, 3, 5))
+    random_input_b_2 = keras.random.normal((3, 3))
+    random_input_2 = {"a": random_input_a_2, "b": random_input_b_2}
+
+    # complete data for testing that mean and std are 0 and 1
+    random_input = tree_concatenate([random_input_1, random_input_2], axis=0)
+
+    layer = Standardization()
+
+    _ = layer(random_input_1, stage="training", forward=True)
+    _ = layer(random_input_2, stage="training", forward=True)
+
+    standardized = layer(random_input, stage="inference", forward=True)
+    standardized = keras.tree.map_structure(keras.ops.convert_to_numpy, standardized)
+
+    np.testing.assert_allclose(
+        np.mean(standardized["a"], axis=tuple(range(standardized["a"].ndim - 1))), 0.0, atol=1e-4
+    )
+    np.testing.assert_allclose(
+        np.mean(standardized["b"], axis=tuple(range(standardized["b"].ndim - 1))), 0.0, atol=1e-4
+    )
+    np.testing.assert_allclose(np.std(standardized["a"], axis=tuple(range(standardized["a"].ndim - 1))), 1.0, atol=1e-4)
+    np.testing.assert_allclose(np.std(standardized["b"], axis=tuple(range(standardized["b"].ndim - 1))), 1.0, atol=1e-4)
+
+
 def test_transformation_type_both_sides_scale():
     # Fix a known covariance and mean in original (not standardized space)
     covariance = np.array([[1, 0.5], [0.5, 2.0]], dtype="float32")
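
For review context, here is a minimal standalone sketch (NumPy only; the helper name `merge_moments` is hypothetical, not BayesFlow API) of the per-feature chunk merge that the patched `_update_moments` performs: counts are now tracked per tensor, each batch contributes `prod(shape[:-1])` samples per feature, and moments are combined with the parallel Welford/Chan update. It checks that merging two chunks reproduces the mean and std of the concatenated data, mirroring what `test_nested_accuracy_forward` verifies through the layer itself.

```python
# Sketch of the parallel moment merge used by the patched _update_moments.
# NumPy only; merge_moments is an illustrative helper, not part of BayesFlow.
import numpy as np

def merge_moments(mean, m2, count, batch):
    """Fold a new batch into running (mean, M2, count), per feature."""
    # Reduce over all axes except the last (feature) axis, matching the
    # diff's reduce_axes = tuple(range(x.ndim - 1)).
    reduce_axes = tuple(range(batch.ndim - 1))
    # Every non-feature element counts as a sample: prod(shape[:-1]),
    # not just shape[0].
    batch_count = np.prod(batch.shape[:-1])
    batch_mean = np.mean(batch, axis=reduce_axes)
    batch_m2 = np.sum((batch - batch_mean) ** 2, axis=reduce_axes)

    # Chan et al. parallel combine of (mean, M2) statistics.
    total = count + batch_count
    delta = batch_mean - mean
    new_mean = mean + delta * batch_count / total
    new_m2 = m2 + batch_m2 + delta**2 * count * batch_count / total
    return new_mean, new_m2, total

rng = np.random.default_rng(0)
x1 = rng.normal(size=(2, 3, 5))  # two chunks with different leading dims
x2 = rng.normal(size=(3, 3, 5))

mean, m2, count = np.zeros(5), np.zeros(5), 0.0
for chunk in (x1, x2):
    mean, m2, count = merge_moments(mean, m2, count, chunk)

# The streamed moments match those of the full concatenated data.
full = np.concatenate([x1, x2], axis=0).reshape(-1, 5)
np.testing.assert_allclose(mean, full.mean(axis=0), atol=1e-10)
np.testing.assert_allclose(np.sqrt(m2 / count), full.std(axis=0), atol=1e-10)
```

Counting `prod(shape[:-1])` rather than `shape[0]` is what lets inputs with different non-feature ranks (e.g. "a" with shape `(batch, 3, 5)` vs. "b" with shape `(batch, 3)`) accumulate the correct number of per-feature samples, and since those totals now differ across inputs, `count` has to become one weight per tensor rather than a single shared scalar.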