diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index b3c22a1b1f11f..2e8390bca761c 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -45,6 +45,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed `TQDMProgressBar` not resetting correctly when using both a finite and iterable dataloader ([#21147](https://github.com/Lightning-AI/pytorch-lightning/pull/21147))
 
+
+- Fixed cleanup of temporary files from `Tuner` on crashes ([#21162](https://github.com/Lightning-AI/pytorch-lightning/pull/21162))
+
 ---
 
 ## [2.5.4] - 2025-08-29
diff --git a/src/lightning/pytorch/tuner/batch_size_scaling.py b/src/lightning/pytorch/tuner/batch_size_scaling.py
index 99badd84bb8ad..78d2aa52f5725 100644
--- a/src/lightning/pytorch/tuner/batch_size_scaling.py
+++ b/src/lightning/pytorch/tuner/batch_size_scaling.py
@@ -76,24 +76,27 @@ def _scale_batch_size(
     if trainer.progress_bar_callback:
         trainer.progress_bar_callback.disable()
 
-    new_size, _ = _adjust_batch_size(trainer, batch_arg_name, value=init_val)
-
-    if mode == "power":
-        new_size = _run_power_scaling(trainer, new_size, batch_arg_name, max_trials, params)
-    elif mode == "binsearch":
-        new_size = _run_binary_scaling(trainer, new_size, batch_arg_name, max_trials, params)
+    try:
+        new_size, _ = _adjust_batch_size(trainer, batch_arg_name, value=init_val)
 
-    garbage_collection_cuda()
+        if mode == "power":
+            new_size = _run_power_scaling(trainer, new_size, batch_arg_name, max_trials, params)
+        elif mode == "binsearch":
+            new_size = _run_binary_scaling(trainer, new_size, batch_arg_name, max_trials, params)
 
-    log.info(f"Finished batch size finder, will continue with full run using batch size {new_size}")
+        garbage_collection_cuda()
 
-    __scale_batch_restore_params(trainer, params)
+        log.info(f"Finished batch size finder, will continue with full run using batch size {new_size}")
+    except Exception as ex:
+        raise ex
+    finally:
+        __scale_batch_restore_params(trainer, params)
 
-    if trainer.progress_bar_callback:
-        trainer.progress_bar_callback.enable()
+        if trainer.progress_bar_callback:
+            trainer.progress_bar_callback.enable()
 
-    trainer._checkpoint_connector.restore(ckpt_path)
-    trainer.strategy.remove_checkpoint(ckpt_path)
+        trainer._checkpoint_connector.restore(ckpt_path)
+        trainer.strategy.remove_checkpoint(ckpt_path)
 
     return new_size
 
diff --git a/src/lightning/pytorch/tuner/lr_finder.py b/src/lightning/pytorch/tuner/lr_finder.py
index b4b61d5cf0f93..5ef35dcd6d992 100644
--- a/src/lightning/pytorch/tuner/lr_finder.py
+++ b/src/lightning/pytorch/tuner/lr_finder.py
@@ -257,40 +257,45 @@
     # Initialize lr finder object (stores results)
     lr_finder = _LRFinder(mode, min_lr, max_lr, num_training)
 
-    # Configure optimizer and scheduler
-    lr_finder._exchange_scheduler(trainer)
-
-    # Fit, lr & loss logged in callback
-    _try_loop_run(trainer, params)
-
-    # Prompt if we stopped early
-    if trainer.global_step != num_training + start_steps:
-        log.info(f"LR finder stopped early after {trainer.global_step} steps due to diverging loss.")
-
-    # Transfer results from callback to lr finder object
-    lr_finder.results.update({"lr": trainer.callbacks[0].lrs, "loss": trainer.callbacks[0].losses})
-    lr_finder._total_batch_idx = trainer.fit_loop.total_batch_idx  # for debug purpose
-
-    __lr_finder_restore_params(trainer, params)
-
-    if trainer.progress_bar_callback:
-        trainer.progress_bar_callback.enable()
-
-    # Update results across ranks
-    lr_finder.results = trainer.strategy.broadcast(lr_finder.results)
-
-    # Restore initial state of model (this will also restore the original optimizer state)
-    trainer._checkpoint_connector.restore(ckpt_path)
-    trainer.strategy.remove_checkpoint(ckpt_path)
-    trainer.fit_loop.restarting = False  # reset restarting flag as checkpoint restoring sets it to True
-    trainer.fit_loop.epoch_loop.restarting = False  # reset restarting flag as checkpoint restoring sets it to True
-    trainer.fit_loop.epoch_loop.val_loop._combined_loader = None
-    trainer.fit_loop._combined_loader = None  # reset data fetcher to avoid issues with the next fit
-    trainer.fit_loop.setup_data()
+    lr_finder_finished = False
+    try:
+        # Configure optimizer and scheduler
+        lr_finder._exchange_scheduler(trainer)
+
+        # Fit, lr & loss logged in callback
+        _try_loop_run(trainer, params)
+
+        # Prompt if we stopped early
+        if trainer.global_step != num_training + start_steps:
+            log.info(f"LR finder stopped early after {trainer.global_step} steps due to diverging loss.")
+
+        # Transfer results from callback to lr finder object
+        lr_finder.results.update({"lr": trainer.callbacks[0].lrs, "loss": trainer.callbacks[0].losses})
+        lr_finder._total_batch_idx = trainer.fit_loop.total_batch_idx  # for debug purpose
+
+        __lr_finder_restore_params(trainer, params)
+
+        if trainer.progress_bar_callback:
+            trainer.progress_bar_callback.enable()
+
+        # Update results across ranks
+        lr_finder.results = trainer.strategy.broadcast(lr_finder.results)
+        lr_finder_finished = True
+    except Exception as ex:
+        raise ex
+    finally:
+        # Restore initial state of model (this will also restore the original optimizer state)
+        trainer._checkpoint_connector.restore(ckpt_path)
+        trainer.strategy.remove_checkpoint(ckpt_path)
+        trainer.fit_loop.restarting = False  # reset restarting flag as checkpoint restoring sets it to True
+        trainer.fit_loop.epoch_loop.restarting = False  # reset restarting flag as checkpoint restoring sets it to True
+        trainer.fit_loop.epoch_loop.val_loop._combined_loader = None
+        trainer.fit_loop._combined_loader = None  # reset data fetcher to avoid issues with the next fit
+        trainer.fit_loop.setup_data()
 
     # Apply LR suggestion after restoring so it persists for the real training run
     # When used as a callback, the suggestion would otherwise be lost due to checkpoint restore
-    if update_attr:
+    if update_attr and lr_finder_finished:
         lr = lr_finder.suggestion()
         if lr is not None:
             # update the attribute on the LightningModule (e.g., lr or learning_rate)
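Both tuner paths now share the same shape: snapshot trainer state to a hidden checkpoint, run the search inside `try`, and restore and delete that checkpoint in `finally` so a mid-run crash cannot leave the file behind. Below is a minimal standalone sketch of that pattern, with hypothetical `save_state`, `run_search`, and `restore_state` callables standing in for the Lightning internals:

```python
import os
import tempfile


def run_with_temp_checkpoint(save_state, run_search, restore_state):
    """Sketch of the save -> try: search -> finally: restore/remove pattern used above."""
    fd, ckpt_path = tempfile.mkstemp(suffix=".ckpt")
    os.close(fd)
    save_state(ckpt_path)  # snapshot state before the destructive search
    try:
        return run_search()  # may raise at any step
    finally:
        # Runs on success and on failure alike, so a crash never leaves the
        # temporary checkpoint on disk.
        restore_state(ckpt_path)
        if os.path.exists(ckpt_path):
            os.remove(ckpt_path)
```

The `lr_finder_finished` flag in the patch serves the same purpose as only returning a result after a successful search: the LR suggestion is applied to the model only when the run actually completed.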
diff --git a/tests/tests_pytorch/tuner/test_lr_finder.py b/tests/tests_pytorch/tuner/test_lr_finder.py
index 69575a351b0a5..81352ebe256ef 100644
--- a/tests/tests_pytorch/tuner/test_lr_finder.py
+++ b/tests/tests_pytorch/tuner/test_lr_finder.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import glob
 import logging
 import math
 import os
@@ -750,3 +751,52 @@
     assert not torch.allclose(gradients, gradients_no_spacing, rtol=0.1), (
         "Gradients should differ significantly in exponential mode when using proper spacing"
     )
+
+
+def test_lr_finder_checkpoint_cleanup_on_error(tmp_path):
+    """Test that temporary checkpoint files are cleaned up even when an error occurs during lr finding."""
+
+    class FailingModel(BoringModel):
+        def __init__(self, fail_on_step=2):
+            super().__init__()
+            self.fail_on_step = fail_on_step
+            self.current_step = 0
+            self.learning_rate = 1e-3
+
+        def training_step(self, batch, batch_idx):
+            self.current_step += 1
+            if self.current_step >= self.fail_on_step:
+                raise RuntimeError("Intentional failure for testing cleanup")
+            return super().training_step(batch, batch_idx)
+
+        def configure_optimizers(self):
+            optimizer = torch.optim.SGD(self.parameters(), lr=self.learning_rate)
+            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
+            return [optimizer], [lr_scheduler]
+
+    model = FailingModel()
+    lr_finder = LearningRateFinder(num_training_steps=5)
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        max_epochs=1,
+        enable_checkpointing=False,
+        enable_progress_bar=False,
+        enable_model_summary=False,
+        logger=False,
+        callbacks=[lr_finder],
+    )
+
+    # Check no lr_find checkpoint files exist initially
+    lr_find_checkpoints = glob.glob(os.path.join(tmp_path, ".lr_find_*.ckpt"))
+    assert len(lr_find_checkpoints) == 0, "No lr_find checkpoint files should exist initially"
+
+    # Run lr finder and expect it to fail
+    with pytest.raises(RuntimeError, match="Intentional failure for testing cleanup"):
+        trainer.fit(model)
+
+    # Check that no lr_find checkpoint files are left behind
+    lr_find_checkpoints = glob.glob(os.path.join(tmp_path, ".lr_find_*.ckpt"))
+    assert len(lr_find_checkpoints) == 0, (
+        f"lr_find checkpoint files should be cleaned up, but found: {lr_find_checkpoints}"
+    )
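The test above globs the trainer's `default_root_dir` for leftover `.lr_find_*.ckpt` files; the batch-size test below does the same for `.scale_batch_size_*.ckpt`. A small helper sketch that factors out that assertion (the helper name and signature are illustrative, not part of the test suite):

```python
import glob
import os


def assert_no_tuner_checkpoints(root_dir, prefixes=(".lr_find_", ".scale_batch_size_")):
    """Assert that no hidden tuner checkpoints remain under ``root_dir``."""
    leftovers = [
        path
        for prefix in prefixes
        for path in glob.glob(os.path.join(str(root_dir), f"{prefix}*.ckpt"))
    ]
    assert not leftovers, f"temporary tuner checkpoints were not cleaned up: {leftovers}"
```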
diff --git a/tests/tests_pytorch/tuner/test_scale_batch_size.py b/tests/tests_pytorch/tuner/test_scale_batch_size.py
index e4ed533c6fa83..f0e5fbe6a3c49 100644
--- a/tests/tests_pytorch/tuner/test_scale_batch_size.py
+++ b/tests/tests_pytorch/tuner/test_scale_batch_size.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import glob
 import logging
 import os
 from copy import deepcopy
@@ -486,3 +487,49 @@
 
     assert trainer.num_val_batches[0] == len(trainer.val_dataloaders)
     assert trainer.num_val_batches[0] != steps_per_trial
+
+
+def test_scale_batch_size_checkpoint_cleanup_on_error(tmp_path):
+    """Test that temporary checkpoint files are cleaned up even when an error occurs during batch size scaling."""
+
+    class FailingModel(BoringModel):
+        def __init__(self, fail_on_step=2):
+            super().__init__()
+            self.fail_on_step = fail_on_step
+            self.current_step = 0
+            self.batch_size = 2
+
+        def training_step(self, batch, batch_idx):
+            self.current_step += 1
+            if self.current_step >= self.fail_on_step:
+                raise RuntimeError("Intentional failure for testing cleanup")
+            return super().training_step(batch, batch_idx)
+
+        def train_dataloader(self):
+            return DataLoader(RandomDataset(32, 64), batch_size=self.batch_size)
+
+    model = FailingModel()
+    batch_size_finder = BatchSizeFinder(max_trials=3, steps_per_trial=2)
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        max_epochs=1,
+        enable_checkpointing=False,
+        enable_progress_bar=False,
+        enable_model_summary=False,
+        logger=False,
+        callbacks=[batch_size_finder],
+    )
+
+    # Check no scale_batch_size checkpoint files exist initially
+    scale_checkpoints = glob.glob(os.path.join(tmp_path, ".scale_batch_size_*.ckpt"))
+    assert len(scale_checkpoints) == 0, "No scale_batch_size checkpoint files should exist initially"
+
+    # Run batch size scaler and expect it to fail
+    with pytest.raises(RuntimeError, match="Intentional failure for testing cleanup"):
+        trainer.fit(model)
+
+    # Check that no scale_batch_size checkpoint files are left behind
+    scale_checkpoints = glob.glob(os.path.join(tmp_path, ".scale_batch_size_*.ckpt"))
+    assert len(scale_checkpoints) == 0, (
+        f"scale_batch_size checkpoint files should be cleaned up, but found: {scale_checkpoints}"
+    )
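For reference, here is the user-facing behavior the fix guarantees, sketched with the public `Tuner` API. `BoringModel` is Lightning's demo module, and the `.lr_find_*.ckpt` glob pattern mirrors what the tests above check in `default_root_dir`; the `num_training` argument and import paths reflect the current Lightning 2.x API as an assumption, so treat the snippet as an illustration rather than a verbatim recipe:

```python
import glob
import os

from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.tuner import Tuner


class TunableModel(BoringModel):
    def __init__(self):
        super().__init__()
        self.learning_rate = 1e-3  # attribute the LR finder updates


if __name__ == "__main__":
    trainer = Trainer(max_epochs=1, logger=False, enable_checkpointing=False)
    tuner = Tuner(trainer)
    try:
        tuner.lr_find(TunableModel(), num_training=20)
    finally:
        # With the fix, this list is empty even if lr_find raised mid-search.
        leftovers = glob.glob(os.path.join(trainer.default_root_dir, ".lr_find_*.ckpt"))
        print("leftover tuner checkpoints:", leftovers)
```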