Skip to content

Commit 96db246

Browse files
committed
fix determinism
1 parent d7c9f1b commit 96db246

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

efficientdet/requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ numpy>=1.19.4
55
Pillow>=6.0.0
66
PyYAML>=5.1
77
six>=1.15.0
8-
tensorflow>=2.7.0
9-
tensorflow-addons>=0.15
8+
tensorflow>=2.8.0
9+
tensorflow-addons>=0.16.1
1010
tensorflow-hub>=0.11
1111
neural-structured-learning>=1.3.1
1212
Cython>=0.29.13

efficientdet/tf2/train.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,11 +164,11 @@ def main(_):
164164
tf.config.experimental.set_memory_growth(gpu, True)
165165

166166
if FLAGS.tf_random_seed:
167-
tf.random.set_seed(FLAGS.tf_random_seed)
167+
tf.keras.utils.set_random_seed(FLAGS.tf_random_seed)
168+
tf.config.experimental.enable_op_determinism()
168169

169170
if FLAGS.debug:
170171
tf.debugging.set_log_device_placement(True)
171-
os.environ['TF_DETERMINISTIC_OPS'] = '1'
172172
logging.set_verbosity(logging.DEBUG)
173173

174174
if FLAGS.strategy == 'tpu':
@@ -241,9 +241,11 @@ def get_dataset(is_training, config):
241241
else:
242242
model = train_lib.EfficientDetNetTrain(config=config)
243243
model = setup_model(model, config)
244+
244245
if FLAGS.debug:
245246
tf.data.experimental.enable_debug_mode()
246247
tf.config.run_functions_eagerly(True)
248+
247249
if tf.train.latest_checkpoint(FLAGS.model_dir):
248250
ckpt_path = tf.train.latest_checkpoint(FLAGS.model_dir)
249251
util_keras.restore_ckpt(

0 commit comments

Comments
 (0)