diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index a2e9ded0aeb2c..444f97e718179 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -41,6 +41,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fix `save_last` behavior in the absence of validation ([#20960](https://github.com/Lightning-AI/pytorch-lightning/pull/20960))
 
+- Fix gradient calculation in `lr_finder` for `mode="exponential"` ([#21055](https://github.com/Lightning-AI/pytorch-lightning/pull/21055))
+
+
 - Fixed `save_hyperparameters` crashing with `dataclasses` using `init=False` fields ([#21051](https://github.com/Lightning-AI/pytorch-lightning/pull/21051))
 
 
diff --git a/src/lightning/pytorch/callbacks/lr_finder.py b/src/lightning/pytorch/callbacks/lr_finder.py
index f667b5c501a10..aaadc3c38ed5e 100644
--- a/src/lightning/pytorch/callbacks/lr_finder.py
+++ b/src/lightning/pytorch/callbacks/lr_finder.py
@@ -106,7 +106,7 @@ def __init__(
         self._attr_name = attr_name
 
         self._early_exit = False
-        self.lr_finder: Optional[_LRFinder] = None
+        self.optimal_lr: Optional[_LRFinder] = None
 
     def lr_find(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         with isolate_rng():
diff --git a/src/lightning/pytorch/tuner/lr_finder.py b/src/lightning/pytorch/tuner/lr_finder.py
index b50bedb10d53f..35f7baa1408c9 100644
--- a/src/lightning/pytorch/tuner/lr_finder.py
+++ b/src/lightning/pytorch/tuner/lr_finder.py
@@ -71,26 +71,10 @@ class _LRFinder:
 
     Args:
         mode: either `linear` or `exponential`, how to increase lr after each step
 
-        lr_min: lr to start search from
 
-        lr_max: lr to stop search
 
-        num_training: number of steps to take between lr_min and lr_max
 
-    Example::
-        # Run lr finder
-        lr_finder = trainer.lr_find(model)
-
-        # Results stored in
-        lr_finder.results
-
-        # Plot using
-        lr_finder.plot()
-
-        # Get suggestion
-        lr = lr_finder.suggestion()
-
     """
 
     def __init__(self, mode: str, lr_min: float, lr_max: float, num_training: int) -> None:
@@ -138,10 +122,9 @@ def plot(
         """Plot results from lr_find run
 
         Args:
             suggest: if True, will mark suggested lr to use with a red point
 
-            show: if True, will show figure
-            ax: Axes object to which the plot is to be drawn. If not provided, a new figure is created.
+
         """
         if not _MATPLOTLIB_AVAILABLE:
             raise MisconfigurationException(
@@ -190,7 +173,10 @@ def suggestion(self, skip_begin: int = 10, skip_end: int = 1) -> Optional[float]
 
         """
         losses = torch.tensor(self.results["loss"][skip_begin:-skip_end])
-        losses = losses[torch.isfinite(losses)]
+        lrs = torch.tensor(self.results["lr"][skip_begin:-skip_end])
+        is_finite = torch.isfinite(losses)
+        losses = losses[is_finite]
+        lrs = lrs[is_finite]
 
         if len(losses) < 2:
             # computing torch.gradient requires at least 2 points
@@ -201,12 +187,12 @@ def suggestion(self, skip_begin: int = 10, skip_end: int = 1) -> Optional[float]
             self._optimal_idx = None
             return None
 
-        # TODO: When computing the argmin here, and some losses are non-finite, the expected indices could be
-        #   incorrectly shifted by an offset
-        gradients = torch.gradient(losses)[0]  # Unpack the tuple
+        gradients = torch.gradient(losses, spacing=[lrs])[0]  # Compute the gradient of losses w.r.t. learning rates
         min_grad = torch.argmin(gradients).item()
-
-        self._optimal_idx = min_grad + skip_begin
+        all_losses_idx = torch.arange(len(self.results["loss"]))
+        idx_non_skipped = all_losses_idx[skip_begin:-skip_end]
+        idx_finite = idx_non_skipped[is_finite]
+        self._optimal_idx = idx_finite[min_grad].item()  # type: ignore
         return self.results["lr"][self._optimal_idx]
diff --git a/tests/tests_pytorch/tuner/test_lr_finder.py b/tests/tests_pytorch/tuner/test_lr_finder.py
index ec894688ccb6c..d99d15321bd75 100644
--- a/tests/tests_pytorch/tuner/test_lr_finder.py
+++ b/tests/tests_pytorch/tuner/test_lr_finder.py
@@ -538,3 +538,79 @@ def training_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT:
     suggested_lr = lr_finder.suggestion()
     assert math.isfinite(suggested_lr)
     assert math.isclose(model.lr, suggested_lr)
+
+
+def test_gradient_correctness():
+    """Test that torch.gradient uses correct spacing parameter."""
+    lr_finder = _LRFinder(mode="exponential", lr_min=1e-6, lr_max=1e-1, num_training=20)
+
+    # Synthetic example
+    lrs = torch.linspace(0, 2 * math.pi, steps=1000)
+    losses = torch.sin(lrs)
+    lr_finder.results = {"lr": lrs.tolist(), "loss": losses.tolist()}
+
+    # Test the suggestion method
+    suggestion = lr_finder.suggestion(skip_begin=2, skip_end=2)
+    assert suggestion is not None
+    assert abs(suggestion - math.pi) < 1e-2, "Suggestion should be close to pi for this synthetic example"
+
+
+def test_exponential_vs_linear_mode_gradient_difference(tmp_path):
+    """Test that exponential and linear modes produce different but valid suggestions.
+
+    This verifies that the spacing fix works for both modes and that they behave differently as expected due to their
+    different lr progressions.
+
+    """
+
+    class TestModel(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.lr = 1e-3
+
+    seed_everything(42)
+
+    # Test both modes with identical parameters
+    model_linear = TestModel()
+    model_exp = TestModel()
+
+    trainer_linear = Trainer(default_root_dir=tmp_path, max_epochs=1)
+    trainer_exp = Trainer(default_root_dir=tmp_path, max_epochs=1)
+
+    tuner_linear = Tuner(trainer_linear)
+    tuner_exp = Tuner(trainer_exp)
+
+    lr_finder_linear = tuner_linear.lr_find(model_linear, min_lr=1e-6, max_lr=1e-1, num_training=50, mode="linear")
+    lr_finder_exp = tuner_exp.lr_find(model_exp, min_lr=1e-6, max_lr=1e-1, num_training=50, mode="exponential")
+
+    # Both should produce valid suggestions
+    suggestion_linear = lr_finder_linear.suggestion()
+    suggestion_exp = lr_finder_exp.suggestion()
+
+    assert suggestion_linear is not None
+    assert suggestion_exp is not None
+    assert suggestion_linear > 0
+    assert suggestion_exp > 0
+
+    # Verify that gradient computation uses correct spacing for both modes
+    for lr_finder, mode in [(lr_finder_linear, "linear"), (lr_finder_exp, "exponential")]:
+        losses = torch.tensor(lr_finder.results["loss"][10:-10])
+        lrs = torch.tensor(lr_finder.results["lr"][10:-10])
+        is_finite = torch.isfinite(losses)
+        losses_filtered = losses[is_finite]
+        lrs_filtered = lrs[is_finite]
+
+        if len(losses_filtered) >= 2:
+            # Test that gradient computation works and produces finite results
+            gradients = torch.gradient(losses_filtered, spacing=[lrs_filtered])[0]
+            assert torch.isfinite(gradients).all(), f"Non-finite gradients in {mode} mode"
+            assert len(gradients) == len(losses_filtered)
+
+            # Verify gradients with spacing differ from gradients without spacing
+            gradients_no_spacing = torch.gradient(losses_filtered)[0]
+
+            # For exponential mode these should clearly differ; for linear mode they might be similar
+            if mode == "exponential":
+                assert not torch.allclose(gradients, gradients_no_spacing, rtol=0.1), (
+                    "Gradients should differ significantly in exponential mode when using proper spacing"
+                )
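

Illustration (not part of the patch; it assumes nothing beyond `torch`, and the toy loss curve and variable names are made up): the core of the fix is passing the recorded learning rates as the `spacing` argument of `torch.gradient`, so that `suggestion()` measures d(loss)/d(lr) rather than d(loss)/d(step index); the index bookkeeping that follows in the patch then maps the argmin back onto the full results list so that skipped and non-finite entries no longer shift the suggested index. The sketch below shows why the spacing matters on an exponential sweep, where step indices are evenly spaced but the learning rates are not.

    import torch

    # Exponentially spaced learning rates: indices are uniform, lr values are not.
    lrs = torch.logspace(-6, -1, steps=100)
    losses = (lrs - 0.05) ** 2  # toy convex loss curve with its minimum at lr = 0.05

    # Without `spacing`, torch.gradient assumes unit spacing between samples,
    # i.e. it computes d(loss)/d(step index).
    per_step = torch.gradient(losses)[0]

    # With `spacing=[lrs]` it computes d(loss)/d(lr), which is what suggestion() wants.
    per_lr = torch.gradient(losses, spacing=[lrs])[0]

    # The steepest-descent points (argmin of the gradient) land at different lrs,
    # so the suggested learning rate changes too.
    print(lrs[torch.argmin(per_step)].item())  # roughly 0.025 for this toy curve
    print(lrs[torch.argmin(per_lr)].item())    # 1e-06, the smallest lr, for this convex curve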