3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -41,6 +41,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fix `save_last` behavior in the absence of validation ([#20960](https://github.com/Lightning-AI/pytorch-lightning/pull/20960))


- Fix gradient calculation in `lr_finder` for `mode="exponential"` ([#21055](https://github.com/Lightning-AI/pytorch-lightning/pull/21055))


- Fixed `save_hyperparameters` crashing with `dataclasses` using `init=False` fields ([#21051](https://github.com/Lightning-AI/pytorch-lightning/pull/21051))


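For context on the `lr_finder` changelog entry above: in `mode="exponential"` the sampled learning rates are not evenly spaced, and `torch.gradient` without a `spacing` argument treats them as if they were one step apart, i.e. it differentiates the loss with respect to the step index (which, for exponential mode, is proportional to log(lr)) rather than with respect to the learning rate itself. A minimal standalone sketch of the difference, using a synthetic loss curve that is not from the PR:

import torch

# Exponentially spaced learning rates, as mode="exponential" produces.
lrs = torch.logspace(-6, -1, steps=100)
# Synthetic loss curve: rapid improvement at small lrs, divergence at large ones.
losses = torch.exp(-2000.0 * lrs) + 50.0 * lrs

# Previous behaviour: unit spacing, i.e. the slope w.r.t. the step index.
grad_by_step = torch.gradient(losses)[0]

# Fixed behaviour: slope w.r.t. the learning rate values themselves.
grad_by_lr = torch.gradient(losses, spacing=[lrs])[0]

# The steepest-descent point generally lands on different learning rates.
print(lrs[grad_by_step.argmin()].item(), lrs[grad_by_lr.argmin()].item())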
2 changes: 1 addition & 1 deletion src/lightning/pytorch/callbacks/lr_finder.py
@@ -106,7 +106,7 @@ def __init__(
self._attr_name = attr_name

self._early_exit = False
self.lr_finder: Optional[_LRFinder] = None
self.optimal_lr: Optional[_LRFinder] = None

def lr_find(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
with isolate_rng():
34 changes: 10 additions & 24 deletions src/lightning/pytorch/tuner/lr_finder.py
@@ -71,26 +71,10 @@ class _LRFinder:

Args:
mode: either `linear` or `exponential`, how to increase lr after each step

lr_min: lr to start search from

lr_max: lr to stop search

num_training: number of steps to take between lr_min and lr_max

Example::
# Run lr finder
lr_finder = trainer.lr_find(model)

# Results stored in
lr_finder.results

# Plot using
lr_finder.plot()

# Get suggestion
lr = lr_finder.suggestion()

"""

def __init__(self, mode: str, lr_min: float, lr_max: float, num_training: int) -> None:
@@ -138,10 +122,9 @@ def plot(
"""Plot results from lr_find run
Args:
suggest: if True, will mark suggested lr to use with a red point

show: if True, will show figure

ax: Axes object to which the plot is to be drawn. If not provided, a new figure is created.

"""
if not _MATPLOTLIB_AVAILABLE:
raise MisconfigurationException(
Expand Down Expand Up @@ -190,7 +173,10 @@ def suggestion(self, skip_begin: int = 10, skip_end: int = 1) -> Optional[float]

"""
losses = torch.tensor(self.results["loss"][skip_begin:-skip_end])
losses = losses[torch.isfinite(losses)]
lrs = torch.tensor(self.results["lr"][skip_begin:-skip_end])
is_finite = torch.isfinite(losses)
losses = losses[is_finite]
lrs = lrs[is_finite]

if len(losses) < 2:
# computing torch.gradient requires at least 2 points
@@ -201,12 +187,12 @@ def suggestion(self, skip_begin: int = 10, skip_end: int = 1) -> Optional[float]
self._optimal_idx = None
return None

# TODO: When computing the argmin here, and some losses are non-finite, the expected indices could be
# incorrectly shifted by an offset
gradients = torch.gradient(losses)[0] # Unpack the tuple
gradients = torch.gradient(losses, spacing=[lrs])[0] # Compute the gradient of losses w.r.t. learning rates
min_grad = torch.argmin(gradients).item()

self._optimal_idx = min_grad + skip_begin
all_losses_idx = torch.arange(len(self.results["loss"]))
idx_non_skipped = all_losses_idx[skip_begin:-skip_end]
idx_finite = idx_non_skipped[is_finite]
self._optimal_idx = idx_finite[min_grad].item() # type: ignore
return self.results["lr"][self._optimal_idx]


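The second half of the `suggestion` hunk above deals with mapping the selected position back to `self.results` after non-finite losses have been filtered out. A standalone sketch of why the old `min_grad + skip_begin` mapping can point at the wrong entry (the values and the `skip_begin` offset here are illustrative only, not from the PR):

import torch

skip_begin = 10  # illustrative offset into results["loss"]
# A window of recorded losses that contains a non-finite value.
loss_window = torch.tensor([1.0, 0.9, float("nan"), 0.7, 0.8, 0.95])

is_finite = torch.isfinite(loss_window)
finite_losses = loss_window[is_finite]   # [1.0, 0.9, 0.7, 0.8, 0.95]; the NaN is dropped

# Position selected inside the filtered tensor (in the real code this is the
# argmin of the gradient; the loss argmin is enough to show the bookkeeping).
pos = int(torch.argmin(finite_losses))   # 2, pointing at the value 0.7

# Old mapping: pos + skip_begin = 12, which is where the NaN sits in results.
# New mapping: push the original indices through the same finite mask.
all_idx = torch.arange(skip_begin, skip_begin + len(loss_window))
idx_finite = all_idx[is_finite]          # [10, 11, 13, 14, 15]
optimal_idx = int(idx_finite[pos])       # 13, the true position of 0.7 in results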
76 changes: 76 additions & 0 deletions tests/tests_pytorch/tuner/test_lr_finder.py
@@ -538,3 +538,79 @@ def training_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT:
suggested_lr = lr_finder.suggestion()
assert math.isfinite(suggested_lr)
assert math.isclose(model.lr, suggested_lr)


def test_gradient_correctness():
"""Test that torch.gradient uses correct spacing parameter."""
lr_finder = _LRFinder(mode="exponential", lr_min=1e-6, lr_max=1e-1, num_training=20)

# Synthetic example
lrs = torch.linspace(0, 2 * math.pi, steps=1000)
losses = torch.sin(lrs)
lr_finder.results = {"lr": lrs.tolist(), "loss": losses.tolist()}

# Test the suggestion method
suggestion = lr_finder.suggestion(skip_begin=2, skip_end=2)
assert suggestion is not None
assert abs(suggestion - math.pi) < 1e-2, "Suggestion should be close to pi for this synthetic example"

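A quick sanity check on the synthetic test above (my own aside, not part of the PR): the "learning rates" are uniformly spaced values of x in [0, 2π] and the "loss" is sin(x), so the steepest negative slope of the loss sits where d/dx sin(x) = cos(x) is smallest, i.e. at x = π, which is exactly what the assertion expects:

import math
import torch

x = torch.linspace(0, 2 * math.pi, steps=1000)
# The gradient of sin(x) w.r.t. x is cos(x), which is most negative at x = pi.
print(x[torch.gradient(torch.sin(x), spacing=[x])[0].argmin()].item())  # ~3.14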

def test_exponential_vs_linear_mode_gradient_difference(tmp_path):
"""Test that exponential and linear modes produce different but valid suggestions.

This verifies that the spacing fix works for both modes and that they behave differently as expected due to their
different lr progressions.

"""

class TestModel(BoringModel):
def __init__(self):
super().__init__()
self.lr = 1e-3

seed_everything(42)

# Test both modes with identical parameters
model_linear = TestModel()
model_exp = TestModel()

trainer_linear = Trainer(default_root_dir=tmp_path, max_epochs=1)
trainer_exp = Trainer(default_root_dir=tmp_path, max_epochs=1)

tuner_linear = Tuner(trainer_linear)
tuner_exp = Tuner(trainer_exp)

lr_finder_linear = tuner_linear.lr_find(model_linear, min_lr=1e-6, max_lr=1e-1, num_training=50, mode="linear")
lr_finder_exp = tuner_exp.lr_find(model_exp, min_lr=1e-6, max_lr=1e-1, num_training=50, mode="exponential")

# Both should produce valid suggestions
suggestion_linear = lr_finder_linear.suggestion()
suggestion_exp = lr_finder_exp.suggestion()

assert suggestion_linear is not None
assert suggestion_exp is not None
assert suggestion_linear > 0
assert suggestion_exp > 0

# Verify that gradient computation uses correct spacing for both modes
for lr_finder, mode in [(lr_finder_linear, "linear"), (lr_finder_exp, "exponential")]:
losses = torch.tensor(lr_finder.results["loss"][10:-10])
lrs = torch.tensor(lr_finder.results["lr"][10:-10])
is_finite = torch.isfinite(losses)
losses_filtered = losses[is_finite]
lrs_filtered = lrs[is_finite]

if len(losses_filtered) >= 2:
# Test that gradient computation works and produces finite results
gradients = torch.gradient(losses_filtered, spacing=[lrs_filtered])[0]
assert torch.isfinite(gradients).all(), f"Non-finite gradients in {mode} mode"
assert len(gradients) == len(losses_filtered)

# Verify gradients with spacing differ from gradients without spacing
gradients_no_spacing = torch.gradient(losses_filtered)[0]

# For exponential mode these should definitely be different; for linear mode they might be similar
if mode == "exponential":
assert not torch.allclose(gradients, gradients_no_spacing, rtol=0.1), (
"Gradients should differ significantly in exponential mode when using proper spacing"
)
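An aside on the last assertion (standalone sketch, not from the PR): with uniformly spaced learning rates, `torch.gradient(..., spacing=[lrs])` is just the unit-spacing result divided by the constant step, so the argmin, and therefore the suggestion, is unchanged in linear mode; only a non-uniform progression such as exponential mode can actually move it:

import torch

# Uniformly spaced "learning rates" (linear mode) and a synthetic loss curve.
lrs = torch.linspace(1e-6, 1e-1, steps=50, dtype=torch.float64)
step = (lrs[1] - lrs[0]).item()
losses = torch.sin(20 * lrs)   # any smooth curve works for this comparison

g_unit = torch.gradient(losses)[0]                   # pretends the samples are 1 apart
g_spaced = torch.gradient(losses, spacing=[lrs])[0]  # uses the actual lr values

# With a constant step the two differ only by the factor 1 / step,
# so the argmin, and hence the suggested lr, is the same either way.
assert torch.allclose(g_spaced * step, g_unit)
assert int(g_spaced.argmin()) == int(g_unit.argmin())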