Skip to content

Commit d4e476f

Browse files
committed
Add regression test for ModelParallel single-file checkpoint
1 parent b09e96e commit d4e476f

File tree

1 file changed

+65
-0
lines changed

1 file changed

+65
-0
lines changed

tests/tests_pytorch/strategies/test_model_parallel_integration.py

Lines changed: 65 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -135,6 +135,33 @@ def configure_model(self):
135135
parallelize(self.model, device_mesh=self.device_mesh)
136136

137137

138+
class SimpleCompiledModule(LightningModule):
    """Tiny 32 -> 64 -> 32 MLP trained with MSE; the network is compiled in ``configure_model``."""

    def __init__(self):
        super().__init__()
        layers = [nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 32)]
        self.model = nn.Sequential(*layers)
        self._loss = nn.MSELoss()

    def configure_model(self):
        # Replace the eager network with the wrapper returned by ``torch.compile``.
        compiled = torch.compile(self.model)
        self.model = compiled

    def training_step(self, batch, batch_idx):
        # ``batch`` is unpacked as an (inputs, targets) pair; the MSE between the
        # network output and the targets is returned as the training loss.
        inputs, targets = batch
        return self._loss(self.model(inputs), targets)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3)
        return optimizer
154+
155+
156+
def _compiled_model_dataloader(batch_size: int = 32, num_batches: int = 2):
    """Return a ``DataLoader`` over seeded random ``(features, targets)`` pairs.

    Both tensors have ``batch_size * num_batches`` rows and 32 columns and are drawn
    from a single generator seeded with 0, so repeated calls yield the same data.
    """
    rng = torch.Generator().manual_seed(0)
    shape = (batch_size * num_batches, 32)
    # Inputs are drawn before labels; both share ``rng``, so the order fixes the values.
    inputs = torch.randn(*shape, generator=rng)
    labels = torch.randn(*shape, generator=rng)
    pairs = torch.utils.data.TensorDataset(inputs, labels)
    return DataLoader(pairs, batch_size=batch_size)
163+
164+
138165
@RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=4)
139166
def test_setup_device_mesh(distributed):
140167
from torch.distributed.device_mesh import DeviceMesh
@@ -237,6 +264,44 @@ def training_step(self, batch):
237264
trainer.fit(model)
238265

239266

267+
@RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=2)
def test_model_parallel_single_file_checkpoint_with_compile(distributed, tmp_path):
    """Replicate the reporter's setup: compiled model + ModelParallel single-file checkpointing.

    Fits ``SimpleCompiledModule`` for two steps under a ``ModelParallelStrategy`` created with
    ``save_distributed_checkpoint=False``, saves a checkpoint manually, and checks that the
    checkpoint was written as a single file.
    """
    seed_everything(0)
    strategy = ModelParallelStrategy(
        data_parallel_size=1,
        tensor_parallel_size=1,
        save_distributed_checkpoint=False,
    )

    trainer = Trainer(
        accelerator="auto",
        devices=1,
        strategy=strategy,
        max_steps=2,
        limit_train_batches=2,
        enable_checkpointing=False,
        logger=False,
        enable_progress_bar=False,
        enable_model_summary=False,
        default_root_dir=tmp_path,
    )

    dataloader = _compiled_model_dataloader(batch_size=32, num_batches=2)

    with trainer.init_module(empty_init=True):
        model = SimpleCompiledModule()

    trainer.fit(model, dataloader)

    checkpoint_path = tmp_path / "compiled-model.ckpt"
    # ``Trainer.save_checkpoint`` is documented as needing to be called on all processes,
    # so it must not be guarded by ``is_global_zero``: a rank-zero-only call would leave
    # the other ranks out of the save routine as soon as more than one device is used.
    trainer.save_checkpoint(checkpoint_path)
    trainer.strategy.barrier()

    # Only inspect the artifact on the global-zero rank.
    if trainer.is_global_zero:
        assert checkpoint_path.is_file()
304+
240305
@RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=4)
241306
@pytest.mark.parametrize(
242307
"compile",

0 commit comments

Comments
 (0)