Commit 809c6c4

Test trainer rewrap compiled module over DDP strategy
1 parent: f2c436a

File tree

1 file changed (+29 −0 lines)


tests/tests_pytorch/strategies/test_ddp_integration.py

Lines changed: 29 additions & 0 deletions
@@ -18,6 +18,7 @@
 import lightning.pytorch as pl
 import pytest
 import torch
+from torch._dynamo import OptimizedModule
 from lightning.fabric.plugins.environments import ClusterEnvironment, LightningEnvironment
 from lightning.fabric.utilities.distributed import _distributed_is_initialized
 from lightning.pytorch import Trainer
@@ -448,3 +449,31 @@ def creates_processes_externally(self):
         RuntimeError, match="Lightning attempted to launch new distributed processes with `local_rank > 0`."
     ):
         trainer.fit(model)
+
+
+@RunIf(dynamo=True)
+@mock.patch("lightning.fabric.wrappers.torch.compile", Mock(wraps=torch.compile))
+@mock.patch.dict(os.environ, {})
+def test_reapply_compile(tmp_path):
+    """Test that Trainer can rewrap a compiled module such that compilation happens over the DDP-wrapper."""
+    trainer = Trainer(accelerator="cpu", devices=2, strategy="ddp", max_steps=2, logger=False)
+
+    model = BoringModel()
+    compile_kwargs = {"mode": "reduce-overhead"}
+    compiled_model = torch.compile(model, **compile_kwargs)
+    torch.compile.reset_mock()
+
+    trainer.fit(compiled_model)
+    trainer_model = trainer.strategy.model
+
+    assert isinstance(trainer_model, OptimizedModule)
+    assert isinstance(trainer_model._orig_mod, DistributedDataParallel)
+    # Assert we called compile again with the same arguments, but on the DDP-wrapped module
+    torch.compile.assert_called_with(trainer_model._orig_mod, **compile_kwargs)
+
+    assert trainer_model._orig_mod.module == model
+
+    # Smoke-testing forward to ensure we don't get compilation errors
+    for _ in range(3):
+        trainer_model(torch.randn(2, 32, device="cpu")).sum().backward()
+    assert True
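
For context, the behavior this test pins down, in outline: when trainer.fit() receives a module that was already passed through torch.compile, the strategy recovers the original module from the OptimizedModule, wraps it in DistributedDataParallel, and re-applies torch.compile with the user's original arguments, so that compilation happens over the DDP wrapper. The snippet below is a minimal sketch of that re-wrap pattern, not Lightning's actual implementation: the helper name _rewrap_compiled_with_ddp is hypothetical, and it assumes a default process group has already been initialized (a hard requirement for constructing DDP).

import torch
from torch._dynamo import OptimizedModule
from torch.nn.parallel import DistributedDataParallel


def _rewrap_compiled_with_ddp(module: torch.nn.Module, **compile_kwargs) -> torch.nn.Module:
    # Hypothetical helper sketching the pattern the test exercises; assumes
    # torch.distributed.init_process_group() has already been called.
    if isinstance(module, OptimizedModule):
        module = module._orig_mod  # recover the module the user compiled
    ddp_module = DistributedDataParallel(module)
    # Re-apply compilation with the same kwargs, now on top of the DDP wrapper.
    # This is why the test asserts torch.compile was called again with
    # trainer_model._orig_mod (the DDP module) and the original compile_kwargs.
    return torch.compile(ddp_module, **compile_kwargs)

In the assertions above, trainer.strategy.model plays the role of this helper's return value: an OptimizedModule whose _orig_mod is the DDP wrapper around the original BoringModel.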
