Labels
bug, needs triage, ver: 2.5.x
Description
Bug description
Hello! I'm new to training models with PyTorch Lightning, and I've run into an issue: it seems like the model's parameters aren't updating at all.
To figure out what's going on, I wrote a checker function and tested it with the 15-minute example from the official website, but it still reports that the parameters aren't changing.
Could there be something off with my checker? I'm feeling stuck and unsure what to try next. Any help would be greatly appreciated!
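For reference, here is a minimal standalone version of the kind of check I have in mind, using a hypothetical toy nn.Linear rather than my actual model. Since named_parameters() yields the live parameter tensors, the snapshot in this sketch is copied with detach().clone() so that it keeps the old values:

import torch
from torch import nn, optim

# Toy model (hypothetical); the snapshot is copied by value with
# detach().clone(), so it does not track later optimizer updates.
model = nn.Linear(4, 2)
snapshot = {name: p.detach().clone() for name, p in model.named_parameters()}

# One optimizer step on a dummy loss.
opt = optim.SGD(model.parameters(), lr=0.1)
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
opt.step()

# Each parameter should now differ from its snapshot.
for name, p in model.named_parameters():
    print(name, "changed:", not torch.equal(p.detach(), snapshot[name]))

Note that if the snapshot is built without clone(), its entries are the same tensor objects the optimizer updates in place, so a later comparison against the current parameters would trivially report them as identical.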
What version are you seeing the problem on?
v2.5
How to reproduce the bug
import torch


def compare_model_parameter_dicts(dict1, dict2):
    """
    Compare two dicts of named parameters from a PyTorch nn.Module.

    Args:
        dict1: First model's parameters as a dict of {name: param}
        dict2: Second model's parameters as a dict of {name: param}
    """
    # Check if parameter names match
    if set(dict1.keys()) != set(dict2.keys()):
        print("Parameter name mismatch!")
        print(f"Dict1 names: {set(dict1.keys())}")
        print(f"Dict2 names: {set(dict2.keys())}")
        return

    # Compare each parameter
    differences_found = False
    for name in dict1:
        param1 = dict1[name]
        param2 = dict2[name]

        # Check if shapes match
        if param1.shape != param2.shape:
            print(f"Shape mismatch for {name}:")
            print(f"  Param1 shape: {param1.shape}")
            print(f"  Param2 shape: {param2.shape}")
            differences_found = True
            continue

        # Check if values are identical
        if not torch.allclose(param1, param2, rtol=1e-5, atol=1e-8):
            differences_found = True
            print(f"\nDifferences found in {name}:")

            # Find differing elements
            diff_mask = ~torch.isclose(param1, param2, rtol=1e-5, atol=1e-8)
            diff_indices = torch.nonzero(diff_mask, as_tuple=False)

            # Print differing values
            for idx in diff_indices:
                idx_tuple = tuple(idx.tolist())
                val1 = param1[idx_tuple].item()
                val2 = param2[idx_tuple].item()
                print(f"Index {idx_tuple}:")
                print(f"  Value 1: {val1}")
                print(f"  Value 2: {val2}")

    if not differences_found:
        print("All parameters are identical!")
import os
from torch import optim, nn, utils, Tensor
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import lightning as L
import torch
from torchhelpers.parameter import compare_model_parameter_dicts
# define any number of nn.Modules (or use your current ones)
encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))
# define the LightningModule
class LitAutoEncoder(L.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.initial_model_param_dict = None  # filled on the first training step

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        # it is independent of forward
        x, _ = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = nn.functional.mse_loss(x_hat, x)
        # Logging to TensorBoard (if installed) by default
        self.log("train_loss", loss)
        if self.initial_model_param_dict is None:
            self.initial_model_param_dict = {name: param for name, param in self.named_parameters()}
        else:
            compare_model_parameter_dicts(
                {name: param for name, param in self.named_parameters()},
                self.initial_model_param_dict,
            )
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

# init the autoencoder
autoencoder = LitAutoEncoder(encoder, decoder)
# setup data
dataset = MNIST(os.getcwd(), download=True, transform=ToTensor())
train_loader = utils.data.DataLoader(dataset)
# train the model (hint: here are some helpful Trainer arguments for rapid idea iteration)
trainer = L.Trainer(limit_train_batches=100, max_epochs=1)
trainer.fit(model=autoencoder, train_dataloaders=train_loader)
# load checkpoint
checkpoint = "./lightning_logs/version_0/checkpoints/epoch=0-step=100.ckpt"
autoencoder = LitAutoEncoder.load_from_checkpoint(checkpoint, encoder=encoder, decoder=decoder)
# choose your trained nn.Module
encoder = autoencoder.encoder
encoder.eval()
# embed 4 fake images!
fake_image_batch = torch.rand(4, 28 * 28, device=autoencoder.device)
embeddings = encoder(fake_image_batch)
print("⚡" * 20, "\nPredictions (4 image embeddings):\n", embeddings, "\n", "⚡" * 20)Error messages and logs
Epoch 0: 93%|█████████▎| 93/100 [00:00<00:00, 233.69it/s, v_num=3]All parameters are identical!
Epoch 0: 94%|█████████▍| 94/100 [00:00<00:00, 234.44it/s, v_num=3]All parameters are identical!
Epoch 0: 95%|█████████▌| 95/100 [00:00<00:00, 235.17it/s, v_num=3]All parameters are identical!
Epoch 0: 96%|█████████▌| 96/100 [00:00<00:00, 235.91it/s, v_num=3]All parameters are identical!
Epoch 0: 97%|█████████▋| 97/100 [00:00<00:00, 236.62it/s, v_num=3]All parameters are identical!
Epoch 0: 98%|█████████▊| 98/100 [00:00<00:00, 237.32it/s, v_num=3]All parameters are identical!
Epoch 0: 99%|█████████▉| 99/100 [00:00<00:00, 238.00it/s, v_num=3]All parameters are identical!
Epoch 0: 100%|██████████| 100/100 [00:00<00:00, 234.96it/s, v_num=3]
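For what it's worth, the same question could also be checked without the helper at all, by snapshotting a single weight by value around the trainer.fit call. A sketch reusing the script above (encoder[0] is the first nn.Linear):

# Sketch: snapshot one weight by value before fit, compare after.
before = autoencoder.encoder[0].weight.detach().clone().cpu()
trainer.fit(model=autoencoder, train_dataloaders=train_loader)
after = autoencoder.encoder[0].weight.detach().cpu()
print("encoder[0].weight changed:", not torch.equal(before, after))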
Environment
#- PyTorch Lightning Version (e.g., 2.5.0):
#- PyTorch Version (e.g., 2.5):
#- Python version (e.g., 3.12):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning (`conda`, `pip`, source):
More info
No response