
Commit a237283

Move implementation to scales comparison

1 parent 1b3365e
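
(Summary, inferred from the diff below: the `optimizer._step_count` guard is dropped from `LightningModule.lr_scheduler_step`; instead, the native AMP plugin compares the `GradScaler` scale around `scaler.step(...)` to detect a skipped optimizer step and records the result in a `_skip_next_scheduler_step` flag on the module, which the training epoch loop checks before advancing step-interval LR schedulers.)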

3 files changed (+5, -6 lines)

src/lightning_fabric/plugins/precision/native_amp.py

2 additions & 0 deletions

@@ -73,8 +73,10 @@ def optimizer_step(
             return super().optimizer_step(optimizer, **kwargs)
         if isinstance(optimizer, LBFGS):
             raise TypeError("Native AMP and the LBFGS optimizer are not compatible.")
+        previous_scale = self.scaler.get_scale()
         # note: the scaler will skip the `optimizer.step` if nonfinite gradients are found
         step_output = self.scaler.step(optimizer, **kwargs)
+        model._skip_next_scheduler_step = self.scaler.get_scale() < previous_scale
         self.scaler.update()
         return step_output

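
For context, a minimal self-contained sketch of the scale-comparison idea in plain PyTorch (the model, data, and scheduler below are invented for illustration; only the `GradScaler` calls mirror the diff). One caveat worth noting: `GradScaler` lowers its scale inside `update()`, so this sketch compares scales after that call.

import torch
from torch.cuda.amp import GradScaler, autocast

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(4, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
scaler = GradScaler(enabled=(device == "cuda"))

for _ in range(5):
    batch = torch.randn(8, 4, device=device)
    optimizer.zero_grad()
    with autocast(enabled=(device == "cuda")):
        loss = model(batch).pow(2).mean()
    scaler.scale(loss).backward()

    previous_scale = scaler.get_scale()
    scaler.step(optimizer)  # silently skips optimizer.step() on inf/NaN grads
    scaler.update()         # backs off (reduces) the scale after a skipped step

    # A reduced scale is the observable sign that the step was skipped, so the
    # LR scheduler should not advance either; otherwise scheduler steps drift
    # ahead of the optimizer steps actually taken.
    if scaler.get_scale() >= previous_scale:
        scheduler.step()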

src/pytorch_lightning/core/module.py

0 additions & 4 deletions

@@ -1649,10 +1649,6 @@ def lr_scheduler_step(self, scheduler, optimizer_idx, metric):
             scheduler.step(epoch=self.current_epoch)

         """
-        optimizer = self.trainer.optimizers[optimizer_idx]
-        if hasattr(optimizer, "_step_count") and optimizer._step_count <= 0:
-            return
-
         if metric is None:
             scheduler.step()  # type: ignore[call-arg]
         else:
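
The deleted lines are the previous form of this fix: `lr_scheduler_step` returned early until the optimizer's private `_step_count` counter recorded at least one completed step. Per the commit message, that heuristic moves to the scale comparison in the AMP plugin above, which observes skipped steps directly rather than inferring them from the step counter.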

src/pytorch_lightning/loops/epoch/training_epoch_loop.py

3 additions & 2 deletions

@@ -216,8 +216,9 @@ def advance(self, data_fetcher: AbstractDataFetcher) -> None:

         # update non-plateau LR schedulers
         # update epoch-interval ones only when we are at the end of training epoch
-        self.update_lr_schedulers("step", update_plateau_schedulers=False)
-        if self._num_ready_batches_reached():
+        if not getattr(self.trainer.lightning_module, "_skip_next_scheduler_step", False):
+            self.update_lr_schedulers("step", update_plateau_schedulers=False)
+        elif self._num_ready_batches_reached():
             self.update_lr_schedulers("epoch", update_plateau_schedulers=False)

         batch_end_outputs = self._prepare_outputs_training_batch_end(
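
A hypothetical reduction of the new control flow, with everything except the `_skip_next_scheduler_step` flag invented, to make the branch semantics explicit. Note the `elif`: as written in this commit, epoch-interval schedulers are only consulted on batches where the step-interval update was skipped.

class ToyModule:
    _skip_next_scheduler_step = False

class ToyLoop:
    def __init__(self, module):
        self.module = module
        self.step_updates = 0
        self.epoch_updates = 0

    def _num_ready_batches_reached(self):
        return True  # pretend every batch ends the epoch, for illustration

    def advance(self):
        if not getattr(self.module, "_skip_next_scheduler_step", False):
            self.step_updates += 1   # step-interval schedulers advance
        elif self._num_ready_batches_reached():
            self.epoch_updates += 1  # epoch-interval schedulers advance

loop = ToyLoop(ToyModule())
loop.advance()                                # normal batch: step branch runs
loop.module._skip_next_scheduler_step = True
loop.advance()                                # AMP skipped the step: elif runs
print(loop.step_updates, loop.epoch_updates)  # prints: 1 1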
