Skip to content

Commit db3d718

Browse files
committed
Add regression test for ModelParallel single-file checkpoint
1 parent b09e96e commit db3d718

File tree

1 file changed: +32 −0 lines changed

tests/tests_pytorch/strategies/test_model_parallel_integration.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,38 @@ def training_step(self, batch):
237237
trainer.fit(model)
238238

239239

240+
@RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=2)
def test_model_parallel_single_file_checkpoint_with_compile(distributed, tmp_path):
    """Ensure assembling non-distributed checkpoints works when the model is compiled (torch.compile)."""
    seed_everything(0)

    # Run a tiny 2-GPU tensor-parallel fit; `save_distributed_checkpoint=False`
    # forces the strategy to assemble a single consolidated checkpoint file.
    trainer = Trainer(
        accelerator="auto",
        devices=2,
        strategy=ModelParallelStrategy(
            data_parallel_size=1,
            tensor_parallel_size=2,
            save_distributed_checkpoint=False,
        ),
        max_steps=2,
        limit_train_batches=2,
        logger=False,
        enable_model_summary=False,
        default_root_dir=tmp_path,
    )

    # Materialize parameters lazily on the target device before fitting.
    with trainer.init_module(empty_init=True):
        model = FSDP2Model(compile=True)

    trainer.fit(model)

    checkpoint_path = tmp_path / "compiled-model.ckpt"
    trainer.save_checkpoint(checkpoint_path)

    # Only rank 0 writes the consolidated file, so check there only.
    if trainer.is_global_zero:
        assert checkpoint_path.is_file()
271+
240272
@RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=4)
241273
@pytest.mark.parametrize(
242274
"compile",

0 commit comments

Comments
 (0)