Skip to content

Commit 1b2b5be

Browse files
committed
Adapt and fix tests
1 parent bde587c commit 1b2b5be

File tree

2 files changed

+12
-13
lines changed

2 files changed

+12
-13
lines changed

bayesflow/networks/standardization/standardization.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def build(self, input_shape: Shape):
4545
self.moving_m2 = [
4646
self.add_weight(shape=(shape[-1],), initializer="ones", trainable=False) for shape in flattened_shapes
4747
]
48-
self.count = self.add_weight(shape=(), initializer="zeros", trainable=False, dtype="int64")
48+
self.count = self.add_weight(shape=(), initializer="zeros", trainable=False)
4949

5050
def call(
5151
self,
@@ -79,9 +79,6 @@ def call(
7979
outputs, log_det_jacs = [], []
8080

8181
for idx, val in enumerate(flattened):
82-
if stage == "training":
83-
self._update_moments(val, idx)
84-
8582
mean = expand_left_as(self.moving_mean[idx], val)
8683
std = expand_left_as(self.moving_std(idx), val)
8784

@@ -94,10 +91,12 @@ def call(
9491

9592
if log_det_jac:
9693
ldj = keras.ops.sum(keras.ops.log(keras.ops.abs(std)), axis=-1)
97-
# For convenience, tile to batch shape of val
9894
ldj = keras.ops.tile(ldj, keras.ops.shape(val)[:-1])
9995
log_det_jacs.append(-ldj if forward else ldj)
10096

97+
if stage == "training":
98+
self._update_moments(val, idx)
99+
101100
outputs = keras.tree.pack_sequence_as(x, outputs)
102101
if log_det_jac:
103102
log_det_jacs = keras.tree.pack_sequence_as(x, log_det_jacs)

tests/test_networks/test_standardization.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,18 +10,16 @@
1010
def test_forward_standardization_training():
1111
random_input = keras.random.normal((8, 4))
1212

13-
layer = Standardization(momentum=0.0) # no EMA for test stability
13+
layer = Standardization()
1414
layer.build(random_input.shape)
1515

1616
out = layer(random_input, stage="training")
1717

1818
moving_mean = keras.ops.convert_to_numpy(layer.moving_mean[0])
19-
moving_std = keras.ops.convert_to_numpy(layer.moving_std[0])
2019
random_input = keras.ops.convert_to_numpy(random_input)
2120
out = keras.ops.convert_to_numpy(out)
2221

2322
np.testing.assert_allclose(moving_mean, np.mean(random_input, axis=0), atol=1e-5)
24-
np.testing.assert_allclose(moving_std, np.std(random_input, axis=0), atol=1e-5)
2523

2624
assert out.shape == random_input.shape
2725
assert not np.any(np.isnan(out))
@@ -42,9 +40,10 @@ def test_inverse_standardization_ldj():
4240

4341
def test_consistency_forward_inverse():
4442
random_input = keras.random.normal((4, 20, 5))
45-
layer = Standardization(momentum=0.0)
46-
layer.build((5,))
47-
standardized = layer(random_input, stage="training", forward=True)
43+
layer = Standardization()
44+
_ = layer(random_input, stage="training", forward=True)
45+
46+
standardized = layer(random_input, stage="inference", forward=True)
4847
recovered = layer(standardized, stage="inference", forward=False)
4948

5049
random_input = keras.ops.convert_to_numpy(random_input)
@@ -58,9 +57,10 @@ def test_nested_consistency_forward_inverse():
5857
random_input_b = keras.random.normal((4, 3))
5958
random_input = {"a": random_input_a, "b": random_input_b}
6059

61-
layer = Standardization(momentum=0.0)
60+
layer = Standardization()
6261

63-
standardized = layer(random_input, stage="training", forward=True)
62+
_ = layer(random_input, stage="training", forward=True)
63+
standardized = layer(random_input, stage="inference", forward=True)
6464
recovered = layer(standardized, stage="inference", forward=False)
6565

6666
random_input = keras.tree.map_structure(keras.ops.convert_to_numpy, random_input)

0 commit comments

Comments (0)