 import pytest
 import torch
 import torch.nn as nn
+from torch._dynamo import OptimizedModule
 from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision
 from torch.distributed.fsdp.wrap import ModuleWrapPolicy, always_wrap_policy, size_based_auto_wrap_policy, wrap
 from torchmetrics import Accuracy
@@ -971,3 +972,30 @@ def configure_optimizers(self):
         max_steps=4,
     )
     trainer.fit(model, ckpt_path=checkpoint_path_full)
+
+
+@RunIf(min_cuda_gpus=2, standalone=True, dynamo=True)
+@mock.patch("lightning.fabric.wrappers.torch.compile", Mock(wraps=torch.compile))
+@mock.patch.dict(os.environ, {})
+def test_reapply_compile():
+    """Test that Trainer can rewrap a compiled module such that compilation happens over the FSDP-wrapper."""
+    trainer = Trainer(accelerator="gpu", devices=2, strategy="fsdp", max_steps=2, logger=False)
+
+    model = BoringModel()
+    compile_kwargs = {"mode": "reduce-overhead"}
+    compiled_model = torch.compile(model, **compile_kwargs)
+    torch.compile.reset_mock()
+
+    trainer.fit(compiled_model)
+    trainer_model = trainer.strategy.model
+
+    assert isinstance(trainer_model, OptimizedModule)
+    assert isinstance(trainer_model._orig_mod, FullyShardedDataParallel)
+    # Assert we called compile again with the same arguments, but on the FSDP-wrapped module
+    torch.compile.assert_called_with(trainer_model._orig_mod, **compile_kwargs)
+
+    assert trainer_model._orig_mod.module == model
+
+    # Smoke-testing forward to ensure we don't get compilation errors
+    for _ in range(3):
+        trainer_model(torch.randn(2, 32, device="cuda")).sum().backward()
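
The assertions above encode the intended behavior: the strategy unwraps the user's `OptimizedModule`, applies FSDP to the inner module, and then calls `torch.compile` again with the originally captured kwargs, so that compilation traces through the FSDP wrapper. Below is a minimal sketch of that idea, not Lightning's actual implementation; `unwrap_compiled`, `setup_fsdp_with_compile`, and the `_compile_kwargs` attribute are hypothetical names introduced here for illustration, since `torch.compile` does not publicly record its kwargs on the returned module.

```python
from typing import Any, Dict, Tuple

import torch
from torch import nn
from torch._dynamo import OptimizedModule
from torch.distributed.fsdp import FullyShardedDataParallel


def unwrap_compiled(module: nn.Module) -> Tuple[nn.Module, Dict[str, Any]]:
    """Return the original module and the kwargs it was compiled with, if any."""
    if isinstance(module, OptimizedModule):
        # ``_orig_mod`` holds the user's original module.  torch.compile does not
        # expose its kwargs, so this sketch assumes they were stashed on the module
        # at compile time under the hypothetical attribute ``_compile_kwargs``.
        return module._orig_mod, getattr(module, "_compile_kwargs", {})
    return module, {}


def setup_fsdp_with_compile(module: nn.Module, **fsdp_kwargs: Any) -> nn.Module:
    """Wrap the inner module with FSDP first, then re-apply torch.compile on top.

    Requires an initialized distributed process group, as any FSDP construction does.
    """
    inner, compile_kwargs = unwrap_compiled(module)
    sharded = FullyShardedDataParallel(inner, **fsdp_kwargs)
    # Compilation now happens over the FSDP wrapper, which is what the test's
    # ``torch.compile.assert_called_with(trainer_model._orig_mod, **compile_kwargs)``
    # assertion checks.
    return torch.compile(sharded, **compile_kwargs)
```

Re-applying compilation on top of the FSDP wrapper (rather than leaving the compiled module inside FSDP) lets dynamo see the sharded forward, which is why the resulting `trainer.strategy.model` is an `OptimizedModule` whose `_orig_mod` is the `FullyShardedDataParallel` instance.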