
Commit 8172433

Copilot and Borda committed
Fix ReduceLROnPlateau scheduler with check_val_every_n_epoch
- Only update plateau schedulers on epochs when validation runs
- This prevents errors when monitored metrics are not available
- Added test case for this scenario

Co-authored-by: Borda <[email protected]>
1 parent 531c1e9 commit 8172433

File tree

2 files changed: +43 -2 lines changed

src/lightning/pytorch/loops/fit_loop.py

Lines changed: 4 additions & 2 deletions
@@ -483,8 +483,10 @@ def on_advance_end(self) -> None:
 
         if not self.restarting and self.epoch_loop._num_ready_batches_reached():
             # since metric-based schedulers require access to metrics and those are not currently saved in the
-            # checkpoint, the plateau schedulers shouldn't be updated
-            self.epoch_loop.update_lr_schedulers("epoch", update_plateau_schedulers=not self.restarting)
+            # checkpoint, the plateau schedulers shouldn't be updated when restarting
+            # only update plateau schedulers if validation ran this epoch to ensure monitored metrics are available
+            if self.epoch_loop._should_check_val_epoch():
+                self.epoch_loop.update_lr_schedulers("epoch", update_plateau_schedulers=True)
 
         # we manually decrease here because loggers expect that the same step is used when logging epoch-end metrics
         # even when the batch loop has finished
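
The new guard defers the plateau update to the epoch loop's validation schedule. As a rough standalone sketch of that decision (the semantics of check_val_every_n_epoch are assumed here, and should_check_val_epoch is an illustrative free function, not the actual private Lightning method):

def should_check_val_epoch(current_epoch: int, check_val_every_n_epoch: int | None, enable_validation: bool = True) -> bool:
    """Sketch of the epoch-level validation gate (assumed logic, not the Lightning source)."""
    if not enable_validation:
        return False
    if check_val_every_n_epoch is None:
        return True
    # epochs are 0-indexed internally; validation is due on every N-th completed epoch
    return (current_epoch + 1) % check_val_every_n_epoch == 0

# With check_val_every_n_epoch=2, only the 2nd of 3 epochs triggers validation,
# so only that epoch steps the plateau scheduler:
assert [should_check_val_epoch(e, 2) for e in range(3)] == [False, True, False]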

tests/tests_pytorch/trainer/optimization/test_optimizers.py

Lines changed: 39 additions & 0 deletions
@@ -165,6 +165,45 @@ def configure_optimizers(self):
     )
 
 
+def test_reducelronplateau_with_check_val_every_n_epoch(tmp_path):
+    """Test that ReduceLROnPlateau works correctly when validation runs every N epochs."""
+
+    class TestModel(BoringModel):
+        def training_step(self, batch, batch_idx):
+            loss = super().training_step(batch, batch_idx)["loss"]
+            self.log("train/loss", loss)
+            return loss
+
+        def validation_step(self, batch, batch_idx):
+            loss = super().validation_step(batch, batch_idx)["x"]
+            self.log("val/loss", loss)
+            return loss
+
+        def configure_optimizers(self):
+            optimizer = optim.Adam(self.parameters())
+            return {
+                "optimizer": optimizer,
+                "lr_scheduler": {
+                    "scheduler": optim.lr_scheduler.ReduceLROnPlateau(optimizer),
+                    "monitor": "val/loss",
+                },
+            }
+
+    model = TestModel()
+    # Validation runs every 2 epochs, but scheduler should only update on validation epochs
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        max_epochs=3,
+        limit_train_batches=2,
+        limit_val_batches=2,
+        check_val_every_n_epoch=2,
+        enable_progress_bar=False,
+        enable_model_summary=False,
+    )
+    # This should not raise an error about missing val/loss metric
+    trainer.fit(model)
+
+
 def test_optimizer_return_options(tmp_path):
     trainer = Trainer(default_root_dir=tmp_path)
     model = BoringModel()
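
Why these Trainer arguments exercise the bug: with max_epochs=3 and check_val_every_n_epoch=2, validation runs only after the second epoch, so the first and third epochs finish without val/loss ever being logged. Before this fix, the epoch-end update still tried to step the plateau scheduler on those epochs and could not find the monitored metric; with the guard in fit_loop.py, those epochs now skip the plateau update entirely.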
