@@ -48,6 +48,14 @@ def __init__(self, image, output_size):
4848 self ._crop_offset_y = tf .constant (0 )
4949 self ._crop_offset_x = tf .constant (0 )
5050
@property
def image(self):
  """The image tensor currently held by this processor."""
  return self._image

@image.setter
def image(self, value):
  # Allows callers to push an externally augmented image back into the
  # processor before further processing (e.g. normalization).
  self._image = value
5159 def normalize_image (self ):
5260 """Normalize the image to zero mean and unit variance."""
5361 # The image normalization is identical to Cloud TPU ResNet.
@@ -61,6 +69,7 @@ def normalize_image(self):
6169 scale = tf .expand_dims (scale , axis = 0 )
6270 scale = tf .expand_dims (scale , axis = 0 )
6371 self ._image /= scale
72+ return self ._image
6473
6574 def set_training_random_scale_factors (self ,
6675 scale_min ,
@@ -126,6 +135,7 @@ def set_scale_factors_to_output_size(self):
126135
127136 def resize_and_crop_image (self , method = tf .image .ResizeMethod .BILINEAR ):
128137 """Resize input image and crop it to the self._output dimension."""
138+ dtype = self ._image .dtype
129139 scaled_image = tf .image .resize (
130140 self ._image , [self ._scaled_height , self ._scaled_width ], method = method )
131141 scaled_image = scaled_image [self ._crop_offset_y :self ._crop_offset_y +
@@ -135,7 +145,8 @@ def resize_and_crop_image(self, method=tf.image.ResizeMethod.BILINEAR):
135145 output_image = tf .image .pad_to_bounding_box (scaled_image , 0 , 0 ,
136146 self ._output_size [0 ],
137147 self ._output_size [1 ])
138- return output_image
148+ self ._image = tf .cast (output_image , dtype )
149+ return self ._image
139150
140151
141152class DetectionInputProcessor (InputProcessor ):
@@ -245,6 +256,70 @@ def __init__(self,
245256 self ._max_instances_per_image = max_instances_per_image or 100
246257 self ._debug = debug
247258
259+ def _common_image_process (self , image , classes , boxes , data , params ):
260+ # Training time preprocessing.
261+ if params ['skip_crowd_during_training' ]:
262+ indices = tf .where (tf .logical_not (data ['groundtruth_is_crowd' ]))
263+ classes = tf .gather_nd (classes , indices )
264+ boxes = tf .gather_nd (boxes , indices )
265+
266+ if params .get ('grid_mask' , None ):
267+ from aug import gridmask # pylint: disable=g-import-not-at-top
268+ image , boxes = gridmask .gridmask (image , boxes )
269+
270+ if params .get ('autoaugment_policy' , None ):
271+ from aug import autoaugment # pylint: disable=g-import-not-at-top
272+ if params ['autoaugment_policy' ] == 'randaug' :
273+ image , boxes = autoaugment .distort_image_with_randaugment (
274+ image , boxes , num_layers = 1 , magnitude = 15 )
275+ else :
276+ image , boxes = autoaugment .distort_image_with_autoaugment (
277+ image , boxes , params ['autoaugment_policy' ])
278+ return image , boxes , classes
279+
def _resize_image_first(self, image, classes, boxes, data, params):
  """Resize/crop first, then run the training-time augmentations.

  Intended for the case where the source image is larger than the target
  size, so the augmentations operate on the already-downscaled tensor.

  Args:
    image: the input image tensor.
    classes: groundtruth classes.
    boxes: groundtruth boxes.
    data: decoded example dict, forwarded to `_common_image_process`.
    params: config dict (image_size, jitter, flip, augmentation flags).

  Returns:
    A (image, boxes, classes, image_scale_to_original) tuple.
  """
  proc = DetectionInputProcessor(image, params['image_size'],
                                 boxes, classes)
  if self._is_training:
    if params['input_rand_hflip']:
      proc.random_horizontal_flip()
    proc.set_training_random_scale_factors(
        params['jitter_min'], params['jitter_max'],
        params.get('target_size', None))
  else:
    proc.set_scale_factors_to_output_size()

  image = proc.resize_and_crop_image()
  boxes, classes = proc.resize_and_crop_boxes()

  if self._is_training:
    image, boxes, classes = self._common_image_process(
        image, classes, boxes, data, params)

  # Push the augmented image back into the processor so normalization
  # operates on it rather than the pre-augmentation tensor.
  proc.image = image
  image = proc.normalize_image()
  return image, boxes, classes, proc.image_scale_to_original
302+
def _resize_image_last(self, image, classes, boxes, data, params):
  """Run the training-time augmentations first, then resize/crop.

  Intended for the case where the source image is not larger than the
  target size, so the augmentations operate on the smaller original
  tensor before it is scaled up.

  Args:
    image: the input image tensor.
    classes: groundtruth classes.
    boxes: groundtruth boxes.
    data: decoded example dict, forwarded to `_common_image_process`.
    params: config dict (image_size, jitter, flip, augmentation flags).

  Returns:
    A (image, boxes, classes, image_scale_to_original) tuple.
  """
  if self._is_training:
    image, boxes, classes = self._common_image_process(
        image, classes, boxes, data, params)

  proc = DetectionInputProcessor(image, params['image_size'],
                                 boxes, classes)
  if self._is_training:
    if params['input_rand_hflip']:
      proc.random_horizontal_flip()
    proc.set_training_random_scale_factors(
        params['jitter_min'], params['jitter_max'],
        params.get('target_size', None))
  else:
    proc.set_scale_factors_to_output_size()
  proc.normalize_image()
  image = proc.resize_and_crop_image()
  boxes, classes = proc.resize_and_crop_boxes()
  return image, boxes, classes, proc.image_scale_to_original
322+
248323 @tf .autograph .experimental .do_not_convert
249324 def dataset_parser (self , value , example_decoder , anchor_labeler , params ):
250325 """Parse data to a fixed dimension input image and learning targets.
@@ -293,41 +368,14 @@ def dataset_parser(self, value, example_decoder, anchor_labeler, params):
293368 is_crowds = data ['groundtruth_is_crowd' ]
294369 image_masks = data .get ('groundtruth_instance_masks' , [])
295370 classes = tf .reshape (tf .cast (classes , dtype = tf .float32 ), [- 1 , 1 ])
296-
297- if self ._is_training :
298- # Training time preprocessing.
299- if params ['skip_crowd_during_training' ]:
300- indices = tf .where (tf .logical_not (data ['groundtruth_is_crowd' ]))
301- classes = tf .gather_nd (classes , indices )
302- boxes = tf .gather_nd (boxes , indices )
303-
304- if params .get ('grid_mask' , None ):
305- from aug import gridmask # pylint: disable=g-import-not-at-top
306- image , boxes = gridmask .gridmask (image , boxes )
307-
308- if params .get ('autoaugment_policy' , None ):
309- from aug import autoaugment # pylint: disable=g-import-not-at-top
310- if params ['autoaugment_policy' ] == 'randaug' :
311- image , boxes = autoaugment .distort_image_with_randaugment (
312- image , boxes , num_layers = 1 , magnitude = 15 )
313- else :
314- image , boxes = autoaugment .distort_image_with_autoaugment (
315- image , boxes , params ['autoaugment_policy' ])
316-
317- input_processor = DetectionInputProcessor (image , params ['image_size' ],
318- boxes , classes )
319- input_processor .normalize_image ()
320- if self ._is_training :
321- if params ['input_rand_hflip' ]:
322- input_processor .random_horizontal_flip ()
323-
324- input_processor .set_training_random_scale_factors (
325- params ['jitter_min' ], params ['jitter_max' ],
326- params .get ('target_size' , None ))
327- else :
328- input_processor .set_scale_factors_to_output_size ()
329- image = input_processor .resize_and_crop_image ()
330- boxes , classes = input_processor .resize_and_crop_boxes ()
371+ source_area = tf .shape (image )[0 ] * tf .shape (image )[1 ]
372+ target_size = utils .parse_image_size (params ['image_size' ])
373+ target_area = target_size [0 ] * target_size [1 ]
374+ # Order the two branches so heavy processing always happens at the
375+ # smaller image size, which can speed up the input pipeline.
376+ image , boxes , classes , image_scale = tf .cond (source_area > target_area ,
377+ lambda : self ._resize_image_first (image , classes , boxes , data , params ),
378+ lambda : self ._resize_image_last (image , classes , boxes , data , params ))
331379
332380 # Assign anchors.
333381 (cls_targets , box_targets ,
@@ -338,7 +386,6 @@ def dataset_parser(self, value, example_decoder, anchor_labeler, params):
338386 source_id = tf .strings .to_number (source_id )
339387
340388 # Pad groundtruth data for evaluation.
341- image_scale = input_processor .image_scale_to_original
342389 boxes *= image_scale
343390 is_crowds = tf .cast (is_crowds , dtype = tf .float32 )
344391 boxes = pad_to_fixed_size (boxes , - 1 , [self ._max_instances_per_image , 4 ])
@@ -349,7 +396,7 @@ def dataset_parser(self, value, example_decoder, anchor_labeler, params):
349396 [self ._max_instances_per_image , 1 ])
350397 if params ['mixed_precision' ]:
351398 dtype = (
352- tf .keras .mixed_precision .experimental . global_policy ().compute_dtype )
399+ tf .keras .mixed_precision .global_policy ().compute_dtype )
353400 image = tf .cast (image , dtype = dtype )
354401 box_targets = tf .nest .map_structure (
355402 lambda box_target : tf .cast (box_target , dtype = dtype ), box_targets )
@@ -427,7 +474,7 @@ def _prefetch_dataset(filename):
427474 return dataset
428475
429476 dataset = dataset .interleave (
430- _prefetch_dataset , num_parallel_calls = tf .data .experimental . AUTOTUNE )
477+ _prefetch_dataset , num_parallel_calls = tf .data .AUTOTUNE )
431478 dataset = dataset .with_options (self .dataset_options )
432479 if self ._is_training :
433480 dataset = dataset .shuffle (64 , seed = seed )
@@ -442,12 +489,12 @@ def _prefetch_dataset(filename):
442489 anchor_labeler , params )
443490 # pylint: enable=g-long-lambda
444491 dataset = dataset .map (
445- map_fn , num_parallel_calls = tf .data .experimental . AUTOTUNE )
492+ map_fn , num_parallel_calls = tf .data .AUTOTUNE )
446493 dataset = dataset .prefetch (batch_size )
447494 dataset = dataset .batch (batch_size , drop_remainder = params ['drop_remainder' ])
448495 dataset = dataset .map (
449496 lambda * args : self .process_example (params , batch_size , * args ))
450- dataset = dataset .prefetch (tf .data .experimental . AUTOTUNE )
497+ dataset = dataset .prefetch (tf .data .AUTOTUNE )
451498 if self ._use_fake_data :
452499 # Turn this dataset into a semi-fake dataset which always loop at the
453500 # first batch. This reduces variance in performance and is useful in
0 commit comments