Skip to content

Commit 5c529a2

Browse files
committed
[no ci] reformulate zero std case
1 parent 0952a29 commit 5c529a2

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

bayesflow/networks/standardization/standardization.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,12 @@ def __init__(self, **kwargs):
3434
self.count = None
3535

3636
def moving_std(self, index: int) -> Tensor:
37-
return keras.ops.sqrt(self.moving_m2[index] / self.count)
37+
# return zeros if count is 0
38+
return keras.ops.where(
39+
self.count > 0,
40+
keras.ops.sqrt(self.moving_m2[index] / self.count),
41+
self.moving_m2[index] * 0.0,
42+
)
3843

3944
def build(self, input_shape: Shape):
4045
flattened_shapes = flatten_shape(input_shape)
@@ -93,13 +98,12 @@ def call(
9398

9499
mean = expand_left_as(self.moving_mean[idx], val)
95100
std = expand_left_as(self.moving_std(idx), val)
96-
std = keras.ops.where(self.count > 0, std, 1.0)
101+
# If the std is zero, val - mean(val) = 0, so we can set it to an arbitrary value.
102+
# Choosing 1 will leave input unmodified when no training has happened yet.
103+
std = keras.ops.where(std == 0.0, 1.0, std)
97104

98105
if forward:
99106
out = (val - mean) / std
100-
# if the std is zero, out will become nan or inf. As val - mean(val) = 0 if std(val) = 0,
101-
# we can just replace them with zeros.
102-
out = keras.ops.nan_to_num(out, nan=0.0, posinf=0.0, neginf=0.0)
103107
else:
104108
match transformation_type:
105109
case "rank1+shift":

0 commit comments

Comments
 (0)