15min example not training #20627

@Risiamu

Bug description

Hello! I'm new to training models with PyTorch Lightning, and I've run into an issue: it looks like the model's parameters aren't updating at all.
To figure out what's going on, I wrote a checker function and tested it with the 15-minute example from the official website, but it still reports that the parameters aren't changing.
Could there be something off with my checker? I'm feeling stuck and unsure what to try next. Any help would be greatly appreciated!
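For reference, here is a minimal sketch (illustrative model and values, not from the script below) of the difference between snapshotting parameter references and snapshotting detached copies, since the two behave very differently under a later comparison:

import torch
from torch import nn

model = nn.Linear(2, 2)

# Stores references: the dict holds the same tensor objects the model trains.
live = {name: p for name, p in model.named_parameters()}

# Stores copies: detach().clone() freezes the values at this moment.
frozen = {name: p.detach().clone() for name, p in model.named_parameters()}

# Simulate one in-place optimizer step.
with torch.no_grad():
    for p in model.parameters():
        p.add_(1.0)

print(torch.equal(live["weight"], model.weight))    # True: same tensor object
print(torch.equal(frozen["weight"], model.weight))  # False: the copy kept the old values

A dict built from named_parameters() directly holds the very tensor objects the optimizer updates in place, so comparing it against the current parameters can never show a difference.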

What version are you seeing the problem on?

v2.5

How to reproduce the bug

import torch


def compare_model_parameter_dicts(dict1, dict2):
    """
    Compare two dicts of named parameters from a PyTorch nn.Module.

    Args:
        dict1: First model's parameters as a dict mapping name -> param tensor
        dict2: Second model's parameters as a dict mapping name -> param tensor
    """

    # Check if parameter names match
    if set(dict1.keys()) != set(dict2.keys()):
        print("Parameter name mismatch!")
        print(f"Params1 names: {set(dict1.keys())}")
        print(f"Params2 names: {set(dict2.keys())}")
        return

    # Compare each parameter
    differences_found = False
    for name in dict1:
        param1 = dict1[name]
        param2 = dict2[name]

        # Check if shapes match
        if param1.shape != param2.shape:
            print(f"Shape mismatch for {name}:")
            print(f"Param1 shape: {param1.shape}")
            print(f"Param2 shape: {param2.shape}")
            differences_found = True
            continue

        # Check if values are identical
        if not torch.allclose(param1, param2, rtol=1e-5, atol=1e-8):
            differences_found = True
            print(f"\nDifferences found in {name}:")

            # Find differing elements
            diff_mask = ~torch.isclose(param1, param2, rtol=1e-5, atol=1e-8)
            diff_indices = torch.nonzero(diff_mask, as_tuple=False)

            # Print differing values
            for idx in diff_indices:
                idx_tuple = tuple(idx.tolist())
                val1 = param1[idx_tuple].item()
                val2 = param2[idx_tuple].item()
                print(f"Index {idx_tuple}:")
                print(f"  Value 1: {val1}")
                print(f"  Value 2: {val2}")

    if not differences_found:
        print("All parameters are identical!")


import os
from torch import optim, nn, utils, Tensor
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import lightning as L
import torch
from torchhelpers.parameter import compare_model_parameter_dicts  # local module containing the checker defined above
# define any number of nn.Modules (or use your current ones)
encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))


# define the LightningModule
class LitAutoEncoder(L.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.initial_model_param_dict = None  # snapshot taken on the first training step

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        # it is independent of forward
        x, _ = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = nn.functional.mse_loss(x_hat, x)
        # Logging to TensorBoard (if installed) by default
        self.log("train_loss", loss)
        if self.initial_model_param_dict is None:
            # NOTE: this keeps references to the live parameter tensors, not copies
            self.initial_model_param_dict = {name: param for name, param in self.named_parameters()}
        else:
            compare_model_parameter_dicts({name: param for name, param in self.named_parameters()}, self.initial_model_param_dict)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer


# init the autoencoder
autoencoder = LitAutoEncoder(encoder, decoder)
# setup data
dataset = MNIST(os.getcwd(), download=True, transform=ToTensor())
train_loader = utils.data.DataLoader(dataset)

# train the model (hint: here are some helpful Trainer arguments for rapid idea iteration)
trainer = L.Trainer(limit_train_batches=100, max_epochs=1)
trainer.fit(model=autoencoder, train_dataloaders=train_loader)

# load checkpoint
checkpoint = "./lightning_logs/version_0/checkpoints/epoch=0-step=100.ckpt"
autoencoder = LitAutoEncoder.load_from_checkpoint(checkpoint, encoder=encoder, decoder=decoder)

# choose your trained nn.Module
encoder = autoencoder.encoder
encoder.eval()

# embed 4 fake images!
fake_image_batch = torch.rand(4, 28 * 28, device=autoencoder.device)
embeddings = encoder(fake_image_batch)
print("⚡" * 20, "\nPredictions (4 image embeddings):\n", embeddings, "\n", "⚡" * 20)

Error messages and logs

Epoch 0:  93%|█████████▎| 93/100 [00:00<00:00, 233.69it/s, v_num=3]All parameters are identical!
Epoch 0:  94%|█████████▍| 94/100 [00:00<00:00, 234.44it/s, v_num=3]All parameters are identical!
Epoch 0:  95%|█████████▌| 95/100 [00:00<00:00, 235.17it/s, v_num=3]All parameters are identical!
Epoch 0:  96%|█████████▌| 96/100 [00:00<00:00, 235.91it/s, v_num=3]All parameters are identical!
Epoch 0:  97%|█████████▋| 97/100 [00:00<00:00, 236.62it/s, v_num=3]All parameters are identical!
Epoch 0:  98%|█████████▊| 98/100 [00:00<00:00, 237.32it/s, v_num=3]All parameters are identical!
Epoch 0:  99%|█████████▉| 99/100 [00:00<00:00, 238.00it/s, v_num=3]All parameters are identical!
Epoch 0: 100%|██████████| 100/100 [00:00<00:00, 234.96it/s, v_num=3]

Environment

#- PyTorch Lightning Version (e.g., 2.5.0):
#- PyTorch Version (e.g., 2.5):
#- Python version (e.g., 3.12):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning (`conda`, `pip`, source):

More info

No response
