Commit 76a4075

UlyssesSP: TiledMLP doc - recomputes forward twice (#7664)
Make it very clear that `TiledMLP`'s memory saving has a cost of recomputing forward.
Parent: 3e64f49


deepspeed/runtime/sequence_parallel/ulysses_sp.py

Lines changed: 6 additions & 2 deletions
@@ -670,6 +670,8 @@ class SequenceTiledCompute(torch.autograd.Function):
     """
     A generic autograd function to perform a tiled compute.
 
+    Please note this module re-computes `forward` in the `backward`. So the `forward` occurs twice each iteration. And if you're using activation checkpointing it then occurs thrice.
+
     Please note that this implementation doesn't require DeepSpeed and can work without it. `compute_params` can remain `None` in such a case.
 
     For an easier to understand example see TiledMLP - which is the same as this autograd function but without the generalization code.
@@ -835,9 +837,11 @@ def backward(ctx, *grads) -> torch.Tensor:
 
 class TiledMLP(torch.autograd.Function):
     """
-    Perform a tiled MLP computation to massively reduce memory usage needed to compute MLP when using very long sequence lengths
+    Perform a tiled MLP computation to massively reduce memory usage needed to compute MLP when using very long sequence lengths.
+
+    Please note this module re-computes `forward` in the `backward`. So the `forward` occurs twice each iteration. And if you're using activation checkpointing it then occurs thrice.
 
-    For a general tiled compute implementation that can handle any `forward` see `SequenceTiledCompute`
+    For a general tiled compute implementation that can handle any `forward` see `SequenceTiledCompute`.
 
     Args:
     - fn: the function to call on sharded inputs
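For readers unfamiliar with the pattern the new note describes, below is a minimal, hypothetical sketch of a recompute-in-backward tiled MLP. It is not DeepSpeed's actual `TiledMLP`; the `TinyTiledMLP` name, the `shards` argument, and the sequence-dimension chunking are illustrative assumptions. The forward pass runs shard by shard under `torch.no_grad()` so no activations are stored; the backward pass then re-runs the MLP forward per shard to rebuild the local graph it needs, which is why `forward` effectively executes twice per iteration.

```python
# Hypothetical toy sketch of the recompute-in-backward idea, not DeepSpeed's TiledMLP.
import torch


class TinyTiledMLP(torch.autograd.Function):
    """Tile an MLP over the sequence dim; recompute forward inside backward."""

    @staticmethod
    def forward(ctx, fn, x, shards):
        # fn: the MLP callable, x: [batch, seqlen, hidden], shards: number of tiles
        ctx.fn = fn
        ctx.shards = shards
        ctx.save_for_backward(x)
        with torch.no_grad():
            # First forward pass: no activations are kept, which is the memory saving.
            return torch.cat([fn(c) for c in x.chunk(shards, dim=1)], dim=1)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        x_grads = []
        for c, g in zip(x.chunk(ctx.shards, dim=1), grad_out.chunk(ctx.shards, dim=1)):
            c = c.detach().requires_grad_(True)
            with torch.enable_grad():
                y = ctx.fn(c)              # second forward pass, one shard at a time
            torch.autograd.backward(y, g)  # param grads accumulate in-place on fn's weights
            x_grads.append(c.grad)
        return None, torch.cat(x_grads, dim=1), None


# Usage sketch: the second forward happens only when .backward() is called.
mlp = torch.nn.Sequential(torch.nn.Linear(64, 256), torch.nn.GELU(), torch.nn.Linear(256, 64))
x = torch.randn(2, 1024, 64, requires_grad=True)
TinyTiledMLP.apply(mlp, x, 8).sum().backward()
```

The saving comes from never holding activations for the full sequence at once; the price is the extra forward FLOPs, and stacking activation checkpointing on top adds yet another recompute, as the docstring change points out.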
