@@ -271,10 +271,10 @@ def _common_image_process(self, image, classes, boxes, data, params):
271271 from aug import autoaugment # pylint: disable=g-import-not-at-top
272272 if params ['autoaugment_policy' ] == 'randaug' :
273273 image , boxes = autoaugment .distort_image_with_randaugment (
274- image , boxes , num_layers = 1 , magnitude = 15 )
274+ image , boxes , num_layers = 1 , magnitude = 15 )
275275 else :
276276 image , boxes = autoaugment .distort_image_with_autoaugment (
277- image , boxes , params ['autoaugment_policy' ])
277+ image , boxes , params ['autoaugment_policy' ])
278278 return image , boxes , classes
279279
280280 def _resize_image_first (self , image , classes , boxes , data , params ):
@@ -285,24 +285,26 @@ def _resize_image_first(self, image, classes, boxes, data, params):
285285 input_processor .random_horizontal_flip ()
286286
287287 input_processor .set_training_random_scale_factors (
288- params ['jitter_min' ], params ['jitter_max' ],
289- params .get ('target_size' , None ))
288+ params ['jitter_min' ], params ['jitter_max' ],
289+ params .get ('target_size' , None ))
290290 else :
291291 input_processor .set_scale_factors_to_output_size ()
292292
293293 image = input_processor .resize_and_crop_image ()
294294 boxes , classes = input_processor .resize_and_crop_boxes ()
295295
296296 if self ._is_training :
297- image , boxes , classes = self ._common_image_process (image , classes , boxes , data , params )
297+ image , boxes , classes = self ._common_image_process (image , classes ,
298+ boxes , data , params )
298299
299300 input_processor .image = image
300301 image = input_processor .normalize_image ()
301302 return image , boxes , classes , input_processor .image_scale_to_original
302303
303304 def _resize_image_last (self , image , classes , boxes , data , params ):
304305 if self ._is_training :
305- image , boxes , classes = self ._common_image_process (image , classes , boxes , data , params )
306+ image , boxes , classes = self ._common_image_process (image , classes ,
307+ boxes , data , params )
306308
307309 input_processor = DetectionInputProcessor (image , params ['image_size' ],
308310 boxes , classes )
@@ -311,8 +313,8 @@ def _resize_image_last(self, image, classes, boxes, data, params):
311313 input_processor .random_horizontal_flip ()
312314
313315 input_processor .set_training_random_scale_factors (
314- params ['jitter_min' ], params ['jitter_max' ],
315- params .get ('target_size' , None ))
316+ params ['jitter_min' ], params ['jitter_max' ],
317+ params .get ('target_size' , None ))
316318 else :
317319 input_processor .set_scale_factors_to_output_size ()
318320 input_processor .normalize_image ()
@@ -367,15 +369,15 @@ def dataset_parser(self, value, example_decoder, anchor_labeler, params):
367369 areas = data ['groundtruth_area' ]
368370 is_crowds = data ['groundtruth_is_crowd' ]
369371 image_masks = data .get ('groundtruth_instance_masks' , [])
370- classes = tf .reshape (tf .cast (classes , dtype = tf .float32 ), [- 1 , 1 ])
371372 source_area = tf .shape (image )[0 ] * tf .shape (image )[1 ]
372373 target_size = utils .parse_image_size (params ['image_size' ])
373374 target_area = target_size [0 ] * target_size [1 ]
374375 # Choose the order so that augmentation always runs on the smaller image
375376 # (resize first when the source is larger), which speeds up the pipeline.
376- image , boxes , classes , image_scale = tf .cond (source_area > target_area ,
377- lambda : self ._resize_image_first (image , classes , boxes , data , params ),
378- lambda : self ._resize_image_last (image , classes , boxes , data , params ))
377+ image , boxes , classes , image_scale = tf .cond (
378+ source_area > target_area ,
379+ lambda : self ._resize_image_first (image , classes , boxes , data , params ),
380+ lambda : self ._resize_image_last (image , classes , boxes , data , params ))
379381
380382 # Assign anchors.
381383 (cls_targets , box_targets ,
@@ -395,8 +397,7 @@ def dataset_parser(self, value, example_decoder, anchor_labeler, params):
395397 classes = pad_to_fixed_size (classes , - 1 ,
396398 [self ._max_instances_per_image , 1 ])
397399 if params ['mixed_precision' ]:
398- dtype = (
399- tf .keras .mixed_precision .global_policy ().compute_dtype )
400+ dtype = tf .keras .mixed_precision .global_policy ().compute_dtype
400401 image = tf .cast (image , dtype = dtype )
401402 box_targets = tf .nest .map_structure (
402403 lambda box_target : tf .cast (box_target , dtype = dtype ), box_targets )
@@ -460,8 +461,6 @@ def __call__(self, params, input_context=None, batch_size=None):
460461 seed = params ['tf_random_seed' ] if self ._debug else None
461462 dataset = tf .data .Dataset .list_files (
462463 self ._file_pattern , shuffle = self ._is_training , seed = seed )
463- if self ._is_training :
464- dataset = dataset .repeat ()
465464 if input_context :
466465 dataset = dataset .shard (input_context .num_input_pipelines ,
467466 input_context .input_pipeline_id )
@@ -495,6 +494,8 @@ def _prefetch_dataset(filename):
495494 dataset = dataset .map (
496495 lambda * args : self .process_example (params , batch_size , * args ))
497496 dataset = dataset .prefetch (tf .data .AUTOTUNE )
497+ if self ._is_training :
498+ dataset = dataset .repeat ()
498499 if self ._use_fake_data :
499500 # Turn this dataset into a semi-fake dataset which always loop at the
500501 # first batch. This reduces variance in performance and is useful in
0 commit comments