Checkpoint callback fails with manual optimization #20674

@rohanshad

Bug description

A similar issue was described in #17167. Briefly: when using manual optimization, the ModelCheckpoint callback fails to save checkpoints at every_n_train_steps or at every_n_epochs. What is different in my code is that the optimization steps happen not in training_step but in on_train_batch_end, out of necessity. I'm not sure whether this prevents ModelCheckpoint from working as expected.

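import lightning.pytorch as pl  # assumed import; `import pytorch_lightning as pl` also works

class CustomModel(pl.LightningModule):
    def __init__(self, ...):
        super().__init__()
        # Manual optimization: Lightning does not call backward()/step() for us
        self.automatic_optimization = False

    def training_step(self, batch, batch_idx):
        input, target = batch
        output_1, output_2 = self.forward(input, target)
        # No loss is computed here; the raw outputs are handed to on_train_batch_end
        return {'output_1': output_1, 'output_2': output_2}

    def on_train_batch_end(self, training_step_outputs, batch, batch_idx):
        # Gather outputs from all ranks while keeping gradients
        output_1 = self.all_gather(training_step_outputs['output_1'], sync_grads=True)
        output_2 = self.all_gather(training_step_outputs['output_2'], sync_grads=True)

        # The optimizer step happens here rather than in training_step
        opt = self.optimizers()
        opt.zero_grad()
        loss = self.loss_fn(output_1, output_2)
        self.manual_backward(loss)
        opt.step()
        self.log('loss', loss, on_step=True, on_epoch=True)

        return loss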
### Checkpoint Callback
checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints/",  # Save directory
    filename="model-{epoch:02d}",
    save_top_k=-1,  # Keep all checkpoints
    every_n_train_steps=50,
    save_on_train_epoch_end=True
)
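
To illustrate what I suspect is happening, here is a toy model in plain Python. It is not Lightning's actual code. It assumes (unverified against the Lightning source) that the step-based trigger fires only when a global step counter has advanced to a multiple of every_n_train_steps, that the callback check runs before the module's own on_train_batch_end, and that optimizer steps taken in that hook might not be counted. The names StepTrigger and count_saves are hypothetical.

```python
from dataclasses import dataclass


@dataclass
class StepTrigger:
    """Toy step-based save trigger (hypothetical, not Lightning's implementation)."""

    every_n_train_steps: int
    last_saved_step: int = 0

    def should_save(self, global_step: int) -> bool:
        # Disabled trigger never saves
        if self.every_n_train_steps < 1:
            return False
        # Nothing new since the last save (also covers a counter stuck at 0)
        if global_step == self.last_saved_step:
            return False
        # Only save on multiples of the interval
        if global_step % self.every_n_train_steps != 0:
            return False
        self.last_saved_step = global_step
        return True


def count_saves(num_batches: int, every_n: int, step_is_counted: bool) -> int:
    """Simulate one run and return how many checkpoints would be written."""
    trigger = StepTrigger(every_n_train_steps=every_n)
    global_step = 0
    saves = 0
    for _ in range(num_batches):
        # Assumption: the callback check runs before the module-level hook
        if trigger.should_save(global_step):
            saves += 1
        # Assumption: the optimizer step in the hook may or may not be counted
        if step_is_counted:
            global_step += 1
    return saves
```

Under these assumptions, a run where the steps are counted saves periodically, while a run where the counter never advances saves nothing and raises no error, which matches what I observe. Printing trainer.global_step during training should show whether the counter is actually stuck.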

What version are you seeing the problem on?

v2.5

How to reproduce the bug

Error messages and logs

No error message

Environment

Current environment
#- PyTorch Lightning Version (e.g., 2.5.0):
#- PyTorch Version (e.g., 2.5):
#- Python version (e.g., 3.12):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):

More info

No response

Metadata

Assignees

No one assigned

    Labels

    bug (Something isn't working), needs triage (Waiting to be triaged by maintainers), ver: 2.5.x
