Skip to content

Commit 9bc5774

Browse files
authored
Add test for reapply_compile with FSDP on gpu
1 parent c74bdab commit 9bc5774

File tree

1 file changed

+28
-0
lines changed

1 file changed

+28
-0
lines changed

tests/tests_pytorch/strategies/test_fsdp.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import pytest
1313
import torch
1414
import torch.nn as nn
15+
from torch._dynamo import OptimizedModule
1516
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision
1617
from torch.distributed.fsdp.wrap import ModuleWrapPolicy, always_wrap_policy, size_based_auto_wrap_policy, wrap
1718
from torchmetrics import Accuracy
@@ -971,3 +972,30 @@ def configure_optimizers(self):
971972
max_steps=4,
972973
)
973974
trainer.fit(model, ckpt_path=checkpoint_path_full)
975+
976+
977+
@RunIf(min_cuda_gpus=2, standalone=True, dynamo=True)
@mock.patch("lightning.fabric.wrappers.torch.compile", Mock(wraps=torch.compile))
@mock.patch.dict(os.environ, {})
def test_reapply_compile():
    """Test that Trainer can rewrap a compiled module such that compilation happens over the FSDP-wrapper.

    The model is compiled *before* being handed to the Trainer. After fitting, the strategy's model is
    expected to be an ``OptimizedModule`` whose ``_orig_mod`` is the ``FullyShardedDataParallel`` wrapper
    around the original (uncompiled) model, and ``torch.compile`` must have been called again with the
    user's original keyword arguments.

    """
    trainer = Trainer(accelerator="gpu", devices=2, strategy="fsdp", max_steps=2, logger=False)

    model = BoringModel()
    compile_kwargs = {"mode": "reduce-overhead"}
    compiled_model = torch.compile(model, **compile_kwargs)
    # Forget the user-side compile call above so that the assertion below only sees the re-compilation
    torch.compile.reset_mock()

    trainer.fit(compiled_model)
    trainer_model = trainer.strategy.model

    assert isinstance(trainer_model, OptimizedModule)
    assert isinstance(trainer_model._orig_mod, FullyShardedDataParallel)
    # Assert we called compile again with the same arguments, but on the FSDP-wrapped module
    torch.compile.assert_called_with(trainer_model._orig_mod, **compile_kwargs)

    assert trainer_model._orig_mod.module == model

    # Smoke-testing forward to ensure we don't get compilation errors
    for _ in range(3):
        # "cuda" is the torch device type; "gpu" is only a Lightning accelerator name and is rejected by torch
        trainer_model(torch.randn(2, 32, device="cuda")).sum().backward()

0 commit comments

Comments
 (0)