Closed
Labels: bug (Something isn't working), needs triage (Waiting to be triaged by maintainers), ver: 2.4.x
Description
Bug description
Hi! 👋
I'm trying to save a model checkpoint every n epochs. As my model trains, I want to save checkpoints so I can explore performance at intervals throughout the run.
To do this, I'm leveraging the ModelCheckpoint class and creating a callback like the one below.
checkpoint_callback = ModelCheckpoint(
    dirpath='checkpoints/every_10_epochs',
    filename='epoch-{epoch:02d}',
    every_n_epochs=10,
)
It seems like the every_n_epochs flag is not working: only one checkpoint is saved, not one every n epochs.
Am I misunderstanding how the checkpointing is supposed to work, or is this a bug?
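As a sanity check on the expected count: Lightning numbers epochs from 0, and the `epoch=49` file name in the log further down (from a 50-epoch run) is consistent with a save being triggered whenever `current_epoch + 1` is a multiple of `every_n_epochs`. The sketch below assumes that rule. `save_epochs` is a hypothetical helper, not part of Lightning, and it lists the 0-indexed epochs that should each produce a checkpoint.

```python
def save_epochs(max_epochs: int, every_n_epochs: int) -> list[int]:
    """Hypothetical helper: 0-indexed epochs at which an every_n_epochs schedule fires."""
    # Assumed rule: save when (current_epoch + 1) is a multiple of every_n_epochs.
    return [epoch for epoch in range(max_epochs) if (epoch + 1) % every_n_epochs == 0]


print(save_epochs(max_epochs=50, every_n_epochs=10))  # [9, 19, 29, 39, 49]
```

Under this assumption the schedule alone accounts for five saves. The open question is what happens to the earlier files after they are written.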
What version are you seeing the problem on?
v2.4
How to reproduce the bug
This is a minimal example that runs 50 epochs with every_n_epochs=10. I expect 5 checkpoints to be saved, but only 1 is.
import os

import pytorch_lightning as pl
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning.callbacks import ModelCheckpoint


class MinimalModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(1, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.layer(x)
        loss = nn.functional.mse_loss(y_hat, y)
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.02)

# Create dummy data
x = torch.linspace(0, 1, 1000).unsqueeze(-1)
y = 3 * x + 0.5 + torch.randn_like(x) * 0.1
train_dataset = TensorDataset(x, y)
train_loader = DataLoader(train_dataset, batch_size=32)
# Create ModelCheckpoint callback
checkpoint_callback = ModelCheckpoint(
    dirpath='checkpoints/every_10_epochs',
    filename='epoch-{epoch:02d}',
    every_n_epochs=10,
)
# Create trainer with the callback
trainer = pl.Trainer(max_epochs=50, callbacks=[checkpoint_callback])
# Train model
model = MinimalModel()
trainer.fit(model, train_loader)
# Function to count models in a directory
def count_models(directory):
    return len([f for f in os.listdir(directory) if f.endswith('.ckpt')])

# Count and print the number of saved models
num_checkpoints = count_models('checkpoints/every_10_epochs')
print(f"Number of checkpoints saved: {num_checkpoints}")
# Test if the number of checkpoints is correct
expected_checkpoints = 5  # After the 10th, 20th, 30th, 40th, and 50th epochs (Lightning numbers these 9, 19, 29, 39, 49)
if num_checkpoints == expected_checkpoints:
    print("Test passed: Correct number of checkpoints saved.")
else:
    print(f"Test failed: Expected {expected_checkpoints} checkpoints, but found {num_checkpoints}.")
# Print paths of saved checkpoints
print("\nSaved checkpoints:")
for checkpoint in os.listdir('checkpoints/every_10_epochs'):
    if checkpoint.endswith('.ckpt'):
        print(os.path.join('checkpoints/every_10_epochs', checkpoint))
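The `filename='epoch-{epoch:02d}'` template yields names like `epoch-epoch=49.ckpt` because `ModelCheckpoint` inserts the metric name in front of each `{...}` field when `auto_insert_metric_name=True`, which is the default. The hypothetical helper `expand_checkpoint_name` below is a simplified re-creation of that substitution, not Lightning's own code. It shows the five names the repro would write if every scheduled save were kept, assuming saves at 0-indexed epochs 9, 19, 29, 39, and 49.

```python
import re


def expand_checkpoint_name(template: str, metrics: dict, auto_insert_metric_name: bool = True) -> str:
    """Hypothetical helper mimicking how a '{name:fmt}' filename template is filled in."""
    if auto_insert_metric_name:
        # Each "{name" followed by ':' or '}' becomes "name={name",
        # so the metric name is written into the filename automatically.
        for group in dict.fromkeys(re.findall(r"(\{.*?)[:\}]", template)):
            name = group[1:]
            template = template.replace(group, f"{name}={group}")
    return template.format(**metrics)


names = [expand_checkpoint_name("epoch-{epoch:02d}", {"epoch": e}) + ".ckpt" for e in (9, 19, 29, 39, 49)]
print(names)
```

With `auto_insert_metric_name=False` the same template would give `epoch-49`. Alternatively, `filename='{epoch:02d}'` gives `epoch=49` without the doubled prefix.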
Error messages and logs
`Trainer.fit` stopped: `max_epochs=50` reached.
Number of checkpoints saved: 1
Test failed: Expected 5 checkpoints, but found 1.
Saved checkpoints:
checkpoints/every_10_epochs/epoch-epoch=49.ckpt
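A single surviving file is consistent with `ModelCheckpoint`'s retention policy and may not indicate a problem with the `every_n_epochs` schedule itself. `save_top_k` defaults to 1. With no `monitor` set, the callback keeps only the most recent checkpoint and removes the previous one after each new save. `save_top_k=-1` keeps every checkpoint. The sketch below models that policy with a hypothetical function, `surviving_checkpoints`. It is plain Python, not Lightning's implementation, and it is applied to the five names a 50-epoch, `every_n_epochs=10` run would write.

```python
def surviving_checkpoints(saved_names: list[str], save_top_k: int = 1) -> list[str]:
    """Hypothetical model of which .ckpt files remain when no monitor is set."""
    if save_top_k not in (-1, 0, 1):
        # Lightning rejects other values when nothing is monitored (with its own exception type).
        raise ValueError("without a monitor, save_top_k must be -1, 0, or 1")
    on_disk: list[str] = []
    for name in saved_names:
        if save_top_k == 0:
            continue  # nothing is ever written
        if save_top_k == 1:
            on_disk.clear()  # the previous checkpoint is removed once a newer one is saved
        on_disk.append(name)
    return on_disk


saves = [f"epoch-epoch={e:02d}.ckpt" for e in (9, 19, 29, 39, 49)]
print(surviving_checkpoints(saves))                       # ['epoch-epoch=49.ckpt']
print(len(surviving_checkpoints(saves, save_top_k=-1)))   # 5
```

If this is the cause, adding `save_top_k=-1` to the `ModelCheckpoint(...)` call in the repro should leave all five files in `checkpoints/every_10_epochs`.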
Environment
Current environment
#- PyTorch Lightning Version (e.g., 2.4.0): 2.4.0
#- PyTorch Version (e.g., 2.4): 2.4.1
#- Python version (e.g., 3.12): 3.11.9
#- OS (e.g., Linux): MacOS
#- CUDA/cuDNN version: NA
#- GPU models and configuration: ?
#- How you installed Lightning(`conda`, `pip`, source): poetry add Lightning (pip)
More info
No response