Commit cdec8d6

Merge pull request #742 from google/est
Simplify estimator train_and_eval.
2 parents e03aaee + f736944

2 files changed: +38 -100 lines changed

efficientdet/det_model_fn.py
Lines changed: 1 addition & 0 deletions

@@ -343,6 +343,7 @@ def _model_fn(features, labels, mode, params, model, variable_filter_fn=None):
   """
   utils.image('input_image', features)
   training_hooks = []
+  params['is_training_bn'] = (mode == tf.estimator.ModeKeys.TRAIN)
 
   if params['use_keras_model']:
     def model_fn(inputs):

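Note on the added line: batch normalization must update its moving statistics only while training, and deriving is_training_bn from the estimator mode lets a single params dict serve train, eval, and predict. Previously main.py built a second eval estimator with is_training_bn=False overridden (see the deleted _eval helper below). A minimal sketch of the pattern, with a toy model_fn standing in for the real _model_fn (all names and shapes here are illustrative, not from the repo):

import tensorflow.compat.v1 as tf

def toy_model_fn(features, labels, mode, params):
  # Same idea as the added line: the flag tracks the estimator mode, so the
  # caller no longer needs to override is_training_bn for eval.
  params['is_training_bn'] = (mode == tf.estimator.ModeKeys.TRAIN)
  net = tf.layers.batch_normalization(
      features['x'], training=params['is_training_bn'])
  loss = tf.reduce_mean(tf.square(net - labels))
  train_op = None
  if mode == tf.estimator.ModeKeys.TRAIN:
    # BN moving-average updates are collected in UPDATE_OPS and must run
    # together with the train op.
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
      train_op = tf.train.GradientDescentOptimizer(1e-3).minimize(
          loss, global_step=tf.train.get_or_create_global_step())
  return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
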
efficientdet/main.py
Lines changed: 37 additions & 100 deletions
@@ -13,7 +13,6 @@
 # limitations under the License.
 # ==============================================================================
 """The main training script."""
-import multiprocessing
 import os
 from absl import app
 from absl import flags
@@ -137,13 +136,12 @@ def main(_):
     tpu_cluster_resolver = None
 
   # Check data path
-  if FLAGS.mode in ('train',
-                    'train_and_eval') and FLAGS.training_file_pattern is None:
-    raise RuntimeError('You must specify --training_file_pattern for training.')
+  if FLAGS.mode in ('train', 'train_and_eval'):
+    if FLAGS.training_file_pattern is None:
+      raise RuntimeError('Must specify --training_file_pattern for train.')
   if FLAGS.mode in ('eval', 'train_and_eval'):
     if FLAGS.validation_file_pattern is None:
-      raise RuntimeError('You must specify --validation_file_pattern '
-                         'for evaluation.')
+      raise RuntimeError('Must specify --validation_file_pattern for eval.')
 
   # Parse and override hparams
   config = hparams_config.get_detection_config(FLAGS.model_name)
@@ -173,15 +171,6 @@ def main(_):
       'image_scales': None,
   }
   # The Input Partition Logic: We partition only the partition-able tensors.
-  # Spatial partition requires that the to-be-partitioned tensors must have a
-  # dimension that is a multiple of `partition_dims`. Depending on the
-  # `partition_dims` and the `image_size` and the `max_level` in config, some
-  # high-level anchor labels (i.e., `cls_targets` and `box_targets`) cannot
-  # be partitioned. For example, when `partition_dims` is [1, 4, 2, 1], image
-  # size is 1536, `max_level` is 9, `cls_targets_8` has a shape of
-  # [batch_size, 6, 6, 9], which cannot be partitioned (6 % 4 != 0). In this
-  # case, the level-8 and level-9 target tensors are not partition-able, and
-  # the highest partition-able level is 7.
   feat_sizes = utils.get_feat_sizes(
       config.get('image_size'), config.get('max_level'))
   for level in range(config.get('min_level'), config.get('max_level') + 1):
@@ -254,56 +243,36 @@ def _can_partition(spatial_dim):
   model_fn_instance = det_model_fn.get_model_fn(FLAGS.model_name)
   max_instances_per_image = config.max_instances_per_image
   eval_steps = int(FLAGS.eval_samples // FLAGS.eval_batch_size)
+  total_examples = int(config.num_epochs * FLAGS.num_examples_per_epoch)
+  train_steps = total_examples // FLAGS.train_batch_size
   use_tpu = (FLAGS.strategy == 'tpu')
   logging.info(params)
 
-  def _train(steps):
-    """Build train estimator and run training if steps > 0."""
-    train_estimator = tf.estimator.tpu.TPUEstimator(
-        model_fn=model_fn_instance,
-        use_tpu=use_tpu,
-        train_batch_size=FLAGS.train_batch_size,
-        config=run_config,
-        params=params)
-    train_estimator.train(
-        input_fn=dataloader.InputReader(
-            FLAGS.training_file_pattern,
-            is_training=True,
-            use_fake_data=FLAGS.use_fake_data,
-            max_instances_per_image=max_instances_per_image),
-        max_steps=steps)
-
-  def _eval(steps):
-    """Build estimator and eval the latest checkpoint if steps > 0."""
-    eval_params = dict(
-        params,
-        strategy=FLAGS.strategy,
-        input_rand_hflip=False,
-        is_training_bn=False,
-    )
-    eval_estimator = tf.estimator.tpu.TPUEstimator(
-        model_fn=model_fn_instance,
-        use_tpu=use_tpu,
-        train_batch_size=FLAGS.train_batch_size,
-        eval_batch_size=FLAGS.eval_batch_size,
-        config=run_config,
-        params=eval_params)
-    eval_results = eval_estimator.evaluate(
-        input_fn=dataloader.InputReader(
-            FLAGS.validation_file_pattern,
-            is_training=False,
-            max_instances_per_image=max_instances_per_image),
-        steps=steps,
-        name=FLAGS.eval_name)
-    logging.info('Evaluation results: %s', eval_results)
-    return eval_results
+  # Use the unified estimator, train, and eval interfaces.
+  estimator = tf.estimator.tpu.TPUEstimator(
+      model_fn=model_fn_instance,
+      use_tpu=use_tpu,
+      train_batch_size=FLAGS.train_batch_size,
+      eval_batch_size=FLAGS.eval_batch_size,
+      config=run_config,
+      params=params)
+  train_input_fn = dataloader.InputReader(
+      FLAGS.training_file_pattern,
+      is_training=True,
+      use_fake_data=FLAGS.use_fake_data,
+      max_instances_per_image=max_instances_per_image)
+  eval_input_fn = dataloader.InputReader(
+      FLAGS.validation_file_pattern,
+      is_training=False,
+      use_fake_data=FLAGS.use_fake_data,
+      max_instances_per_image=max_instances_per_image)
 
   # start train/eval flow.
   if FLAGS.mode == 'train':
-    total_examples = int(config.num_epochs * FLAGS.num_examples_per_epoch)
-    _train(total_examples // FLAGS.train_batch_size)
+    total_examples = int(config.num_epochs * FLAGS.num_examples_per_epoch)
+    estimator.train(input_fn=train_input_fn, max_steps=train_steps)
     if FLAGS.eval_after_training:
-      _eval(eval_steps)
+      estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)
 
   elif FLAGS.mode == 'eval':
     # Run evaluation when there's a new checkpoint
@@ -314,7 +283,7 @@ def _eval(steps):
 
       logging.info('Starting to evaluate.')
       try:
-        eval_results = _eval(eval_steps)
+        eval_results = estimator.evaluate(eval_input_fn, steps=eval_steps)
         # Terminate eval job when final checkpoint is reached.
         try:
           current_step = int(os.path.basename(ckpt).split('-')[1])
@@ -323,53 +292,21 @@ def _eval(steps):
           break
 
         utils.archive_ckpt(eval_results, eval_results['AP'], ckpt)
-        total_step = int((config.num_epochs * FLAGS.num_examples_per_epoch) /
-                         FLAGS.train_batch_size)
-        if current_step >= total_step:
-          logging.info('Evaluation finished after training step %d',
-                       current_step)
+        if current_step >= train_steps:
+          logging.info('Eval finished step %d/%d', current_step, train_steps)
           break
 
       except tf.errors.NotFoundError:
-        # Since the coordinator is on a different job than the TPU worker,
-        # sometimes the TPU worker does not finish initializing until long after
-        # the CPU job tells it to start evaluating. In this case, the checkpoint
-        # file could have been deleted already.
+        # The checkpoint might have been deleted by the time eval finished.
+        # We simply skip such cases.
         logging.info('Checkpoint %s no longer exists, skipping.', ckpt)
 
   elif FLAGS.mode == 'train_and_eval':
-    ckpt = tf.train.latest_checkpoint(FLAGS.model_dir)
-    if not ckpt and FLAGS.ckpt:
-      # Load the pretrained ckpt from FLAGS.ckpt at the begining of training.
-      ckpt = tf.train.latest_checkpoint(FLAGS.ckpt)
-    try:
-      step = int(os.path.basename(ckpt).split('-')[1])
-      current_epoch = (
-          step * FLAGS.train_batch_size // FLAGS.num_examples_per_epoch)
-      logging.info('found ckpt at step %d (epoch %d)', step, current_epoch)
-    except (IndexError, TypeError):
-      logging.info('Folder %s has no ckpt with valid step.', FLAGS.model_dir)
-      current_epoch = 0
-
-    def run_train_and_eval(e):
-      print('-----------------------------------------------------\n'
-            '=====> Starting training, epoch: %d.' % e)
-      _train(e * FLAGS.num_examples_per_epoch // FLAGS.train_batch_size)
-      print('-----------------------------------------------------\n'
-            '=====> Starting evaluation, epoch: %d.' % e)
-      eval_results = _eval(eval_steps)
-      ckpt = tf.train.latest_checkpoint(FLAGS.model_dir)
-      utils.archive_ckpt(eval_results, eval_results['AP'], ckpt)
-
-    epochs_per_cycle = 1  # higher number has less graph construction overhead.
-    for e in range(current_epoch + 1, config.num_epochs + 1, epochs_per_cycle):
-      if FLAGS.run_epoch_in_child_process:
-        p = multiprocessing.Process(target=run_train_and_eval, args=(e,))
-        p.start()
-        p.join()
-      else:
-        run_train_and_eval(e)
-
+    train_spec = tf.estimator.TrainSpec(
+        input_fn=train_input_fn, max_steps=train_steps)
+    eval_spec = tf.estimator.EvalSpec(
+        input_fn=eval_input_fn, steps=eval_steps, throttle_secs=600)
+    tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
   else:
     logging.info('Invalid mode: %s', FLAGS.mode)

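Taken together, the main.py changes build one TPUEstimator up front, reuse two InputReader callables as input_fns, and replace the hand-rolled per-epoch loop (with its optional multiprocessing child processes) by tf.estimator.train_and_evaluate. A self-contained sketch of that flow using a plain Estimator and toy input/model functions (TPU config, the EfficientDet dataloader, and the flags are omitted; every name below is illustrative):

import tensorflow.compat.v1 as tf

def input_fn():
  # Toy dataset; stands in for dataloader.InputReader.
  ds = tf.data.Dataset.from_tensors(({'x': [1.0]}, [2.0])).repeat()
  return ds.batch(8)

def model_fn(features, labels, mode):
  # Toy linear model; stands in for the EfficientDet model_fn.
  logits = tf.layers.dense(features['x'], 1)
  loss = tf.losses.mean_squared_error(labels=labels, predictions=logits)
  train_op = None
  if mode == tf.estimator.ModeKeys.TRAIN:
    train_op = tf.train.GradientDescentOptimizer(0.01).minimize(
        loss, global_step=tf.train.get_or_create_global_step())
  return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir='/tmp/toy_est')
# Mirrors the new train_and_eval branch: train up to max_steps, evaluating
# the latest checkpoint at most every throttle_secs seconds.
train_spec = tf.estimator.TrainSpec(input_fn=input_fn, max_steps=100)
eval_spec = tf.estimator.EvalSpec(
    input_fn=input_fn, steps=10, throttle_secs=600)
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)

As in the diff, throttle_secs=600 means a new checkpoint is evaluated at most once every ten minutes, so short runs may see only the final evaluation; one-eval-per-epoch now depends on checkpointing frequency rather than an explicit loop.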