 import keras
 
 from bayesflow.types import Tensor, Shape
-from bayesflow.utils.serialization import serialize, deserialize, serializable
+from bayesflow.utils.serialization import serializable
 from bayesflow.utils import expand_left_as, layer_kwargs
 from bayesflow.utils.tree import flatten_shape
 
 
 @serializable("bayesflow.networks")
 class Standardization(keras.Layer):
-    def __init__(self, momentum: float = 0.95, epsilon: float = 1e-6, **kwargs):
+    def __init__(self, **kwargs):
         """
         Initializes a Standardization layer that will keep track of the running mean and
         running standard deviation across a batch of potentially nested tensors.
 
+        The layer computes and stores running estimates of the mean and variance using a
+        numerically stable online algorithm, allowing for consistent normalization during
+        both training and inference, regardless of batch composition.
+
         Parameters
         ----------
-        momentum : float, optional
-            Momentum for the exponential moving average used to update the mean and
-            standard deviation during training. Must be between 0 and 1.
-            Default is 0.95.
-        epsilon : float, optional
-            Stability parameter to avoid division by zero.
+        **kwargs
+            Additional keyword arguments passed to the base Keras Layer.
         """
         super().__init__(**layer_kwargs(kwargs))
 
-        self.momentum = momentum
-        self.epsilon = epsilon
         self.moving_mean = None
-        self.moving_std = None
+        self.moving_m2 = None
+        self.count = None
+
+    def moving_std(self, index: int) -> Tensor:
+        return keras.ops.sqrt(self.moving_m2[index] / self.count)
 
     def build(self, input_shape: Shape):
         flattened_shapes = flatten_shape(input_shape)
+
         self.moving_mean = [
             self.add_weight(shape=(shape[-1],), initializer="zeros", trainable=False) for shape in flattened_shapes
         ]
-        self.moving_std = [
-            self.add_weight(shape=(shape[-1],), initializer="ones", trainable=False) for shape in flattened_shapes
+        self.moving_m2 = [
+            self.add_weight(shape=(shape[-1],), initializer="zeros", trainable=False) for shape in flattened_shapes
         ]
-
-    def get_config(self) -> dict:
-        base_config = super().get_config()
-        config = {"momentum": self.momentum, "epsilon": self.epsilon}
-        return base_config | serialize(config)
-
-    @classmethod
-    def from_config(cls, config, custom_objects=None):
-        return cls(**deserialize(config, custom_objects=custom_objects))
+        self.count = self.add_weight(shape=(), initializer="zeros", trainable=False)
 
     def call(
         self,
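With this change, the standard deviation is no longer a stored weight updated by an exponential moving average; it is derived from the accumulated sum of squared deviations (M2) and the sample count. A minimal NumPy sketch of the quantity the new moving_std computes (illustrative only, not part of the diff):

    import numpy as np

    x = np.random.default_rng(0).normal(size=(128, 4))

    count = x.shape[0]
    mean = x.mean(axis=0)
    m2 = ((x - mean) ** 2).sum(axis=0)  # sum of squared deviations per feature

    # moving_std(index) == sqrt(M2 / count), i.e. the population standard deviation
    assert np.allclose(np.sqrt(m2 / count), x.std(axis=0))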
@@ -80,23 +78,25 @@ def call(
         flattened = keras.tree.flatten(x)
         outputs, log_det_jacs = [], []
 
-        for i, val in enumerate(flattened):
+        for idx, val in enumerate(flattened):
             if stage == "training":
-                self._update_moments(val, i)
+                self._update_moments(val, idx)
 
-            mean = expand_left_as(self.moving_mean[i], val)
-            std = expand_left_as(self.moving_std[i], val)
+            mean = expand_left_as(self.moving_mean[idx], val)
+            std = expand_left_as(self.moving_std(idx), val)
 
             if forward:
                 out = (val - mean) / std
+                # If std is zero, out becomes NaN (0 / 0), since val - mean = 0 wherever
+                # std = 0. The NaNs can therefore safely be replaced with zeros.
+                out = keras.ops.nan_to_num(out, nan=0.0)
             else:
                 out = mean + std * val
 
             outputs.append(out)
 
             if log_det_jac:
                 ldj = keras.ops.sum(keras.ops.log(keras.ops.abs(std)), axis=-1)
-                # For convenience, tile to batch shape of val
                 ldj = keras.ops.tile(ldj, keras.ops.shape(val)[:-1])
                 log_det_jacs.append(-ldj if forward else ldj)
 
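The sign convention above follows from the forward map z = (x - mean) / std having a diagonal Jacobian with entries 1 / std_j, so its log-determinant is -sum(log|std_j|). A small numerical check (illustrative only, not part of the diff):

    import numpy as np

    std = np.array([0.5, 2.0, 1.5])
    jac = np.diag(1.0 / std)  # Jacobian of z = (x - mean) / std
    sign, logdet = np.linalg.slogdet(jac)
    assert np.isclose(logdet, -np.sum(np.log(np.abs(std))))  # matches -ldj in the forward pass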
@@ -108,9 +108,41 @@ def call(
         return outputs
 
     def _update_moments(self, x: Tensor, index: int):
-        mean = keras.ops.mean(x, axis=tuple(range(keras.ops.ndim(x) - 1)))
-        std = keras.ops.std(x, axis=tuple(range(keras.ops.ndim(x) - 1)))
-        std = keras.ops.maximum(std, self.epsilon)
+        """
+        Incrementally updates the running mean and variance (M2) per feature using a
+        numerically stable online algorithm.
+
+        Parameters
+        ----------
+        x : Tensor
+            Input tensor of shape (..., features), where all axes except the last are treated
+            as batch/sample axes. The method aggregates batch statistics over all non-feature
+            axes and updates the running totals (mean, M2, and sample count) accordingly.
+        index : int
+            The index of the corresponding running statistics to be updated.
+        """
+
+        reduce_axes = tuple(range(x.ndim - 1))
+        # Count all samples aggregated over the non-feature axes (not just the first axis),
+        # so that sqrt(M2 / count) is consistent with the reductions below
+        batch_count = keras.ops.cast(keras.ops.prod(keras.ops.shape(x)[:-1]), self.count.dtype)
+
+        # Compute batch mean and M2 (sum of squared deviations) per feature
+        batch_mean = keras.ops.mean(x, axis=reduce_axes)
+        batch_m2 = keras.ops.sum((x - expand_left_as(batch_mean, x)) ** 2, axis=reduce_axes)
+
+        # Read current running totals
+        mean = self.moving_mean[index]
+        m2 = self.moving_m2[index]
+        count = self.count
+
+        # Merge batch statistics into the running totals (parallel Welford / Chan et al. update)
+        total_count = count + batch_count
+        delta = batch_mean - mean
+
+        new_mean = mean + delta * (batch_count / total_count)
+        new_m2 = m2 + batch_m2 + (delta**2) * (count * batch_count / total_count)
 
-        self.moving_mean[index].assign(self.momentum * self.moving_mean[index] + (1.0 - self.momentum) * mean)
-        self.moving_std[index].assign(self.momentum * self.moving_std[index] + (1.0 - self.momentum) * std)
+        self.moving_mean[index].assign(new_mean)
+        self.moving_m2[index].assign(new_m2)
+        self.count.assign(total_count)
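The update above is the standard parallel merge of (mean, M2, count) summaries (Chan et al.). A short sketch (illustrative only, not part of the diff) verifying that merging two batches reproduces the statistics of their concatenation:

    import numpy as np

    rng = np.random.default_rng(1)
    a, b = rng.normal(size=(100, 3)), rng.normal(loc=2.0, size=(50, 3))

    def stats(x):
        mean = x.mean(axis=0)
        return mean, ((x - mean) ** 2).sum(axis=0), x.shape[0]

    (mean, m2, count), (batch_mean, batch_m2, batch_count) = stats(a), stats(b)
    total_count = count + batch_count
    delta = batch_mean - mean
    new_mean = mean + delta * (batch_count / total_count)
    new_m2 = m2 + batch_m2 + delta**2 * (count * batch_count / total_count)

    ref_mean, ref_m2, _ = stats(np.concatenate([a, b]))
    assert np.allclose(new_mean, ref_mean) and np.allclose(new_m2, ref_m2)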