Skip to content

Commit 13f15fa

Browse files
awaelchli authored and lexierule committed
fix plateau scheduler stepping on incomplete epoch (#8861)
1 parent e622bca commit 13f15fa

File tree

3 files changed

+32
-15
lines changed

3 files changed

+32
-15
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2424
- Fixed `accelerator=ddp` choice for CPU ([#8645](https://github.com/PyTorchLightning/pytorch-lightning/pull/8645))
2525

2626

27+
- Fixed plateau scheduler stepping on incomplete epoch ([#8861](https://github.com/PyTorchLightning/pytorch-lightning/pull/8861))
28+
29+
2730
## [1.4.0] - 2021-07-27
2831

2932
### Added

pytorch_lightning/loops/epoch/training_epoch_loop.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,8 @@ def on_run_end(self) -> List[List[STEP_OUTPUT]]:
236236
self.trainer.call_hook("on_epoch_end")
237237
self.trainer.logger_connector.on_epoch_end()
238238

239-
self.update_lr_schedulers("epoch", update_plateau_schedulers=True)
239+
if self._num_training_batches_reached(self.is_last_batch):
240+
self.update_lr_schedulers("epoch", update_plateau_schedulers=True)
240241

241242
epoch_output = self._epoch_output
242243
# free memory

tests/trainer/optimization/test_optimizers.py

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
from pytorch_lightning import Callback, Trainer
2121
from pytorch_lightning.callbacks import ModelCheckpoint
2222
from pytorch_lightning.utilities.exceptions import MisconfigurationException
23-
from tests.base import EvalModelTemplate
2423
from tests.helpers.boring_model import BoringModel
2524
from tests.helpers.runif import RunIf
2625

@@ -79,7 +78,7 @@ def test_reducelronplateau_with_no_monitor_raises(tmpdir):
7978
"""
8079
Test exception when a ReduceLROnPlateau is used with no monitor
8180
"""
82-
model = EvalModelTemplate()
81+
model = BoringModel()
8382
optimizer = optim.Adam(model.parameters())
8483
model.configure_optimizers = lambda: ([optimizer], [optim.lr_scheduler.ReduceLROnPlateau(optimizer)])
8584
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
@@ -93,7 +92,7 @@ def test_reducelronplateau_with_no_monitor_in_lr_scheduler_dict_raises(tmpdir):
9392
"""
9493
Test exception when lr_scheduler dict has a ReduceLROnPlateau with no monitor
9594
"""
96-
model = EvalModelTemplate()
95+
model = BoringModel()
9796
optimizer = optim.Adam(model.parameters())
9897
model.configure_optimizers = lambda: {
9998
"optimizer": optimizer,
@@ -376,33 +375,47 @@ def configure_optimizers(self):
376375
trainer.fit(model)
377376

378377

379-
def test_lr_scheduler_strict(tmpdir):
378+
@pytest.mark.parametrize("complete_epoch", [True, False])
379+
@mock.patch("torch.optim.lr_scheduler.ReduceLROnPlateau.step")
380+
def test_lr_scheduler_strict(step_mock, tmpdir, complete_epoch):
380381
"""
381382
Test "strict" support in lr_scheduler dict
382383
"""
383-
model = EvalModelTemplate()
384+
model = BoringModel()
384385
optimizer = optim.Adam(model.parameters())
385386
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)
386-
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
387+
max_epochs = 1 if complete_epoch else None
388+
max_steps = None if complete_epoch else 1
389+
trainer = Trainer(default_root_dir=tmpdir, max_epochs=max_epochs, max_steps=max_steps)
387390

388391
model.configure_optimizers = lambda: {
389392
"optimizer": optimizer,
390393
"lr_scheduler": {"scheduler": scheduler, "monitor": "giraffe", "strict": True},
391394
}
392-
with pytest.raises(
393-
MisconfigurationException,
394-
match=r"ReduceLROnPlateau conditioned on metric .* which is not available\. Available metrics are:",
395-
):
395+
396+
if complete_epoch:
397+
with pytest.raises(
398+
MisconfigurationException,
399+
match=r"ReduceLROnPlateau conditioned on metric .* which is not available\. Available metrics are:",
400+
):
401+
trainer.fit(model)
402+
else:
396403
trainer.fit(model)
397404

405+
step_mock.assert_not_called()
406+
398407
model.configure_optimizers = lambda: {
399408
"optimizer": optimizer,
400409
"lr_scheduler": {"scheduler": scheduler, "monitor": "giraffe", "strict": False},
401410
}
402-
with pytest.warns(
403-
RuntimeWarning, match=r"ReduceLROnPlateau conditioned on metric .* which is not available but strict"
404-
):
405-
trainer.fit(model)
411+
412+
if complete_epoch:
413+
with pytest.warns(
414+
RuntimeWarning, match=r"ReduceLROnPlateau conditioned on metric .* which is not available but strict"
415+
):
416+
trainer.fit(model)
417+
418+
step_mock.assert_not_called()
406419

407420

408421
def test_unknown_configure_optimizers_raises(tmpdir):

0 commit comments

Comments
 (0)