
Commit 89ff87d

Reapply compile in Fabric.setup() by default (#19382)
1 parent af7e79a commit 89ff87d

2 files changed: +34 −30 lines changed

docs/source-fabric/advanced/compile.rst

Lines changed: 26 additions & 23 deletions
@@ -220,6 +220,7 @@ On PyTorch 2.2 and later, ``torch.compile`` will detect dynamism automatically a
 
 Numbers produced with NVIDIA A100 SXM4 40GB, PyTorch 2.2.0, CUDA 12.1.
 
+
 ----
 
 
@@ -255,17 +256,33 @@ Naturally, the tradoff here is that it will consume a bit more memory.
 
 You can find a full list of compile options in the `PyTorch documentation <https://pytorch.org/docs/stable/generated/torch.compile.html>`_.
 
+
+----
+
+
+**************************************
+A note about torch.compile in practice
+**************************************
+
+In practice, you will find that ``torch.compile`` often doesn't work well and can even be counter-productive.
+Compilation may fail with cryptic error messages that are impossible to debug without help from the PyTorch team.
+It is also not uncommon that ``torch.compile`` will produce a significantly *slower* model or one with much higher memory usage.
+On top of that, the compilation phase itself can be incredibly slow, taking several minutes to finish.
+For these reasons, we recommend that you don't waste too much time trying to apply ``torch.compile`` during development, and rather evaluate its effectiveness toward the end when you are about to launch long-running, expensive experiments.
+Always compare the speed and memory usage of the compiled model against the original model!
+
+
 ----
 
 
-*******************************************************
-(Experimental) Apply torch.compile over FSDP, DDP, etc.
-*******************************************************
+*************************************
+Using torch.compile with FSDP and DDP
+*************************************
 
 As stated earlier, we recommend that you compile the model before calling ``fabric.setup()``.
-However, if you are using DDP or FSDP with Fabric, the compilation won't incorporate the distributed calls inside these wrappers by default.
-In an experimental feature, you can let ``fabric.setup()`` reapply the ``torch.compile`` call after the model gets wrapped in DDP/FSDP internally.
-In the future, this option will become the default.
+In the case of DDP and FSDP, ``fabric.setup()`` will automatically reapply the ``torch.compile`` call after the model gets wrapped in DDP/FSDP internally.
+This will ensure that the compilation can incorporate the distributed calls and optimize them.
+However, should you have issues compiling DDP and FSDP models, you can opt out of this feature:
 
 .. code-block:: python
 
@@ -275,25 +292,11 @@ In the future, this option will become the default.
     # Compile the model
     model = torch.compile(model)
 
-    # Default: `fabric.setup()` will not reapply the compilation over DDP/FSDP
-    model = fabric.setup(model, _reapply_compile=False)
-
-    # Recompile the model over DDP/FSDP (experimental)
+    # Default: `fabric.setup()` will configure compilation over DDP/FSDP for you
     model = fabric.setup(model, _reapply_compile=True)
 
+    # Turn it off if you see issues with DDP/FSDP
+    model = fabric.setup(model, _reapply_compile=False)
 
-----
-
-
-**************************************
-A note about torch.compile in practice
-**************************************
-
-In practice, you will find that ``torch.compile`` often doesn't work well and can even be counter-productive.
-Compilation may fail with cryptic error messages that are impossible to debug without help from the PyTorch team.
-It is also not uncommon that ``torch.compile`` will produce a significantly *slower* model or one with much higher memory usage.
-On top of that, the compilation phase itself can be incredibly slow, taking several minutes to finish.
-For these reasons, we recommend that you don't waste too much time trying to apply ``torch.compile`` during development, and rather evaluate its effectiveness toward the end when you are about to launch long-running, expensive experiments.
-Always compare the speed and memory usage of the compiled model against the original model!
 
 |
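
The new "A note about torch.compile in practice" section above closes by telling readers to always compare the compiled model against the original. As an illustrative sketch only (not part of this commit), such a comparison could look like the following; the toy model, batch size, and CUDA device are assumptions:

.. code-block:: python

    import time

    import torch


    def benchmark(model, batch, steps=20):
        """Return (seconds per forward pass, peak CUDA memory in bytes)."""
        torch.cuda.reset_peak_memory_stats()
        # Warm-up: the first calls on a compiled model include the compilation time
        for _ in range(3):
            model(batch)
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(steps):
            model(batch)
        torch.cuda.synchronize()
        return (time.perf_counter() - start) / steps, torch.cuda.max_memory_allocated()


    model = torch.nn.Sequential(torch.nn.Linear(1024, 4096), torch.nn.GELU(), torch.nn.Linear(4096, 1024)).cuda()
    batch = torch.randn(64, 1024, device="cuda")

    eager_time, eager_mem = benchmark(model, batch)
    compiled_time, compiled_mem = benchmark(torch.compile(model), batch)
    print(f"eager:    {eager_time * 1e3:.2f} ms/step, {eager_mem / 2**20:.0f} MiB peak")
    print(f"compiled: {compiled_time * 1e3:.2f} ms/step, {compiled_mem / 2**20:.0f} MiB peak")

The warm-up iterations matter because the first calls on a compiled model trigger the (potentially slow) compilation that the docs warn about.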

src/lightning/fabric/fabric.py

Lines changed: 8 additions & 7 deletions
@@ -214,7 +214,7 @@ def setup(
         module: nn.Module,
         *optimizers: Optimizer,
         move_to_device: bool = True,
-        _reapply_compile: Optional[bool] = None,
+        _reapply_compile: bool = True,
     ) -> Any:  # no specific return because the way we want our API to look does not play well with mypy
         r"""Set up a model and its optimizers for accelerated training.
 
@@ -223,10 +223,11 @@ def setup(
             *optimizers: The optimizer(s) to set up (no optimizers is also possible)
             move_to_device: If set ``True`` (default), moves the model to the correct device. Set this to ``False``
                 and alternatively use :meth:`to_device` manually.
-            _reapply_compile: (Experimental) If set to ``True``, and the model was ``torch.compile``d before, the
+            _reapply_compile: If ``True`` (default), and the model was ``torch.compile``d before, the
                 corresponding :class:`~torch._dynamo.OptimizedModule` wrapper will be removed and reapplied with the
                 same settings after the model was set up by the strategy (e.g., after the model was wrapped by DDP,
-                FSDP etc.). Only supported on PyTorch >= 2.1. Defaults to ``False``, but it may change in the future.
+                FSDP etc.). Only applies on PyTorch >= 2.1. Set it to ``False`` if compiling DDP/FSDP is causing
+                issues.
 
         Returns:
             The tuple containing wrapped module and the optimizers, in the same order they were passed in.
@@ -280,7 +281,7 @@ def setup(
         return module
 
     def setup_module(
-        self, module: nn.Module, move_to_device: bool = True, _reapply_compile: Optional[bool] = None
+        self, module: nn.Module, move_to_device: bool = True, _reapply_compile: bool = True
     ) -> _FabricModule:
         r"""Set up a model for accelerated training or inference.
 
@@ -292,11 +293,11 @@ def setup_module(
             module: A :class:`torch.nn.Module` to set up
             move_to_device: If set ``True`` (default), moves the model to the correct device. Set this to ``False``
                 and alternatively use :meth:`to_device` manually.
-            _reapply_compile: (Experimental) If set to ``True``, and the model was ``torch.compile``d before, the
+            _reapply_compile: If ``True`` (default), and the model was ``torch.compile``d before, the
                 corresponding :class:`~torch._dynamo.OptimizedModule` wrapper will be removed and reapplied with the
                 same settings after the model was set up by the strategy (e.g., after the model was wrapped by DDP,
-                FSDP etc.). Only supported on PyTorch >= 2.1. Defaults to ``False``, but it may change in the future.
-
+                FSDP etc.). Only applies on PyTorch >= 2.1. Set it to ``False`` if compiling DDP/FSDP is causing
+                issues.
         Returns:
             The wrapped model.
 
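
With ``_reapply_compile`` now defaulting to ``True``, a script that compiles the model before ``fabric.setup()`` gets the DDP/FSDP-aware recompilation without passing any extra argument; the flag is only needed to opt out. A minimal usage sketch (the two-GPU DDP configuration and toy model are assumptions, not taken from the commit):

.. code-block:: python

    import torch
    from lightning.fabric import Fabric

    fabric = Fabric(accelerator="cuda", devices=2, strategy="ddp")
    fabric.launch()

    model = torch.nn.Linear(32, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    # Compile first, then hand the model to Fabric, as the docs recommend
    model = torch.compile(model)

    # New default: same as passing _reapply_compile=True, so the compile call
    # is reapplied on top of the DDP-wrapped module inside setup()
    model, optimizer = fabric.setup(model, optimizer)

    # Escape hatch if compiling the DDP/FSDP-wrapped model causes problems:
    # model, optimizer = fabric.setup(model, optimizer, _reapply_compile=False)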

0 commit comments
