diff --git a/src/lightning/fabric/plugins/precision/amp.py b/src/lightning/fabric/plugins/precision/amp.py
index d5fc1f0c1cc2a..429d891a9169e 100644
--- a/src/lightning/fabric/plugins/precision/amp.py
+++ b/src/lightning/fabric/plugins/precision/amp.py
@@ -88,9 +88,11 @@ def optimizer_step(
             return super().optimizer_step(optimizer, **kwargs)
         if isinstance(optimizer, LBFGS):
             raise TypeError("AMP and the LBFGS optimizer are not compatible.")
+        previous_scale = self.scaler.get_scale()
         # note: the scaler will skip the `optimizer.step` if nonfinite gradients are found
         step_output = self.scaler.step(optimizer, **kwargs)  # type: ignore[arg-type]
         self.scaler.update()
+        optimizer._skip_next_scheduler_step = self.scaler.get_scale() < previous_scale
         return step_output
 
     @override
diff --git a/src/lightning/pytorch/loops/training_epoch_loop.py b/src/lightning/pytorch/loops/training_epoch_loop.py
index 3d01780b705fe..86d9f4819d93c 100644
--- a/src/lightning/pytorch/loops/training_epoch_loop.py
+++ b/src/lightning/pytorch/loops/training_epoch_loop.py
@@ -498,6 +498,9 @@ def _update_learning_rates(self, interval: str, update_plateau_schedulers: bool)
                         )
                         continue
 
+                if getattr(self.trainer.optimizers[config.opt_idx], "_skip_next_scheduler_step", False):
+                    continue
+
                 self.scheduler_progress.increment_ready()
 
                 # update LR
diff --git a/src/lightning/pytorch/plugins/precision/amp.py b/src/lightning/pytorch/plugins/precision/amp.py
index 5ea62233e1f69..5b63f59ee7b44 100644
--- a/src/lightning/pytorch/plugins/precision/amp.py
+++ b/src/lightning/pytorch/plugins/precision/amp.py
@@ -92,8 +92,10 @@ def optimizer_step(  # type: ignore[override]
         # in manual optimization, the closure does not return a value
         if not skip_unscaling:
             # note: the scaler will skip the `optimizer.step` if nonfinite gradients are found
-            step_output = self.scaler.step(optimizer, **kwargs)  # type: ignore[arg-type]
+            previous_scale = self.scaler.get_scale()
+            step_output = self.scaler.step(optimizer, **kwargs)
             self.scaler.update()
+            optimizer._skip_next_scheduler_step = self.scaler.get_scale() < previous_scale
             return step_output
         return closure_result
 
diff --git a/tests/tests_fabric/plugins/precision/test_amp.py b/tests/tests_fabric/plugins/precision/test_amp.py
index 73507f085936b..e3ca8fc013255 100644
--- a/tests/tests_fabric/plugins/precision/test_amp.py
+++ b/tests/tests_fabric/plugins/precision/test_amp.py
@@ -70,6 +70,7 @@ def test_amp_precision_backward():
 def test_amp_precision_optimizer_step_with_scaler():
     precision = MixedPrecision(precision="16-mixed", device="cuda")
     precision.scaler = Mock()
+    precision.scaler.get_scale = Mock(return_value=1.0)
     optimizer = Mock()
 
     precision.optimizer_step(optimizer, keyword="arg")
diff --git a/tests/tests_pytorch/models/test_hooks.py b/tests/tests_pytorch/models/test_hooks.py
index e943d0533cab5..8aff09e76bc54 100644
--- a/tests/tests_pytorch/models/test_hooks.py
+++ b/tests/tests_pytorch/models/test_hooks.py
@@ -303,11 +303,7 @@ def _auto_train_batch(trainer, model, batches, device, current_epoch=0, current_
                     "name": "optimizer_step",
                     "args": (current_epoch, i, ANY, ANY),
                 },
-                *(
-                    [{"name": "lr_scheduler_step", "args": (ANY, None)}]
-                    if i == (trainer.num_training_batches - 1)
-                    else []
-                ),
+                *([{"name": "lr_scheduler_step", "args": ANY}] if i == (trainer.num_training_batches - 1) else []),
                 {"name": "Callback.on_train_batch_end", "args": (trainer, model, {"loss": ANY}, ANY, i)},
                 {"name": "on_train_batch_end", "args": ({"loss": ANY}, ANY, i)},
             ])
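
A minimal standalone sketch (not part of the patch; it assumes a CUDA device, and the model, optimizer, and scheduler below are illustrative only) of the GradScaler behavior the change builds on: after `update()`, a decreased `get_scale()` indicates that `scaler.step(optimizer)` skipped `optimizer.step()` because of non-finite gradients, which is the condition the patch records in `_skip_next_scheduler_step` to hold back the matching LR scheduler step.

import torch

model = torch.nn.Linear(4, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
scaler = torch.cuda.amp.GradScaler()

for step in range(3):
    optimizer.zero_grad()
    with torch.autocast("cuda", dtype=torch.float16):
        loss = model(torch.randn(2, 4, device="cuda")).sum()
    scaler.scale(loss).backward()
    if step == 1:
        # corrupt the gradients so the scaler skips optimizer.step() this iteration
        for p in model.parameters():
            p.grad.fill_(float("inf"))

    previous_scale = scaler.get_scale()
    scaler.step(optimizer)  # skipped internally when gradients are non-finite
    scaler.update()         # lowers the scale when the step was skipped
    if scaler.get_scale() < previous_scale:
        continue  # mirrors the patch: don't advance the LR schedule for a skipped step
    scheduler.step()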