Skip to content

Commit df1761b

Browse files
committed
adapt handling of the special case M^2=0
1 parent c6d79ae commit df1761b

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

bayesflow/networks/standardization/standardization.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,14 @@ def __init__(self, **kwargs):
3434
self.count = None
3535

3636
def moving_std(self, index: int) -> Tensor:
37-
# return zeros if count is 0
37+
"""Calculates the standard deviation from the moving M^2 at the given index and the count.
38+
39+
Important: Where M^2=0, it will return a standard deviation of 1 instead of 0, even if count > 0.
40+
"""
3841
return keras.ops.where(
39-
self.count > 0,
42+
self.moving_m2[index] > 0,
4043
keras.ops.sqrt(self.moving_m2[index] / self.count),
41-
self.moving_m2[index] * 0.0,
44+
1.0,
4245
)
4346

4447
def build(self, input_shape: Shape):
@@ -98,10 +101,8 @@ def call(
98101
self._update_moments(val, idx)
99102

100103
mean = expand_left_as(self.moving_mean[idx], val)
104+
# moving_std will return 1 in the case of std=0, so no further checks are necessary here
101105
std = expand_left_as(self.moving_std(idx), val)
102-
# If the std is zero, val - mean(val) = 0, so we can set it to an arbitrary value.
103-
# Choosing 1 will leave input unmodified when no training has happened yet.
104-
std = keras.ops.where(std == 0.0, 1.0, std)
105106

106107
if forward:
107108
out = (val - mean) / std

0 commit comments

Comments
 (0)