3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -41,6 +41,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fix `save_last` behavior in the absence of validation ([#20960](https://github.com/Lightning-AI/pytorch-lightning/pull/20960))


- Fixed integration between `LearningRateFinder` and `EarlyStopping` ([#21056](https://github.com/Lightning-AI/pytorch-lightning/pull/21056))


- Fix gradient calculation in `lr_finder` for `mode="exponential"` ([#21055](https://github.com/Lightning-AI/pytorch-lightning/pull/21055))


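For context on the `mode="exponential"` entry: the fix concerns computing the loss gradient against the actual log-spaced learning-rate values rather than the uniform step index. A minimal illustrative sketch, with placeholder data and variable names that are ours rather than the library's internals:

```python
import torch

# Log-spaced learning rates, as produced in exponential mode, and a dummy loss curve.
lrs = torch.logspace(-6, -1, steps=20)
losses = torch.sin(lrs.log())  # placeholder losses for illustration only

# Passing the lr values as `spacing` makes torch.gradient differentiate the losses
# with respect to the learning rate instead of assuming unit spacing between points.
grad_wrt_lr = torch.gradient(losses, spacing=(lrs,))[0]
print(grad_wrt_lr.shape)  # same length as `losses`
```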
3 changes: 2 additions & 1 deletion src/lightning/pytorch/tuner/lr_finder.py
@@ -292,7 +292,8 @@ def _lr_find(
trainer.fit_loop.restarting = False # reset restarting flag as checkpoint restoring sets it to True
trainer.fit_loop.epoch_loop.restarting = False # reset restarting flag as checkpoint restoring sets it to True
trainer.fit_loop.epoch_loop.val_loop._combined_loader = None

trainer.fit_loop._combined_loader = None # reset data fetcher to avoid issues with the next fit
trainer.fit_loop.setup_data()
return lr_finder
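
For reference, this is the scenario the reset targets: running the LR finder and then a regular fit on the same `Trainer`. A minimal usage sketch (not part of the PR), assuming the public `Tuner` API and `BoringModel` from the demos package:

```python
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.tuner import Tuner

model = BoringModel()
trainer = Trainer(max_epochs=2, limit_train_batches=5, limit_val_batches=2)

# Run the learning-rate finder first.
tuner = Tuner(trainer)
lr_finder = tuner.lr_find(model, num_training=20, update_attr=False)
print(lr_finder.suggestion())

# With the combined loader reset above, the subsequent fit rebuilds its dataloaders
# instead of reusing the finder's exhausted loader.
trainer.fit(model)
```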


62 changes: 62 additions & 0 deletions tests/tests_pytorch/tuner/test_lr_finder.py
@@ -23,6 +23,7 @@
from lightning_utilities.test.warning import no_warning_call

from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.callbacks import EarlyStopping
from lightning.pytorch.callbacks.lr_finder import LearningRateFinder
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.tuner.lr_finder import _LRFinder
@@ -540,6 +541,67 @@ def training_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT:
assert math.isclose(model.lr, suggested_lr)


def test_lr_finder_with_early_stopping(tmp_path):
class ModelWithValidation(BoringModel):
def __init__(self):
super().__init__()
self.learning_rate = 0.1

def validation_step(self, batch, batch_idx):
output = self.step(batch)
# Log validation loss that EarlyStopping will monitor
self.log("val_loss", output, on_epoch=True)
return output

def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)

# Add ReduceLROnPlateau scheduler that monitors val_loss (issue #20355)
plateau_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode="min", factor=0.5, patience=2
)
scheduler_config = {"scheduler": plateau_scheduler, "interval": "epoch", "monitor": "val_loss"}

return {"optimizer": optimizer, "lr_scheduler": scheduler_config}

model = ModelWithValidation()

# Both callbacks that previously caused issues
callbacks = [
LearningRateFinder(num_training_steps=100, update_attr=False),
EarlyStopping(monitor="val_loss", patience=3),
]

trainer = Trainer(
default_root_dir=tmp_path,
max_epochs=10,
callbacks=callbacks,
limit_train_batches=5,
limit_val_batches=3,
enable_model_summary=False,
enable_progress_bar=False,
)

trainer.fit(model)
assert trainer.state.finished

# Verify that both callbacks were active
lr_finder_callback = None
early_stopping_callback = None
for callback in trainer.callbacks:
if isinstance(callback, LearningRateFinder):
lr_finder_callback = callback
elif isinstance(callback, EarlyStopping):
early_stopping_callback = callback

assert lr_finder_callback is not None, "LearningRateFinder callback should be present"
assert early_stopping_callback is not None, "EarlyStopping callback should be present"

# Verify learning rate finder ran and has results
assert lr_finder_callback.optimal_lr is not None, "Learning rate finder should have results"
assert lr_finder_callback.optimal_lr.suggestion() > 0, "Learning rate suggestion should be positive"


def test_gradient_correctness():
"""Test that torch.gradient uses correct spacing parameter."""
lr_finder = _LRFinder(mode="exponential", lr_min=1e-6, lr_max=1e-1, num_training=20)