Closed
Labels
bug, needs triage, ver: 2.5.x
Description
Bug description
A similar issue was reported in #17167. Briefly: when using manual optimization, the `ModelCheckpoint` callback fails to save checkpoints at `every_n_train_steps` or at `every_n_epochs`. What is different in my code is that the optimization step happens not within `training_step` but inside `on_train_batch_end`, out of necessity. I'm not sure whether this is what prevents the checkpoint callback from working as expected.

    class CustomModel(pl.LightningModule):
        def __init__(self, ...):
            super().__init__()
            # Manual optimization: the optimizer is stepped by user code.
            self.automatic_optimization = False

        def training_step(self, ...):
            input, target = batch
            output_1, output_2 = self.forward(input, target)
            return {'output_1': output_1, 'output_2': output_2}

        def on_train_batch_end(self, ...):
            # Gather the outputs from all processes, keeping gradients.
            output_1 = self.all_gather(training_step_outputs['output_1'], sync_grads=True)
            output_2 = self.all_gather(training_step_outputs['output_2'], sync_grads=True)
            opt = self.optimizers()
            opt.zero_grad()
            loss = self.loss_fn(output_1, output_2)
            self.manual_backward(loss)
            opt.step()
            self.log('loss', loss, on_step=True, on_epoch=True)
            return loss

    # Checkpoint callback
    checkpoint_callback = ModelCheckpoint(
        dirpath="checkpoints/",        # Save directory
        filename="model-{epoch:02d}",
        save_top_k=-1,                 # Keep all checkpoints
        every_n_train_steps=50,
        save_on_train_epoch_end=True,
    )

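One possible explanation, offered as a hypothesis rather than a confirmed diagnosis: step-based checkpointing depends on the trainer's global step count, and if optimizer steps taken outside `training_step` are not counted, the trigger would never fire. The sketch below is a simplified pure-Python model of that idea, not Lightning's actual code. `should_save`, `simulate`, and the `steps_counted` flag are hypothetical names introduced only for illustration.

```python
def should_save(global_step: int, last_saved_step: int, every_n_train_steps: int) -> bool:
    """Hypothetical step-based trigger: save on every N-th new global step."""
    if every_n_train_steps < 1:
        return False
    if global_step == last_saved_step:
        # No counted optimizer step since the last save (or since the start).
        return False
    return global_step % every_n_train_steps == 0


def simulate(num_batches: int, every_n_train_steps: int, steps_counted: bool) -> list:
    """Return the batch indices at which a checkpoint would be written.

    steps_counted=False models the assumed failure mode: the optimizer is
    stepped, but the global step counter is never advanced.
    """
    global_step = 0
    last_saved_step = 0
    saved_at = []
    for batch_idx in range(num_batches):
        if steps_counted:
            global_step += 1  # the optimizer step is registered by the loop
        if should_save(global_step, last_saved_step, every_n_train_steps):
            saved_at.append(batch_idx)
            last_saved_step = global_step
    return saved_at


if __name__ == "__main__":
    print(simulate(200, 50, steps_counted=True))   # [49, 99, 149, 199]
    print(simulate(200, 50, steps_counted=False))  # []
```

If this hypothesis applies, `self.global_step` would be expected to stay at 0 throughout training, which can be checked by printing it inside `on_train_batch_end`.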
What version are you seeing the problem on?
v2.5
How to reproduce the bug
Error messages and logs
No error message
Environment
Current environment
#- PyTorch Lightning Version (e.g., 2.5.0):
#- PyTorch Version (e.g., 2.5):
#- Python version (e.g., 3.12):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):
More info
No response