File tree Expand file tree Collapse file tree 1 file changed +20
-0
lines changed
Expand file tree Collapse file tree 1 file changed +20
-0
lines changed Original file line number Diff line number Diff line change @@ -171,6 +171,26 @@ def test_compatibility_with_async_checkpoint_io(checkpoint_directory):
171171 _verify_equal_state_dict (model .state_dict (), loaded_checkpoint ["state_dict" ])
172172
173173
174+ def test_compatibility_with_lightning_checkpoint_load (checkpoint_directory ):
175+ nonce = random .randrange (2 ** 64 )
176+ dataset = WikiText2 (data_dir = Path (f"/tmp/data/{ nonce } " ))
177+ dataloader = DataLoader (dataset , num_workers = 3 )
178+ model = LightningTransformer (vocab_size = dataset .vocab_size )
179+ s3_lightning_checkpoint = S3LightningCheckpoint (region = checkpoint_directory .region )
180+ trainer = L .Trainer (
181+ default_root_dir = checkpoint_directory .s3_uri ,
182+ plugins = [s3_lightning_checkpoint ],
183+ max_epochs = 1 ,
184+ max_steps = 3 ,
185+ )
186+ trainer .fit (model , dataloader )
187+ checkpoint_key = "lightning_logs/version_0/checkpoints/epoch=0-step=3.ckpt"
188+ checkpoint_s3_uri = f"{ checkpoint_directory .s3_uri } { checkpoint_key } "
189+ new_model = LightningTransformer (vocab_size = dataset .vocab_size )
190+ trainer .fit (new_model , dataloader , ckpt_path = checkpoint_s3_uri )
191+ _verify_equal_state_dict (model .state_dict (), new_model .state_dict ())
192+
193+
174194def test_nn_checkpointing (checkpoint_directory ):
175195 nn_model = Net ()
176196 checkpoint_name = "lightning_neural_network_model.pt"
You can’t perform that action at this time.
0 commit comments