
Commit 539ab65

update for tf2.4 (#908)
* update for tf2.4
* fix mixed precision with recompute gradient
* update README
* fix multi gpus training
* update README
* fix LossScaleOptimizer bug
* disable steps_per_execution in default
* split all reduce
1 parent 53753bb commit 539ab65
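
Most of the file changes below apply one recurring TF 2.4 migration: the Keras mixed-precision and tf.data symbols moved out of their experimental namespaces. A minimal before/after sketch of the renames this commit relies on (reference only; these are the public TF 2.3 and TF 2.4 API names):

  import tensorflow as tf

  # TF 2.3: tf.keras.mixed_precision.experimental.Policy
  policy = tf.keras.mixed_precision.Policy('mixed_float16')
  # TF 2.3: tf.keras.mixed_precision.experimental.set_policy
  tf.keras.mixed_precision.set_global_policy(policy)
  # TF 2.3: tf.config.experimental_run_functions_eagerly
  tf.config.run_functions_eagerly(False)
  # TF 2.3: tf.data.experimental.AUTOTUNE
  autotune = tf.data.AUTOTUNE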

File tree

13 files changed: +170 -86 lines changed


efficientdet/README.md
Lines changed: 1 addition & 1 deletion

@@ -369,7 +369,7 @@ For more instructions about training on TPUs, please refer to the following tuto
 
 * EfficientNet tutorial: https://cloud.google.com/tpu/docs/tutorials/efficientnet
 
-## 11. Reducing Memory Usage when Training EfficientDets on GPU. (The current approach doesn't support mirrored multi GPU or mixed-precision training)
+## 11. Reducing Memory Usage when Training EfficientDets on GPU.
 
 EfficientDets use a lot of GPU memory for a few reasons:

efficientdet/dataloader.py
Lines changed: 88 additions & 41 deletions
@@ -48,6 +48,14 @@ def __init__(self, image, output_size):
     self._crop_offset_y = tf.constant(0)
     self._crop_offset_x = tf.constant(0)
 
+  @property
+  def image(self):
+    return self._image
+
+  @image.setter
+  def image(self, image):
+    self._image = image
+
   def normalize_image(self):
     """Normalize the image to zero mean and unit variance."""
     # The image normalization is identical to Cloud TPU ResNet.
@@ -61,6 +69,7 @@ def normalize_image(self):
     scale = tf.expand_dims(scale, axis=0)
     scale = tf.expand_dims(scale, axis=0)
     self._image /= scale
+    return self._image
 
   def set_training_random_scale_factors(self,
                                         scale_min,
@@ -126,6 +135,7 @@ def set_scale_factors_to_output_size(self):
 
   def resize_and_crop_image(self, method=tf.image.ResizeMethod.BILINEAR):
     """Resize input image and crop it to the self._output dimension."""
+    dtype = self._image.dtype
     scaled_image = tf.image.resize(
         self._image, [self._scaled_height, self._scaled_width], method=method)
     scaled_image = scaled_image[self._crop_offset_y:self._crop_offset_y +
@@ -135,7 +145,8 @@ def resize_and_crop_image(self, method=tf.image.ResizeMethod.BILINEAR):
     output_image = tf.image.pad_to_bounding_box(scaled_image, 0, 0,
                                                 self._output_size[0],
                                                 self._output_size[1])
-    return output_image
+    self._image = tf.cast(output_image, dtype)
+    return self._image
 
 
 class DetectionInputProcessor(InputProcessor):
@@ -245,6 +256,70 @@ def __init__(self,
     self._max_instances_per_image = max_instances_per_image or 100
     self._debug = debug
 
+  def _common_image_process(self, image, classes, boxes, data, params):
+    # Training time preprocessing.
+    if params['skip_crowd_during_training']:
+      indices = tf.where(tf.logical_not(data['groundtruth_is_crowd']))
+      classes = tf.gather_nd(classes, indices)
+      boxes = tf.gather_nd(boxes, indices)
+
+    if params.get('grid_mask', None):
+      from aug import gridmask  # pylint: disable=g-import-not-at-top
+      image, boxes = gridmask.gridmask(image, boxes)
+
+    if params.get('autoaugment_policy', None):
+      from aug import autoaugment  # pylint: disable=g-import-not-at-top
+      if params['autoaugment_policy'] == 'randaug':
+        image, boxes = autoaugment.distort_image_with_randaugment(
+            image, boxes, num_layers=1, magnitude=15)
+      else:
+        image, boxes = autoaugment.distort_image_with_autoaugment(
+            image, boxes, params['autoaugment_policy'])
+    return image, boxes, classes
+
+  def _resize_image_first(self, image, classes, boxes, data, params):
+    input_processor = DetectionInputProcessor(image, params['image_size'],
+                                              boxes, classes)
+    if self._is_training:
+      if params['input_rand_hflip']:
+        input_processor.random_horizontal_flip()
+
+      input_processor.set_training_random_scale_factors(
+          params['jitter_min'], params['jitter_max'],
+          params.get('target_size', None))
+    else:
+      input_processor.set_scale_factors_to_output_size()
+
+    image = input_processor.resize_and_crop_image()
+    boxes, classes = input_processor.resize_and_crop_boxes()
+
+    if self._is_training:
+      image, boxes, classes = self._common_image_process(image, classes, boxes, data, params)
+
+    input_processor.image = image
+    image = input_processor.normalize_image()
+    return image, boxes, classes, input_processor.image_scale_to_original
+
+  def _resize_image_last(self, image, classes, boxes, data, params):
+    if self._is_training:
+      image, boxes, classes = self._common_image_process(image, classes, boxes, data, params)
+
+    input_processor = DetectionInputProcessor(image, params['image_size'],
+                                              boxes, classes)
+    if self._is_training:
+      if params['input_rand_hflip']:
+        input_processor.random_horizontal_flip()
+
+      input_processor.set_training_random_scale_factors(
+          params['jitter_min'], params['jitter_max'],
+          params.get('target_size', None))
+    else:
+      input_processor.set_scale_factors_to_output_size()
+    input_processor.normalize_image()
+    image = input_processor.resize_and_crop_image()
+    boxes, classes = input_processor.resize_and_crop_boxes()
+    return image, boxes, classes, input_processor.image_scale_to_original
+
   @tf.autograph.experimental.do_not_convert
   def dataset_parser(self, value, example_decoder, anchor_labeler, params):
     """Parse data to a fixed dimension input image and learning targets.
@@ -293,41 +368,14 @@ def dataset_parser(self, value, example_decoder, anchor_labeler, params):
     is_crowds = data['groundtruth_is_crowd']
     image_masks = data.get('groundtruth_instance_masks', [])
     classes = tf.reshape(tf.cast(classes, dtype=tf.float32), [-1, 1])
-
-    if self._is_training:
-      # Training time preprocessing.
-      if params['skip_crowd_during_training']:
-        indices = tf.where(tf.logical_not(data['groundtruth_is_crowd']))
-        classes = tf.gather_nd(classes, indices)
-        boxes = tf.gather_nd(boxes, indices)
-
-      if params.get('grid_mask', None):
-        from aug import gridmask  # pylint: disable=g-import-not-at-top
-        image, boxes = gridmask.gridmask(image, boxes)
-
-      if params.get('autoaugment_policy', None):
-        from aug import autoaugment  # pylint: disable=g-import-not-at-top
-        if params['autoaugment_policy'] == 'randaug':
-          image, boxes = autoaugment.distort_image_with_randaugment(
-              image, boxes, num_layers=1, magnitude=15)
-        else:
-          image, boxes = autoaugment.distort_image_with_autoaugment(
-              image, boxes, params['autoaugment_policy'])
-
-    input_processor = DetectionInputProcessor(image, params['image_size'],
-                                              boxes, classes)
-    input_processor.normalize_image()
-    if self._is_training:
-      if params['input_rand_hflip']:
-        input_processor.random_horizontal_flip()
-
-      input_processor.set_training_random_scale_factors(
-          params['jitter_min'], params['jitter_max'],
-          params.get('target_size', None))
-    else:
-      input_processor.set_scale_factors_to_output_size()
-    image = input_processor.resize_and_crop_image()
-    boxes, classes = input_processor.resize_and_crop_boxes()
+    source_area = tf.shape(image)[0] * tf.shape(image)[1]
+    target_size = utils.parse_image_size(params['image_size'])
+    target_area = target_size[0] * target_size[1]
+    # Set the condition so that the smaller image is always processed
+    # first, which speeds up the pipeline.
+    image, boxes, classes, image_scale = tf.cond(source_area > target_area,
+        lambda: self._resize_image_first(image, classes, boxes, data, params),
+        lambda: self._resize_image_last(image, classes, boxes, data, params))
 
     # Assign anchors.
     (cls_targets, box_targets,
@@ -338,7 +386,6 @@ def dataset_parser(self, value, example_decoder, anchor_labeler, params):
     source_id = tf.strings.to_number(source_id)
 
     # Pad groundtruth data for evaluation.
-    image_scale = input_processor.image_scale_to_original
     boxes *= image_scale
     is_crowds = tf.cast(is_crowds, dtype=tf.float32)
     boxes = pad_to_fixed_size(boxes, -1, [self._max_instances_per_image, 4])
@@ -349,7 +396,7 @@ def dataset_parser(self, value, example_decoder, anchor_labeler, params):
                              [self._max_instances_per_image, 1])
     if params['mixed_precision']:
       dtype = (
-          tf.keras.mixed_precision.experimental.global_policy().compute_dtype)
+          tf.keras.mixed_precision.global_policy().compute_dtype)
       image = tf.cast(image, dtype=dtype)
       box_targets = tf.nest.map_structure(
           lambda box_target: tf.cast(box_target, dtype=dtype), box_targets)
@@ -427,7 +474,7 @@ def _prefetch_dataset(filename):
       return dataset
 
     dataset = dataset.interleave(
-        _prefetch_dataset, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+        _prefetch_dataset, num_parallel_calls=tf.data.AUTOTUNE)
     dataset = dataset.with_options(self.dataset_options)
     if self._is_training:
       dataset = dataset.shuffle(64, seed=seed)
@@ -442,12 +489,12 @@ def _prefetch_dataset(filename):
                                    anchor_labeler, params)
     # pylint: enable=g-long-lambda
     dataset = dataset.map(
-        map_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+        map_fn, num_parallel_calls=tf.data.AUTOTUNE)
     dataset = dataset.prefetch(batch_size)
     dataset = dataset.batch(batch_size, drop_remainder=params['drop_remainder'])
     dataset = dataset.map(
         lambda *args: self.process_example(params, batch_size, *args))
-    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
+    dataset = dataset.prefetch(tf.data.AUTOTUNE)
     if self._use_fake_data:
       # Turn this dataset into a semi-fake dataset which always loop at the
       # first batch. This reduces variance in performance and is useful in
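
The area check added in dataset_parser decides whether an image is downscaled before or after the per-pixel work (normalization, augmentation), so the expensive ops always run on the smaller representation. A minimal standalone sketch of that dispatch (illustrative only; TARGET, _normalize, and preprocess are stand-ins, not the repo's code):

  import tensorflow as tf

  TARGET = (512, 512)  # assumed target size for illustration

  def _normalize(image):
    # Per-channel zero-mean/unit-variance, like InputProcessor.normalize_image.
    image = tf.cast(image, tf.float32) / 255.0
    offset = tf.constant([0.485, 0.456, 0.406], shape=(1, 1, 3))
    scale = tf.constant([0.229, 0.224, 0.225], shape=(1, 1, 3))
    return (image - offset) / scale

  def preprocess(image):
    source_area = tf.shape(image)[0] * tf.shape(image)[1]
    target_area = TARGET[0] * TARGET[1]
    resize = lambda img: tf.image.resize(img, TARGET)
    # Run the per-pixel work on whichever representation is smaller.
    return tf.cond(source_area > target_area,
                   lambda: _normalize(resize(image)),  # big input: shrink first
                   lambda: resize(_normalize(image)))  # small input: shrink last

  image = tf.random.uniform([1024, 1024, 3], maxval=255)
  out = preprocess(image)  # (512, 512, 3), float32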

efficientdet/det_model_fn.py
Lines changed: 1 addition & 1 deletion

@@ -341,7 +341,7 @@ def model_fn(inputs):
 
   precision = utils.get_precision(params['strategy'], params['mixed_precision'])
   cls_outputs, box_outputs = utils.build_model_with_precision(
-      precision, model_fn, features, params['is_training_bn'])
+      precision, model_fn, features)
 
   levels = cls_outputs.keys()
   for level in levels:
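
This file and inference.py both drop the is_training flag from the utils.build_model_with_precision call, matching the helper's new signature in this commit. For orientation, a hedged sketch of what such a helper typically does under TF 2.4 (an illustration, not the repo's utils implementation):

  import tensorflow as tf

  def build_model_with_precision(precision, model_fn, inputs, *args, **kwargs):
    """Calls model_fn under the given precision policy, then restores float32."""
    if precision == 'mixed_float16':
      tf.keras.mixed_precision.set_global_policy('mixed_float16')
      inputs = tf.cast(inputs, tf.float16)
      outputs = model_fn(inputs, *args, **kwargs)
      tf.keras.mixed_precision.set_global_policy('float32')
    else:
      outputs = model_fn(inputs, *args, **kwargs)
    return outputs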

efficientdet/inference.py
Lines changed: 1 addition & 1 deletion

@@ -159,7 +159,7 @@ def model_arch(feats, model_name=None, **kwargs):
   model_arch = det_model_fn.get_model_arch(model_name)
 
   cls_outputs, box_outputs = utils.build_model_with_precision(
-      precision, model_arch, inputs, False, model_name, **kwargs)
+      precision, model_arch, inputs, model_name, **kwargs)
 
   if mixed_precision:
     # Post-processing has multiple places with hard-coded float32.

efficientdet/keras/efficientdet_keras.py
Lines changed: 3 additions & 1 deletion

@@ -28,7 +28,9 @@
 from keras import tfmot
 from keras import util_keras
 # pylint: disable=arguments-differ  # for keras layers.
-
+utils.BatchNormalization = util_keras.get_batch_norm(tf.keras.layers.BatchNormalization)
+utils.SyncBatchNormalization = util_keras.get_batch_norm(tf.keras.layers.experimental.SyncBatchNormalization)
+utils.TpuBatchNormalization = util_keras.get_batch_norm(tf.keras.layers.experimental.SyncBatchNormalization)
 
 def add_n(nodes):
   """A customized add_n to add up a list of tensors."""

efficientdet/keras/infer.py
Lines changed: 3 additions & 3 deletions

@@ -53,9 +53,9 @@ def main(_):
   config.override(FLAGS.hparams)
 
   # Use 'mixed_float16' if running on GPUs.
-  policy = tf.keras.mixed_precision.experimental.Policy('float32')
-  tf.keras.mixed_precision.experimental.set_policy(policy)
-  tf.config.experimental_run_functions_eagerly(FLAGS.debug)
+  policy = tf.keras.mixed_precision.Policy('float32')
+  tf.keras.mixed_precision.set_global_policy(policy)
+  tf.config.run_functions_eagerly(FLAGS.debug)
 
   # Create and run the model.
   model = efficientdet_keras.EfficientDetModel(config=config)

efficientdet/keras/inference.py
Lines changed: 2 additions & 2 deletions

@@ -179,8 +179,8 @@ def __init__(self,
     mixed_precision = self.params.get('mixed_precision', None)
     precision = utils.get_precision(
         self.params.get('strategy', None), mixed_precision)
-    policy = tf.keras.mixed_precision.experimental.Policy(precision)
-    tf.keras.mixed_precision.experimental.set_policy(policy)
+    policy = tf.keras.mixed_precision.Policy(precision)
+    tf.keras.mixed_precision.set_global_policy(policy)
 
   @property
   def model(self):

efficientdet/keras/train.py
Lines changed: 4 additions & 4 deletions

@@ -78,7 +78,7 @@ def define_flags():
   flags.DEFINE_integer('batch_size', 64, 'training batch size')
   flags.DEFINE_integer('eval_samples', 5000, 'The number of samples for '
                        'evaluation.')
-  flags.DEFINE_integer('steps_per_execution', 200,
+  flags.DEFINE_integer('steps_per_execution', 1,
                        'Number of steps per training execution.')
   flags.DEFINE_string(
       'train_file_pattern', None,
@@ -163,7 +163,7 @@ def main(_):
     tf.config.experimental.set_memory_growth(gpu, True)
 
   if FLAGS.debug:
-    tf.config.experimental_run_functions_eagerly(True)
+    tf.config.run_functions_eagerly(True)
     tf.debugging.set_log_device_placement(True)
     os.environ['TF_DETERMINISTIC_OPS'] = '1'
     tf.random.set_seed(FLAGS.tf_random_seed)
@@ -202,8 +202,8 @@ def main(_):
   config.override(params, True)
   # set mixed precision policy by keras api.
   precision = utils.get_precision(config.strategy, config.mixed_precision)
-  policy = tf.keras.mixed_precision.experimental.Policy(precision)
-  tf.keras.mixed_precision.experimental.set_policy(policy)
+  policy = tf.keras.mixed_precision.Policy(precision)
+  tf.keras.mixed_precision.set_global_policy(policy)
 
 def get_dataset(is_training, config):
   file_pattern = (
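
The steps_per_execution default drops from 200 to 1 (the commit message's "disable steps_per_execution in default"). In TF 2.4 this flag maps onto Model.compile's steps_per_execution argument; a hedged sketch of the wiring (the toy model is a stand-in, not the repo's setup code):

  import tensorflow as tf

  model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
  model.compile(
      optimizer='sgd',
      loss='mse',
      # Values > 1 run several train steps per tf.function call, which can
      # raise throughput but makes per-step callbacks and logging coarser;
      # 1 restores plain step-by-step execution.
      steps_per_execution=1)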

efficientdet/keras/train_lib.py
Lines changed: 7 additions & 7 deletions

@@ -310,10 +310,9 @@ def get_optimizer(params):
         optimizer, average_decay=moving_average_decay, dynamic_decay=True)
   precision = utils.get_precision(params['strategy'], params['mixed_precision'])
   if precision == 'mixed_float16' and params['loss_scale']:
-    optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
+    optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
         optimizer,
-        loss_scale=tf.mixed_precision.experimental.DynamicLossScale(
-            params['loss_scale']))
+        initial_scale=params['loss_scale'])
   return optimizer

@@ -777,17 +776,18 @@ def train_step(self, data):
       loss_vals['reg_l2_loss'] = reg_l2_loss
       total_loss += tf.cast(reg_l2_loss, loss_dtype)
       if isinstance(self.optimizer,
-                    tf.keras.mixed_precision.experimental.LossScaleOptimizer):
+                    tf.keras.mixed_precision.LossScaleOptimizer):
        scaled_loss = self.optimizer.get_scaled_loss(total_loss)
+        optimizer = self.optimizer.inner_optimizer
       else:
         scaled_loss = total_loss
+        optimizer = self.optimizer
     loss_vals['loss'] = total_loss
-    loss_vals['learning_rate'] = self.optimizer.learning_rate(
-        self.optimizer.iterations)
+    loss_vals['learning_rate'] = optimizer.learning_rate(optimizer.iterations)
     trainable_vars = self._freeze_vars()
     scaled_gradients = tape.gradient(scaled_loss, trainable_vars)
     if isinstance(self.optimizer,
-                  tf.keras.mixed_precision.experimental.LossScaleOptimizer):
+                  tf.keras.mixed_precision.LossScaleOptimizer):
       gradients = self.optimizer.get_unscaled_gradients(scaled_gradients)
     else:
       gradients = scaled_gradients
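
The TF 2.4 LossScaleOptimizer takes scaling arguments directly instead of a DynamicLossScale object, and exposes the wrapped optimizer as inner_optimizer, which is why train_step now reads learning_rate and iterations from there. A minimal sketch of the new API in a custom train step (toy model and data, not the repo's code):

  import tensorflow as tf

  opt = tf.keras.mixed_precision.LossScaleOptimizer(
      tf.keras.optimizers.SGD(0.01), initial_scale=2**15)

  model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
  x, y = tf.ones([4, 8]), tf.ones([4, 1])

  with tf.GradientTape() as tape:
    loss = tf.reduce_mean(tf.square(model(x) - y))
    scaled_loss = opt.get_scaled_loss(loss)          # multiply by loss scale
  scaled_grads = tape.gradient(scaled_loss, model.trainable_variables)
  grads = opt.get_unscaled_gradients(scaled_grads)   # divide it back out
  opt.apply_gradients(zip(grads, model.trainable_variables))

  # The wrapped optimizer's state lives on the inner optimizer:
  print(opt.inner_optimizer.iterations.numpy())      # -> 1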

efficientdet/keras/util_keras.py
Lines changed: 7 additions & 0 deletions

@@ -174,3 +174,10 @@ def fp16_to_fp32_nested(input_nested):
   else:
     return input_nested
   return out_tensor_dict
+
+def get_batch_norm(bn_class):
+  def _wrapper(*args, **kwargs):
+    if not kwargs.get('name', None):
+      kwargs['name'] = 'tpu_batch_normalization'
+    return bn_class(*args, **kwargs)
+  return _wrapper
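
This wrapper is what efficientdet_keras.py installs over utils.BatchNormalization and friends: any batch-norm layer built without an explicit name defaults to 'tpu_batch_normalization', keeping variable names aligned with existing checkpoints across backends. A self-contained usage sketch (assumes only tf.keras; the wrapper is copied from the diff above):

  import tensorflow as tf

  def get_batch_norm(bn_class):
    def _wrapper(*args, **kwargs):
      if not kwargs.get('name', None):
        kwargs['name'] = 'tpu_batch_normalization'
      return bn_class(*args, **kwargs)
    return _wrapper

  BatchNormalization = get_batch_norm(tf.keras.layers.BatchNormalization)

  bn = BatchNormalization()                  # named 'tpu_batch_normalization'
  custom = BatchNormalization(name='my_bn')  # explicit names pass through
  print(bn.name, custom.name)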
