Closed
Labels: bug, checkpointing, help wanted, priority: 0
Description
🐛 Bug
When the hyperparameters are passed in as a single argument container and then saved using
self.save_hyperparameters(container), the name of the argument is no longer captured as it was before.
To Reproduce
import os

import torch
from argparse import Namespace
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self, cfg: Namespace):
        super().__init__()
        self.save_hyperparameters(cfg)
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    model = BoringModel(cfg=Namespace(a=2, b=3))
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        weights_summary=None,
    )
    trainer.fit(model, train_dataloader=train_data)
    checkpoint = torch.load(trainer.checkpoint_callback.best_model_path)
    print(checkpoint["hyper_parameters"])
    assert checkpoint["hparams_name"] == "cfg"


if __name__ == "__main__":
    run()

Expected behavior
The assertion should pass, but checkpoint["hparams_name"] was "hparams" instead of "cfg".
Environment
git bisect shows that PR #9125 caused it.
The commit was included in the 1.4.7 release and is blocking NVIDIA from upgrading their NeMo code base to the newest PL version.
Additional context
Reported on Slack.