@@ -34,11 +34,14 @@ def __init__(self, **kwargs):
3434 self .count = None
3535
3636 def moving_std (self , index : int ) -> Tensor :
37- # return zeros if count is 0
37+ """Calculates the standard deviation from the moving M^2 at the given index and the count.
38+
39+ Important: Where M^2=0, it will return a standard deviation of 1 instead of 0, even if count > 0.
40+ """
3841 return keras .ops .where (
39- self .count > 0 ,
42+ self .moving_m2 [ index ] > 0 ,
4043 keras .ops .sqrt (self .moving_m2 [index ] / self .count ),
41- self . moving_m2 [ index ] * 0 .0 ,
44+ 1 .0 ,
4245 )
4346
4447 def build (self , input_shape : Shape ):
@@ -98,10 +101,8 @@ def call(
98101 self ._update_moments (val , idx )
99102
100103 mean = expand_left_as (self .moving_mean [idx ], val )
104+ # moving_std will return 1 in the case of std=0, so no further checks are necessary here
101105 std = expand_left_as (self .moving_std (idx ), val )
102- # If the std is zero, val - mean(val) = 0, so we can set it to an arbitrary value.
103- # Choosing 1 will leave input unmodified when no training has happened yet.
104- std = keras .ops .where (std == 0.0 , 1.0 , std )
105106
106107 if forward :
107108 out = (val - mean ) / std
0 commit comments