
Commit 9f4d4ed

Move implementation to scales comparison
1 parent 2539dba commit 9f4d4ed

3 files changed: +5 additions, -6 deletions

src/lightning_fabric/plugins/precision/native_amp.py

Lines changed: 2 additions & 0 deletions

@@ -73,8 +73,10 @@ def optimizer_step(
             return super().optimizer_step(optimizer, **kwargs)
         if isinstance(optimizer, LBFGS):
             raise TypeError("Native AMP and the LBFGS optimizer are not compatible.")
+        previous_scale = self.scaler.get_scale()
         # note: the scaler will skip the `optimizer.step` if nonfinite gradients are found
         step_output = self.scaler.step(optimizer, **kwargs)
+        model._skip_next_scheduler_step = self.scaler.get_scale() < previous_scale
         self.scaler.update()
         return step_output
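The detection idea hinges on GradScaler's backoff behaviour: when scaler.step() finds non-finite gradients it skips optimizer.step(), and the following scaler.update() multiplies the scale by its backoff_factor (0.5 by default), so a drop in get_scale() marks a skipped step. A minimal standalone sketch of that behaviour, assuming a CUDA device and a recent PyTorch (this is not Lightning code, and the forced inf gradient is purely for demonstration; note the sketch reads the scale after update(), which is where the backoff is applied):

    import torch

    model = torch.nn.Linear(4, 1, device="cuda")
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = torch.cuda.amp.GradScaler()

    loss = model(torch.randn(8, 4, device="cuda")).sum()
    scaler.scale(loss).backward()
    model.weight.grad.fill_(float("inf"))  # force a skipped step for the demo

    previous_scale = scaler.get_scale()
    scaler.step(optimizer)  # optimizer.step() is skipped: non-finite grads found
    scaler.update()         # backoff: the scale is halved here
    print(scaler.get_scale() < previous_scale)  # True -> the step was skipped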

src/pytorch_lightning/core/module.py

Lines changed: 0 additions & 4 deletions

@@ -1621,10 +1621,6 @@ def lr_scheduler_step(self, scheduler, optimizer_idx, metric):
                 scheduler.step(epoch=self.current_epoch)

         """
-        optimizer = self.trainer.optimizers[optimizer_idx]
-        if hasattr(optimizer, "_step_count") and optimizer._step_count <= 0:
-            return
-
         if metric is None:
             scheduler.step()  # type: ignore[call-arg]
         else:
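For context on the removed guard, a sketch of the mechanism it relied on (not Lightning code): constructing an LR scheduler patches optimizer.step with a call counter, so the private optimizer._step_count stays at 0 until the first step. That only distinguishes "never stepped" from "stepped at least once", so it cannot flag later skipped steps, which the scale comparison in native_amp.py now covers.

    import torch

    param = torch.zeros(1, requires_grad=True)
    optimizer = torch.optim.SGD([param], lr=0.1)
    # Creating a scheduler wraps optimizer.step with a counter.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)

    print(optimizer._step_count)  # 0: no optimizer.step() has run yet
    optimizer.step()
    print(optimizer._step_count)  # 1: counts calls, not successful updates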

src/pytorch_lightning/loops/epoch/training_epoch_loop.py

Lines changed: 3 additions & 2 deletions

@@ -216,8 +216,9 @@ def advance(self, data_fetcher: AbstractDataFetcher) -> None:

         # update non-plateau LR schedulers
         # update epoch-interval ones only when we are at the end of training epoch
-        self.update_lr_schedulers("step", update_plateau_schedulers=False)
-        if self._num_ready_batches_reached():
+        if not getattr(self.trainer.lightning_module, "_skip_next_scheduler_step", False):
+            self.update_lr_schedulers("step", update_plateau_schedulers=False)
+        elif self._num_ready_batches_reached():
             self.update_lr_schedulers("epoch", update_plateau_schedulers=False)

         batch_end_outputs = self._prepare_outputs_training_batch_end(
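Putting the pieces together: the precision plugin marks the module when AMP skipped the optimizer step, and the loop above consults that flag before stepping step-interval schedulers. A toy sketch of the resulting control flow (_skip_next_scheduler_step is the attribute from this diff; everything else is illustrative), showing how the LR schedule stays in lockstep with real optimizer updates instead of running ahead on skipped steps:

    import torch

    param = torch.zeros(1, requires_grad=True)
    optimizer = torch.optim.SGD([param], lr=1.0)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)
    module = torch.nn.Module()  # stands in for the LightningModule

    for batch, amp_skipped in enumerate([False, True, False]):
        if not amp_skipped:
            optimizer.step()  # in real code, scaler.step(optimizer) decides this
        # what the precision plugin records after comparing scales:
        module._skip_next_scheduler_step = amp_skipped
        # the guard added to TrainingEpochLoop.advance:
        if not getattr(module, "_skip_next_scheduler_step", False):
            scheduler.step()
        print(batch, optimizer.param_groups[0]["lr"])
    # prints 0.5, 0.5, 0.25: the LR holds on the skipped batch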
