
Commit 38791da

fix syncbatchnorm
1 parent 870c510

2 files changed: +34 -12 lines changed

efficientdet/det_model_fn.py

Lines changed: 2 additions & 1 deletion
@@ -319,6 +319,7 @@ def reg_l2_loss(weight_decay, regex=r'.*(kernel|weight):0$'):
   ])
 
 
+@tf.autograph.experimental.do_not_convert
 def _model_fn(features, labels, mode, params, model, variable_filter_fn=None):
   """Model definition entry.
 
@@ -574,7 +575,6 @@ def scaffold_fn():
         tf.train.init_from_checkpoint(checkpoint, var_map)
         return tf.train.Scaffold()
   elif mode == tf.estimator.ModeKeys.EVAL and moving_average_decay:
-
     def scaffold_fn():
       """Load moving average variables for eval."""
       logging.info('Load EMA vars with ema_decay=%f', moving_average_decay)
@@ -622,6 +622,7 @@ def before_run(self, run_context):
         training_hooks=training_hooks)
   else:
     eval_metric_ops = eval_metrics[0](eval_metrics[1]) if eval_metrics else None
+    utils.get_tpu_host_call(global_step, params)
     return tf.estimator.EstimatorSpec(
         mode=mode,
         loss=total_loss,
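Note: tf.autograph.experimental.do_not_convert is the standard TF 2.x decorator for opting a function out of AutoGraph conversion; here it presumably keeps AutoGraph from trying to rewrite the TF1-style estimator _model_fn. A minimal standalone sketch of the decorator (the function below is illustrative, not part of this commit):

import tensorflow as tf

@tf.autograph.experimental.do_not_convert
def plain_python_sum(xs):
  # AutoGraph leaves this Python loop untouched instead of rewriting it
  # into graph ops when the function is traced (e.g. inside a tf.function).
  total = 0.0
  for x in xs:
    total += x
  return total

print(plain_python_sum([1.0, 2.0, 3.0]))  # 6.0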

efficientdet/utils.py

Lines changed: 32 additions & 11 deletions
@@ -209,11 +209,11 @@ def __init__(self, fused=False, **kwargs):
       kwargs['name'] = 'tpu_batch_normalization'
     if fused in (True, None):
       raise ValueError('TpuBatchNormalization does not support fused=True.')
-    super(TpuBatchNormalization, self).__init__(fused=fused, **kwargs)
+    super().__init__(fused=fused, **kwargs)
 
   def _moments(self, inputs, reduction_axes, keep_dims):
     """Compute the mean and variance: it overrides the original _moments."""
-    shard_mean, shard_variance = super(TpuBatchNormalization, self)._moments(
+    shard_mean, shard_variance = super()._moments(
         inputs, reduction_axes, keep_dims=keep_dims)
 
     num_shards = tpu_function.get_tpu_context().number_of_shards or 1
@@ -233,23 +233,44 @@ def _moments(self, inputs, reduction_axes, keep_dims):
       return (shard_mean, shard_variance)
 
   def call(self, inputs, training=None):
-    outputs = super(TpuBatchNormalization, self).call(inputs, training)
+    outputs = super().call(inputs, training)
     # A temporary hack for tf1 compatibility with keras batch norm.
     for u in self.updates:
       tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, u)
     return outputs
 
 
-class SyncBatchNormalization(tf2.keras.layers.experimental.SyncBatchNormalization):
+class SyncBatchNormalization(tf.keras.layers.BatchNormalization):
   """Cross replica batch normalization."""
-
-  def __init__(self, **kwargs):
+  def __init__(self, fused=False, **kwargs):
     if not kwargs.get('name', None):
       kwargs['name'] = 'tpu_batch_normalization'
-    super(SyncBatchNormalization, self).__init__(**kwargs)
+    if fused in (True, None):
+      raise ValueError('SyncBatchNormalization does not support fused=True.')
+    super().__init__(fused=fused, **kwargs)
+
+  def _moments(self, inputs, reduction_axes, keep_dims):
+    """Compute the mean and variance: it overrides the original _moments."""
+    shard_mean, shard_variance = super()._moments(
+        inputs, reduction_axes, keep_dims=keep_dims)
+
+    replica_context = tf.distribute.get_replica_context()
+    num_shards = replica_context.num_replicas_in_sync or 1
+
+    if num_shards > 1:
+      # Compute variance using: Var[X]= E[X^2] - E[X]^2.
+      shard_square_of_mean = tf.math.square(shard_mean)
+      shard_mean_of_square = shard_variance + shard_square_of_mean
+      shard_stack = tf.stack([shard_mean, shard_mean_of_square])
+      group_mean, group_mean_of_square = tf.unstack(
+          replica_context.all_reduce(tf.distribute.ReduceOp.MEAN, shard_stack))
+      group_variance = group_mean_of_square - tf.math.square(group_mean)
+      return (group_mean, group_variance)
+    else:
+      return (shard_mean, shard_variance)
 
   def call(self, inputs, training=None):
-    outputs = super(SyncBatchNormalization, self).call(inputs, training)
+    outputs = super().call(inputs, training)
     # A temporary hack for tf1 compatibility with keras batch norm.
     for u in self.updates:
       tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, u)
@@ -262,10 +283,10 @@ class BatchNormalization(tf.keras.layers.BatchNormalization):
   def __init__(self, **kwargs):
     if not kwargs.get('name', None):
       kwargs['name'] = 'tpu_batch_normalization'
-    super(BatchNormalization, self).__init__(**kwargs)
+    super().__init__(**kwargs)
 
   def call(self, inputs, training=None):
-    outputs = super(BatchNormalization, self).call(inputs, training)
+    outputs = super().call(inputs, training)
     # A temporary hack for tf1 compatibility with keras batch norm.
     for u in self.updates:
       tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, u)
@@ -384,7 +405,7 @@ def num_params_flops(readable_format=True):
 class Pair(tuple):
 
   def __new__(cls, name, value):
-    return super(Pair, cls).__new__(cls, (name, value))
+    return super().__new__(cls, (name, value))
 
   def __init__(self, name, _):  # pylint: disable=super-init-not-called
     self.name = name
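The _moments override added to SyncBatchNormalization relies on the identity Var[X] = E[X^2] - E[X]^2: each replica computes its local mean and local E[X^2], a MEAN all-reduce averages those statistics across replicas, and the group variance is recovered from the averaged values. A small NumPy sanity check of that aggregation scheme (assuming equal-sized shards, which is what averaging with ReduceOp.MEAN implies; the values are illustrative):

import numpy as np

# Two equal-sized "shards", i.e. the per-replica slices of one global batch.
shards = [np.array([1.0, 2.0, 3.0, 4.0]), np.array([5.0, 6.0, 7.0, 8.0])]

# Per-shard statistics, as each replica would compute them locally.
shard_means = [s.mean() for s in shards]
shard_mean_of_squares = [(s ** 2).mean() for s in shards]

# Cross-replica aggregation: a MEAN all-reduce over [mean, E[X^2]].
group_mean = np.mean(shard_means)
group_mean_of_square = np.mean(shard_mean_of_squares)
group_variance = group_mean_of_square - group_mean ** 2

# The aggregated moments match those of the full (concatenated) batch.
full = np.concatenate(shards)
assert np.isclose(group_mean, full.mean())      # 4.5
assert np.isclose(group_variance, full.var())   # 5.25

With unequal per-replica batch sizes a plain mean of per-shard statistics would no longer equal the global moments, which is one reason data-parallel training normally keeps replica batch sizes identical.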
