How to save and load checkpoints using the DeepSpeed plugin with stage 3? #9321
Unanswered
yidong72 asked this question in Lightning Trainer API: Trainer, LightningModule, LightningDataModule
Replies: 1 comment 1 reply
-
Here is a small example. Run it twice: once without modifications, and a second time with `max_epochs` increased and the `resume_from_checkpoint` line uncommented:

```python
import torch
from torch.utils.data import Dataset, DataLoader

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.plugins import DeepSpeedPlugin


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    model = BoringModel()
    # Writes checkpoints/epoch=XX.ckpt; under stage 3 (without full-weight
    # saving) each checkpoint path is a directory of shards, not a single file.
    checkpoint_callback = ModelCheckpoint(
        dirpath="checkpoints/",
        filename="{epoch:02d}",
    )
    trainer = Trainer(
        # resume_from_checkpoint="checkpoints/epoch=09.ckpt",
        max_epochs=1,  # increase when resuming
        gpus=2,
        accelerator="ddp",
        plugins=[DeepSpeedPlugin(stage=3)],
        limit_train_batches=1,
        limit_val_batches=1,
        num_sanity_val_steps=0,
        precision=16,
        weights_summary=None,
        callbacks=[checkpoint_callback],
    )
    trainer.fit(model, train_dataloader=train_data)


if __name__ == "__main__":
    run()
```
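If you instead want to load the weights outside a `Trainer` (e.g. with `load_from_checkpoint`), note that under stage 3 the checkpoint path above is a directory of shards, which `load_from_checkpoint` cannot read directly. A minimal sketch of consolidating it first, assuming a Lightning version that ships `convert_zero_checkpoint_to_fp32_state_dict` (the paths are illustrative):

```python
from pytorch_lightning.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict

# "checkpoints/epoch=09.ckpt" is a *directory* of sharded states under stage 3.
# Consolidate it into a single fp32, Lightning-style checkpoint file ...
convert_zero_checkpoint_to_fp32_state_dict(
    "checkpoints/epoch=09.ckpt",       # input checkpoint directory
    "checkpoints/epoch=09_fp32.ckpt",  # consolidated output file
)

# ... which can then be loaded the usual way, without a process group.
model = BoringModel.load_from_checkpoint("checkpoints/epoch=09_fp32.ckpt")
```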
-
I have been struggling to figure out how to save and load my model with the DeepSpeed plugin, and I cannot find any examples of doing it.

Here is my setup: I use the `ModelCheckpoint` callback to save the checkpoints. It generates either a single checkpoint file or a directory of `.pt` files, depending on whether `save_full_weights` is true or false. However, I don't know how to load the checkpoint files. I tried both `Model.load_from_checkpoint` and `Trainer(resume_from_checkpoint=...)`, and neither works for me: I get `AttributeError: 'NoneType' object has no attribute 'trainer'` and `Default process group has not been initialized, please make sure to call init_process_group.` errors. Could you show me a working example? Thanks.
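For reference, a minimal sketch of the `save_full_weights` path mentioned above, assuming a Lightning version whose `DeepSpeedPlugin` accepts that argument (as the question implies): with it enabled, `ModelCheckpoint` writes a single consolidated file rather than a directory of shards, which `load_from_checkpoint` can read directly.

```python
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DeepSpeedPlugin

# Sketch only: save_full_weights is taken from the question above; check
# that your installed Lightning version's DeepSpeedPlugin supports it.
trainer = Trainer(
    gpus=2,
    accelerator="ddp",
    precision=16,
    # With save_full_weights=True the checkpoint is one consolidated file,
    # so Model.load_from_checkpoint(path) works without manual merging.
    plugins=[DeepSpeedPlugin(stage=3, save_full_weights=True)],
)
```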