src/lightning/pytorch/loops/fit_loop.py (13 changes: 9 additions & 4 deletions)
@@ -481,10 +481,15 @@ def on_advance_end(self) -> None:

         trainer._logger_connector.on_epoch_end()

-        if not self.restarting and self.epoch_loop._num_ready_batches_reached():
-            # since metric-based schedulers require access to metrics and those are not currently saved in the
-            # checkpoint, the plateau schedulers shouldn't be updated
-            self.epoch_loop.update_lr_schedulers("epoch", update_plateau_schedulers=not self.restarting)
+        # since metric-based schedulers require access to metrics and those are not currently saved in the
+        # checkpoint, the plateau schedulers shouldn't be updated when restarting
+        # only update plateau schedulers if validation ran this epoch to ensure monitored metrics are available
+        if (
+            not self.restarting
+            and self.epoch_loop._num_ready_batches_reached()
+            and self.epoch_loop._should_check_val_epoch()
+        ):
+            self.epoch_loop.update_lr_schedulers("epoch", update_plateau_schedulers=True)

         # we manually decrease here because loggers expect that the same step is used when logging epoch-end metrics
         # even when the batch loop has finished
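For context, here is a minimal sketch of the gate this hunk introduces. `should_step_plateau_scheduler` is a hypothetical helper written for illustration, and it assumes `_should_check_val_epoch` reduces to the usual `check_val_every_n_epoch` modular arithmetic; the real loop internals may differ:

```python
# Sketch only: illustrates the gating condition from the hunk above,
# not Lightning's actual internals.
def should_step_plateau_scheduler(
    restarting: bool,
    num_ready_batches_reached: bool,
    current_epoch: int,
    check_val_every_n_epoch: int,
) -> bool:
    # Assumption: _should_check_val_epoch() boils down to this check.
    val_ran_this_epoch = (current_epoch + 1) % check_val_every_n_epoch == 0
    return not restarting and num_ready_batches_reached and val_ran_this_epoch


# With check_val_every_n_epoch=2 and max_epochs=3 (as in the test below),
# only the second epoch (index 1) steps the plateau scheduler:
assert [should_step_plateau_scheduler(False, True, e, 2) for e in range(3)] == [False, True, False]
```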
tests/tests_pytorch/trainer/optimization/test_optimizers.py (39 changes: 39 additions & 0 deletions)
@@ -165,6 +165,45 @@ def configure_optimizers(self):
     )


+def test_reducelronplateau_with_check_val_every_n_epoch(tmp_path):
+    """Test that ReduceLROnPlateau works correctly when validation runs every N epochs."""
+
+    class TestModel(BoringModel):
+        def training_step(self, batch, batch_idx):
+            loss = super().training_step(batch, batch_idx)["loss"]
+            self.log("train/loss", loss)
+            return loss
+
+        def validation_step(self, batch, batch_idx):
+            loss = super().validation_step(batch, batch_idx)["x"]
+            self.log("val/loss", loss)
+            return loss
+
+        def configure_optimizers(self):
+            optimizer = optim.Adam(self.parameters())
+            return {
+                "optimizer": optimizer,
+                "lr_scheduler": {
+                    "scheduler": optim.lr_scheduler.ReduceLROnPlateau(optimizer),
+                    "monitor": "val/loss",
+                },
+            }
+
+    model = TestModel()
+    # Validation runs every 2 epochs, so the scheduler should only step on validation epochs
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        max_epochs=3,
+        limit_train_batches=2,
+        limit_val_batches=2,
+        check_val_every_n_epoch=2,
+        enable_progress_bar=False,
+        enable_model_summary=False,
+    )
+    # This should not raise an error about the missing val/loss metric
+    trainer.fit(model)
+
+
 def test_optimizer_return_options(tmp_path):
     trainer = Trainer(default_root_dir=tmp_path)
     model = BoringModel()
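Why the gate matters: when a plateau scheduler steps, Lightning looks up the monitored metric and fails if it is absent. A rough, self-contained sketch of that failure mode follows; `step_plateau_scheduler` and the `MisconfigurationException` stand-in are illustrative simplifications, not the library's actual code:

```python
from typing import Mapping, Optional


class MisconfigurationException(Exception):
    """Stand-in for Lightning's exception type (sketch only)."""


def step_plateau_scheduler(scheduler, monitor: str, callback_metrics: Mapping[str, float]) -> None:
    # If validation did not run this epoch, the monitored value was never
    # logged, so the scheduler step has nothing to condition on.
    value: Optional[float] = callback_metrics.get(monitor)
    if value is None:
        raise MisconfigurationException(
            f"ReduceLROnPlateau conditioned on metric {monitor!r} which is not available"
        )
    scheduler.step(value)
```

Under this reading, before the fix the non-validation epochs of the test above (epochs 1 and 3 with `check_val_every_n_epoch=2`) would hit a branch like this and raise, since `val/loss` had not been logged that epoch.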