# limitations under the License.
# ==============================================================================
"""The main training script."""
-import multiprocessing
import os
from absl import app
from absl import flags
@@ -137,13 +136,12 @@ def main(_):
  tpu_cluster_resolver = None

  # Check data path
-  if FLAGS.mode in ('train',
-                    'train_and_eval') and FLAGS.training_file_pattern is None:
-    raise RuntimeError('You must specify --training_file_pattern for training.')
+  if FLAGS.mode in ('train', 'train_and_eval'):
+    if FLAGS.training_file_pattern is None:
+      raise RuntimeError('Must specify --training_file_pattern for train.')
  if FLAGS.mode in ('eval', 'train_and_eval'):
    if FLAGS.validation_file_pattern is None:
-      raise RuntimeError('You must specify --validation_file_pattern '
-                         'for evaluation.')
+      raise RuntimeError('Must specify --validation_file_pattern for eval.')

  # Parse and override hparams
  config = hparams_config.get_detection_config(FLAGS.model_name)
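For context, a minimal invocation satisfying both checks in `train_and_eval` mode might look like the following (the script name, model name, and paths are illustrative only):

    python main.py --mode=train_and_eval \
        --model_name=efficientdet-d0 \
        --model_dir=/tmp/model_dir \
        --training_file_pattern=/data/train-*.tfrecord \
        --validation_file_pattern=/data/val-*.tfrecord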
@@ -173,15 +171,6 @@ def main(_):
        'image_scales': None,
    }
    # The Input Partition Logic: We partition only the partition-able tensors.
-    # Spatial partition requires that the to-be-partitioned tensors must have a
-    # dimension that is a multiple of `partition_dims`. Depending on the
-    # `partition_dims` and the `image_size` and the `max_level` in config, some
-    # high-level anchor labels (i.e., `cls_targets` and `box_targets`) cannot
-    # be partitioned. For example, when `partition_dims` is [1, 4, 2, 1], image
-    # size is 1536, `max_level` is 9, `cls_targets_8` has a shape of
-    # [batch_size, 6, 6, 9], which cannot be partitioned (6 % 4 != 0). In this
-    # case, the level-8 and level-9 target tensors are not partition-able, and
-    # the highest partition-able level is 7.
    feat_sizes = utils.get_feat_sizes(
        config.get('image_size'), config.get('max_level'))
    for level in range(config.get('min_level'), config.get('max_level') + 1):
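The comment removed above still documents the constraint that the partition logic enforces: a tensor can be spatially partitioned only if each spatial dimension is divisible by the corresponding partition factor. A minimal sketch of that check (hypothetical explicit-argument signature; the real `_can_partition` named in the next hunk's context takes its partition dims from the enclosing scope):

    # Sketch only: partition-able iff the spatial dimension divides evenly
    # across every partition factor.
    def _can_partition(spatial_dim, partition_dims):
      return all(spatial_dim % d == 0 for d in partition_dims)

    # Example from the deleted comment: image_size=1536 and max_level=9 give
    # level-8 features of width 1536 // 2**8 == 6, and 6 % 4 != 0, so levels
    # 8 and 9 are not partition-able with partition_dims=[1, 4, 2, 1].
    assert not _can_partition(6, [1, 4, 2, 1])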
@@ -254,56 +243,36 @@ def _can_partition(spatial_dim):
  model_fn_instance = det_model_fn.get_model_fn(FLAGS.model_name)
  max_instances_per_image = config.max_instances_per_image
  eval_steps = int(FLAGS.eval_samples // FLAGS.eval_batch_size)
+  total_examples = int(config.num_epochs * FLAGS.num_examples_per_epoch)
+  train_steps = total_examples // FLAGS.train_batch_size
  use_tpu = (FLAGS.strategy == 'tpu')
  logging.info(params)

-  def _train(steps):
-    """Build train estimator and run training if steps > 0."""
-    train_estimator = tf.estimator.tpu.TPUEstimator(
-        model_fn=model_fn_instance,
-        use_tpu=use_tpu,
-        train_batch_size=FLAGS.train_batch_size,
-        config=run_config,
-        params=params)
-    train_estimator.train(
-        input_fn=dataloader.InputReader(
-            FLAGS.training_file_pattern,
-            is_training=True,
-            use_fake_data=FLAGS.use_fake_data,
-            max_instances_per_image=max_instances_per_image),
-        max_steps=steps)
-
-  def _eval(steps):
-    """Build estimator and eval the latest checkpoint if steps > 0."""
-    eval_params = dict(
-        params,
-        strategy=FLAGS.strategy,
-        input_rand_hflip=False,
-        is_training_bn=False,
-    )
-    eval_estimator = tf.estimator.tpu.TPUEstimator(
-        model_fn=model_fn_instance,
-        use_tpu=use_tpu,
-        train_batch_size=FLAGS.train_batch_size,
-        eval_batch_size=FLAGS.eval_batch_size,
-        config=run_config,
-        params=eval_params)
-    eval_results = eval_estimator.evaluate(
-        input_fn=dataloader.InputReader(
-            FLAGS.validation_file_pattern,
-            is_training=False,
-            max_instances_per_image=max_instances_per_image),
-        steps=steps,
-        name=FLAGS.eval_name)
-    logging.info('Evaluation results: %s', eval_results)
-    return eval_results
+  # Use the unified estimator, train, and eval interfaces.
+  estimator = tf.estimator.tpu.TPUEstimator(
+      model_fn=model_fn_instance,
+      use_tpu=use_tpu,
+      train_batch_size=FLAGS.train_batch_size,
+      eval_batch_size=FLAGS.eval_batch_size,
+      config=run_config,
+      params=params)
+  train_input_fn = dataloader.InputReader(
+      FLAGS.training_file_pattern,
+      is_training=True,
+      use_fake_data=FLAGS.use_fake_data,
+      max_instances_per_image=max_instances_per_image)
+  eval_input_fn = dataloader.InputReader(
+      FLAGS.validation_file_pattern,
+      is_training=False,
+      use_fake_data=FLAGS.use_fake_data,
+      max_instances_per_image=max_instances_per_image)

  # Start train/eval flow.
  if FLAGS.mode == 'train':
-    total_examples = int(config.num_epochs * FLAGS.num_examples_per_epoch)
-    _train(total_examples // FLAGS.train_batch_size)
+    estimator.train(input_fn=train_input_fn, max_steps=train_steps)
    if FLAGS.eval_after_training:
-      _eval(eval_steps)
+      estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)

  elif FLAGS.mode == 'eval':
    # Run evaluation when there's a new checkpoint
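The refactor above replaces the two per-mode estimator builders with a single `TPUEstimator` plus shared input functions, and hoists the step arithmetic to one place. A quick worked example of that arithmetic (all numbers purely illustrative, not values from this repo's configs):

    # Illustrative numbers only.
    num_epochs, num_examples_per_epoch, train_batch_size = 300, 12000, 64
    total_examples = int(num_epochs * num_examples_per_epoch)  # 3,600,000
    train_steps = total_examples // train_batch_size           # 56,250 steps

    eval_samples, eval_batch_size = 5000, 64
    eval_steps = int(eval_samples // eval_batch_size)  # 78 batches; the last
    # 8 samples (5000 - 78 * 64) are dropped by the floor division.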
@@ -314,7 +283,7 @@ def _eval(steps):

      logging.info('Starting to evaluate.')
      try:
-        eval_results = _eval(eval_steps)
+        eval_results = estimator.evaluate(eval_input_fn, steps=eval_steps)
        # Terminate eval job when final checkpoint is reached.
        try:
          current_step = int(os.path.basename(ckpt).split('-')[1])
@@ -323,53 +292,21 @@ def _eval(steps):
          break

        utils.archive_ckpt(eval_results, eval_results['AP'], ckpt)
-        total_step = int((config.num_epochs * FLAGS.num_examples_per_epoch) /
-                         FLAGS.train_batch_size)
-        if current_step >= total_step:
-          logging.info('Evaluation finished after training step %d',
-                       current_step)
+        if current_step >= train_steps:
+          logging.info('Eval finished step %d/%d', current_step, train_steps)
          break

      except tf.errors.NotFoundError:
-        # Since the coordinator is on a different job than the TPU worker,
-        # sometimes the TPU worker does not finish initializing until long after
-        # the CPU job tells it to start evaluating. In this case, the checkpoint
-        # file could have been deleted already.
+        # The checkpoint may have been deleted by the time this eval runs,
+        # so simply skip such cases.
        logging.info('Checkpoint %s no longer exists, skipping.', ckpt)

  elif FLAGS.mode == 'train_and_eval':
-    ckpt = tf.train.latest_checkpoint(FLAGS.model_dir)
-    if not ckpt and FLAGS.ckpt:
-      # Load the pretrained ckpt from FLAGS.ckpt at the begining of training.
-      ckpt = tf.train.latest_checkpoint(FLAGS.ckpt)
-    try:
-      step = int(os.path.basename(ckpt).split('-')[1])
-      current_epoch = (
-          step * FLAGS.train_batch_size // FLAGS.num_examples_per_epoch)
-      logging.info('found ckpt at step %d (epoch %d)', step, current_epoch)
-    except (IndexError, TypeError):
-      logging.info('Folder %s has no ckpt with valid step.', FLAGS.model_dir)
-      current_epoch = 0
-
-    def run_train_and_eval(e):
-      print('-----------------------------------------------------\n'
-            '=====> Starting training, epoch: %d.' % e)
-      _train(e * FLAGS.num_examples_per_epoch // FLAGS.train_batch_size)
-      print('-----------------------------------------------------\n'
-            '=====> Starting evaluation, epoch: %d.' % e)
-      eval_results = _eval(eval_steps)
-      ckpt = tf.train.latest_checkpoint(FLAGS.model_dir)
-      utils.archive_ckpt(eval_results, eval_results['AP'], ckpt)
-
-    epochs_per_cycle = 1  # higher number has less graph construction overhead.
-    for e in range(current_epoch + 1, config.num_epochs + 1, epochs_per_cycle):
-      if FLAGS.run_epoch_in_child_process:
-        p = multiprocessing.Process(target=run_train_and_eval, args=(e,))
-        p.start()
-        p.join()
-      else:
-        run_train_and_eval(e)
-
+    train_spec = tf.estimator.TrainSpec(
+        input_fn=train_input_fn, max_steps=train_steps)
+    eval_spec = tf.estimator.EvalSpec(
+        input_fn=eval_input_fn, steps=eval_steps, throttle_secs=600)
+    tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
  else:
    logging.info('Invalid mode: %s', FLAGS.mode)

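One behavioral note on the new `train_and_eval` path: unlike the removed per-epoch loop, `tf.estimator.train_and_evaluate` evaluates the most recent checkpoint, throttled here to at most one eval per `throttle_secs=600` seconds, so eval cadence follows checkpoint frequency rather than epoch boundaries. If per-epoch evals are still wanted, one option (a sketch, assuming `run_config` is a standard `tf.estimator`-style RunConfig built earlier in the file) is to save one checkpoint per epoch:

    # Sketch only: align checkpoints (and hence evals) with epoch boundaries.
    steps_per_epoch = FLAGS.num_examples_per_epoch // FLAGS.train_batch_size
    run_config = run_config.replace(save_checkpoints_steps=steps_per_epoch)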