Skip to content

Commit 0b32998

Browse files
committed
breaking: fix bugs regarding count in standardization layer
Fixes #524 This fixes the two bugs described in c4cc133: - count was accidentally updated, leading to wrong values - count was calculated wrongly, as only the batch size was used. The correct value is the product of all reduce dimensions. This led to wrong standard deviations. While the batch dimension is the same for all inputs, the size of the second dimension might vary. For this reason, we need to introduce an input-specific `count` variable. This breaks serialization.
1 parent c4cc133 commit 0b32998

File tree

2 files changed

+7
-6
lines changed

2 files changed

+7
-6
lines changed

bayesflow/networks/standardization/standardization.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def moving_std(self, index: int) -> Tensor:
4040
"""
4141
return keras.ops.where(
4242
self.moving_m2[index] > 0,
43-
keras.ops.sqrt(self.moving_m2[index] / self.count),
43+
keras.ops.sqrt(self.moving_m2[index] / self.count[index]),
4444
1.0,
4545
)
4646

@@ -53,7 +53,7 @@ def build(self, input_shape: Shape):
5353
self.moving_m2 = [
5454
self.add_weight(shape=(shape[-1],), initializer="zeros", trainable=False) for shape in flattened_shapes
5555
]
56-
self.count = self.add_weight(shape=(), initializer="zeros", trainable=False)
56+
self.count = [self.add_weight(shape=(), initializer="zeros", trainable=False) for _ in flattened_shapes]
5757

5858
def call(
5959
self,
@@ -150,7 +150,7 @@ def _update_moments(self, x: Tensor, index: int):
150150
"""
151151

152152
reduce_axes = tuple(range(x.ndim - 1))
153-
batch_count = keras.ops.cast(keras.ops.shape(x)[0], self.count.dtype)
153+
batch_count = keras.ops.cast(keras.ops.prod(keras.ops.shape(x)[:-1]), self.count[index].dtype)
154154

155155
# Compute batch mean and M2 per feature
156156
batch_mean = keras.ops.mean(x, axis=reduce_axes)
@@ -159,7 +159,7 @@ def _update_moments(self, x: Tensor, index: int):
159159
# Read current totals
160160
mean = self.moving_mean[index]
161161
m2 = self.moving_m2[index]
162-
count = self.count
162+
count = self.count[index]
163163

164164
total_count = count + batch_count
165165
delta = batch_mean - mean
@@ -169,4 +169,4 @@ def _update_moments(self, x: Tensor, index: int):
169169

170170
self.moving_mean[index].assign(new_mean)
171171
self.moving_m2[index].assign(new_m2)
172-
self.count.assign(total_count)
172+
self.count[index].assign(total_count)

tests/test_approximators/test_approximator_standardization/test_approximator_standardization.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import keras
22
from tests.utils import assert_models_equal
3+
import numpy as np
34

45

56
def test_save_and_load(tmp_path, approximator, train_dataset, validation_dataset):
@@ -8,7 +9,7 @@ def test_save_and_load(tmp_path, approximator, train_dataset, validation_dataset
89
approximator.build(data_shapes)
910
for layer in approximator.standardize_layers.values():
1011
assert layer.built
11-
assert layer.count == 0
12+
np.testing.assert_allclose([c.value.numpy() for c in layer.count], 0.0)
1213
approximator.compute_metrics(**train_dataset[0])
1314

1415
keras.saving.save_model(approximator, tmp_path / "model.keras")

0 commit comments

Comments
 (0)