@@ -256,72 +256,6 @@ def __init__(self,
     self._max_instances_per_image = max_instances_per_image or 100
     self._debug = debug

-  def _common_image_process(self, image, classes, boxes, data, params):
-    # Training time preprocessing.
-    if params['skip_crowd_during_training']:
-      indices = tf.where(tf.logical_not(data['groundtruth_is_crowd']))
-      classes = tf.gather_nd(classes, indices)
-      boxes = tf.gather_nd(boxes, indices)
-
-    if params.get('grid_mask', None):
-      from aug import gridmask  # pylint: disable=g-import-not-at-top
-      image, boxes = gridmask.gridmask(image, boxes)
-
-    if params.get('autoaugment_policy', None):
-      from aug import autoaugment  # pylint: disable=g-import-not-at-top
-      if params['autoaugment_policy'] == 'randaug':
-        image, boxes = autoaugment.distort_image_with_randaugment(
-            image, boxes, num_layers=1, magnitude=15)
-      else:
-        image, boxes = autoaugment.distort_image_with_autoaugment(
-            image, boxes, params['autoaugment_policy'])
-    return image, boxes, classes
-
-  def _resize_image_first(self, image, classes, boxes, data, params):
-    input_processor = DetectionInputProcessor(image, params['image_size'],
-                                              boxes, classes)
-    if self._is_training:
-      if params['input_rand_hflip']:
-        input_processor.random_horizontal_flip()
-
-      input_processor.set_training_random_scale_factors(
-          params['jitter_min'], params['jitter_max'],
-          params.get('target_size', None))
-    else:
-      input_processor.set_scale_factors_to_output_size()
-
-    image = input_processor.resize_and_crop_image()
-    boxes, classes = input_processor.resize_and_crop_boxes()
-
-    if self._is_training:
-      image, boxes, classes = self._common_image_process(image, classes,
-                                                         boxes, data, params)
-
-    input_processor.image = image
-    image = input_processor.normalize_image()
-    return image, boxes, classes, input_processor.image_scale_to_original
-
-  def _resize_image_last(self, image, classes, boxes, data, params):
-    if self._is_training:
-      image, boxes, classes = self._common_image_process(image, classes,
-                                                         boxes, data, params)
-
-    input_processor = DetectionInputProcessor(image, params['image_size'],
-                                              boxes, classes)
-    if self._is_training:
-      if params['input_rand_hflip']:
-        input_processor.random_horizontal_flip()
-
-      input_processor.set_training_random_scale_factors(
-          params['jitter_min'], params['jitter_max'],
-          params.get('target_size', None))
-    else:
-      input_processor.set_scale_factors_to_output_size()
-    input_processor.normalize_image()
-    image = input_processor.resize_and_crop_image()
-    boxes, classes = input_processor.resize_and_crop_boxes()
-    return image, boxes, classes, input_processor.image_scale_to_original
-
   @tf.autograph.experimental.do_not_convert
   def dataset_parser(self, value, example_decoder, anchor_labeler, params):
     """Parse data to a fixed dimension input image and learning targets.
@@ -369,16 +303,41 @@ def dataset_parser(self, value, example_decoder, anchor_labeler, params):
       areas = data['groundtruth_area']
       is_crowds = data['groundtruth_is_crowd']
       image_masks = data.get('groundtruth_instance_masks', [])
-      source_area = tf.shape(image)[0] * tf.shape(image)[1]
-      target_size = utils.parse_image_size(params['image_size'])
-      target_area = target_size[0] * target_size[1]
-      # set condition in order to always process small
-      # first which could speed up pipeline
-      image, boxes, classes, image_scale = tf.cond(
-          source_area > target_area,
-          lambda: self._resize_image_first(image, classes, boxes, data, params),
-          lambda: self._resize_image_last(image, classes, boxes, data, params))

+      if self._is_training:
+        # Training time preprocessing.
+        if params['skip_crowd_during_training']:
+          indices = tf.where(tf.logical_not(data['groundtruth_is_crowd']))
+          classes = tf.gather_nd(classes, indices)
+          boxes = tf.gather_nd(boxes, indices)
+
+        if params.get('grid_mask', None):
+          from aug import gridmask  # pylint: disable=g-import-not-at-top
+          image, boxes = gridmask.gridmask(image, boxes)
+
+        if params.get('autoaugment_policy', None):
+          from aug import autoaugment  # pylint: disable=g-import-not-at-top
+          if params['autoaugment_policy'] == 'randaug':
+            image, boxes = autoaugment.distort_image_with_randaugment(
+                image, boxes, num_layers=1, magnitude=15)
+          else:
+            image, boxes = autoaugment.distort_image_with_autoaugment(
+                image, boxes, params['autoaugment_policy'])
+
+      input_processor = DetectionInputProcessor(image, params['image_size'],
+                                                boxes, classes)
+      input_processor.normalize_image()
+      if self._is_training:
+        if params['input_rand_hflip']:
+          input_processor.random_horizontal_flip()
+
+        input_processor.set_training_random_scale_factors(
+            params['jitter_min'], params['jitter_max'],
+            params.get('target_size', None))
+      else:
+        input_processor.set_scale_factors_to_output_size()
+      image = input_processor.resize_and_crop_image()
+      boxes, classes = input_processor.resize_and_crop_boxes()
       # Assign anchors.
       (cls_targets, box_targets,
        num_positives) = anchor_labeler.label_anchors(boxes, classes)
@@ -388,6 +347,7 @@ def dataset_parser(self, value, example_decoder, anchor_labeler, params):
       source_id = tf.strings.to_number(source_id)

       # Pad groundtruth data for evaluation.
+      image_scale = input_processor.image_scale_to_original
       boxes *= image_scale
       is_crowds = tf.cast(is_crowds, dtype=tf.float32)
       boxes = pad_to_fixed_size(boxes, -1, [self._max_instances_per_image, 4])
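
For reference, below is a minimal sketch (not part of this commit) of the single preprocessing path that dataset_parser follows after this change: normalize the image first, then optionally flip and scale-jitter, then resize and crop. DetectionInputProcessor, its methods, and the params keys are taken from the diff above; preprocess_example, its is_training argument, and the import path are hypothetical names used only for illustration.

from dataloader import DetectionInputProcessor  # assumed import path, adjust as needed


def preprocess_example(image, boxes, classes, params, is_training):
  """Sketch of the inlined path: normalize, optionally augment/scale, resize and crop."""
  input_processor = DetectionInputProcessor(image, params['image_size'],
                                            boxes, classes)
  # Normalize pixel values before any resizing, matching the added lines above.
  input_processor.normalize_image()
  if is_training:
    if params['input_rand_hflip']:
      input_processor.random_horizontal_flip()
    # Random scale jitter between jitter_min and jitter_max (optionally toward target_size).
    input_processor.set_training_random_scale_factors(
        params['jitter_min'], params['jitter_max'],
        params.get('target_size', None))
  else:
    input_processor.set_scale_factors_to_output_size()
  image = input_processor.resize_and_crop_image()
  boxes, classes = input_processor.resize_and_crop_boxes()
  # image_scale_to_original is used later to map boxes back to the original resolution.
  return image, boxes, classes, input_processor.image_scale_to_original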