Skip to content

Commit 82f9a7d

Browse files
committed
Add regression test for ModelParallel single-file checkpoint
1 parent b09e96e commit 82f9a7d

File tree

1 file changed

+31
-0
lines changed

1 file changed

+31
-0
lines changed

tests/tests_pytorch/strategies/test_model_parallel_integration.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,37 @@ def training_step(self, batch):
237237
trainer.fit(model)
238238

239239

240+
@RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=2)
def test_model_parallel_single_file_checkpoint_with_compile(distributed, tmp_path):
    """Ensure assembling non-distributed checkpoints works when the model is compiled (torch.compile)."""
    seed_everything(0)
    # save_distributed_checkpoint=False is the code path under regression test:
    # the strategy must consolidate the sharded state into a single checkpoint file.
    strategy = ModelParallelStrategy(
        data_parallel_size=1,
        tensor_parallel_size=2,
        save_distributed_checkpoint=False,
    )

    # Minimal trainer: two short steps are enough to materialize optimizer/model
    # state so that save_checkpoint has something to consolidate.
    trainer = Trainer(
        accelerator="auto",
        devices=2,
        strategy=strategy,
        max_steps=2,
        limit_train_batches=2,
        logger=False,
        enable_model_summary=False,
        default_root_dir=tmp_path,
    )

    # empty_init=True defers weight materialization; compile=True wraps the model
    # with torch.compile (presumably adds an "_orig_mod." prefix to state-dict
    # keys — the scenario this regression test guards; confirm in FSDP2Model).
    with trainer.init_module(empty_init=True):
        model = FSDP2Model(compile=True)

    trainer.fit(model)
    checkpoint_path = tmp_path / "compiled-model.ckpt"
    trainer.save_checkpoint(checkpoint_path)
    # A single consolidated file must exist (not a distributed checkpoint directory).
    assert checkpoint_path.is_file()
269+
270+
240271
@RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=4)
241272
@pytest.mark.parametrize(
242273
"compile",

0 commit comments

Comments
 (0)