@@ -135,6 +135,33 @@ def configure_model(self):
135135 parallelize (self .model , device_mesh = self .device_mesh )
136136
137137
class SimpleCompiledModule(LightningModule):
    """Minimal LightningModule whose network is wrapped with ``torch.compile``.

    The network is a small MLP (``Linear(32, 64) -> ReLU -> Linear(64, 32)``)
    trained with an MSE loss. Compilation happens in ``configure_model`` rather
    than in ``__init__``.
    """

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 32))
        self._loss = nn.MSELoss()
        # Tracks whether ``configure_model`` has already compiled ``self.model``.
        self._model_compiled = False

    def configure_model(self):
        """Compile ``self.model`` once; later calls are no-ops.

        Lightning may invoke this hook more than once in the same process
        (e.g. for successive fit/validate/test stages), so it has to be
        idempotent. Without the guard a second call would wrap the already
        compiled module in another ``torch.compile`` layer.
        """
        if self._model_compiled:
            return
        self.model = torch.compile(self.model)
        self._model_compiled = True

    def training_step(self, batch, batch_idx):
        """Return the MSE loss between the model's predictions and the targets.

        ``batch`` is unpacked as an ``(inputs, targets)`` pair; ``batch_idx`` is unused.
        """
        x, y = batch
        preds = self.model(x)
        return self._loss(preds, y)

    def configure_optimizers(self):
        """Return an AdamW optimizer over all module parameters (lr=1e-3)."""
        return torch.optim.AdamW(self.parameters(), lr=1e-3)
154+
155+
def _compiled_model_dataloader(batch_size: int = 32, num_batches: int = 2):
    """Build a deterministic DataLoader of random ``(features, targets)`` pairs.

    Both tensors have shape ``(batch_size * num_batches, 32)`` and are drawn
    from a single generator seeded with 0, features first and targets second,
    so repeated calls with the same arguments yield identical data.

    Args:
        batch_size: Number of samples per batch.
        num_batches: Number of batches the loader yields.

    Returns:
        A ``DataLoader`` over a ``TensorDataset`` of the two tensors.
    """
    rng = torch.Generator().manual_seed(0)
    num_samples = batch_size * num_batches
    # Draw order matters: features consume the generator before targets.
    inputs = torch.randn(num_samples, 32, generator=rng)
    labels = torch.randn(num_samples, 32, generator=rng)
    return DataLoader(
        torch.utils.data.TensorDataset(inputs, labels),
        batch_size=batch_size,
    )
163+
164+
138165@RunIf (min_torch = "2.4" , standalone = True , min_cuda_gpus = 4 )
139166def test_setup_device_mesh (distributed ):
140167 from torch .distributed .device_mesh import DeviceMesh
@@ -237,6 +264,44 @@ def training_step(self, batch):
237264 trainer .fit (model )
238265
239266
# NOTE(review): the test is gated on two CUDA GPUs but only requests ``devices=1``
# below -- confirm whether the two-GPU requirement is intentional.
@RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=2)
def test_model_parallel_single_file_checkpoint_with_compile(distributed, tmp_path):
    """Replicate the reporter's setup: compiled model + ModelParallel single-file checkpointing.

    Trains a ``torch.compile``-wrapped module for two steps under
    ``ModelParallelStrategy(save_distributed_checkpoint=False)``, saves a
    checkpoint manually and checks that it was written as a single file.
    """

    seed_everything(0)
    strategy = ModelParallelStrategy(
        data_parallel_size=1,
        tensor_parallel_size=1,
        save_distributed_checkpoint=False,
    )

    trainer = Trainer(
        accelerator="auto",
        devices=1,
        strategy=strategy,
        max_steps=2,
        limit_train_batches=2,
        # Automatic checkpointing is off; the checkpoint is saved explicitly below.
        enable_checkpointing=False,
        logger=False,
        enable_progress_bar=False,
        enable_model_summary=False,
        default_root_dir=tmp_path,
    )

    dataloader = _compiled_model_dataloader(batch_size=32, num_batches=2)

    with trainer.init_module(empty_init=True):
        model = SimpleCompiledModule()

    trainer.fit(model, dataloader)

    checkpoint_path = tmp_path / "compiled-model.ckpt"
    # ``Trainer.save_checkpoint`` is documented as needing to be called on all
    # processes when the strategy handles distributed checkpointing, so it is
    # not guarded by a rank check. With ``devices=1`` this is equivalent to the
    # former rank-zero-only call.
    trainer.save_checkpoint(checkpoint_path)

    # Make sure the write has finished everywhere before inspecting the result.
    trainer.strategy.barrier()

    if trainer.is_global_zero:
        # ``save_distributed_checkpoint=False`` should yield one consolidated file.
        assert checkpoint_path.is_file()
303+
304+
240305@RunIf (min_torch = "2.4" , standalone = True , min_cuda_gpus = 4 )
241306@pytest .mark .parametrize (
242307 "compile" ,
0 commit comments