Skip to content

Commit ea9d3c5

Browse files
committed
do not rewrite estimator batch norm
1 parent c2e3f55 commit ea9d3c5

File tree

3 files changed

+9
-6
lines changed

3 files changed

+9
-6
lines changed

efficientdet/hparams_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def merge_dict_recursive(target, src):
145145
"""Recursively merge two nested dictionary."""
146146
for k in src.keys():
147147
if ((k in target and isinstance(target[k], dict) and
148-
isinstance(src[k], collections.abc.Mapping))):
148+
isinstance(src[k], collections.Mapping))):
149149
merge_dict_recursive(target[k], src[k])
150150
else:
151151
target[k] = src[k]

efficientdet/keras/efficientdet_keras.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,7 @@
2727
from keras import postprocess
2828
from keras import tfmot
2929
from keras import util_keras
30-
# pylint: disable=arguments-differ # fo keras layers.
31-
utils.BatchNormalization = util_keras.get_batch_norm(tf.keras.layers.BatchNormalization)
32-
utils.SyncBatchNormalization = util_keras.get_batch_norm(tf.keras.layers.experimental.SyncBatchNormalization)
33-
utils.TpuBatchNormalization = util_keras.get_batch_norm(tf.keras.layers.experimental.SyncBatchNormalization)
30+
3431

3532
def add_n(nodes):
3633
"""A customized add_n to add up a list of tensors."""

efficientdet/keras/util_keras.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,9 +175,15 @@ def fp16_to_fp32_nested(input_nested):
175175
return input_nested
176176
return out_tensor_dict
177177

178+
178179
def get_batch_norm(bn_class):
179180
def _wrapper(*args, **kwargs):
180181
if not kwargs.get('name', None):
181182
kwargs['name'] = 'tpu_batch_normalization'
182183
return bn_class(*args, **kwargs)
183-
return _wrapper
184+
return _wrapper
185+
186+
187+
utils.BatchNormalization = get_batch_norm(tf.keras.layers.BatchNormalization)
188+
utils.SyncBatchNormalization = get_batch_norm(tf.keras.layers.experimental.SyncBatchNormalization)
189+
utils.TpuBatchNormalization = get_batch_norm(tf.keras.layers.experimental.SyncBatchNormalization)

0 commit comments

Comments
 (0)