
Commit e494379

fix #889 (#890)
* fix #889
* revert changes and add limits to readme
* add limits
* update readme
1 parent b640618 commit e494379

File tree (4 files changed: +7, -16 lines)

* efficientdet/README.md
* efficientdet/det_model_fn.py
* efficientdet/main.py
* efficientdet/utils.py


efficientdet/README.md (2 additions, 3 deletions)

@@ -369,15 +369,15 @@ For more instructions about training on TPUs, please refer to the following tuto
 
 * EfficientNet tutorial: https://cloud.google.com/tpu/docs/tutorials/efficientnet
 
-## 11. Reducing Memory Usage when Training EfficientDets on GPU.
+## 11. Reducing Memory Usage when Training EfficientDets on GPU. (The current approach doesn't support mirrored multi-GPU or mixed-precision training.)
 
 EfficientDets use a lot of GPU memory for a few reasons:
 
 * Large input resolution: because resolution is one of the scaling dimensions, our resolution tends to be higher, which significantly increases activations (with no increase in parameters).
 * Large internal activations for the backbone: our backbone uses a relatively large expansion ratio (6), causing large expanded activations.
 * Deep BiFPN: our BiFPN has multiple top-down and bottom-up paths, which leads to a lot of intermediate memory usage during training.
 
-To train this model on a GPU with limited memory, there is an experimental option gradient_checkpointing.
+To train this model on a GPU with limited memory, there is an experimental option grad_checkpoint.
 
 Check these links for a high-level idea of what gradient checkpointing is doing:
 1. https://medium.com/tensorflow/fitting-larger-networks-into-memory-583e3c758ff9

@@ -387,7 +387,6 @@ Check these links for a high-level idea of what gradient checkpointing is doing:
 If set to True, the keras model uses ```tf.recompute_grad``` to achieve gradient checkpointing.
 
 Testing shows that:
-* It allows training a d7x network with a batch size of 2 via keras/train.py on an 11GB (1080Ti) GPU
 * It also allows training a d6 network with a batch size of 2 via main.py on an 11GB (1080Ti) GPU
 
 ## 12. Visualize TF-Records.
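
The README text above describes `grad_checkpoint` only at a high level. As a rough illustration of the underlying mechanism, here is a minimal, self-contained sketch of gradient checkpointing with `tf.recompute_grad`; the weights, shapes, and block definition are invented for the example and are not the EfficientDet code.

```python
import tensorflow as tf

# Hypothetical stand-in for a memory-hungry sub-network (weights and shapes
# are arbitrary; this is not the EfficientDet backbone or BiFPN).
w1 = tf.Variable(tf.random.normal([256, 256]))
w2 = tf.Variable(tf.random.normal([256, 256]))

def block(x):
  # Two chained matmuls: normally the intermediate activation is kept alive
  # until the backward pass.
  return tf.nn.relu(tf.matmul(tf.nn.relu(tf.matmul(x, w1)), w2))

# tf.recompute_grad returns a wrapped callable that discards the block's
# intermediate activations after the forward pass and recomputes them from
# the block inputs during backprop, trading extra compute for memory.
block_ckpt = tf.recompute_grad(block)

x = tf.random.normal([32, 256])
with tf.GradientTape() as tape:
  tape.watch(x)
  loss = tf.reduce_mean(block_ckpt(x))

# Gradients flow through the recomputed block as usual.
grads = tape.gradient(loss, x)
```

In the repository itself, `grad_checkpoint` is a config flag and the wrapping happens inside the Keras model code, so this sketch only mirrors the idea, not the actual call sites.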

efficientdet/det_model_fn.py (0 additions, 3 deletions)

@@ -386,9 +386,6 @@ def model_fn(inputs):
 
   if is_tpu:
     optimizer = tf.tpu.CrossShardOptimizer(optimizer)
-  elif params['mixed_precision']:
-    optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
-        optimizer)
 
   # Batch norm requires update_ops to be added as a train_op dependency.
   update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
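
For context on the deleted branch: `enable_mixed_precision_graph_rewrite` wraps an existing TF1-style optimizer so that eligible ops are rewritten to float16 and loss scaling is added automatically. A minimal sketch of that (now removed) pattern, with a made-up optimizer and learning rate rather than EfficientDet's actual optimizer setup, might look like this:

```python
import tensorflow.compat.v1 as tf

# The graph rewrite only affects graph execution (Estimator/Session), so this
# sketch assumes TF1-style graph mode.
tf.disable_eager_execution()

# Hypothetical optimizer; EfficientDet builds its optimizer elsewhere in
# det_model_fn.py.
optimizer = tf.train.MomentumOptimizer(learning_rate=0.08, momentum=0.9)

# Wrap the optimizer: TensorFlow rewrites eligible ops to float16 in the graph
# and enables dynamic loss scaling. This is the call the commit removes.
optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
    optimizer, loss_scale='dynamic')
```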

efficientdet/main.py (0 additions, 2 deletions)

@@ -127,8 +127,6 @@ def main(_):
     tpu_grpc_url = tpu_cluster_resolver.get_master()
     tf.Session.reset(tpu_grpc_url)
   else:
-    # Always enable auto mixed precision graph rewrite
-    os.environ['TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_IGNORE_PERFORMANCE'] = '1'
     tpu_cluster_resolver = None
 
   # Check data path

efficientdet/utils.py (5 additions, 8 deletions)

@@ -618,14 +618,11 @@ def build_model_with_precision(pp, mm, ii, tt, *args, **kwargs):
       outputs = mm(inputs, *args, **kwargs)
     set_precision_policy('float32')
   elif pp == 'mixed_float16':
-    if tt:
-      outputs = mm(ii, *args, **kwargs)
-    else:
-      set_precision_policy(pp, loss_scale=tt)
-      inputs = tf.cast(ii, tf.float16)
-      with float16_scope():
-        outputs = mm(inputs, *args, **kwargs)
-      set_precision_policy('float32')
+    set_precision_policy(pp, loss_scale=tt)
+    inputs = tf.cast(ii, tf.float16)
+    with float16_scope():
+      outputs = mm(inputs, *args, **kwargs)
+    set_precision_policy('float32')
   elif not pp or pp == 'float32':
     outputs = mm(ii, *args, **kwargs)
   else:
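
The change above makes the `mixed_float16` branch of `build_model_with_precision` unconditional: set the float16 policy, cast the inputs, run the model under a float16 scope, then restore float32. `set_precision_policy` and `float16_scope` are helpers defined elsewhere in utils.py; a rough, standalone approximation of the same pattern using only the stock Keras mixed-precision API (and omitting the loss-scale argument `tt`) could look like the sketch below, with an invented toy model and input shape.

```python
import tensorflow as tf
from tensorflow.keras import mixed_precision

def build_and_run_in_mixed_float16(model_fn, images):
  """Rough analogue of the mixed_float16 branch: policy on, cast, run, policy off."""
  # Enable the float16 policy *before* layers are constructed, so they pick up
  # a float16 compute dtype with float32 variables.
  mixed_precision.set_global_policy('mixed_float16')
  inputs = tf.cast(images, tf.float16)
  outputs = model_fn(inputs)
  # Restore the default policy so later model builds are unaffected.
  mixed_precision.set_global_policy('float32')
  return outputs

# Hypothetical usage: model_fn builds and applies a toy network, mirroring how
# build_model_with_precision receives the model as a callable (mm) and the
# inputs separately (ii).
def toy_model_fn(x):
  return tf.keras.layers.Conv2D(8, 3, padding='same')(x)

result = build_and_run_in_mixed_float16(toy_model_fn, tf.random.normal([1, 64, 64, 3]))
```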
