diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index aaa9e640a8110..f90bffe788dbb 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -41,6 +41,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fix `save_last` behavior in the absence of validation ([#20960](https://github.com/Lightning-AI/pytorch-lightning/pull/20960))
 
+- Fixed integration between `LearningRateFinder` and `EarlyStopping` ([#21056](https://github.com/Lightning-AI/pytorch-lightning/pull/21056))
+
+
 - Fix gradient calculation in `lr_finder` for `mode="exponential"` ([#21055](https://github.com/Lightning-AI/pytorch-lightning/pull/21055))
 
 
diff --git a/src/lightning/pytorch/tuner/lr_finder.py b/src/lightning/pytorch/tuner/lr_finder.py
index 35f7baa1408c9..a5d758f7fff19 100644
--- a/src/lightning/pytorch/tuner/lr_finder.py
+++ b/src/lightning/pytorch/tuner/lr_finder.py
@@ -292,7 +292,8 @@ def _lr_find(
     trainer.fit_loop.restarting = False  # reset restarting flag as checkpoint restoring sets it to True
     trainer.fit_loop.epoch_loop.restarting = False  # reset restarting flag as checkpoint restoring sets it to True
     trainer.fit_loop.epoch_loop.val_loop._combined_loader = None
-
+    trainer.fit_loop._combined_loader = None  # reset data fetcher to avoid issues with the next fit
+    trainer.fit_loop.setup_data()
     return lr_finder
 
 
diff --git a/tests/tests_pytorch/tuner/test_lr_finder.py b/tests/tests_pytorch/tuner/test_lr_finder.py
index d99d15321bd75..442c795491320 100644
--- a/tests/tests_pytorch/tuner/test_lr_finder.py
+++ b/tests/tests_pytorch/tuner/test_lr_finder.py
@@ -23,6 +23,7 @@
 from lightning_utilities.test.warning import no_warning_call
 
 from lightning.pytorch import Trainer, seed_everything
+from lightning.pytorch.callbacks import EarlyStopping
 from lightning.pytorch.callbacks.lr_finder import LearningRateFinder
 from lightning.pytorch.demos.boring_classes import BoringModel
 from lightning.pytorch.tuner.lr_finder import _LRFinder
@@ -540,6 +541,67 @@ def training_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT:
     assert math.isclose(model.lr, suggested_lr)
 
 
+def test_lr_finder_with_early_stopping(tmp_path):
+    class ModelWithValidation(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.learning_rate = 0.1
+
+        def validation_step(self, batch, batch_idx):
+            output = self.step(batch)
+            # Log validation loss that EarlyStopping will monitor
+            self.log("val_loss", output, on_epoch=True)
+            return output
+
+        def configure_optimizers(self):
+            optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
+
+            # Add ReduceLROnPlateau scheduler that monitors val_loss (issue #20355)
+            plateau_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+                optimizer, mode="min", factor=0.5, patience=2
+            )
+            scheduler_config = {"scheduler": plateau_scheduler, "interval": "epoch", "monitor": "val_loss"}
+
+            return {"optimizer": optimizer, "lr_scheduler": scheduler_config}
+
+    model = ModelWithValidation()
+
+    # Both callbacks that previously caused issues
+    callbacks = [
+        LearningRateFinder(num_training_steps=100, update_attr=False),
+        EarlyStopping(monitor="val_loss", patience=3),
+    ]
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        max_epochs=10,
+        callbacks=callbacks,
+        limit_train_batches=5,
+        limit_val_batches=3,
+        enable_model_summary=False,
+        enable_progress_bar=False,
+    )
+
+    trainer.fit(model)
+    assert trainer.state.finished
+
+    # Verify that both callbacks were active
+    lr_finder_callback = None
+    early_stopping_callback = None
+    for callback in trainer.callbacks:
+        if isinstance(callback, LearningRateFinder):
+            lr_finder_callback = callback
+        elif isinstance(callback, EarlyStopping):
+            early_stopping_callback = callback
+
+    assert lr_finder_callback is not None, "LearningRateFinder callback should be present"
+    assert early_stopping_callback is not None, "EarlyStopping callback should be present"
+
+    # Verify learning rate finder ran and has results
+    assert lr_finder_callback.optimal_lr is not None, "Learning rate finder should have results"
+    assert lr_finder_callback.optimal_lr.suggestion() > 0, "Learning rate suggestion should be positive"
+
+
 def test_gradient_correctness():
     """Test that torch.gradient uses correct spacing parameter."""
     lr_finder = _LRFinder(mode="exponential", lr_min=1e-6, lr_max=1e-1, num_training=20)