Commit d90cb7f

Authored by maxoppelt, rohitgr7, carmocca, tchaton and kaushikb11
Bugfix: Scheduler monitor for manual optimization (#7643)
Co-authored-by: Rohit Gupta <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: thomas chaton <[email protected]>
Co-authored-by: Kaushik B <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kaushik B <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
1 parent eaa16c7 commit d90cb7f

File tree: 4 files changed (+102, -37 lines)


CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -233,6 +233,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `accumulate_grad_batches` not been recomputed during model reload ([#5334](https://github.com/PyTorchLightning/pytorch-lightning/pull/5334))
 - Fixed a `TypeError` when wrapping optimizers in the `HorovodPlugin` and running `Trainer.test` ([#7840](https://github.com/PyTorchLightning/pytorch-lightning/pull/7840))
 - Fixed `BackboneFinetuning` restoration ([#8501](https://github.com/PyTorchLightning/pytorch-lightning/pull/8501))
+- Fixed `lr_scheduler` with metric (e.g. `torch.optim.lr_scheduler.ReduceLROnPlateau`) when using `automatic_optimization = False` ([#7643](https://github.com/PyTorchLightning/pytorch-lightning/pull/7643))
 
 
 ## [1.3.8] - 2021-07-01

docs/source/common/optimizers.rst

Lines changed: 15 additions & 0 deletions
@@ -230,6 +230,21 @@ If you want to call ``lr_scheduler.step()`` every ``n`` steps/epochs, do the following:
         if self.trainer.is_last_batch and (self.trainer.current_epoch + 1) % n == 0:
             sch.step()
 
+If you want to call schedulers that require a metric value after each epoch, consider doing the following:
+
+.. testcode::
+
+    def __init__(self):
+        super().__init__()
+        self.automatic_optimization = False
+
+    def training_epoch_end(self, outputs):
+        sch = self.lr_schedulers()
+
+        # If the selected scheduler is a ReduceLROnPlateau scheduler.
+        if isinstance(sch, torch.optim.lr_scheduler.ReduceLROnPlateau):
+            sch.step(self.trainer.callback_metrics["loss"])
+
 -----
 
 Use closure for LBFGS-like optimizers
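For reference, the docs snippet above shows only the two hooks this commit touches. A complete, self-contained module using the same pattern could look like the sketch below; it is an illustration of the documented API only, and the module name, layer sizes and the logged metric name "loss" are assumptions rather than part of the commit:

import torch
import pytorch_lightning as pl


class ManualPlateauModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False
        self.layer = torch.nn.Linear(32, 1)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.layer(x), y)

        # Manual optimization: backward/step are driven from the LightningModule.
        opt.zero_grad()
        self.manual_backward(loss)
        opt.step()

        # Log the value that the scheduler will consume at epoch end.
        self.log("loss", loss)
        return loss

    def training_epoch_end(self, outputs):
        sch = self.lr_schedulers()
        # ReduceLROnPlateau needs the monitored metric; other schedulers do not.
        if isinstance(sch, torch.optim.lr_scheduler.ReduceLROnPlateau):
            sch.step(self.trainer.callback_metrics["loss"])

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
        # With this fix, a bare ReduceLROnPlateau no longer requires a `monitor`
        # when `automatic_optimization` is False.
        return [optimizer], [scheduler]

Training then runs as usual via `Trainer(...).fit(model, train_dataloader)` with any `(x, y)` dataloader of matching shapes.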

pytorch_lightning/trainer/optimizers.py

Lines changed: 43 additions & 37 deletions
@@ -119,21 +119,8 @@ def configure_schedulers(
         lr_schedulers = []
         default_config = _get_default_scheduler_config()
         for scheduler in schedulers:
-            if isinstance(scheduler, dict):
-                # check provided keys
-                extra_keys = [k for k in scheduler.keys() if k not in default_config.keys()]
-                if extra_keys:
-                    rank_zero_warn(f"Found unsupported keys in the lr scheduler dict: {extra_keys}", RuntimeWarning)
-                if "scheduler" not in scheduler:
-                    raise MisconfigurationException(
-                        'The lr scheduler dict must have the key "scheduler" with its item being an lr scheduler'
-                    )
-                if "interval" in scheduler and scheduler["interval"] not in ("step", "epoch"):
-                    raise MisconfigurationException(
-                        f'The "interval" key in lr scheduler dict must be "step" or "epoch"'
-                        f' but is "{scheduler["interval"]}"'
-                    )
-                if is_manual_optimization:
+            if is_manual_optimization:
+                if isinstance(scheduler, dict):
                     invalid_keys = {"interval", "frequency", "reduce_on_plateau", "monitor", "strict"}
                     keys_to_warn = [k for k in scheduler.keys() if k in invalid_keys]
 
@@ -144,30 +131,49 @@ def configure_schedulers(
                             RuntimeWarning,
                         )
 
-                scheduler["reduce_on_plateau"] = isinstance(
-                    scheduler["scheduler"], optim.lr_scheduler.ReduceLROnPlateau
-                )
-                if scheduler["reduce_on_plateau"] and scheduler.get("monitor", None) is None:
-                    raise MisconfigurationException(
-                        "The lr scheduler dict must include a monitor when a `ReduceLROnPlateau` scheduler is used."
-                        ' For example: {"optimizer": optimizer, "lr_scheduler":'
-                        ' {"scheduler": scheduler, "monitor": "your_loss"}}'
+                    scheduler = {key: scheduler[key] for key in scheduler if key not in invalid_keys}
+                    lr_schedulers.append({**default_config, **scheduler})
+                else:
+                    lr_schedulers.append({**default_config, "scheduler": scheduler})
+            else:
+                if isinstance(scheduler, dict):
+                    # check provided keys
+                    extra_keys = [k for k in scheduler.keys() if k not in default_config.keys()]
+                    if extra_keys:
+                        rank_zero_warn(f"Found unsupported keys in the lr scheduler dict: {extra_keys}", RuntimeWarning)
+                    if "scheduler" not in scheduler:
+                        raise MisconfigurationException(
+                            'The lr scheduler dict must have the key "scheduler" with its item being an lr scheduler'
+                        )
+                    if "interval" in scheduler and scheduler["interval"] not in ("step", "epoch"):
+                        raise MisconfigurationException(
+                            'The "interval" key in lr scheduler dict must be "step" or "epoch"'
+                            f' but is "{scheduler["interval"]}"'
+                        )
+                    scheduler["reduce_on_plateau"] = isinstance(
+                        scheduler["scheduler"], optim.lr_scheduler.ReduceLROnPlateau
                     )
-                lr_schedulers.append({**default_config, **scheduler})
-            elif isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
-                if monitor is None:
-                    raise MisconfigurationException(
-                        "`configure_optimizers` must include a monitor when a `ReduceLROnPlateau` scheduler is used."
-                        " For example:"
-                        ' {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "metric_to_track"}'
+                    if scheduler["reduce_on_plateau"] and scheduler.get("monitor", None) is None:
+                        raise MisconfigurationException(
+                            "The lr scheduler dict must include a monitor when a `ReduceLROnPlateau` scheduler is used."
+                            ' For example: {"optimizer": optimizer, "lr_scheduler":'
+                            ' {"scheduler": scheduler, "monitor": "your_loss"}}'
+                        )
+                    lr_schedulers.append({**default_config, **scheduler})
+                elif isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
+                    if monitor is None:
+                        raise MisconfigurationException(
+                            "`configure_optimizers` must include a monitor when a `ReduceLROnPlateau`"
+                            " scheduler is used. For example:"
+                            ' {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "metric_to_track"}'
+                        )
+                    lr_schedulers.append(
+                        {**default_config, "scheduler": scheduler, "reduce_on_plateau": True, "monitor": monitor}
                     )
-                lr_schedulers.append(
-                    {**default_config, "scheduler": scheduler, "reduce_on_plateau": True, "monitor": monitor}
-                )
-            elif isinstance(scheduler, optim.lr_scheduler._LRScheduler):
-                lr_schedulers.append({**default_config, "scheduler": scheduler})
-            else:
-                raise ValueError(f'The provided lr scheduler "{scheduler}" is invalid')
+                elif isinstance(scheduler, optim.lr_scheduler._LRScheduler):
+                    lr_schedulers.append({**default_config, "scheduler": scheduler})
+                else:
+                    raise ValueError(f'The provided lr scheduler "{scheduler}" is invalid')
         return lr_schedulers
 
 
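To make the control flow above easier to follow, here is a minimal standalone sketch of what the new manual-optimization branch does with a user-provided scheduler entry. It is not Lightning's actual implementation: the `default_config` values and the exact warning wording are assumptions (the test below only asserts the "but the keys will be ignored" substring).

import warnings

# Assumed stand-in for what Lightning's _get_default_scheduler_config() returns.
default_config = {
    "scheduler": None,
    "interval": "epoch",
    "frequency": 1,
    "reduce_on_plateau": False,
    "monitor": None,
    "strict": True,
}


def normalize_manual_scheduler(scheduler):
    """Mimic the manual-optimization branch: warn on and drop unsupported keys."""
    invalid_keys = {"interval", "frequency", "reduce_on_plateau", "monitor", "strict"}
    if isinstance(scheduler, dict):
        keys_to_warn = [k for k in scheduler if k in invalid_keys]
        if keys_to_warn:
            # Wording abridged; the real message ends with "but the keys will be ignored."
            warnings.warn(
                f"The lr scheduler dict contains {keys_to_warn}, but the keys will be ignored.",
                RuntimeWarning,
            )
        scheduler = {key: scheduler[key] for key in scheduler if key not in invalid_keys}
        return {**default_config, **scheduler}
    # A bare scheduler object is simply wrapped in the default config.
    return {**default_config, "scheduler": scheduler}

For example, `normalize_manual_scheduler({"scheduler": sched, "monitor": "train_loss"})` emits the RuntimeWarning and returns a config whose `monitor` falls back to the default `None`, which is the behaviour the new test below exercises.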
tests/trainer/optimization/test_manual_optimization.py

Lines changed: 43 additions & 0 deletions
@@ -957,6 +957,49 @@ def configure_optimizers(self):
     trainer.fit(model)
 
 
+@pytest.mark.parametrize("scheduler_as_dict", [True, False])
+def test_lr_schedulers_reduce_lr_on_plateau(tmpdir, scheduler_as_dict):
+    class TestModel(BoringModel):
+        def __init__(self, scheduler_as_dict):
+            super().__init__()
+            self.scheduler_as_dict = scheduler_as_dict
+            self.automatic_optimization = False
+
+        def training_step(self, batch, batch_idx):
+            return {"train_loss": torch.tensor([0.0])}
+
+        def training_epoch_end(self, outputs):
+            scheduler = self.lr_schedulers()
+
+            loss = torch.stack([x["train_loss"] for x in outputs]).mean()
+            scheduler.step(loss)
+
+        def configure_optimizers(self):
+            optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
+
+            if self.scheduler_as_dict:
+                scheduler = {
+                    "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer),
+                    "monitor": "train_loss",
+                }
+            else:
+                scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
+
+            return [optimizer], [scheduler]
+
+    model = TestModel(scheduler_as_dict=scheduler_as_dict)
+
+    trainer = Trainer(
+        default_root_dir=tmpdir, max_epochs=1, limit_train_batches=1, limit_val_batches=1, limit_test_batches=1
+    )
+
+    if scheduler_as_dict:
+        with pytest.warns(RuntimeWarning, match="but the keys will be ignored"):
+            trainer.fit(model)
+    else:
+        trainer.fit(model)
+
+
 def test_lr_scheduler_step_not_called(tmpdir):
     """
     Test `lr_scheduler.step()` is not called in manual optimization.