Skip to content

Commit bde5d8f

Browse files
authored
Move implementation to scales comparison
1 parent 18d760a commit bde5d8f

File tree

3 files changed

+5
-6
lines changed

3 files changed

+5
-6
lines changed

src/lightning_lite/plugins/precision/native_amp.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,10 @@ def optimizer_step(
7474
return super().optimizer_step(optimizer, **kwargs)
7575
if isinstance(optimizer, LBFGS):
7676
raise TypeError("Native AMP and the LBFGS optimizer are not compatible.")
77+
previous_scale = self.scaler.get_scale()
7778
# note: the scaler will skip the `optimizer.step` if nonfinite gradients are found
7879
step_output = self.scaler.step(optimizer, **kwargs)
80+
model._skip_next_scheduler_step = self.scaler.get_scale() < previous_scale
7981
self.scaler.update()
8082
return step_output
8183

src/pytorch_lightning/core/module.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1585,10 +1585,6 @@ def lr_scheduler_step(self, scheduler, optimizer_idx, metric):
15851585
scheduler.step(epoch=self.current_epoch)
15861586
15871587
"""
1588-
optimizer = self.trainer.optimizers[optimizer_idx]
1589-
if hasattr(optimizer, "_step_count") and optimizer._step_count <= 0:
1590-
return
1591-
15921588
if metric is None:
15931589
scheduler.step() # type: ignore[call-arg]
15941590
else:

src/pytorch_lightning/loops/epoch/training_epoch_loop.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -216,8 +216,9 @@ def advance(self, data_fetcher: AbstractDataFetcher) -> None:
216216

217217
# update non-plateau LR schedulers
218218
# update epoch-interval ones only when we are at the end of training epoch
219-
self.update_lr_schedulers("step", update_plateau_schedulers=False)
220-
if self._num_ready_batches_reached():
219+
if not getattr(self.trainer.lightning_module, "_skip_next_scheduler_step", False):
220+
self.update_lr_schedulers("step", update_plateau_schedulers=False)
221+
elif self._num_ready_batches_reached():
221222
self.update_lr_schedulers("epoch", update_plateau_schedulers=False)
222223

223224
batch_end_outputs = self._prepare_outputs_training_batch_end(

0 commit comments

Comments (0)