
incorrect "hparams_name" saved to checkpoint when saving params from a container #9631

@awaelchli

Description

🐛 Bug

When the hyperparameters are passed in as a single container argument and then saved with
self.save_hyperparameters(container), the name of that argument is no longer captured as it was before (see the reproduction below).

To Reproduce

import os

import torch
from argparse import Namespace
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self, cfg: Namespace):
        super().__init__()
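        # Passing the whole config container; per this report, the argument
        # name ("cfg") is expected to be recorded as hparams_name in the checkpoint.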
        self.save_hyperparameters(cfg)
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel(cfg=Namespace(a=2, b=3))
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        weights_summary=None,
    )
    trainer.fit(model, train_dataloader=train_data)

    checkpoint = torch.load(trainer.checkpoint_callback.best_model_path)
    print(checkpoint["hyper_parameters"])
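    # Expected "cfg"; on 1.4.7 the saved value is "hparams" instead (the regression reported here).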
    assert checkpoint["hparams_name"] == "cfg" 


if __name__ == "__main__":
    run()

Expected behavior

The assertion should pass, but instead the stored value is "hparams".
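
A minimal sketch of the relevant checkpoint entry (key name taken from the repro above; the path is a placeholder for trainer.checkpoint_callback.best_model_path):

import torch

checkpoint = torch.load("path/to/best.ckpt")  # placeholder path

# Name of the __init__ argument that held the hyperparameter container:
print(checkpoint["hparams_name"])
# expected:         "cfg"
# actual on 1.4.7:  "hparams"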

Environment

A git bisect shows that PR #9125 introduced the regression.
The change was included in the 1.4.7 release and is blocking NVIDIA from upgrading their NeMo code base to the latest PyTorch Lightning version.

Additional context

Reported on Slack.

Labels

bug (Something isn't working), checkpointing (Related to checkpointing), help wanted (Open to be worked on), priority: 0 (High priority task)
