
Model restore fails from stored checkpoint when using DeepSpeed #7282

@gurvindersingh

Description

๐Ÿ› Bug

Trying to restore a checkpoint to resume training fails with the exception below (the same error is raised by every DDP process):

    self.lightning_module.load_state_dict(ckpt['state_dict'])
  File "/home/ca5b7a03-2d901b-2d45e5-2d969e-2df8ccc075972b/.local/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1052, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for BoringModel:
        size mismatch for layer.weight: copying a param with shape torch.Size([2, 32]) from checkpoint, the shape in current model is torch.Size([1]).
        size mismatch for layer.bias: copying a param with shape torch.Size([2]) from checkpoint, the shape in current model is torch.Size([1]).
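
The shapes in the error suggest that the checkpoint stores the full, unpartitioned parameters, while the live module under ZeRO stage 3 only exposes partitioned placeholder tensors of shape [1]. A minimal inspection sketch of the checkpoint side (assuming the 'tests/epoch=01.ckpt' file from the first run is a regular torch.save file, as the ckpt['state_dict'] access in the trace suggests):

import torch

ckpt = torch.load('tests/epoch=01.ckpt', map_location='cpu')
# The checkpoint holds the full parameters of BoringModel's Linear(32, 2) layer:
print(ckpt['state_dict']['layer.weight'].shape)  # torch.Size([2, 32])
print(ckpt['state_dict']['layer.bias'].shape)    # torch.Size([2])
# Under ZeRO stage 3 the live module's parameters are partitioned placeholders,
# so load_state_dict() on this state_dict raises the size mismatch above.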

To Reproduce

Run the script below with the resume_from_checkpoint argument commented out, then run it again with that line uncommented; the second run fails with the exception above.

import os
import torch
from torch.utils.data import Dataset, DataLoader
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.plugins import DeepSpeedPlugin
from deepspeed.ops.adam import FusedAdam


class RandomDataset(Dataset):

    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)
        return loss

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    test_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    checkpoint_callback = ModelCheckpoint(
        dirpath='tests/',
        filename='{epoch:02d}',
    )
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        gpus=-1,
        limit_train_batches=1,
        limit_val_batches=1,
        num_sanity_val_steps=0,
        precision=16,
        accelerator='ddp',
        max_epochs=2,
        plugins=[DeepSpeedPlugin(cpu_offload=False, stage=3)],
        weights_summary=None,
        callbacks=[checkpoint_callback],
        # Uncomment for the second run to reproduce the failure:
        #resume_from_checkpoint='tests/epoch=01.ckpt',
    )
    trainer.fit(model, train_dataloader=train_data, val_dataloaders=val_data)
    trainer.test(model, test_dataloaders=test_data)


if __name__ == '__main__':
    run()
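
As a hedged sanity check (not part of the script above, and assuming 'tests/epoch=01.ckpt' exists from the first run): loading the same state_dict into a fresh, unpartitioned BoringModel should succeed, since its layer shapes (2, 32) and (2,) match the checkpoint, which points at the ZeRO stage 3 parameter partitioning as the source of the mismatch.

# Run in the same module so BoringModel and torch are already available.
plain_model = BoringModel()
ckpt = torch.load('tests/epoch=01.ckpt', map_location='cpu')
plain_model.load_state_dict(ckpt['state_dict'])  # expected to succeed: shapes match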

Expected behavior

Training resumes successfully from the stored checkpoint.

Environment

PyTorch Lightning: 1.2.10, 1.3.0rc1, and master
PyTorch: 1.7.1
OS: Ubuntu 18.04

@SeanNaren As discussed on Slack ^^

Labels

3rd party, bug, distributed, help wanted, priority: 1
