🐛 Bug
Trying to restore a checkpoint to resume training fails with the exception below:
self.lightning_module.load_state_dict(ckpt['state_dict'])
  File "/home/ca5b7a03-2d901b-2d45e5-2d969e-2df8ccc075972b/.local/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1052, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for BoringModel:
    size mismatch for layer.weight: copying a param with shape torch.Size([2, 32]) from checkpoint, the shape in current model is torch.Size([1]).
    size mismatch for layer.bias: copying a param with shape torch.Size([2]) from checkpoint, the shape in current model is torch.Size([1]).

(the same RuntimeError is raised on every DDP rank)
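The torch.Size([1]) shapes reported for the live model are consistent with ZeRO stage 3 parameter partitioning, where DeepSpeed replaces each parameter's storage with a small placeholder once it is sharded, so a strict load_state_dict against the full checkpoint tensors fails. A minimal sketch to confirm the checkpoint itself still holds the full shapes (the path is the one produced by the repro below):

import torch

# Inspect what the Lightning checkpoint actually stores. With the repro
# below, both tensors should report their full, un-partitioned shapes
# ([2, 32] for layer.weight, [2] for layer.bias), matching the traceback.
ckpt = torch.load('tests/epoch=01.ckpt', map_location='cpu')
for name, tensor in ckpt['state_dict'].items():
    print(name, tuple(tensor.shape))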
To Reproduce
Run the following script with the resume_from_checkpoint argument commented out, then run it again with that line uncommented; the second run raises the exception above.
import os

import torch
from torch.utils.data import Dataset, DataLoader

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.plugins import DeepSpeedPlugin
from deepspeed.ops.adam import FusedAdam


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)
        return loss

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    test_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    checkpoint_callback = ModelCheckpoint(
        dirpath='tests/',
        filename='{epoch:02d}',
    )
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        gpus=-1,
        limit_train_batches=1,
        limit_val_batches=1,
        num_sanity_val_steps=0,
        precision=16,
        accelerator='ddp',
        max_epochs=2,
        plugins=[DeepSpeedPlugin(cpu_offload=False, stage=3)],
        weights_summary=None,
        callbacks=[checkpoint_callback],
        # resume_from_checkpoint='tests/epoch=01.ckpt',  # uncomment for the second run
    )
    trainer.fit(model, train_dataloader=train_data, val_dataloaders=val_data)
    trainer.test(model, test_dataloaders=test_data)


if __name__ == '__main__':
    run()

Expected behavior
Training resumes successfully from the stored checkpoint.
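Until this is fixed, one possible direction (an untested sketch; load_full_state_dict is a hypothetical helper, not a Lightning or DeepSpeed API) would be to gather the stage-3 partitions so the module briefly exposes full-shape parameters while the state dict is loaded:

import deepspeed

# Untested sketch: deepspeed.zero.GatheredParameters temporarily reassembles
# the partitioned parameters; modifier_rank=0 means rank 0 applies the new
# weights and DeepSpeed broadcasts them to the other ranks. Where to hook
# this into Lightning's restore path is an open question.
def load_full_state_dict(module, state_dict):
    with deepspeed.zero.GatheredParameters(list(module.parameters()), modifier_rank=0):
        module.load_state_dict(state_dict)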
Environment
Tried with Lightning versions 1.2.10, 1.3.0rc1, and master.
PyTorch: 1.7.1
OS: Ubuntu 18.04
@SeanNaren, as discussed on Slack ^^