Skip to content

Commit 860a9d8

Browse files
dnanutiIsaevIlya
authored andcommitted
Add Lightning checkpoint load test (#152)
1 parent 3020b15 commit 860a9d8

File tree

1 file changed

+20
-0
lines changed

1 file changed

+20
-0
lines changed

s3torchconnector/tst/e2e/test_e2e_s3_lightning_checkpoint.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,26 @@ def test_compatibility_with_async_checkpoint_io(checkpoint_directory):
171171
_verify_equal_state_dict(model.state_dict(), loaded_checkpoint["state_dict"])
172172

173173

174+
def test_compatibility_with_lightning_checkpoint_load(checkpoint_directory):
175+
nonce = random.randrange(2**64)
176+
dataset = WikiText2(data_dir=Path(f"/tmp/data/{nonce}"))
177+
dataloader = DataLoader(dataset, num_workers=3)
178+
model = LightningTransformer(vocab_size=dataset.vocab_size)
179+
s3_lightning_checkpoint = S3LightningCheckpoint(region=checkpoint_directory.region)
180+
trainer = L.Trainer(
181+
default_root_dir=checkpoint_directory.s3_uri,
182+
plugins=[s3_lightning_checkpoint],
183+
max_epochs=1,
184+
max_steps=3,
185+
)
186+
trainer.fit(model, dataloader)
187+
checkpoint_key = "lightning_logs/version_0/checkpoints/epoch=0-step=3.ckpt"
188+
checkpoint_s3_uri = f"{checkpoint_directory.s3_uri}{checkpoint_key}"
189+
new_model = LightningTransformer(vocab_size=dataset.vocab_size)
190+
trainer.fit(new_model, dataloader, ckpt_path=checkpoint_s3_uri)
191+
_verify_equal_state_dict(model.state_dict(), new_model.state_dict())
192+
193+
174194
def test_nn_checkpointing(checkpoint_directory):
175195
nn_model = Net()
176196
checkpoint_name = "lightning_neural_network_model.pt"

0 commit comments

Comments
 (0)