
Commit c346f4d

Compile guide for Fabric (#19330)
1 parent 6421dd8 commit c346f4d

5 files changed: +321 -4 lines changed
Lines changed: 299 additions & 0 deletions
@@ -0,0 +1,299 @@

#################################
Speed up models by compiling them
#################################

Compiling your PyTorch model can result in significant speedups, especially on the latest generations of GPUs.
This guide shows you how to apply ``torch.compile`` correctly in your code.

.. note::

    This requires PyTorch >= 2.0.


----


*********************************
Apply torch.compile to your model
*********************************

Compiling a model in a script together with Fabric is as simple as adding one line of code, calling :func:`torch.compile`:

.. code-block:: python

    import torch
    import lightning as L

    # Set up Fabric
    fabric = L.Fabric(devices=1)

    # Define the model
    model = ...

    # Compile the model
    model = torch.compile(model)

    # `fabric.setup()` should come after `torch.compile()`
    model = fabric.setup(model)


.. important::

    You should compile the model **before** calling ``fabric.setup()`` as shown above for an optimal integration with features in Fabric.

The newly added call to ``torch.compile()`` by itself doesn't do much. It just wraps the model in a "compiled model".
The actual optimization will start when calling ``forward()`` on the model for the first time:

.. code-block:: python

    # 1st execution compiles the model (slow)
    output = model(input)

    # All future executions will be fast (for inputs of the same size)
    output = model(input)
    output = model(input)
    ...

This is important to know when you measure the speed of a compiled model and compare it to a regular model.
You should always *exclude* the first call to ``forward()`` from your measurements, since it includes the compilation time.

.. collapse:: Full example with benchmark

    Below is an example that measures the speedup you get when compiling the InceptionV3 model from TorchVision.

    .. code-block:: python

        import statistics
        import torch
        import torchvision.models as models
        import lightning as L


        @torch.no_grad()
        def benchmark(model, input, num_iters=10):
            """Runs the model on the input several times and returns the median execution time."""
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            times = []
            for _ in range(num_iters):
                start.record()
                model(input)
                end.record()
                torch.cuda.synchronize()
                times.append(start.elapsed_time(end) / 1000)
            return statistics.median(times)


        fabric = L.Fabric(accelerator="cuda", devices=1)

        model = models.inception_v3()
        input = torch.randn(16, 3, 512, 512, device=fabric.device)

        # Compile!
        compiled_model = torch.compile(model)

        # Set up the model with Fabric
        model = fabric.setup(model)
        compiled_model = fabric.setup(compiled_model)

        # Warm up the compiled model before we benchmark
        compiled_model(input)

        # Run multiple forward passes and time them
        eager_time = benchmark(model, input)
        compile_time = benchmark(compiled_model, input)

        # Compare the speedup for the compiled execution
        speedup = eager_time / compile_time
        print(f"Eager median time: {eager_time:.4f} seconds")
        print(f"Compile median time: {compile_time:.4f} seconds")
        print(f"Speedup: {speedup:.1f}x")

    On an NVIDIA A100 SXM4 40GB with PyTorch 2.2.0, CUDA 12.1, we get the following speedup:

    .. code-block:: text

        Eager median time: 0.0254 seconds
        Compile median time: 0.0185 seconds
        Speedup: 1.4x

----

******************
Avoid graph breaks
******************

When ``torch.compile`` looks at the code in your model's ``forward()`` method, it will try to compile as much of the code as possible.
If there are regions in the code that it doesn't understand, it will introduce a so-called "graph break" that essentially splits the code into optimized and unoptimized parts.
Graph breaks aren't a deal breaker, since the optimized parts should still run faster.
But if you want to get the most out of ``torch.compile``, you might want to invest in rewriting the problematic sections of code that produce the breaks.

You can check whether your model produces graph breaks by calling ``torch.compile`` with ``fullgraph=True``:

.. code-block:: python

    # Force an error if there is a graph break in the model
    model = torch.compile(model, fullgraph=True)

Be aware that the error messages produced here are often quite cryptic, so you will likely have to do some `troubleshooting <https://pytorch.org/docs/stable/torch.compiler_troubleshooting.html>`_ to fully optimize your model.
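
As an illustration (these toy modules are not part of the guide's examples), the sketch below shows one of the most common sources of graph breaks, a Python-level ``if`` on a tensor value, together with a tensor-level rewrite that avoids it:

.. code-block:: python

    import torch
    from torch import nn


    class BranchyModel(nn.Module):
        def forward(self, x):
            x = torch.relu(x)
            # Data-dependent Python control flow: the tracer cannot follow this branch
            if x.sum() > 0:
                x = x * 2
            return x


    class BranchlessModel(nn.Module):
        def forward(self, x):
            x = torch.relu(x)
            # The same logic expressed with tensor ops compiles into a single graph
            x = torch.where(x.sum() > 0, x * 2, x)
            return x


    # With fullgraph=True, the problem surfaces on the first forward call
    model = torch.compile(BranchyModel(), fullgraph=True)
    # model(torch.randn(4, 4))  # raises an error pointing at the `if` statement

    # The rewritten model compiles without graph breaks
    model = torch.compile(BranchlessModel(), fullgraph=True)
    model(torch.randn(4, 4))
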
----

*******************
Avoid recompilation
*******************

As mentioned before, the compilation of the model happens the first time you call ``forward()``.
At this point, PyTorch will inspect the input tensor(s) and optimize the compiled code for the particular shape, data type, and other properties the input has.
If the shape of the input remains the same across all calls to ``forward()``, PyTorch will reuse the compiled code it generated and you will get the best speedup.
However, if these properties change across subsequent calls to ``forward()``, PyTorch will be forced to recompile the model for the new shapes, and this will significantly slow down your training if it happens on every iteration.

**When your training suddenly becomes slow, it's probably because PyTorch is recompiling the model!**
Here are some common scenarios when this can happen:

- Your training code switches from training to validation/testing and the input shape changes, triggering a recompilation.
- Your dataset size is not divisible by the batch size, and the dataloader has ``drop_last=False`` (the default).
  The last batch in your training loop will be smaller and trigger a recompilation (see the sketch below for one way to avoid this).
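
For the ``drop_last`` scenario above, the sketch below shows one way to keep the input shape static, together with a way to make recompilations visible in the logs. It assumes a recent PyTorch version with ``torch._logging.set_logs``, and ``train_dataset`` is a placeholder for your own dataset:

.. code-block:: python

    import torch
    import lightning as L
    from torch.utils.data import DataLoader

    fabric = L.Fabric(accelerator="cuda", devices=1)

    # Print a log message every time the compiled model recompiles
    # (same effect as running with the environment variable TORCH_LOGS=recompiles)
    torch._logging.set_logs(recompiles=True)

    # Your existing dataset (placeholder)
    train_dataset = ...

    # Drop the smaller last batch so every training iteration sees the same input shape
    train_dataloader = DataLoader(train_dataset, batch_size=16, drop_last=True)
    train_dataloader = fabric.setup_dataloaders(train_dataloader)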

Ideally, you should try to make the input shape(s) to ``forward()`` static.
However, when this is not possible, you can request PyTorch to compile the code taking into account possible changes to the input shapes.

.. code-block:: python

    # On PyTorch < 2.2
    model = torch.compile(model, dynamic=True)

A model compiled with ``dynamic=True`` will typically be slower than a model compiled with static shapes, but it will avoid the extreme cost of recompilation on every iteration.
On PyTorch 2.2 and later, ``torch.compile`` will detect dynamism automatically and you should no longer need to set this.

.. collapse:: Example with dynamic shapes

    The code below shows an example where the model recompiles for several seconds because the input shape changed.
    You can compare the timing results by toggling ``dynamic=True/False`` in the call to ``torch.compile``:

    .. code-block:: python

        import time
        import torch
        import torchvision.models as models
        import lightning as L

        fabric = L.Fabric(accelerator="cuda", devices=1)

        model = models.inception_v3()

        # dynamic=False is the default
        torch._dynamo.config.automatic_dynamic_shapes = False

        compiled_model = torch.compile(model)
        compiled_model = fabric.setup(compiled_model)

        input = torch.randn(16, 3, 512, 512, device=fabric.device)
        t0 = time.time()
        compiled_model(input)
        torch.cuda.synchronize()
        print(f"1st forward: {time.time() - t0:.2f} seconds.")

        input = torch.randn(8, 3, 512, 512, device=fabric.device)  # note the change in shape
        t0 = time.time()
        compiled_model(input)
        torch.cuda.synchronize()
        print(f"2nd forward: {time.time() - t0:.2f} seconds.")

    With ``automatic_dynamic_shapes=True``:

    .. code-block:: text

        1st forward: 41.90 seconds.
        2nd forward: 89.27 seconds.

    With ``automatic_dynamic_shapes=False``:

    .. code-block:: text

        1st forward: 42.12 seconds.
        2nd forward: 47.77 seconds.

    Numbers produced with an NVIDIA A100 SXM4 40GB, PyTorch 2.2.0, CUDA 12.1.

----

***********************************
Experiment with compilation options
***********************************

There are optional settings that, depending on your model, can give additional speedups.

**CUDA Graphs:** By enabling CUDA Graphs, CUDA will record all computations in a graph and replay it every time forward and backward is called.
The requirement is that your model must be static, i.e., the input shape must not change and your model must execute the same operations every time.
Enabling CUDA Graphs often results in a significant speedup, but it sometimes also increases the memory usage of your model.

.. code-block:: python

    # Enable CUDA Graphs
    compiled_model = torch.compile(model, mode="reduce-overhead")

    # This does the same
    compiled_model = torch.compile(model, options={"triton.cudagraphs": True})

|

**Shape padding:** The specific shape/size of the tensors involved in the computation of your model (input, activations, weights, gradients, etc.) can have an impact on performance.
With shape padding enabled, ``torch.compile`` can extend the tensors by padding them to a size that gives a better memory alignment.
Naturally, the tradeoff here is that it will consume a bit more memory.

.. code-block:: python

    # Default is False
    compiled_model = torch.compile(model, options={"shape_padding": True})

You can find a full list of compile options in the `PyTorch documentation <https://pytorch.org/docs/stable/generated/torch.compile.html>`_.
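
If longer compilation times are acceptable, another documented option you could experiment with is the ``max-autotune`` mode, which spends extra time searching for faster kernels. Whether it pays off depends on your model and GPU, so treat the line below as a suggestion to benchmark rather than a recommendation:

.. code-block:: python

    # Searches more aggressively for fast kernels at the cost of a longer compile time
    compiled_model = torch.compile(model, mode="max-autotune")
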
----

*******************************************************
(Experimental) Apply torch.compile over FSDP, DDP, etc.
*******************************************************

As stated earlier, we recommend that you compile the model before calling ``fabric.setup()``.
However, if you are using DDP or FSDP with Fabric, the compilation won't incorporate the distributed calls inside these wrappers by default.
As an experimental feature, you can let ``fabric.setup()`` reapply the ``torch.compile`` call after the model gets wrapped in DDP/FSDP internally.
In the future, this option will become the default.

.. code-block:: python

    # Choose a distributed strategy like DDP or FSDP
    fabric = L.Fabric(devices=2, strategy="ddp")

    # Compile the model
    model = torch.compile(model)

    # Default: `fabric.setup()` will not reapply the compilation over DDP/FSDP
    model = fabric.setup(model, _reapply_compile=False)

    # Recompile the model over DDP/FSDP (experimental)
    model = fabric.setup(model, _reapply_compile=True)

----

**************************************
A note about torch.compile in practice
**************************************

In practice, you will find that ``torch.compile`` often doesn't work well and can even be counter-productive.
Compilation may fail with cryptic error messages that are impossible to debug without help from the PyTorch team.
It is also not uncommon that ``torch.compile`` will produce a significantly *slower* model or one with much higher memory usage.
On top of that, the compilation phase itself can be incredibly slow, taking several minutes to finish.
For these reasons, we recommend that you don't waste too much time trying to apply ``torch.compile`` during development, and rather evaluate its effectiveness toward the end when you are about to launch long-running, expensive experiments.
Always compare the speed and memory usage of the compiled model against the original model!
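
A minimal sketch for the memory side of that comparison is shown below. It assumes the ``model``, ``compiled_model``, and ``input`` from the benchmark example earlier in this guide, and measures the peak GPU memory of a single forward pass:

.. code-block:: python

    import torch


    @torch.no_grad()
    def peak_memory_mb(model, input):
        """Runs one forward pass and returns the peak GPU memory it used, in MB."""
        torch.cuda.reset_peak_memory_stats()
        model(input)
        torch.cuda.synchronize()
        return torch.cuda.max_memory_allocated() / 1024**2


    print(f"Eager peak memory:    {peak_memory_mb(model, input):.0f} MB")
    print(f"Compiled peak memory: {peak_memory_mb(compiled_model, input):.0f} MB")
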
|

docs/source-fabric/glossary/index.rst

Lines changed: 5 additions & 0 deletions

@@ -69,6 +69,11 @@ Glossary
    :button_link: ../advanced/distributed_communication.html
    :col_css: col-md-4

+.. displayitem::
+   :header: Compile
+   :button_link: ../advanced/compile.html
+   :col_css: col-md-4
+
 .. displayitem::
    :header: CUDA
    :button_link: ../fundamentals/accelerators.html

docs/source-fabric/guide/index.rst

Lines changed: 8 additions & 0 deletions

@@ -157,6 +157,14 @@ Advanced Topics
    :height: 160
    :tag: advanced

+.. displayitem::
+   :header: Speed up models by compiling them
+   :description: Use torch.compile to speed up models on modern hardware
+   :button_link: ../advanced/compile.html
+   :col_css: col-md-4
+   :height: 150
+   :tag: advanced
+
 .. displayitem::
    :header: Train models with billions of parameters
    :description: Train the largest models with FSDP across multiple GPUs and machines

docs/source-fabric/index.rst

Lines changed: 0 additions & 4 deletions

@@ -113,8 +113,6 @@ Get Started
    <div class="display-card-container">
    <div class="row">

-.. Add callout items below this line
-
 .. displayitem::
    :header: Convert to Fabric in 5 minutes
    :description: Learn how to add Fabric to your PyTorch code
@@ -168,8 +166,6 @@ Get Started
    </div>
    </div>

-.. End of callout item section
-
 |
 |

docs/source-fabric/levels/advanced.rst

Lines changed: 9 additions & 0 deletions

@@ -5,6 +5,7 @@
    <../advanced/gradient_accumulation>
    <../advanced/distributed_communication>
    <../advanced/multiple_setup>
+   <../advanced/compile>
    <../advanced/model_parallel/fsdp>
    <../guide/checkpoint/distributed_checkpoint>

@@ -42,6 +43,14 @@ Advanced skills
    :height: 170
    :tag: advanced

+.. displayitem::
+   :header: Speed up models by compiling them
+   :description: Use torch.compile to speed up models on modern hardware
+   :button_link: ../advanced/compile.html
+   :col_css: col-md-4
+   :height: 170
+   :tag: advanced
+
 .. displayitem::
    :header: Train models with billions of parameters
    :description: Train the largest models with FSDP across multiple GPUs and machines
