
Commit 4e71c28

Fix implementation
1 parent 17a48b3 commit 4e71c28

4 files changed (+9, −4 lines)

src/lightning_fabric/plugins/precision/native_amp.py

Lines changed: 1 addition & 1 deletion
@@ -76,8 +76,8 @@ def optimizer_step(
         previous_scale = self.scaler.get_scale()
         # note: the scaler will skip the `optimizer.step` if nonfinite gradients are found
         step_output = self.scaler.step(optimizer, **kwargs)
-        model._skip_next_scheduler_step = self.scaler.get_scale() < previous_scale
         self.scaler.update()
+        optimizer._skip_next_scheduler_step = self.scaler.get_scale() < previous_scale
         return step_output

     def state_dict(self) -> Dict[str, Any]:
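
For reference, the change follows the standard torch.cuda.amp GradScaler pattern: scaler.step(optimizer) skips the underlying optimizer.step() when nonfinite gradients are found, and the following scaler.update() then lowers the scale, so a drop in scaler.get_scale() is the signal that the step was skipped and the LR scheduler should not advance. A minimal standalone sketch of that pattern (plain PyTorch, not Lightning code; `loader` is an assumed DataLoader yielding CUDA batches):

import torch

model = torch.nn.Linear(4, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
scaler = torch.cuda.amp.GradScaler()

for inputs, targets in loader:  # `loader` is assumed, not part of the commit
    optimizer.zero_grad()
    with torch.autocast("cuda"):
        loss = torch.nn.functional.mse_loss(model(inputs), targets)
    scaler.scale(loss).backward()

    previous_scale = scaler.get_scale()
    scaler.step(optimizer)  # skipped internally if inf/NaN gradients were found
    scaler.update()         # lowers the scale when the step was skipped
    if scaler.get_scale() >= previous_scale:
        scheduler.step()    # only advance the scheduler when the optimizer actually stepped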

src/pytorch_lightning/loops/epoch/training_epoch_loop.py

Lines changed: 5 additions & 3 deletions
@@ -216,9 +216,8 @@ def advance(self, data_fetcher: AbstractDataFetcher) -> None:

         # update non-plateau LR schedulers
         # update epoch-interval ones only when we are at the end of training epoch
-        if not getattr(self.trainer.lightning_module, "_skip_next_scheduler_step", False):
-            self.update_lr_schedulers("step", update_plateau_schedulers=False)
-        elif self._num_ready_batches_reached():
+        self.update_lr_schedulers("step", update_plateau_schedulers=False)
+        if self._num_ready_batches_reached():
             self.update_lr_schedulers("epoch", update_plateau_schedulers=False)

         batch_end_outputs = self._prepare_outputs_training_batch_end(
@@ -451,6 +450,9 @@ def _update_learning_rates(
                         )
                         continue

+                if getattr(self.trainer.optimizers[config.opt_idx], "_skip_next_scheduler_step", False):
+                    continue
+
                 self.scheduler_progress.increment_ready()

                 # update LR
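
The net effect is that the skip decision moves out of advance (where a single flag on the LightningModule gated every step-interval scheduler) into _update_learning_rates, where it is checked against the optimizer that each scheduler belongs to. A rough sketch of the consumer side, using illustrative names rather than Lightning's exact internals:

for config in lr_scheduler_configs:          # illustrative loop, not the real API
    optimizer = optimizers[config.opt_idx]   # the optimizer this scheduler is attached to
    if getattr(optimizer, "_skip_next_scheduler_step", False):
        continue  # the scaler skipped optimizer.step(), so don't advance this scheduler
    config.scheduler.step()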

src/pytorch_lightning/plugins/precision/native_amp.py

Lines changed: 2 additions & 0 deletions
@@ -85,8 +85,10 @@ def optimizer_step( # type: ignore[override]
         # in manual optimization, the closure does not return a value
         if not model.automatic_optimization or not skipped_backward:
             # note: the scaler will skip the `optimizer.step` if nonfinite gradients are found
+            previous_scale = self.scaler.get_scale()
             step_output = self.scaler.step(optimizer, **kwargs)
             self.scaler.update()
+            optimizer._skip_next_scheduler_step = self.scaler.get_scale() < previous_scale
             return step_output
         return closure_result


tests/tests_fabric/plugins/precision/test_native_amp.py

Lines changed: 1 addition & 0 deletions
@@ -69,6 +69,7 @@ def test_native_amp_precision_backward():
 def test_native_amp_precision_optimizer_step_with_scaler():
     precision = MixedPrecision(precision="mixed", device="cuda")
     precision.scaler = Mock()
+    precision.scaler.get_scale = Mock(return_value=1.0)
     optimizer = Mock()

     precision.optimizer_step(optimizer, keyword="arg")
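
Presumably the extra mock is needed because the plugin now reads the scale around the step: get_scale has to return a real float, since comparing two bare Mock objects with < would raise a TypeError.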
