File tree Expand file tree Collapse file tree 1 file changed +32
-0
lines changed
tests/tests_pytorch/strategies Expand file tree Collapse file tree 1 file changed +32
-0
lines changed Original file line number Diff line number Diff line change @@ -237,6 +237,38 @@ def training_step(self, batch):
237237 trainer .fit (model )
238238
239239
240+ @RunIf (min_torch = "2.4" , standalone = True , min_cuda_gpus = 2 )
241+ def test_model_parallel_single_file_checkpoint_with_compile (distributed , tmp_path ):
242+ """Ensure assembling non-distributed checkpoints works when the model is compiled (torch.compile)."""
243+
244+ seed_everything (0 )
245+ strategy = ModelParallelStrategy (
246+ data_parallel_size = 1 ,
247+ tensor_parallel_size = 2 ,
248+ save_distributed_checkpoint = False ,
249+ )
250+
251+ trainer = Trainer (
252+ accelerator = "auto" ,
253+ devices = 2 ,
254+ strategy = strategy ,
255+ max_steps = 2 ,
256+ limit_train_batches = 2 ,
257+ logger = False ,
258+ enable_model_summary = False ,
259+ default_root_dir = tmp_path ,
260+ )
261+
262+ with trainer .init_module (empty_init = True ):
263+ model = FSDP2Model (compile = True )
264+
265+ trainer .fit (model )
266+ checkpoint_path = tmp_path / "compiled-model.ckpt"
267+ trainer .save_checkpoint (checkpoint_path )
268+ if trainer .is_global_zero :
269+ assert checkpoint_path .is_file ()
270+
271+
240272@RunIf (min_torch = "2.4" , standalone = True , min_cuda_gpus = 4 )
241273@pytest .mark .parametrize (
242274 "compile" ,
You can’t perform that action at this time.
0 commit comments