Commit 87565cb

FP8 + FSDP2 + torch.compile examples for PyTorch Lightning and Fabric (#20440)
* Minimal transformer examples
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Add tests for compile after fsdp2/tp
* Add README's
* Add docs
* Rename folder, add cross-reference
* Fix link
* Newline after code-block directive
* Update section name
* Fix reference
* Half standalone tests batch size
* Fix integration tests
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 333d1cf commit 87565cb

File tree

12 files changed: +592 -33 lines changed

docs/source-fabric/advanced/compile.rst

Lines changed: 107 additions & 1 deletion

@@ -115,9 +115,115 @@ always exclude the first call to ``forward()`` from your measurements, since it

    Compile median time: 0.0185 seconds
    Speedup: 1.4x

----

**********************************************
Apply torch.compile with ModelParallelStrategy
**********************************************

:func:`torch.compile` can also be invoked as part of the `parallelize_fn` argument of :class:`~lightning.fabric.strategies.model_parallel.ModelParallelStrategy`.

This is particularly handy when :func:`torch.compile` is used in combination with the `torch.distributed.tensor` API.

Here is an example:

.. code-block:: python

    import lightning as L
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from lightning.pytorch.demos import Transformer
    from lightning.fabric.strategies.model_parallel import ModelParallelStrategy
    from torch.distributed._composable.fsdp.fully_shard import fully_shard
    from torch.distributed.device_mesh import DeviceMesh


    def parallelize(model: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        for module in model.modules():
            if isinstance(module, (torch.nn.TransformerEncoderLayer, torch.nn.TransformerDecoderLayer)):
                fully_shard(module, mesh=device_mesh)

        fully_shard(model, mesh=device_mesh)

        return torch.compile(model)


    def train():
        L.seed_everything(42)

        with torch.device("meta"):
            model = Transformer(
                vocab_size=50257,
                nlayers=16,
                nhid=4096,
                ninp=1024,
                nhead=32,
            )

        strategy = ModelParallelStrategy(data_parallel_size=4, tensor_parallel_size=1, parallelize_fn=parallelize)

        fabric = L.Fabric(precision="bf16-true", strategy=strategy)
        fabric.launch()

        model = fabric.setup(model)

The advantage here is that `parallelize` is called when sharding the model,
so :func:`torch.compile` is guaranteed to run on model shards and capture distributed operations.
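
For orientation, here is a minimal sketch of how training might continue after ``fabric.setup()``; the optimizer, dataloader, and loss below are assumptions for illustration, not part of the documented example:

.. code-block:: python

    # continuing inside train(); `dataloader` is assumed to yield (input, target) batches
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    optimizer = fabric.setup_optimizers(optimizer)
    dataloader = fabric.setup_dataloaders(dataloader)

    for input, target in dataloader:
        output = model(input, target)
        loss = F.nll_loss(output, target.view(-1))
        fabric.backward(loss)
        optimizer.step()
        optimizer.zero_grad()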

Also, when using other libraries like `torch ao <https://github.com/pytorch/ao>`_
that need to be applied in a similar fashion, it's easy to reason about the sequence of calls
needed to achieve the equivalent of `compile(distributed(quantized(model)))`:

.. code-block:: python

    import lightning as L
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from lightning.pytorch.demos import Transformer
    from lightning.fabric.strategies.model_parallel import ModelParallelStrategy
    from torch.distributed._composable.fsdp.fully_shard import fully_shard
    from torch.distributed.device_mesh import DeviceMesh
    from torchao.float8 import Float8LinearConfig, convert_to_float8_training


    def parallelize(model: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        float8_config = Float8LinearConfig(
            pad_inner_dim=True,
        )

        def module_filter_fn(mod: torch.nn.Module, fqn: str):
            return fqn != "decoder"

        convert_to_float8_training(model, config=float8_config, module_filter_fn=module_filter_fn)

        for module in model.modules():
            if isinstance(module, (torch.nn.TransformerEncoderLayer, torch.nn.TransformerDecoderLayer)):
                fully_shard(module, mesh=device_mesh)

        fully_shard(model, mesh=device_mesh)

        return torch.compile(model)


    def train():
        L.seed_everything(42)

        with torch.device("meta"):
            model = Transformer(
                vocab_size=50257,
                nlayers=16,
                nhid=4096,
                ninp=1024,
                nhead=32,
            )

        strategy = ModelParallelStrategy(data_parallel_size=4, tensor_parallel_size=1, parallelize_fn=parallelize)

        fabric = L.Fabric(precision="bf16-true", strategy=strategy)
        fabric.launch()

        model = fabric.setup(model)

For a full example, see our `FP8 Distributed Transformer example <https://github.com/Lightning-AI/lightning/blob/master/examples/fabric/fp8_distributed_transformer>`_.

----

******************
Avoid graph breaks

docs/source-pytorch/advanced/compile.rst

Lines changed: 118 additions & 2 deletions

@@ -138,6 +138,122 @@ always exclude the first call to ``forward()``/``*_step()`` from your measuremen

----

**************************************
Apply torch.compile in configure_model
**************************************

:func:`torch.compile` can also be invoked as part of the :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_model` hook.

This is particularly handy when :func:`torch.compile` is used in combination with :class:`~lightning.pytorch.strategies.model_parallel.ModelParallelStrategy`.

Here is an example:

.. code-block:: python

    import lightning as L
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from lightning.pytorch.demos import Transformer
    from lightning.pytorch.strategies.model_parallel import ModelParallelStrategy
    from torch.distributed.device_mesh import DeviceMesh
    from torch.distributed._composable.fsdp.fully_shard import fully_shard


    class LanguageModel(L.LightningModule):
        def __init__(self, vocab_size):
            super().__init__()
            self.vocab_size = vocab_size
            self.model = None

        def configure_model(self):
            if self.model is not None:
                return

            with torch.device("meta"):
                model = Transformer(
                    vocab_size=self.vocab_size,
                    nlayers=16,
                    nhid=4096,
                    ninp=1024,
                    nhead=32,
                )

            for module in model.modules():
                if isinstance(module, (nn.TransformerEncoderLayer, nn.TransformerDecoderLayer)):
                    fully_shard(module, mesh=self.device_mesh)

            fully_shard(model, mesh=self.device_mesh)

            self.model = torch.compile(model)

        def training_step(self, batch):
            input, target = batch
            output = self.model(input, target)
            loss = F.nll_loss(output, target.view(-1))
            self.log("train_loss", loss)
            return loss

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-4)

The advantage here is that `configure_model` is called when sharding the model,
so :func:`torch.compile` is guaranteed to run on model shards and capture distributed operations.
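
For context, here is a minimal sketch of how this module might be driven by the Trainer; the strategy sizes, precision, and ``train_dataloader`` below are assumptions for illustration, not part of the documented example:

.. code-block:: python

    # assumes 4 GPUs and a `train_dataloader` yielding (input, target) batches
    strategy = ModelParallelStrategy(data_parallel_size=4, tensor_parallel_size=1)
    trainer = L.Trainer(strategy=strategy, precision="bf16-true", max_steps=100)
    trainer.fit(LanguageModel(vocab_size=50257), train_dataloader)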

Also, when using other libraries like `torch ao <https://github.com/pytorch/ao>`_
that need to be applied in a similar fashion, it's easy to reason about the sequence of calls
needed to achieve the equivalent of `compile(distributed(quantized(model)))`:

.. code-block:: python

    import lightning as L
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from lightning.pytorch.demos import Transformer
    from lightning.pytorch.strategies.model_parallel import ModelParallelStrategy
    from torch.distributed._composable.fsdp.fully_shard import fully_shard
    from torch.distributed.device_mesh import DeviceMesh
    from torchao.float8 import Float8LinearConfig, convert_to_float8_training


    class LanguageModel(L.LightningModule):
        def __init__(self, vocab_size):
            super().__init__()
            self.vocab_size = vocab_size
            self.model = None

        def configure_model(self):
            if self.model is not None:
                return

            with torch.device("meta"):
                model = Transformer(
                    vocab_size=self.vocab_size,
                    nlayers=16,
                    nhid=4096,
                    ninp=1024,
                    nhead=32,
                )

            float8_config = Float8LinearConfig(
                pad_inner_dim=True,
            )

            def module_filter_fn(mod: torch.nn.Module, fqn: str):
                return fqn != "decoder"

            convert_to_float8_training(model, config=float8_config, module_filter_fn=module_filter_fn)

            for module in model.modules():
                if isinstance(module, (nn.TransformerEncoderLayer, nn.TransformerDecoderLayer)):
                    fully_shard(module, mesh=self.device_mesh)

            fully_shard(model, mesh=self.device_mesh)

            self.model = torch.compile(model)

For a full example, see our `FP8 Distributed Transformer example <https://github.com/Lightning-AI/lightning/blob/master/examples/pytorch/fp8_distributed_transformer>`_.

----

******************
Avoid graph breaks

@@ -253,8 +369,8 @@ Limitations

There are a few limitations you should be aware of when using ``torch.compile`` **in conjunction with the Trainer**:

- * The Trainer currently does not reapply ``torch.compile`` over DDP/FSDP, meaning distributed operations can't benefit from speed ups at the moment.
-   This limitation will be lifted in the future.
+ * The Trainer currently does not reapply ``torch.compile`` over :class:`~lightning.pytorch.strategies.DDPStrategy` and :class:`~lightning.pytorch.strategies.FSDPStrategy`, meaning distributed operations can't benefit from speed ups at the moment.
+   This limitation can be avoided by using :class:`~lightning.pytorch.strategies.model_parallel.ModelParallelStrategy`, as described in `Apply torch.compile in configure_model`_ above.

* In some cases, using ``self.log()`` in your LightningModule will cause compilation errors.
  Until addressed, you can work around these issues by applying ``torch.compile`` to the submodule(s) of your LightningModule rather than to the entire LightningModule at once.
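
As a rough illustration of that submodule workaround (a sketch only; ``Backbone`` and the layer sizes are hypothetical):

.. code-block:: python

    class LitModel(L.LightningModule):
        def __init__(self):
            super().__init__()
            self.backbone = torch.compile(Backbone())  # compile only the submodule
            self.head = nn.Linear(128, 10)  # left uncompiled

        def training_step(self, batch):
            x, y = batch
            loss = F.cross_entropy(self.head(self.backbone(x)), y)
            self.log("train_loss", loss)  # logging stays outside the compiled region
            return loss
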
Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@

## Distributed, Low-Precision Transformer Example

This example shows how to use `ModelParallelStrategy` in `Fabric` to train a Transformer model minimizing memory usage, maximizing throughput, and distributing load across multiple GPUs.

### Training Large Models and Memory Requirements

One of the main challenges when training large models, like large language models (LLMs), is dealing with their memory footprint. LLMs can be so large that weights, activations, gradients, and optimizer state don't fit on a single GPU, so they need to be distributed across multiple GPUs, and across multiple machines. There are multiple ways of distributing computation, among them fully-sharded data parallelism (FSDP) and tensor parallelism (TP).

An additional way of reducing memory requirements is to represent floating-point numbers in weights and activations in low numerical precision, such as 16-bit (`bfloat16`) or 8-bit (`fp8`). This saves memory usage as well as memory bandwidth (fewer bytes transferred from device memory to GPU cores per unit time).

Roughly, reducing precision to `fp8` for linear layers can lead to a 2x reduction in memory requirements and a 1.6x improvement in throughput. Support for `fp8` weights and activations requires recent GPUs: Hopper, Ada Lovelace and above (e.g. H100, L4, L40).
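
As a quick check (this snippet is illustrative and not part of the example), the compute capability reported by PyTorch indicates whether the GPU has `fp8` tensor-core support (8.9 is Ada Lovelace, 9.0 is Hopper):

```python
import torch

# fp8 tensor cores require compute capability 8.9 (Ada Lovelace) or 9.0+ (Hopper)
if torch.cuda.is_available():
    capability = torch.cuda.get_device_capability()
    print("fp8-capable GPU" if capability >= (8, 9) else "no fp8 tensor-core support")
```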

The introduction of tensor subclasses in PyTorch brought two new APIs that can be used in combination to achieve memory savings and distributed training (as well as inference):

- [torch ao](https://github.com/pytorch/ao) to execute linear layers in low numerical precision (`fp8` and other quantized formats)
- [dtensors](https://pytorch.org/docs/stable/distributed.tensor.html) to distribute models across GPUs by combining TP and FSDP (referred to as FSDP2 in PyTorch)

Notably, `torch ao` introduces quantization and dequantization operations in the model that may result in slowdowns if not optimized. Using `torch.compile` after `torch ao` recovers performance by generating optimized kernels for those operations.
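
To make that ordering concrete, here is a minimal single-GPU sketch (the toy model is an assumption; the example below uses the real `Transformer` model instead):

```python
import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training

model = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 1024)).cuda()
convert_to_float8_training(model)  # swaps nn.Linear layers for fp8 training variants in place
model = torch.compile(model)  # fuses the inserted quantize/dequantize ops into optimized kernels
```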

### Vanilla Transformer Example

This example shows how to train a vanilla Transformer model using `fp8` precision and the FSDP2 distributed strategy, and then optimize the resulting model through `torch.compile`.

Specifically, we employ the `ModelParallelStrategy`, and use the `configure_model` hook to distribute the model using the PyTorch DTensor API.
In the same hook we also pass the model through the `torch ao` API (prior to FSDP2), as well as `torch.compile` (after FSDP2).

The resulting code follows the PyTorch API closely, while also taking advantage of the rest of PyTorch Lightning.

To execute the code directly, just run:

```bash
python train.py
```

### A Note on torch.compile

Note that PyTorch Lightning also supports calling `torch.compile` on a `LightningModule` and passing it to the `Trainer`.

While this works for simple cases, to get the most out of the combination of the latest distributed, quantization, and compile PyTorch APIs, we recommend invoking `torch.compile` at the end of the `configure_model` hook, as shown in this example.
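
For contrast, that simpler route looks roughly like the sketch below (`LitModel` stands in for any `LightningModule`; the trainer arguments are placeholders):

```python
import lightning as L
import torch

model = torch.compile(LitModel())  # compile the whole LightningModule up front
trainer = L.Trainer(max_epochs=1)
trainer.fit(model)
```
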
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
torchao>=0.7.0
Lines changed: 100 additions & 0 deletions

@@ -0,0 +1,100 @@

```python
import lightning as L
import torch
import torch.nn as nn
import torch.nn.functional as F
from lightning.fabric.strategies import ModelParallelStrategy
from lightning.pytorch.demos import Transformer, WikiText2
from torch.distributed._composable.fsdp.fully_shard import fully_shard
from torch.distributed.device_mesh import DeviceMesh
from torch.utils.data import DataLoader
from torchao.float8 import Float8LinearConfig, convert_to_float8_training
from tqdm import tqdm


def configure_model(model: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
    float8_config = Float8LinearConfig(
        # pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly  # noqa
        pad_inner_dim=True,
    )

    def module_filter_fn(mod: torch.nn.Module, fqn: str):
        # we skip the decoder because its vocabulary size is typically
        # not divisible by 16, as required by float8
        return fqn != "decoder"

    # quantize first, then shard with FSDP2, then compile
    convert_to_float8_training(model, config=float8_config, module_filter_fn=module_filter_fn)

    for module in model.modules():
        if isinstance(module, (torch.nn.TransformerEncoderLayer, torch.nn.TransformerDecoderLayer)):
            fully_shard(module, mesh=device_mesh)

    fully_shard(model, mesh=device_mesh)

    return torch.compile(model)


def train():
    L.seed_everything(42)

    batch_size = 8
    micro_batch_size = 1

    max_steps = 100

    dataset = WikiText2()
    dataloader = DataLoader(dataset, num_workers=8, batch_size=micro_batch_size)

    # instantiate the model on the meta device; parameters are materialized when Fabric shards it
    with torch.device("meta"):
        model = Transformer(
            vocab_size=dataset.vocab_size,
            nlayers=16,
            nhid=4096,
            ninp=1024,
            nhead=32,
        )

    strategy = ModelParallelStrategy(data_parallel_size=4, tensor_parallel_size=1, parallelize_fn=configure_model)

    fabric = L.Fabric(precision="bf16-true", strategy=strategy)
    fabric.launch()

    model = fabric.setup(model)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    optimizer = fabric.setup_optimizers(optimizer)

    dataloader = fabric.setup_dataloaders(dataloader)

    iterable = tqdm(enumerate(dataloader), total=len(dataloader)) if fabric.is_global_zero else enumerate(dataloader)

    steps = 0

    for i, batch in iterable:
        input, target = batch

        # accumulate gradients over batch_size // micro_batch_size micro-batches
        is_accumulating = i % (batch_size // micro_batch_size) != 0

        with fabric.no_backward_sync(model, enabled=is_accumulating):
            output = model(input, target)
            loss = F.nll_loss(output, target.view(-1))
            fabric.backward(loss)

        if not is_accumulating:
            fabric.clip_gradients(model, optimizer, max_norm=1.0)
            optimizer.step()
            optimizer.zero_grad()
            steps += 1

        if fabric.is_global_zero:
            iterable.set_postfix_str(f"train_loss={loss.item():.2f}")

        if steps == max_steps:
            break

    fabric.print(torch.cuda.memory_summary())


if __name__ == "__main__":
    torch.set_float32_matmul_precision("high")

    train()
```
