@@ -130,16 +130,18 @@ def update(self, x):
     x = bm.as_jax(x)

     if share.load('fit'):
-      mean = jnp.mean(x, self.axis)
-      mean_of_square = jnp.mean(_square(x), self.axis)
-      if self.axis_name is not None:
-        mean, mean_of_square = jnp.split(lax.pmean(jnp.concatenate([mean, mean_of_square]),
-                                                   axis_name=self.axis_name,
-                                                   axis_index_groups=self.axis_index_groups),
-                                         2)
-      var = jnp.maximum(0., mean_of_square - _square(mean))
-      self.running_mean.value = (self.momentum * self.running_mean + (1 - self.momentum) * mean)
-      self.running_var.value = (self.momentum * self.running_var + (1 - self.momentum) * var)
+      mean = jnp.mean(x, self.axis)
+      mean_of_square = jnp.mean(_square(x), self.axis)
+      if self.axis_name is not None:
+        mean, mean_of_square = jnp.split(
+          lax.pmean(jnp.concatenate([mean, mean_of_square]),
+                    axis_name=self.axis_name,
+                    axis_index_groups=self.axis_index_groups),
+          2
+        )
+      var = jnp.maximum(0., mean_of_square - _square(mean))
+      self.running_mean.value = (self.momentum * self.running_mean + (1 - self.momentum) * mean)
+      self.running_var.value = (self.momentum * self.running_var + (1 - self.momentum) * var)
     else:
       mean = self.running_mean.value
       var = self.running_var.value
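Note (illustrative, not part of the diff): in training mode the block above computes batch statistics, optionally averages them across devices with lax.pmean, and folds them into an exponential moving average that inference mode reads back. A minimal standalone sketch of that statistic update in plain JAX follows; the axis and momentum defaults are assumptions chosen for illustration and this is not BrainPy's public API.

import jax.numpy as jnp

def _square(x):
  return x * x

def update_running_stats(x, running_mean, running_var, axis=(0,), momentum=0.99):
  # Batch statistics over the reduction axes (the diff additionally averages
  # these across devices with lax.pmean when axis_name is set).
  mean = jnp.mean(x, axis)
  mean_of_square = jnp.mean(_square(x), axis)
  # Variance via E[x^2] - E[x]^2, clipped at zero to guard against round-off.
  var = jnp.maximum(0., mean_of_square - _square(mean))
  # Exponential moving average of the statistics, used at inference time.
  new_mean = momentum * running_mean + (1 - momentum) * mean
  new_var = momentum * running_var + (1 - momentum) * var
  return new_mean, new_var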
@@ -488,7 +490,7 @@ def __init__(
     self.bias = bm.TrainVar(parameter(self.bias_initializer, self.normalized_shape))
     self.scale = bm.TrainVar(parameter(self.scale_initializer, self.normalized_shape))

-  def update(self,x):
+  def update(self, x):
     if x.shape[-len(self.normalized_shape):] != self.normalized_shape:
       raise ValueError(f'Expect the input shape should be (..., {", ".join(self.normalized_shape)}), '
                        f'but we got {x.shape}')
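Note (illustrative, not part of the diff): the context lines above guard update() by requiring the input's trailing dimensions to match normalized_shape. A tiny sketch with hypothetical shapes:

normalized_shape = (10, 4)   # hypothetical normalized trailing dims
x_shape = (32, 10, 4)        # hypothetical input: batch dim plus normalized dims
# Same check as in update(): trailing dims must equal normalized_shape exactly.
assert x_shape[-len(normalized_shape):] == normalized_shape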
@@ -629,6 +631,8 @@ def __init__(
                      scale_initializer=scale_initializer,
                      mode=mode,
                      name=name)
+
+
 BatchNorm1D = BatchNorm1d
 BatchNorm2D = BatchNorm2d
-BatchNorm3D = BatchNorm3d
+BatchNorm3D = BatchNorm3d