@@ -34,7 +34,12 @@ def __init__(self, **kwargs):
3434 self .count = None
3535
def moving_std(self, index: int) -> Tensor:
    """Return the running standard deviation for the statistics slot ``index``.

    Computed as ``sqrt(moving_m2[index] / count)``. While ``count`` is not
    positive, zeros shaped like ``moving_m2[index]`` are returned instead.
    """
    m2 = self.moving_m2[index]
    has_samples = self.count > 0
    # ``where`` evaluates both of its branches, so dividing by a zero count
    # would still compute NaN/inf in the discarded branch (which can trip
    # backend NaN checks and leak into gradients). Divide by 1 there instead;
    # the result of that branch is thrown away anyway.
    safe_count = keras.ops.where(has_samples, self.count, keras.ops.ones_like(self.count))
    return keras.ops.where(
        has_samples,
        keras.ops.sqrt(m2 / safe_count),
        keras.ops.zeros_like(m2),
    )
3843
3944 def build (self , input_shape : Shape ):
4045 flattened_shapes = flatten_shape (input_shape )
@@ -93,13 +98,12 @@ def call(
9398
9499 mean = expand_left_as (self .moving_mean [idx ], val )
95100 std = expand_left_as (self .moving_std (idx ), val )
96- std = keras .ops .where (self .count > 0 , std , 1.0 )
101+ # If the std is zero, val - mean(val) = 0, so we can set it to an arbitrary value.
102+ # Choosing 1 will leave input unmodified when no training has happened yet.
103+ std = keras .ops .where (std == 0.0 , 1.0 , std )
97104
98105 if forward :
99106 out = (val - mean ) / std
100- # if the std is zero, out will become nan or inf. As val - mean(val) = 0 if std(val) = 0,
101- # we can just replace them with zeros.
102- out = keras .ops .nan_to_num (out , nan = 0.0 , posinf = 0.0 , neginf = 0.0 )
103107 else :
104108 match transformation_type :
105109 case "rank1+shift" :
0 commit comments