Skip to content

Commit 455271c

Browse files
committed
fix code style
1 parent 539ab65 commit 455271c

File tree

1 file changed

+17
-16
lines changed

1 file changed

+17
-16
lines changed

efficientdet/dataloader.py

Lines changed: 17 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -271,10 +271,10 @@ def _common_image_process(self, image, classes, boxes, data, params):
271271
from aug import autoaugment # pylint: disable=g-import-not-at-top
272272
if params['autoaugment_policy'] == 'randaug':
273273
image, boxes = autoaugment.distort_image_with_randaugment(
274-
image, boxes, num_layers=1, magnitude=15)
274+
image, boxes, num_layers=1, magnitude=15)
275275
else:
276276
image, boxes = autoaugment.distort_image_with_autoaugment(
277-
image, boxes, params['autoaugment_policy'])
277+
image, boxes, params['autoaugment_policy'])
278278
return image, boxes, classes
279279

280280
def _resize_image_first(self, image, classes, boxes, data, params):
@@ -285,24 +285,26 @@ def _resize_image_first(self, image, classes, boxes, data, params):
285285
input_processor.random_horizontal_flip()
286286

287287
input_processor.set_training_random_scale_factors(
288-
params['jitter_min'], params['jitter_max'],
289-
params.get('target_size', None))
288+
params['jitter_min'], params['jitter_max'],
289+
params.get('target_size', None))
290290
else:
291291
input_processor.set_scale_factors_to_output_size()
292292

293293
image = input_processor.resize_and_crop_image()
294294
boxes, classes = input_processor.resize_and_crop_boxes()
295295

296296
if self._is_training:
297-
image, boxes, classes = self._common_image_process(image, classes, boxes, data, params)
297+
image, boxes, classes = self._common_image_process(image, classes,
298+
boxes, data, params)
298299

299300
input_processor.image = image
300301
image = input_processor.normalize_image()
301302
return image, boxes, classes, input_processor.image_scale_to_original
302303

303304
def _resize_image_last(self, image, classes, boxes, data, params):
304305
if self._is_training:
305-
image, boxes, classes = self._common_image_process(image, classes, boxes, data, params)
306+
image, boxes, classes = self._common_image_process(image, classes,
307+
boxes, data, params)
306308

307309
input_processor = DetectionInputProcessor(image, params['image_size'],
308310
boxes, classes)
@@ -311,8 +313,8 @@ def _resize_image_last(self, image, classes, boxes, data, params):
311313
input_processor.random_horizontal_flip()
312314

313315
input_processor.set_training_random_scale_factors(
314-
params['jitter_min'], params['jitter_max'],
315-
params.get('target_size', None))
316+
params['jitter_min'], params['jitter_max'],
317+
params.get('target_size', None))
316318
else:
317319
input_processor.set_scale_factors_to_output_size()
318320
input_processor.normalize_image()
@@ -367,15 +369,15 @@ def dataset_parser(self, value, example_decoder, anchor_labeler, params):
367369
areas = data['groundtruth_area']
368370
is_crowds = data['groundtruth_is_crowd']
369371
image_masks = data.get('groundtruth_instance_masks', [])
370-
classes = tf.reshape(tf.cast(classes, dtype=tf.float32), [-1, 1])
371372
source_area = tf.shape(image)[0] * tf.shape(image)[1]
372373
target_size = utils.parse_image_size(params['image_size'])
373374
target_area = target_size[0] * target_size[1]
374375
# set condition in order to always process small
375376
# first which could speed up pipeline
376-
image, boxes, classes, image_scale = tf.cond(source_area > target_area,
377-
lambda: self._resize_image_first(image, classes, boxes, data, params),
378-
lambda: self._resize_image_last(image, classes, boxes, data, params))
377+
image, boxes, classes, image_scale = tf.cond(
378+
source_area > target_area,
379+
lambda: self._resize_image_first(image, classes, boxes, data, params),
380+
lambda: self._resize_image_last(image, classes, boxes, data, params))
379381

380382
# Assign anchors.
381383
(cls_targets, box_targets,
@@ -395,8 +397,7 @@ def dataset_parser(self, value, example_decoder, anchor_labeler, params):
395397
classes = pad_to_fixed_size(classes, -1,
396398
[self._max_instances_per_image, 1])
397399
if params['mixed_precision']:
398-
dtype = (
399-
tf.keras.mixed_precision.global_policy().compute_dtype)
400+
dtype = tf.keras.mixed_precision.global_policy().compute_dtype
400401
image = tf.cast(image, dtype=dtype)
401402
box_targets = tf.nest.map_structure(
402403
lambda box_target: tf.cast(box_target, dtype=dtype), box_targets)
@@ -460,8 +461,6 @@ def __call__(self, params, input_context=None, batch_size=None):
460461
seed = params['tf_random_seed'] if self._debug else None
461462
dataset = tf.data.Dataset.list_files(
462463
self._file_pattern, shuffle=self._is_training, seed=seed)
463-
if self._is_training:
464-
dataset = dataset.repeat()
465464
if input_context:
466465
dataset = dataset.shard(input_context.num_input_pipelines,
467466
input_context.input_pipeline_id)
@@ -495,6 +494,8 @@ def _prefetch_dataset(filename):
495494
dataset = dataset.map(
496495
lambda *args: self.process_example(params, batch_size, *args))
497496
dataset = dataset.prefetch(tf.data.AUTOTUNE)
497+
if self._is_training:
498+
dataset = dataset.repeat()
498499
if self._use_fake_data:
499500
# Turn this dataset into a semi-fake dataset which always loop at the
500501
# first batch. This reduces variance in performance and is useful in

0 commit comments

Comments (0)