 import lightning.pytorch as pl
 import pytest
 import torch
+from torch._dynamo import OptimizedModule
 from lightning.fabric.plugins.environments import ClusterEnvironment, LightningEnvironment
 from lightning.fabric.utilities.distributed import _distributed_is_initialized
 from lightning.pytorch import Trainer
@@ -448,3 +449,30 @@ def creates_processes_externally(self):
         RuntimeError, match="Lightning attempted to launch new distributed processes with `local_rank > 0`."
     ):
         trainer.fit(model)
+
+
+@RunIf(dynamo=True)
+@mock.patch("lightning.fabric.wrappers.torch.compile", Mock(wraps=torch.compile))
+@mock.patch.dict(os.environ, {})
+def test_reapply_compile(tmp_path):
+    """Test that the Trainer can rewrap a compiled module such that compilation happens over the DDP wrapper."""
+    trainer = Trainer(default_root_dir=tmp_path, accelerator="cpu", devices=2, strategy="ddp", max_steps=2, logger=False)
+
+    model = BoringModel()
+    compile_kwargs = {"mode": "reduce-overhead"}
+    compiled_model = torch.compile(model, **compile_kwargs)
+    torch.compile.reset_mock()
+
+    trainer.fit(compiled_model)
+    trainer_model = trainer.strategy.model
+
+    assert isinstance(trainer_model, OptimizedModule)
+    assert isinstance(trainer_model._orig_mod, DistributedDataParallel)
+    # Assert that `torch.compile` was called again with the same arguments, but on the DDP-wrapped module
+    torch.compile.assert_called_with(trainer_model._orig_mod, **compile_kwargs)
+
+    assert trainer_model._orig_mod.module == model
+
+    # Smoke-test the forward pass to make sure there are no compilation errors
+    for _ in range(3):
+        trainer_model(torch.randn(2, 32, device="cpu")).sum().backward()
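
A note on the patching at the top of the new test: `Mock(wraps=torch.compile)` is the standard `unittest.mock` spy pattern, where calls are forwarded to the real callable but still recorded, so the test runs real compilation and can still use `torch.compile.assert_called_with(...)` afterwards. A minimal, self-contained illustration of the pattern (the `add` function here is purely hypothetical):

```python
from unittest.mock import Mock


def add(a, b):
    return a + b


spy = Mock(wraps=add)              # forwards calls to the real `add`
assert spy(1, 2) == 3              # behaves like the wrapped callable
spy.assert_called_once_with(1, 2)  # ...while still recording calls for assertions
spy.reset_mock()                   # clears the recorded calls, as the test does after the first compile
```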
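For readers skimming the diff, the behaviour under test, re-applying `torch.compile` on top of the DDP wrapper, can be sketched outside of Lightning roughly as follows. This is a simplified illustration under assumptions, not the Trainer's actual code path; the helper name `reapply_compile_over_ddp` and the single-process gloo setup are invented for the example.

```python
import os

import torch
import torch.distributed as dist
from torch._dynamo import OptimizedModule
from torch.nn.parallel import DistributedDataParallel


def reapply_compile_over_ddp(compiled_model: OptimizedModule, **compile_kwargs):
    """Hypothetical helper: unwrap the compiled module, wrap it in DDP, then compile again."""
    # Recover the original (un-compiled) module from the `torch.compile` wrapper
    original = compiled_model._orig_mod
    # Wrap the plain module in DDP ...
    ddp_module = DistributedDataParallel(original)
    # ... and re-apply `torch.compile` with the same kwargs, so that graph
    # capture happens over the DDP forward rather than the bare module.
    return torch.compile(ddp_module, **compile_kwargs)


if __name__ == "__main__":
    # A single-process gloo group is enough to construct DDP on CPU.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    compiled = torch.compile(torch.nn.Linear(32, 2))
    rewrapped = reapply_compile_over_ddp(compiled)
    # Compilation is lazy, so the first forward call triggers it.
    rewrapped(torch.randn(2, 32)).sum().backward()

    dist.destroy_process_group()
```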