
Wrong learning rate when loading a ckpt to continue training if check_val_every_n_epoch > 1 #20495


Description

@razgzy

Bug description

I run training with LightningCLI. I set check_val_every_n_epoch > 1 (e.g. 2) and run an experiment with max_epochs = 20; checkpoints are saved by lightning.pytorch.callbacks.ModelCheckpoint. The learning rate scheduler is torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=self.trainer.max_epochs, T_mult=1, eta_min=self.eta_min), stepped once per epoch. When I load a ckpt (e.g. one saved at epoch 3) to continue training, the learning rate advances one epoch faster than expected.
[Figure: learning rate curves logged by lightning.pytorch.callbacks.LearningRateMonitor. The red curve is the original run; the yellow curve is the resumed run, shifted one epoch ahead.]
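The mismatch can be observed directly. A minimal diagnostic sketch, assuming a single scheduler is configured (so self.lr_schedulers() returns it rather than a list): add this hook to the LightningModule and compare the printed counters between a fresh run and a resumed one.

```python
# Minimal diagnostic sketch (assumes a single scheduler is configured, so
# self.lr_schedulers() returns it directly rather than a list).
def on_train_epoch_start(self):
    scheduler = self.lr_schedulers()
    # In a fresh run, last_epoch tracks current_epoch; after resuming from a
    # ckpt the restored counter appears one step ahead, matching the shifted
    # yellow curve above.
    print(
        f"epoch={self.trainer.current_epoch} "
        f"last_epoch={scheduler.last_epoch} "
        f"lr={scheduler.get_last_lr()[0]:.3e}"
    )
```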

What version are you seeing the problem on?

v2.4

How to reproduce the bug

The trainer settings:

```yaml
seed_everything: 0
trainer:
  accelerator: gpu
  strategy: auto
  devices:
  - 0
  num_nodes: 1
  precision: 32-true
  logger:
    class_path: lightning.pytorch.loggers.TensorBoardLogger
    init_args:
      save_dir: runs
      name: name
      version: null
      log_graph: false
      default_hp_metric: true
      prefix: ''
      sub_dir: test
  callbacks:
  - class_path: utils.log_manager.LogManager
  - class_path: lightning.pytorch.callbacks.LearningRateMonitor
    init_args: 
      logging_interval: epoch
  - class_path: lightning.pytorch.callbacks.ModelCheckpoint
    init_args:
      dirpath: null
      filename: 'epoch={epoch}-psnr={hp_metric:.4f}'
      monitor: hp_metric
      verbose: false
      save_last: true
      save_top_k: 5
      save_weights_only: false
      mode: max
      auto_insert_metric_name: false
      every_n_train_steps: null
      train_time_interval: null
      every_n_epochs: 1
      save_on_train_epoch_end: null
  fast_dev_run: false
  max_epochs: 20
  min_epochs: null
  max_steps: -1
  min_steps: null
  max_time: null
  limit_train_batches: null
  limit_val_batches: null
  limit_test_batches: null
  limit_predict_batches: null
  overfit_batches: 0.0
  val_check_interval: null
  check_val_every_n_epoch: 2
  num_sanity_val_steps: null
  log_every_n_steps: 20
  enable_checkpointing: null
  enable_progress_bar: null
  enable_model_summary: null
  accumulate_grad_batches: 1
  gradient_clip_val: null
  gradient_clip_algorithm: null
  deterministic: null
  benchmark: null
  inference_mode: true
  use_distributed_sampler: true
  profiler: null
  detect_anomaly: false
  barebones: false
  plugins: null
  sync_batchnorm: false
  reload_dataloaders_every_n_epochs: 0
  default_root_dir: null
```

The configure_optimizers of the LightningModule:

```python
def configure_optimizers(self):
    params = [{
        'params': self.network.parameters(),
        'lr': self.lr,
        'weight_decay': self.weight_decay,
        'name': 'network',
    }]
    optimizer = torch.optim.AdamW(params)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=self.trainer.max_epochs, T_mult=1, eta_min=self.eta_min
    )
    lr_scheduler_config = {
        "scheduler": lr_scheduler,
        "interval": 'epoch',  # step the scheduler once per training epoch
        "name": 'AdamW',
    }
    return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config}
```
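As a temporary workaround, the restored scheduler can be re-synced with the trainer's epoch counter once training resumes. This is a sketch under my assumptions, not a confirmed fix: on_train_start runs after the checkpoint state is loaded, trainer.ckpt_path is non-None only when resuming, and CosineAnnealingWarmRestarts.step() accepts an explicit epoch from which it recomputes its internal T_cur and last_epoch.

```python
# Workaround sketch, not a confirmed fix: re-sync the restored scheduler with
# the trainer's epoch counter right after the checkpoint has been loaded.
def on_train_start(self):
    if self.trainer.ckpt_path is not None:  # only when resuming from a ckpt
        scheduler = self.lr_schedulers()
        # CosineAnnealingWarmRestarts recomputes T_cur/last_epoch when given
        # an explicit epoch, overwriting the off-by-one restored state.
        scheduler.step(epoch=self.trainer.current_epoch)
```

This mutates scheduler state directly, so it should be dropped once the underlying off-by-one is fixed.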

Error messages and logs

# Error messages and logs here please

Environment

Current environment
#- PyTorch Lightning Version (e.g., 2.4.0):
#- PyTorch Version (e.g., 2.4):
#- Python version (e.g., 3.12):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning (`conda`, `pip`, source):

More info

No response

cc @lantiga
