Skip to content

Commit 955113b

Browse files
committed
fix update ops
1 parent ea9d3c5 commit 955113b

File tree

2 files changed

+5
-6
lines changed

2 files changed

+5
-6
lines changed

efficientdet/keras/util_keras.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def _wrapper(*args, **kwargs):
183183
return bn_class(*args, **kwargs)
184184
return _wrapper
185185

186-
187-
utils.BatchNormalization = get_batch_norm(tf.keras.layers.BatchNormalization)
188-
utils.SyncBatchNormalization = get_batch_norm(tf.keras.layers.experimental.SyncBatchNormalization)
189-
utils.TpuBatchNormalization = get_batch_norm(tf.keras.layers.experimental.SyncBatchNormalization)
186+
if tf.compat.v1.executing_eagerly_outside_functions():
187+
utils.BatchNormalization = get_batch_norm(tf.keras.layers.BatchNormalization)
188+
utils.SyncBatchNormalization = get_batch_norm(tf.keras.layers.experimental.SyncBatchNormalization)
189+
utils.TpuBatchNormalization = get_batch_norm(tf.keras.layers.experimental.SyncBatchNormalization)

efficientdet/main.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def _filter_fn(item):
5757

5858
custom_gradient.get_variable_by_name = get_variable_by_name
5959
import tensorflow.compat.v1 as tf
60-
60+
tf.disable_eager_execution()
6161
import dataloader
6262
import det_model_fn
6363
import hparams_config
@@ -158,7 +158,6 @@ def _filter_fn(item):
158158

159159
def main(_):
160160
if FLAGS.strategy == 'tpu':
161-
tf.disable_eager_execution()
162161
tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
163162
FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
164163
tpu_grpc_url = tpu_cluster_resolver.get_master()

0 commit comments

Comments
 (0)