src/lightning/pytorch/CHANGELOG.md (3 additions, 0 deletions)

```diff
@@ -45,6 +45,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed `TQDMProgressBar` not resetting correctly when using both a finite and iterable dataloader ([#21147](https://github.com/Lightning-AI/pytorch-lightning/pull/21147))
 
+
+- Fixed cleanup of temporary files from `Tuner` on crashes ([#21162](https://github.com/Lightning-AI/pytorch-lightning/pull/21162))
+
 ---
 
 ## [2.5.4] - 2025-08-29
```
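For orientation (not part of the diff): during a tuning run, Lightning snapshots the trainer state to a hidden checkpoint matching `.lr_find_*.ckpt` or `.scale_batch_size_*.ckpt` under `default_root_dir`, restores from it afterwards, and deletes it. The sketch below is a rough, assumed usage example (the `TunedModel` subclass and temp directory are illustrative, not from the PR) showing the user-facing behaviour this changelog entry refers to; before the fix, a crash mid-search could leave that file behind.

```python
import glob
import os
import tempfile

from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.tuner import Tuner


class TunedModel(BoringModel):
    def __init__(self):
        super().__init__()
        self.learning_rate = 1e-3  # attribute the LR finder reads and may update


root = tempfile.mkdtemp()
trainer = Trainer(default_root_dir=root, max_epochs=1, logger=False, enable_checkpointing=False)

try:
    result = Tuner(trainer).lr_find(TunedModel(), num_training=20)
    print("suggested lr:", result.suggestion())
except RuntimeError:
    pass  # e.g. an OOM or a bug raised from training_step

# With the fix, the temporary snapshot is deleted whether or not lr_find crashed.
assert glob.glob(os.path.join(root, ".lr_find_*.ckpt")) == []
```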
src/lightning/pytorch/tuner/batch_size_scaling.py (16 additions, 13 deletions)

```diff
@@ -76,24 +76,27 @@ def _scale_batch_size(
     if trainer.progress_bar_callback:
         trainer.progress_bar_callback.disable()
 
-    new_size, _ = _adjust_batch_size(trainer, batch_arg_name, value=init_val)
 
-    if mode == "power":
-        new_size = _run_power_scaling(trainer, new_size, batch_arg_name, max_trials, params)
-    elif mode == "binsearch":
-        new_size = _run_binary_scaling(trainer, new_size, batch_arg_name, max_trials, params)
+    try:
+        new_size, _ = _adjust_batch_size(trainer, batch_arg_name, value=init_val)
 
-    garbage_collection_cuda()
+        if mode == "power":
+            new_size = _run_power_scaling(trainer, new_size, batch_arg_name, max_trials, params)
+        elif mode == "binsearch":
+            new_size = _run_binary_scaling(trainer, new_size, batch_arg_name, max_trials, params)
 
-    log.info(f"Finished batch size finder, will continue with full run using batch size {new_size}")
+        garbage_collection_cuda()
 
-    __scale_batch_restore_params(trainer, params)
+        log.info(f"Finished batch size finder, will continue with full run using batch size {new_size}")
+    except Exception as e:
+        raise e
+    finally:
+        __scale_batch_restore_params(trainer, params)
 
-    if trainer.progress_bar_callback:
-        trainer.progress_bar_callback.enable()
+        if trainer.progress_bar_callback:
+            trainer.progress_bar_callback.enable()
 
-    trainer._checkpoint_connector.restore(ckpt_path)
-    trainer.strategy.remove_checkpoint(ckpt_path)
+        trainer._checkpoint_connector.restore(ckpt_path)
+        trainer.strategy.remove_checkpoint(ckpt_path)
 
     return new_size
```
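The change above follows a standard resource-cleanup shape: snapshot the state before the trials mutate it, run the trials, and do the restore plus temp-file removal in a `finally` block so it happens on success and on failure alike. Below is a minimal, framework-free sketch of that pattern, assuming hypothetical `save_snapshot`, `restore_snapshot`, and `run_trials` callables (these are not Lightning APIs).

```python
import os
import tempfile


def tune_with_cleanup(save_snapshot, restore_snapshot, run_trials):
    """Run tuning trials against a temporary snapshot, guaranteeing cleanup on any outcome."""
    fd, ckpt_path = tempfile.mkstemp(suffix=".ckpt")
    os.close(fd)
    save_snapshot(ckpt_path)  # persist the current state before the trials mutate it
    try:
        return run_trials()  # may raise, e.g. CUDA OOM or a bug in training_step
    finally:
        # Executed on success and on failure alike: put the original state back
        # and remove the temporary checkpoint so nothing is left on disk.
        restore_snapshot(ckpt_path)
        os.remove(ckpt_path)
```

In the patch itself, the restore half is `trainer._checkpoint_connector.restore(ckpt_path)` followed by `trainer.strategy.remove_checkpoint(ckpt_path)`, both now inside the `finally` block.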
src/lightning/pytorch/tuner/lr_finder.py (36 additions, 31 deletions)

```diff
@@ -257,40 +257,45 @@ def _lr_find(
     # Initialize lr finder object (stores results)
     lr_finder = _LRFinder(mode, min_lr, max_lr, num_training)
 
-    # Configure optimizer and scheduler
-    lr_finder._exchange_scheduler(trainer)
-
-    # Fit, lr & loss logged in callback
-    _try_loop_run(trainer, params)
-
-    # Prompt if we stopped early
-    if trainer.global_step != num_training + start_steps:
-        log.info(f"LR finder stopped early after {trainer.global_step} steps due to diverging loss.")
-
-    # Transfer results from callback to lr finder object
-    lr_finder.results.update({"lr": trainer.callbacks[0].lrs, "loss": trainer.callbacks[0].losses})
-    lr_finder._total_batch_idx = trainer.fit_loop.total_batch_idx  # for debug purpose
-
-    __lr_finder_restore_params(trainer, params)
-
-    if trainer.progress_bar_callback:
-        trainer.progress_bar_callback.enable()
-
-    # Update results across ranks
-    lr_finder.results = trainer.strategy.broadcast(lr_finder.results)
-
-    # Restore initial state of model (this will also restore the original optimizer state)
-    trainer._checkpoint_connector.restore(ckpt_path)
-    trainer.strategy.remove_checkpoint(ckpt_path)
-    trainer.fit_loop.restarting = False  # reset restarting flag as checkpoint restoring sets it to True
-    trainer.fit_loop.epoch_loop.restarting = False  # reset restarting flag as checkpoint restoring sets it to True
-    trainer.fit_loop.epoch_loop.val_loop._combined_loader = None
-    trainer.fit_loop._combined_loader = None  # reset data fetcher to avoid issues with the next fit
-    trainer.fit_loop.setup_data()
+    lr_finder_finished = False
+    try:
+        # Configure optimizer and scheduler
+        lr_finder._exchange_scheduler(trainer)
+
+        # Fit, lr & loss logged in callback
+        _try_loop_run(trainer, params)
+
+        # Prompt if we stopped early
+        if trainer.global_step != num_training + start_steps:
+            log.info(f"LR finder stopped early after {trainer.global_step} steps due to diverging loss.")
+
+        # Transfer results from callback to lr finder object
+        lr_finder.results.update({"lr": trainer.callbacks[0].lrs, "loss": trainer.callbacks[0].losses})
+        lr_finder._total_batch_idx = trainer.fit_loop.total_batch_idx  # for debug purpose
+
+        __lr_finder_restore_params(trainer, params)
+
+        if trainer.progress_bar_callback:
+            trainer.progress_bar_callback.enable()
+
+        # Update results across ranks
+        lr_finder.results = trainer.strategy.broadcast(lr_finder.results)
+        lr_finder_finished = True
+    except Exception as e:
+        raise e
+    finally:
+        # Restore initial state of model (this will also restore the original optimizer state)
+        trainer._checkpoint_connector.restore(ckpt_path)
+        trainer.strategy.remove_checkpoint(ckpt_path)
+        trainer.fit_loop.restarting = False  # reset restarting flag as checkpoint restoring sets it to True
+        trainer.fit_loop.epoch_loop.restarting = False  # reset restarting flag as checkpoint restoring sets it to True
+        trainer.fit_loop.epoch_loop.val_loop._combined_loader = None
+        trainer.fit_loop._combined_loader = None  # reset data fetcher to avoid issues with the next fit
+        trainer.fit_loop.setup_data()
 
     # Apply LR suggestion after restoring so it persists for the real training run
     # When used as a callback, the suggestion would otherwise be lost due to checkpoint restore
-    if update_attr:
+    if update_attr and lr_finder_finished:
        lr = lr_finder.suggestion()
        if lr is not None:
            # update the attribute on the LightningModule (e.g., lr or learning_rate)
```
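Besides the `try`/`finally`, `_lr_find` now tracks a `lr_finder_finished` flag so the LR suggestion is only applied when the search actually completed; otherwise a suggestion computed from a partially collected loss curve could overwrite the module's learning rate. The sketch below isolates that guard, with hypothetical callables standing in for the real steps (it is an illustration of the idea, not the PR's code).

```python
from typing import Callable, Optional, TypeVar

T = TypeVar("T")


def search_and_maybe_apply(
    run_search: Callable[[], T],
    apply_result: Callable[[T], None],
    cleanup: Callable[[], None],
) -> Optional[T]:
    """Only apply a result if the search ran to completion; always clean up."""
    finished = False
    result: Optional[T] = None
    try:
        result = run_search()  # may raise partway through
        finished = True
    finally:
        cleanup()  # restore state / delete temp files, on success and on failure alike

    if finished and result is not None:  # never apply a result from a partial run
        apply_result(result)
    return result
```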
tests/tests_pytorch/tuner/test_lr_finder.py (50 additions, 0 deletions)

```diff
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import glob
 import logging
 import math
 import os
@@ -750,3 +751,52 @@ def __init__(self):
     assert not torch.allclose(gradients, gradients_no_spacing, rtol=0.1), (
         "Gradients should differ significantly in exponential mode when using proper spacing"
     )
+
+
+def test_lr_finder_checkpoint_cleanup_on_error(tmp_path):
+    """Test that temporary checkpoint files are cleaned up even when an error occurs during lr finding."""
+
+    class FailingModel(BoringModel):
+        def __init__(self, fail_on_step=2):
+            super().__init__()
+            self.fail_on_step = fail_on_step
+            self.current_step = 0
+            self.learning_rate = 1e-3
+
+        def training_step(self, batch, batch_idx):
+            self.current_step += 1
+            if self.current_step >= self.fail_on_step:
+                raise RuntimeError("Intentional failure for testing cleanup")
+            return super().training_step(batch, batch_idx)
+
+        def configure_optimizers(self):
+            optimizer = torch.optim.SGD(self.parameters(), lr=self.learning_rate)
+            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
+            return [optimizer], [lr_scheduler]
+
+    model = FailingModel()
+    lr_finder = LearningRateFinder(num_training_steps=5)
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        max_epochs=1,
+        enable_checkpointing=False,
+        enable_progress_bar=False,
+        enable_model_summary=False,
+        logger=False,
+        callbacks=[lr_finder],
+    )
+
+    # Check no lr_find checkpoint files exist initially
+    lr_find_checkpoints = glob.glob(os.path.join(tmp_path, ".lr_find_*.ckpt"))
+    assert len(lr_find_checkpoints) == 0, "No lr_find checkpoint files should exist initially"
+
+    # Run lr finder and expect it to fail
+    with pytest.raises(RuntimeError, match="Intentional failure for testing cleanup"):
+        trainer.fit(model)
+
+    # Check that no lr_find checkpoint files are left behind
+    lr_find_checkpoints = glob.glob(os.path.join(tmp_path, ".lr_find_*.ckpt"))
+    assert len(lr_find_checkpoints) == 0, (
+        f"lr_find checkpoint files should be cleaned up, but found: {lr_find_checkpoints}"
+    )
```
tests/tests_pytorch/tuner/test_scale_batch_size.py (47 additions, 0 deletions)

```diff
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import glob
 import logging
 import os
 from copy import deepcopy
@@ -486,3 +487,49 @@ def test_batch_size_finder_callback_val_batches(tmp_path):
 
     assert trainer.num_val_batches[0] == len(trainer.val_dataloaders)
     assert trainer.num_val_batches[0] != steps_per_trial
+
+
+def test_scale_batch_size_checkpoint_cleanup_on_error(tmp_path):
+    """Test that temporary checkpoint files are cleaned up even when an error occurs during batch size scaling."""
+
+    class FailingModel(BoringModel):
+        def __init__(self, fail_on_step=2):
+            super().__init__()
+            self.fail_on_step = fail_on_step
+            self.current_step = 0
+            self.batch_size = 2
+
+        def training_step(self, batch, batch_idx):
+            self.current_step += 1
+            if self.current_step >= self.fail_on_step:
+                raise RuntimeError("Intentional failure for testing cleanup")
+            return super().training_step(batch, batch_idx)
+
+        def train_dataloader(self):
+            return DataLoader(RandomDataset(32, 64), batch_size=self.batch_size)
+
+    model = FailingModel()
+    batch_size_finder = BatchSizeFinder(max_trials=3, steps_per_trial=2)
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        max_epochs=1,
+        enable_checkpointing=False,
+        enable_progress_bar=False,
+        enable_model_summary=False,
+        logger=False,
+        callbacks=[batch_size_finder],
+    )
+
+    # Check no scale_batch_size checkpoint files exist initially
+    scale_checkpoints = glob.glob(os.path.join(tmp_path, ".scale_batch_size_*.ckpt"))
+    assert len(scale_checkpoints) == 0, "No scale_batch_size checkpoint files should exist initially"
+
+    # Run batch size scaler and expect it to fail
+    with pytest.raises(RuntimeError, match="Intentional failure for testing cleanup"):
+        trainer.fit(model)
+
+    # Check that no scale_batch_size checkpoint files are left behind
+    scale_checkpoints = glob.glob(os.path.join(tmp_path, ".scale_batch_size_*.ckpt"))
+    assert len(scale_checkpoints) == 0, (
+        f"scale_batch_size checkpoint files should be cleaned up, but found: {scale_checkpoints}"
+    )
```
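Assuming a development checkout with the test requirements installed, the two new tests can be run in isolation; a small sketch using pytest's programmatic entry point (the node IDs follow from the file paths and test names added above):

```python
import pytest

# Run only the two cleanup tests introduced in this PR, verbosely.
pytest.main([
    "tests/tests_pytorch/tuner/test_lr_finder.py::test_lr_finder_checkpoint_cleanup_on_error",
    "tests/tests_pytorch/tuner/test_scale_batch_size.py::test_scale_batch_size_checkpoint_cleanup_on_error",
    "-v",
])
```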