
Commit 105bb20

Fix lr_finder gradient calculation for mode="exponential" (#21055)
* fix impl for exponential spacing
* add testing
* small doc fixes
1 parent 60883a9 commit 105bb20
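The headline change concerns how `torch.gradient` is invoked in `suggestion()`: with `mode="exponential"` the sampled learning rates are not evenly spaced, so the slope must be taken with respect to the actual lr values rather than the sample index. A minimal standalone sketch of the difference (illustrative only, not code from this commit; the loss curve is made up):

    import torch

    # Learning rates as produced by an exponential sweep
    lrs = torch.logspace(-6, -1, steps=50)
    # A synthetic loss curve over those lrs, purely for illustration
    losses = torch.cos(lrs / lrs.max() * 3.14)

    # Old behaviour: torch.gradient assumes unit spacing between samples
    grad_uniform = torch.gradient(losses)[0]
    # Fixed behaviour: differentiate with respect to the actual lr values
    grad_wrt_lr = torch.gradient(losses, spacing=[lrs])[0]

    # On a log-spaced grid the two slopes differ, so the argmin used by
    # suggestion() may land on a different point
    print(torch.argmin(grad_uniform).item(), torch.argmin(grad_wrt_lr).item())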

4 files changed (+90 -25 lines)


src/lightning/pytorch/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -41,6 +41,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fix `save_last` behavior in the absence of validation ([#20960](https://github.com/Lightning-AI/pytorch-lightning/pull/20960))


+- Fix gradient calculation in `lr_finder` for `mode="exponential"` ([#21055](https://github.com/Lightning-AI/pytorch-lightning/pull/21055))
+
+
 - Fixed `save_hyperparameters` crashing with `dataclasses` using `init=False` fields ([#21051](https://github.com/Lightning-AI/pytorch-lightning/pull/21051))

src/lightning/pytorch/callbacks/lr_finder.py

Lines changed: 1 addition & 1 deletion
@@ -106,7 +106,7 @@ def __init__(
         self._attr_name = attr_name

         self._early_exit = False
-        self.lr_finder: Optional[_LRFinder] = None
+        self.optimal_lr: Optional[_LRFinder] = None

     def lr_find(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         with isolate_rng():

src/lightning/pytorch/tuner/lr_finder.py

Lines changed: 10 additions & 24 deletions
@@ -71,26 +71,10 @@ class _LRFinder:

     Args:
         mode: either `linear` or `exponential`, how to increase lr after each step
-
         lr_min: lr to start search from
-
         lr_max: lr to stop search
-
         num_training: number of steps to take between lr_min and lr_max

-    Example::
-        # Run lr finder
-        lr_finder = trainer.lr_find(model)
-
-        # Results stored in
-        lr_finder.results
-
-        # Plot using
-        lr_finder.plot()
-
-        # Get suggestion
-        lr = lr_finder.suggestion()
-
     """

     def __init__(self, mode: str, lr_min: float, lr_max: float, num_training: int) -> None:

@@ -138,10 +122,9 @@ def plot(
         """Plot results from lr_find run
         Args:
             suggest: if True, will mark suggested lr to use with a red point
-
             show: if True, will show figure
-
             ax: Axes object to which the plot is to be drawn. If not provided, a new figure is created.
+
         """
         if not _MATPLOTLIB_AVAILABLE:
             raise MisconfigurationException(

@@ -190,7 +173,10 @@ def suggestion(self, skip_begin: int = 10, skip_end: int = 1) -> Optional[float]

         """
         losses = torch.tensor(self.results["loss"][skip_begin:-skip_end])
-        losses = losses[torch.isfinite(losses)]
+        lrs = torch.tensor(self.results["lr"][skip_begin:-skip_end])
+        is_finite = torch.isfinite(losses)
+        losses = losses[is_finite]
+        lrs = lrs[is_finite]

         if len(losses) < 2:
             # computing torch.gradient requires at least 2 points

@@ -201,12 +187,12 @@ def suggestion(self, skip_begin: int = 10, skip_end: int = 1) -> Optional[float]
             self._optimal_idx = None
             return None

-        # TODO: When computing the argmin here, and some losses are non-finite, the expected indices could be
-        # incorrectly shifted by an offset
-        gradients = torch.gradient(losses)[0]  # Unpack the tuple
+        gradients = torch.gradient(losses, spacing=[lrs])[0]  # Compute the gradient of losses w.r.t. learning rates
         min_grad = torch.argmin(gradients).item()
-
-        self._optimal_idx = min_grad + skip_begin
+        all_losses_idx = torch.arange(len(self.results["loss"]))
+        idx_non_skipped = all_losses_idx[skip_begin:-skip_end]
+        idx_finite = idx_non_skipped[is_finite]
+        self._optimal_idx = idx_finite[min_grad].item()  # type: ignore
         return self.results["lr"][self._optimal_idx]
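Besides passing the lrs as `spacing`, the rewritten `suggestion()` also maps the argmin taken over the skipped-and-filtered losses back to an index into the full results list; the deleted TODO notes that the old `min_grad + skip_begin` arithmetic could be shifted by an offset once non-finite losses were dropped. A rough standalone sketch of that bookkeeping on made-up numbers (not the commit's code verbatim):

    import torch

    # Toy sweep results containing one non-finite loss (numbers are made up)
    results = {
        "lr":   [1e-5, 1e-4, 1e-3, 1e-2, 3e-2, 1e-1],
        "loss": [1.50, 1.45, float("nan"), 0.40, 0.35, 5.00],
    }
    skip_begin, skip_end = 1, 1

    losses = torch.tensor(results["loss"][skip_begin:-skip_end])
    lrs = torch.tensor(results["lr"][skip_begin:-skip_end])
    is_finite = torch.isfinite(losses)
    losses, lrs = losses[is_finite], lrs[is_finite]

    # Steepest descent of the loss with respect to the learning rate
    gradients = torch.gradient(losses, spacing=[lrs])[0]
    min_grad = torch.argmin(gradients).item()

    # Map the position in the filtered view back to the original results list
    all_idx = torch.arange(len(results["loss"]))
    optimal_idx = all_idx[skip_begin:-skip_end][is_finite][min_grad].item()

    # The remapped index always refers to a finite loss in the full list, which
    # plain `min_grad + skip_begin` does not guarantee once entries were dropped
    assert torch.isfinite(torch.tensor(results["loss"][optimal_idx]))
    print(results["lr"][optimal_idx])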

tests/tests_pytorch/tuner/test_lr_finder.py

Lines changed: 76 additions & 0 deletions
@@ -538,3 +538,79 @@ def training_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT:
     suggested_lr = lr_finder.suggestion()
     assert math.isfinite(suggested_lr)
     assert math.isclose(model.lr, suggested_lr)
+
+
+def test_gradient_correctness():
+    """Test that torch.gradient uses correct spacing parameter."""
+    lr_finder = _LRFinder(mode="exponential", lr_min=1e-6, lr_max=1e-1, num_training=20)
+
+    # Synthetic example
+    lrs = torch.linspace(0, 2 * math.pi, steps=1000)
+    losses = torch.sin(lrs)
+    lr_finder.results = {"lr": lrs.tolist(), "loss": losses.tolist()}
+
+    # Test the suggestion method
+    suggestion = lr_finder.suggestion(skip_begin=2, skip_end=2)
+    assert suggestion is not None
+    assert abs(suggestion - math.pi) < 1e-2, "Suggestion should be close to pi for this synthetic example"
+
+
+def test_exponential_vs_linear_mode_gradient_difference(tmp_path):
+    """Test that exponential and linear modes produce different but valid suggestions.
+
+    This verifies that the spacing fix works for both modes and that they behave differently as expected due to their
+    different lr progressions.
+
+    """
+
+    class TestModel(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.lr = 1e-3
+
+    seed_everything(42)
+
+    # Test both modes with identical parameters
+    model_linear = TestModel()
+    model_exp = TestModel()
+
+    trainer_linear = Trainer(default_root_dir=tmp_path, max_epochs=1)
+    trainer_exp = Trainer(default_root_dir=tmp_path, max_epochs=1)
+
+    tuner_linear = Tuner(trainer_linear)
+    tuner_exp = Tuner(trainer_exp)
+
+    lr_finder_linear = tuner_linear.lr_find(model_linear, min_lr=1e-6, max_lr=1e-1, num_training=50, mode="linear")
+    lr_finder_exp = tuner_exp.lr_find(model_exp, min_lr=1e-6, max_lr=1e-1, num_training=50, mode="exponential")
+
+    # Both should produce valid suggestions
+    suggestion_linear = lr_finder_linear.suggestion()
+    suggestion_exp = lr_finder_exp.suggestion()
+
+    assert suggestion_linear is not None
+    assert suggestion_exp is not None
+    assert suggestion_linear > 0
+    assert suggestion_exp > 0
+
+    # Verify that gradient computation uses correct spacing for both modes
+    for lr_finder, mode in [(lr_finder_linear, "linear"), (lr_finder_exp, "exponential")]:
+        losses = torch.tensor(lr_finder.results["loss"][10:-10])
+        lrs = torch.tensor(lr_finder.results["lr"][10:-10])
+        is_finite = torch.isfinite(losses)
+        losses_filtered = losses[is_finite]
+        lrs_filtered = lrs[is_finite]
+
+        if len(losses_filtered) >= 2:
+            # Test that gradient computation works and produces finite results
+            gradients = torch.gradient(losses_filtered, spacing=[lrs_filtered])[0]
+            assert torch.isfinite(gradients).all(), f"Non-finite gradients in {mode} mode"
+            assert len(gradients) == len(losses_filtered)
+
+            # Verify gradients with spacing differ from gradients without spacing
+            gradients_no_spacing = torch.gradient(losses_filtered)[0]
+
+            # For exponential mode, these should definitely be different, for linear mode, they might be similar
+            if mode == "exponential":
+                assert not torch.allclose(gradients, gradients_no_spacing, rtol=0.1), (
+                    "Gradients should differ significantly in exponential mode when using proper spacing"
+                )
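For reference, a minimal end-to-end sketch mirroring the new test, using the public Tuner API of Lightning 2.x (the model name and trainer settings here are illustrative, not part of the commit):

    from lightning.pytorch import Trainer, seed_everything
    from lightning.pytorch.demos.boring_classes import BoringModel
    from lightning.pytorch.tuner import Tuner


    class DemoModel(BoringModel):
        def __init__(self):
            super().__init__()
            self.lr = 1e-3  # attribute the lr finder updates with its suggestion


    seed_everything(42)
    model = DemoModel()
    trainer = Trainer(max_epochs=1)

    # Exponentially spaced sweep; suggestion() now differentiates w.r.t. the lrs
    lr_finder = Tuner(trainer).lr_find(model, min_lr=1e-6, max_lr=1e-1, num_training=50, mode="exponential")
    print(lr_finder.suggestion())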
