23 commits
1bc2ce7  Add compile_fn for Trainer (mieshkiwrk, Sep 10, 2024)
e26132a  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 10, 2024)
925c376  Add parameter description (mieshkiwrk, Sep 10, 2024)
1946070  Merge branch 'master' into feature/trainer-compile-fn (lantiga, Nov 12, 2024)
86d2c70  Test reapply_compile for trainer (mieshkiwrk, Nov 27, 2024)
8db1a6f  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 27, 2024)
f0f0a57  Remove reapply_compile flag (mieshkiwrk, Dec 2, 2024)
2c498f5  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Dec 2, 2024)
bbcaad1  Dict -> dict (mieshkiwrk, Dec 2, 2024)
f2c436a  Merge branch 'master' into feature/trainer-compile-fn (mieshkiwrk, Dec 2, 2024)
809c6c4  Test trainer rewrap compiled module over DDP strategy (mieshkiwrk, Jan 10, 2025)
d3645c4  Merge branch 'master' into feature/trainer-compile-fn (mieshkiwrk, Jan 10, 2025)
c74bdab  Run DDP test_reapply_compile on gpu (mieshkiwrk, Jan 10, 2025)
9bc5774  Add test for reapply_compile with FSDP on gpu (mieshkiwrk, Jan 10, 2025)
87c1377  Update test_ddp_integration.py (mieshkiwrk, Jan 10, 2025)
b17a3dc  Remove not used tmp_path argument (mieshkiwrk, Jan 10, 2025)
8e73a21  test_trainer_compiled_model change (mieshkiwrk, Feb 5, 2025)
fc59439  Merge branch 'master' into feature/trainer-compile-fn (mieshkiwrk, Feb 5, 2025)
f7aad69  Merge branch 'master' into feature/trainer-compile-fn (mieshkiwrk, Feb 17, 2025)
20c5353  Merge branch 'master' into feature/trainer-compile-fn (Borda, Mar 13, 2025)
0096755  Merge branch 'master' into feature/trainer-compile-fn (mieshkiwrk, May 8, 2025)
a0f79ca  Merge branch 'master' into feature/trainer-compile-fn (Borda, Aug 19, 2025)
8a84d07  Merge branch 'master' into feature/trainer-compile-fn (Borda, Sep 10, 2025)
22 changes: 18 additions & 4 deletions src/lightning/pytorch/trainer/trainer.py
@@ -37,6 +37,7 @@
from lightning.fabric.utilities.apply_func import convert_tensors_to_scalars
from lightning.fabric.utilities.cloud_io import _is_local_file_protocol
from lightning.fabric.utilities.types import _PATH
from lightning.fabric.wrappers import _to_compiled, _unwrap_compiled
from lightning.pytorch.accelerators import Accelerator
from lightning.pytorch.callbacks import Callback, Checkpoint, EarlyStopping, ProgressBar
from lightning.pytorch.core.datamodule import LightningDataModule
@@ -565,20 +566,26 @@ def fit(
For more information about multiple dataloaders, see this :ref:`section <multiple-dataloaders>`.

"""
model = _maybe_unwrap_optimized(model)
        # if a compiled model was provided, unwrap it here and re-apply compile after the strategy is applied
model, compile_kwargs = (
_unwrap_compiled(model)
if isinstance(model, torch._dynamo.OptimizedModule)
else (_maybe_unwrap_optimized(model), None)
)
self.strategy._lightning_module = model
_verify_strategy_supports_compile(model, self.strategy)
self.state.fn = TrainerFn.FITTING
self.state.status = TrainerStatus.RUNNING
self.training = True
self.should_stop = False
call._call_and_handle_interrupt(
self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
self, self._fit_impl, model, compile_kwargs, train_dataloaders, val_dataloaders, datamodule, ckpt_path
)

def _fit_impl(
self,
model: "pl.LightningModule",
compile_kwargs: Optional[dict[str, Any]] = None,
train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None,
val_dataloaders: Optional[EVAL_DATALOADERS] = None,
datamodule: Optional[LightningDataModule] = None,
@@ -610,7 +617,7 @@ def _fit_impl(
model_provided=True,
model_connected=self.lightning_module is not None,
)
self._run(model, ckpt_path=ckpt_path)
self._run(model, compile_kwargs, ckpt_path=ckpt_path)

assert self.state.stopped
self.training = False
@@ -947,7 +954,10 @@ def _predict_impl(
return results

def _run(
self, model: "pl.LightningModule", ckpt_path: Optional[_PATH] = None
self,
model: "pl.LightningModule",
compile_kwargs: Optional[dict[str, Any]] = None,
ckpt_path: Optional[_PATH] = None,
) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
if self.state.fn == TrainerFn.FITTING:
min_epochs, max_epochs = _parse_loop_limits(
@@ -1001,6 +1011,10 @@ def _run(
# strategy will configure model and move it to the device
self.strategy.setup(self)

        # if a compiled model was provided, it was unwrapped in `fit`; re-apply compile now that the strategy has wrapped the module
if compile_kwargs is not None:
self.strategy.model = _to_compiled(self.strategy.model, compile_kwargs)

# hook
if self.state.fn == TrainerFn.FITTING:
call._call_callback_hooks(self, "on_fit_start")
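Taken together, the change lets Trainer.fit accept an already-compiled module: fit strips the torch.compile wrapper and remembers its kwargs, the strategy wraps the bare module (DDP, FSDP, ...), and _run re-applies compilation on top of the strategy wrapper. A minimal usage sketch of what this enables, mirroring the tests below (assumes two CUDA devices; BoringModel is the demo module shipped with Lightning and is used here only for illustration):

import torch
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel

model = BoringModel()
compiled_model = torch.compile(model, mode="reduce-overhead")

# Trainer unwraps the compiled module, lets the strategy wrap it (e.g. DDP),
# then re-applies torch.compile with the same kwargs over the wrapper.
trainer = Trainer(accelerator="gpu", devices=2, strategy="ddp", max_steps=2, logger=False)
trainer.fit(compiled_model)

# After fit, the strategy's model is compiled over the strategy-wrapped module.
assert isinstance(trainer.strategy.model, torch._dynamo.OptimizedModule)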
28 changes: 28 additions & 0 deletions tests/tests_pytorch/strategies/test_ddp_integration.py
@@ -17,6 +17,7 @@

import pytest
import torch
from torch._dynamo import OptimizedModule
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.multiprocessing import ProcessRaisedException
from torch.nn.parallel.distributed import DistributedDataParallel
@@ -448,3 +449,30 @@ def creates_processes_externally(self):
RuntimeError, match="Lightning attempted to launch new distributed processes with `local_rank > 0`."
):
trainer.fit(model)


@RunIf(min_cuda_gpus=2, standalone=True, dynamo=True)
@mock.patch("lightning.fabric.wrappers.torch.compile", Mock(wraps=torch.compile))
@mock.patch.dict(os.environ, {})
def test_reapply_compile():
"""Test that Trainer can rewrap a compiled module such that compilation happens over the DDP-wrapper."""
trainer = Trainer(accelerator="gpu", devices=2, strategy="ddp", max_steps=2, logger=False)

model = BoringModel()
compile_kwargs = {"mode": "reduce-overhead"}
compiled_model = torch.compile(model, **compile_kwargs)
torch.compile.reset_mock()

trainer.fit(compiled_model)
trainer_model = trainer.strategy.model

assert isinstance(trainer_model, OptimizedModule)
assert isinstance(trainer_model._orig_mod, DistributedDataParallel)
# Assert we called compile again with the same arguments, but on the DDP-wrapped module
torch.compile.assert_called_with(trainer_model._orig_mod, **compile_kwargs)

assert trainer_model._orig_mod.module == model

# Smoke-testing forward to ensure we don't get compilation errors
for _ in range(3):
        trainer_model(torch.randn(2, 32, device="cuda")).sum().backward()
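The mock.patch("lightning.fabric.wrappers.torch.compile", Mock(wraps=torch.compile)) decorator installs a spy: calls still go through to the real torch.compile, but they are recorded, which is what makes the assert_called_with check on the DDP-wrapped module possible. A standalone sketch of the same spy pattern (illustrative module and names, not part of the PR):

from unittest import mock

import torch
import torch.nn as nn

# Wrap torch.compile in a recording Mock; calls pass through to the real function.
compile_spy = mock.Mock(wraps=torch.compile)
with mock.patch("torch.compile", compile_spy):
    module = nn.Linear(32, 2)
    torch.compile(module, mode="reduce-overhead")
compile_spy.assert_called_with(module, mode="reduce-overhead")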
28 changes: 28 additions & 0 deletions tests/tests_pytorch/strategies/test_fsdp.py
@@ -12,6 +12,7 @@
import pytest
import torch
import torch.nn as nn
from torch._dynamo import OptimizedModule
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision
from torch.distributed.fsdp.wrap import ModuleWrapPolicy, always_wrap_policy, size_based_auto_wrap_policy, wrap
from torchmetrics import Accuracy
@@ -974,3 +975,30 @@ def configure_optimizers(self):
max_steps=4,
)
trainer.fit(model, ckpt_path=checkpoint_path_full)


@RunIf(min_cuda_gpus=2, standalone=True, dynamo=True)
@mock.patch("lightning.fabric.wrappers.torch.compile", Mock(wraps=torch.compile))
@mock.patch.dict(os.environ, {})
def test_reapply_compile():
"""Test that Trainer can rewrap a compiled module such that compilation happens over the FSDP-wrapper."""
trainer = Trainer(accelerator="gpu", devices=2, strategy="fsdp", max_steps=2, logger=False)

model = BoringModel()
compile_kwargs = {"mode": "reduce-overhead"}
compiled_model = torch.compile(model, **compile_kwargs)
torch.compile.reset_mock()

trainer.fit(compiled_model)
trainer_model = trainer.strategy.model

assert isinstance(trainer_model, OptimizedModule)
assert isinstance(trainer_model._orig_mod, FullyShardedDataParallel)
# Assert we called compile again with the same arguments, but on the FSDP-wrapped module
torch.compile.assert_called_with(trainer_model._orig_mod, **compile_kwargs)

assert trainer_model._orig_mod.module == model

# Smoke-testing forward to ensure we don't get compilation errors
for _ in range(3):
        trainer_model(torch.randn(2, 32, device="cuda")).sum().backward()
8 changes: 2 additions & 6 deletions tests/tests_pytorch/utilities/test_compile.py
@@ -46,18 +46,14 @@ def test_trainer_compiled_model(_, tmp_path, monkeypatch, mps_count_0):

model = BoringModel()
compiled_model = torch.compile(model)
assert model._compiler_ctx is compiled_model._compiler_ctx # shared reference

# can train with compiled model
trainer = Trainer(**trainer_kwargs)
trainer.fit(compiled_model)
assert trainer.model._compiler_ctx["compiler"] == "dynamo"
assert isinstance(trainer.strategy.model, torch._dynamo.OptimizedModule)

# the compiled model can be uncompiled
to_uncompiled_model = to_uncompiled(compiled_model)
assert model._compiler_ctx is None
assert compiled_model._compiler_ctx is None
assert to_uncompiled_model._compiler_ctx is None

# the compiled model needs to be passed
with pytest.raises(ValueError, match="required to be a compiled LightningModule"):
@@ -66,7 +62,7 @@ def test_trainer_compiled_model(_, tmp_path, monkeypatch, mps_count_0):
# the uncompiled model can be fitted
trainer = Trainer(**trainer_kwargs)
trainer.fit(model)
assert trainer.model._compiler_ctx is None
assert not isinstance(trainer.strategy.model, torch._dynamo.OptimizedModule)

# some strategies do not support it
if RequirementCache("deepspeed"):