Commit 3fd77cb

Author: Sean Naren

[docs] Add activation checkpointing information (#9165)

1 parent 0dfc6a1 commit 3fd77cb

File tree: 1 file changed (+41, -6 lines)

docs/source/advanced/advanced_gpu.rst

Lines changed: 41 additions & 6 deletions
@@ -170,12 +170,16 @@ Below is an example of using both ``wrap`` and ``auto_wrap`` to create your mode
 FairScale Activation Checkpointing
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

-Activation checkpointing frees activations from memory as soon as they are not needed during the forward pass. They are then re-computed for the backwards pass as needed.
+Activation checkpointing frees activations from memory as soon as they are not needed during the forward pass. They are then re-computed for the backwards pass as needed. Activation checkpointing is very useful when you have intermediate layers that produce large activations.

 FairScale's checkpointing wrapper also handles batch norm layers correctly, unlike the PyTorch implementation, ensuring stats are tracked correctly despite the multiple forward passes.

 This saves memory when training larger models; however, it requires wrapping the modules you'd like to use activation checkpointing on. See `here <https://fairscale.readthedocs.io/en/latest/api/nn/misc/checkpoint_activations.html>`__ for more information.

+.. warning::
+
+    Do not wrap the entire model with activation checkpointing. This is not the intended usage of activation checkpointing, and will lead to failures as seen in `this discussion <https://github.com/PyTorchLightning/pytorch-lightning/discussions/9144>`__.
+
 .. code-block:: python

     from pytorch_lightning import Trainer
@@ -185,7 +189,8 @@ This saves memory when training larger models however requires wrapping modules
     class MyModel(pl.LightningModule):
         def __init__(self):
             # Wrap layers using checkpoint_wrapper
-            self.block = checkpoint_wrapper(nn.Sequential(nn.Linear(32, 32), nn.ReLU()))
+            self.block_1 = checkpoint_wrapper(nn.Sequential(nn.Linear(32, 32), nn.ReLU()))
+            self.block_2 = nn.Linear(32, 2)


 .. _deepspeed:
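Outside the diff itself, a minimal self-contained sketch of the pattern this hunk documents may help: only the block with large intermediate activations is wrapped, and the rest of the model is left alone. The `fairscale.nn` import path, the `forward` method, and the Trainer call are illustrative assumptions and are not part of this commit.

.. code-block:: python

    import torch.nn as nn
    import pytorch_lightning as pl
    from fairscale.nn import checkpoint_wrapper  # assumed import path, see the FairScale docs linked above
    from pytorch_lightning import Trainer


    class MyModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            # Checkpoint only the block that produces large activations,
            # never the entire model (see the warning added in this commit)
            self.block_1 = checkpoint_wrapper(nn.Sequential(nn.Linear(32, 32), nn.ReLU()))
            self.block_2 = nn.Linear(32, 2)

        def forward(self, x):
            # block_1's activations are freed after the forward pass
            # and recomputed during the backward pass
            return self.block_2(self.block_1(x))


    # Illustrative only: any Trainer configuration works with checkpoint_wrapper
    trainer = Trainer(gpus=1, precision=16)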
@@ -515,7 +520,36 @@ DeepSpeed Activation Checkpointing
 Activation checkpointing frees activations from memory as soon as they are not needed during the forward pass.
 They are then re-computed for the backwards pass as needed.

-This saves memory when training larger models however requires using a checkpoint function to run the module as shown below.
+Activation checkpointing is very useful when you have intermediate layers that produce large activations.
+
+This saves memory when training larger models; however, it requires using a checkpoint function to run modules as shown below.
+
+.. warning::
+
+    Do not wrap the entire model with activation checkpointing. This is not the intended usage of activation checkpointing, and will lead to failures as seen in `this discussion <https://github.com/PyTorchLightning/pytorch-lightning/discussions/9144>`__.
+
+.. code-block:: python
+
+    from pytorch_lightning import Trainer
+    from pytorch_lightning.plugins import DeepSpeedPlugin
+    import deepspeed
+
+
+    class MyModel(LightningModule):
+        ...
+
+        def __init__(self):
+            super().__init__()
+            self.block_1 = nn.Sequential(nn.Linear(32, 32), nn.ReLU())
+            self.block_2 = torch.nn.Linear(32, 2)
+
+        def forward(self, x):
+            # Use the checkpoint function instead of calling the module directly:
+            # checkpointing self.block_1 means its activations are deleted after use
+            # and re-calculated during the backward pass
+            x = torch.utils.checkpoint.checkpoint(self.block_1, x)
+            return self.block_2(x)
+

 .. code-block:: python

@@ -528,12 +562,13 @@ This saves memory when training larger models however requires using a checkpoin
         ...

         def configure_sharded_model(self):
-            self.block = nn.Sequential(nn.Linear(32, 32), nn.ReLU())
+            self.block_1 = nn.Sequential(nn.Linear(32, 32), nn.ReLU())
+            self.block_2 = torch.nn.Linear(32, 2)

         def forward(self, x):
             # Use the DeepSpeed checkpointing function instead of calling the module directly
-            output = deepspeed.checkpointing.checkpoint(self.block, x)
-            return output
+            x = deepspeed.checkpointing.checkpoint(self.block_1, x)
+            return self.block_2(x)


     model = MyModel()
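To show how the module from these hunks might be launched, here is a brief, hypothetical sketch; the `gpus`, `DeepSpeedPlugin(stage=3)`, and `precision=16` arguments are illustrative assumptions about a typical ZeRO stage 3 run and are not part of this commit.

.. code-block:: python

    from pytorch_lightning import Trainer
    from pytorch_lightning.plugins import DeepSpeedPlugin

    # MyModel is the LightningModule from the code block above;
    # it is assumed to also implement training_step and train_dataloader
    model = MyModel()

    # Assumed setup: DeepSpeed ZeRO stage 3 with 16-bit precision on a single GPU
    trainer = Trainer(gpus=1, plugins=DeepSpeedPlugin(stage=3), precision=16)
    trainer.fit(model)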
