@@ -209,11 +209,11 @@ def __init__(self, fused=False, **kwargs):
       kwargs['name'] = 'tpu_batch_normalization'
     if fused in (True, None):
       raise ValueError('TpuBatchNormalization does not support fused=True.')
-    super(TpuBatchNormalization, self).__init__(fused=fused, **kwargs)
+    super().__init__(fused=fused, **kwargs)
 
   def _moments(self, inputs, reduction_axes, keep_dims):
     """Compute the mean and variance: it overrides the original _moments."""
-    shard_mean, shard_variance = super(TpuBatchNormalization, self)._moments(
+    shard_mean, shard_variance = super()._moments(
         inputs, reduction_axes, keep_dims=keep_dims)
 
     num_shards = tpu_function.get_tpu_context().number_of_shards or 1
@@ -233,23 +233,44 @@ def _moments(self, inputs, reduction_axes, keep_dims):
       return (shard_mean, shard_variance)
 
   def call(self, inputs, training=None):
-    outputs = super(TpuBatchNormalization, self).call(inputs, training)
+    outputs = super().call(inputs, training)
     # A temporary hack for tf1 compatibility with keras batch norm.
     for u in self.updates:
       tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, u)
     return outputs
 
 
-class SyncBatchNormalization(tf2.keras.layers.experimental.SyncBatchNormalization):
+class SyncBatchNormalization(tf.keras.layers.BatchNormalization):
   """Cross replica batch normalization."""
-
-  def __init__(self, **kwargs):
+  def __init__(self, fused=False, **kwargs):
     if not kwargs.get('name', None):
       kwargs['name'] = 'tpu_batch_normalization'
-    super(SyncBatchNormalization, self).__init__(**kwargs)
+    if fused in (True, None):
+      raise ValueError('SyncBatchNormalization does not support fused=True.')
+    super().__init__(fused=fused, **kwargs)
+
+  def _moments(self, inputs, reduction_axes, keep_dims):
+    """Compute the mean and variance: it overrides the original _moments."""
+    shard_mean, shard_variance = super()._moments(
+        inputs, reduction_axes, keep_dims=keep_dims)
+
+    replica_context = tf.distribute.get_replica_context()
+    num_shards = replica_context.num_replicas_in_sync or 1
+
+    if num_shards > 1:
+      # Compute variance using: Var[X]= E[X^2] - E[X]^2.
+      shard_square_of_mean = tf.math.square(shard_mean)
+      shard_mean_of_square = shard_variance + shard_square_of_mean
+      shard_stack = tf.stack([shard_mean, shard_mean_of_square])
+      group_mean, group_mean_of_square = tf.unstack(
+          replica_context.all_reduce(tf.distribute.ReduceOp.MEAN, shard_stack))
+      group_variance = group_mean_of_square - tf.math.square(group_mean)
+      return (group_mean, group_variance)
+    else:
+      return (shard_mean, shard_variance)
 
   def call(self, inputs, training=None):
-    outputs = super(SyncBatchNormalization, self).call(inputs, training)
+    outputs = super().call(inputs, training)
     # A temporary hack for tf1 compatibility with keras batch norm.
     for u in self.updates:
       tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, u)
@@ -262,10 +283,10 @@ class BatchNormalization(tf.keras.layers.BatchNormalization):
   def __init__(self, **kwargs):
     if not kwargs.get('name', None):
       kwargs['name'] = 'tpu_batch_normalization'
-    super(BatchNormalization, self).__init__(**kwargs)
+    super().__init__(**kwargs)
 
   def call(self, inputs, training=None):
-    outputs = super(BatchNormalization, self).call(inputs, training)
+    outputs = super().call(inputs, training)
     # A temporary hack for tf1 compatibility with keras batch norm.
     for u in self.updates:
       tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, u)
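The "temporary hack" in each `call` exists because the Keras batch-norm layer does not register its moving-mean/variance updates in `tf.GraphKeys.UPDATE_OPS`, which TF1-style training loops depend on. A rough graph-mode sketch (assumed setup, not from this repo; a stock `tf.keras.layers.BatchNormalization` plus the same manual `add_to_collection` step, and `layer.updates` may be empty on newer TF releases):

```python
import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

x = tf.placeholder(tf.float32, [None, 4])
bn = tf.keras.layers.BatchNormalization(name='tpu_batch_normalization')
y = bn(x, training=True)
for u in bn.updates:  # the same hack as in the classes above
  tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, u)
loss = tf.reduce_mean(tf.square(y))

# TF1 idiom: make the train op depend on UPDATE_OPS so the moving
# statistics are refreshed on every training step.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
  train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(train_op, {x: np.random.rand(8, 4).astype(np.float32)})
```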
@@ -384,7 +405,7 @@ def num_params_flops(readable_format=True):
 class Pair(tuple):
 
   def __new__(cls, name, value):
-    return super(Pair, cls).__new__(cls, (name, value))
+    return super().__new__(cls, (name, value))
 
   def __init__(self, name, _):  # pylint: disable=super-init-not-called
     self.name = name
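For context, `Pair` is a tuple subclass that carries a `.name` attribute alongside its value. A small usage sketch based only on the lines shown above (the example name and value are made up):

```python
class Pair(tuple):
  """(name, value) tuple that also exposes .name, as defined above."""

  def __new__(cls, name, value):
    return super().__new__(cls, (name, value))

  def __init__(self, name, _):  # pylint: disable=super-init-not-called
    self.name = name


p = Pair('top_1_accuracy', 0.83)
assert p == ('top_1_accuracy', 0.83)   # compares like a plain tuple
assert p.name == 'top_1_accuracy' and p[1] == 0.83
```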