diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py
index bae7f876c8211..355f49b6c3cdf 100644
--- a/src/lightning/pytorch/core/module.py
+++ b/src/lightning/pytorch/core/module.py
@@ -16,6 +16,7 @@
 import copy
 import logging
 import numbers
+import warnings
 import weakref
 from collections.abc import Generator, Mapping, Sequence
 from contextlib import contextmanager, nullcontext
@@ -1329,10 +1330,26 @@ def lr_scheduler_step(self, scheduler, metric):
                 scheduler.step(epoch=self.current_epoch)
 
         """
-        if metric is None:
-            scheduler.step()  # type: ignore[call-arg]
-        else:
-            scheduler.step(metric)
+        with warnings.catch_warnings(record=True) as caught:
+            warnings.simplefilter("always")
+
+            if metric is None:
+                scheduler.step()  # type: ignore[call-arg]
+            else:
+                scheduler.step(metric)
+
+        for w in caught:
+            msg = str(w.message)
+
+            if "lr_scheduler.step()" in msg and "optimizer.step()" in msg:
+                msg = (
+                    f"{msg} Lightning note: When training in mixed/half precision "
+                    "(e.g., 16-bit), an overflow on the first iteration can skip the "
+                    "optimizer step; the scheduler then runs before any optimizer "
+                    "step, which surfaces this warning."
+                )
+
+            rank_zero_warn(msg, category=w.category, stacklevel=2)
 
     def optimizer_step(
         self,
diff --git a/tests/tests_pytorch/core/test_lightning_module.py b/tests/tests_pytorch/core/test_lightning_module.py
index c33488a4f2626..3ecd0a7a1b1e8 100644
--- a/tests/tests_pytorch/core/test_lightning_module.py
+++ b/tests/tests_pytorch/core/test_lightning_module.py
@@ -446,6 +446,45 @@ def test_lightning_module_scriptable():
     torch.jit.script(model)
 
 
+@RunIf(min_cuda_gpus=1)
+def test_lr_scheduler_step_warning_message(tmp_path):
+    class MinimalSchedulerModel(BoringModel):
+        def training_step(self, batch, batch_idx):
+            loss = self.step(batch)
+            if batch_idx == 0:
+                loss = loss * torch.tensor(float("nan"), device=self.device)
+            return loss
+
+        def configure_optimizers(self):
+            optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
+            scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
+            return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "interval": "step"}}
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        accelerator="gpu",
+        devices=1,
+        precision="16-mixed",
+        max_steps=2,
+        limit_train_batches=2,
+        limit_val_batches=0,
+        logger=False,
+        enable_checkpointing=False,
+        enable_model_summary=False,
+        enable_progress_bar=False,
+    )
+
+    model = MinimalSchedulerModel()
+
+    warning_match = (
+        r"lr_scheduler\.step\(\).*optimizer\.step\(\).*"
+        r"Lightning note: When training in mixed/half precision \(e\.g\., 16-bit\)"
+    )
+
+    with pytest.warns(UserWarning, match=warning_match):
+        trainer.fit(model)
+
+
 def test_trainer_reference_recursively():
     ensemble = LightningModule()
     inner = LightningModule()