3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -45,6 +45,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed `TQDMProgressBar` not resetting correctly when using both a finite and iterable dataloader ([#21147](https://github.com/Lightning-AI/pytorch-lightning/pull/21147))
 
+
+- Fixed cleanup of temporary files from `Tuner` on crashes ([#21162](https://github.com/Lightning-AI/pytorch-lightning/pull/21162))
+
 ---
 
 ## [2.5.4] - 2025-08-29
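For context: both tuner entry points save a temporary checkpoint before running their trials and restore from it afterwards; this PR wraps the trial loop so that restore and cleanup also run when a trial raises. Below is a minimal sketch of that pattern only — `run_trials` is an illustrative stand-in, not a real Lightning helper; the restore/remove calls mirror the ones visible in the diffs that follow.

```python
# Minimal sketch of the cleanup pattern this PR applies in both tuner code paths.
# `run_trials` is illustrative; the restore/remove calls are the ones shown below.
def run_with_cleanup(trainer, ckpt_path, run_trials):
    try:
        result = run_trials(trainer)  # may raise, e.g. an OOM mid-trial
    finally:
        # always executed: restore the pre-tuning state and delete the temp checkpoint
        trainer._checkpoint_connector.restore(ckpt_path)
        trainer.strategy.remove_checkpoint(ckpt_path)
    return result
```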
29 changes: 16 additions & 13 deletions src/lightning/pytorch/tuner/batch_size_scaling.py
@@ -76,24 +76,27 @@ def _scale_batch_size(
     if trainer.progress_bar_callback:
         trainer.progress_bar_callback.disable()
 
-    new_size, _ = _adjust_batch_size(trainer, batch_arg_name, value=init_val)
-
-    if mode == "power":
-        new_size = _run_power_scaling(trainer, new_size, batch_arg_name, max_trials, params)
-    elif mode == "binsearch":
-        new_size = _run_binary_scaling(trainer, new_size, batch_arg_name, max_trials, params)
+    try:
+        new_size, _ = _adjust_batch_size(trainer, batch_arg_name, value=init_val)
 
-    garbage_collection_cuda()
+        if mode == "power":
+            new_size = _run_power_scaling(trainer, new_size, batch_arg_name, max_trials, params)
+        elif mode == "binsearch":
+            new_size = _run_binary_scaling(trainer, new_size, batch_arg_name, max_trials, params)
 
-    log.info(f"Finished batch size finder, will continue with full run using batch size {new_size}")
+        garbage_collection_cuda()
 
-    __scale_batch_restore_params(trainer, params)
+        log.info(f"Finished batch size finder, will continue with full run using batch size {new_size}")
+    except Exception as ex:
+        raise ex
+    finally:
+        __scale_batch_restore_params(trainer, params)
 
-    if trainer.progress_bar_callback:
-        trainer.progress_bar_callback.enable()
+        if trainer.progress_bar_callback:
+            trainer.progress_bar_callback.enable()
 
-    trainer._checkpoint_connector.restore(ckpt_path)
-    trainer.strategy.remove_checkpoint(ckpt_path)
+        trainer._checkpoint_connector.restore(ckpt_path)
+        trainer.strategy.remove_checkpoint(ckpt_path)
 
     return new_size
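For reference, a hedged usage sketch (not part of the diff) of the batch size finder path that the try/finally above protects. `TunableModel` is illustrative; the model (or datamodule) must expose the attribute named by `batch_arg_name` (default `"batch_size"`). Per the new test below, the temporary file created during the scan matches `.scale_batch_size_*.ckpt` under the trainer's root dir, and with this change it is removed even if a trial crashes.

```python
# Hedged usage sketch of Tuner.scale_batch_size; `TunableModel` is illustrative.
from torch.utils.data import DataLoader

from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
from lightning.pytorch.tuner import Tuner


class TunableModel(BoringModel):
    def __init__(self):
        super().__init__()
        self.batch_size = 2  # attribute the finder scales

    def train_dataloader(self):
        return DataLoader(RandomDataset(32, 64), batch_size=self.batch_size)


if __name__ == "__main__":
    trainer = Trainer(max_epochs=1, logger=False, enable_checkpointing=False)
    new_size = Tuner(trainer).scale_batch_size(TunableModel(), mode="binsearch", max_trials=3)
    print(new_size)  # suggested batch size after the scan
```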
67 changes: 36 additions & 31 deletions src/lightning/pytorch/tuner/lr_finder.py
@@ -257,40 +257,45 @@ def _lr_find(
     # Initialize lr finder object (stores results)
     lr_finder = _LRFinder(mode, min_lr, max_lr, num_training)
 
-    # Configure optimizer and scheduler
-    lr_finder._exchange_scheduler(trainer)
-
-    # Fit, lr & loss logged in callback
-    _try_loop_run(trainer, params)
-
-    # Prompt if we stopped early
-    if trainer.global_step != num_training + start_steps:
-        log.info(f"LR finder stopped early after {trainer.global_step} steps due to diverging loss.")
-
-    # Transfer results from callback to lr finder object
-    lr_finder.results.update({"lr": trainer.callbacks[0].lrs, "loss": trainer.callbacks[0].losses})
-    lr_finder._total_batch_idx = trainer.fit_loop.total_batch_idx  # for debug purpose
-
-    __lr_finder_restore_params(trainer, params)
-
-    if trainer.progress_bar_callback:
-        trainer.progress_bar_callback.enable()
-
-    # Update results across ranks
-    lr_finder.results = trainer.strategy.broadcast(lr_finder.results)
-
-    # Restore initial state of model (this will also restore the original optimizer state)
-    trainer._checkpoint_connector.restore(ckpt_path)
-    trainer.strategy.remove_checkpoint(ckpt_path)
-    trainer.fit_loop.restarting = False  # reset restarting flag as checkpoint restoring sets it to True
-    trainer.fit_loop.epoch_loop.restarting = False  # reset restarting flag as checkpoint restoring sets it to True
-    trainer.fit_loop.epoch_loop.val_loop._combined_loader = None
-    trainer.fit_loop._combined_loader = None  # reset data fetcher to avoid issues with the next fit
-    trainer.fit_loop.setup_data()
+    lr_finder_finished = False
+    try:
+        # Configure optimizer and scheduler
+        lr_finder._exchange_scheduler(trainer)
+
+        # Fit, lr & loss logged in callback
+        _try_loop_run(trainer, params)
+
+        # Prompt if we stopped early
+        if trainer.global_step != num_training + start_steps:
+            log.info(f"LR finder stopped early after {trainer.global_step} steps due to diverging loss.")
+
+        # Transfer results from callback to lr finder object
+        lr_finder.results.update({"lr": trainer.callbacks[0].lrs, "loss": trainer.callbacks[0].losses})
+        lr_finder._total_batch_idx = trainer.fit_loop.total_batch_idx  # for debug purpose
+
+        __lr_finder_restore_params(trainer, params)
+
+        if trainer.progress_bar_callback:
+            trainer.progress_bar_callback.enable()
+
+        # Update results across ranks
+        lr_finder.results = trainer.strategy.broadcast(lr_finder.results)
+        lr_finder_finished = True
+    except Exception as ex:
+        raise ex
+    finally:
+        # Restore initial state of model (this will also restore the original optimizer state)
+        trainer._checkpoint_connector.restore(ckpt_path)
+        trainer.strategy.remove_checkpoint(ckpt_path)
+        trainer.fit_loop.restarting = False  # reset restarting flag as checkpoint restoring sets it to True
+        trainer.fit_loop.epoch_loop.restarting = False  # reset restarting flag as checkpoint restoring sets it to True
+        trainer.fit_loop.epoch_loop.val_loop._combined_loader = None
+        trainer.fit_loop._combined_loader = None  # reset data fetcher to avoid issues with the next fit
+        trainer.fit_loop.setup_data()
 
     # Apply LR suggestion after restoring so it persists for the real training run
     # When used as a callback, the suggestion would otherwise be lost due to checkpoint restore
-    if update_attr:
+    if update_attr and lr_finder_finished:
         lr = lr_finder.suggestion()
         if lr is not None:
             # update the attribute on the LightningModule (e.g., lr or learning_rate)
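A hedged usage sketch (not part of the diff) of the learning-rate finder path protected above. `LitModel` is illustrative; with `update_attr=True` (the default) the suggestion is written back to `learning_rate`, and after this change that only happens when the search actually completed (the `lr_finder_finished` guard). Per the new test below, the temporary file matches `.lr_find_*.ckpt` and is cleaned up even if a training step raises.

```python
# Hedged usage sketch of Tuner.lr_find; `LitModel` is illustrative.
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.tuner import Tuner


class LitModel(BoringModel):
    def __init__(self):
        super().__init__()
        self.learning_rate = 1e-3  # attribute updated when update_attr=True


if __name__ == "__main__":
    model = LitModel()
    trainer = Trainer(max_epochs=1, logger=False, enable_checkpointing=False)
    lr_finder = Tuner(trainer).lr_find(model, num_training=20)
    if lr_finder is not None:
        print(lr_finder.suggestion())  # may be None if no stable suggestion was found
```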
50 changes: 50 additions & 0 deletions tests/tests_pytorch/tuner/test_lr_finder.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import glob
 import logging
 import math
 import os
@@ -750,3 +751,52 @@ def __init__(self):
     assert not torch.allclose(gradients, gradients_no_spacing, rtol=0.1), (
         "Gradients should differ significantly in exponential mode when using proper spacing"
     )
+
+
+def test_lr_finder_checkpoint_cleanup_on_error(tmp_path):
+    """Test that temporary checkpoint files are cleaned up even when an error occurs during lr finding."""
+
+    class FailingModel(BoringModel):
+        def __init__(self, fail_on_step=2):
+            super().__init__()
+            self.fail_on_step = fail_on_step
+            self.current_step = 0
+            self.learning_rate = 1e-3
+
+        def training_step(self, batch, batch_idx):
+            self.current_step += 1
+            if self.current_step >= self.fail_on_step:
+                raise RuntimeError("Intentional failure for testing cleanup")
+            return super().training_step(batch, batch_idx)
+
+        def configure_optimizers(self):
+            optimizer = torch.optim.SGD(self.parameters(), lr=self.learning_rate)
+            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
+            return [optimizer], [lr_scheduler]
+
+    model = FailingModel()
+    lr_finder = LearningRateFinder(num_training_steps=5)
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        max_epochs=1,
+        enable_checkpointing=False,
+        enable_progress_bar=False,
+        enable_model_summary=False,
+        logger=False,
+        callbacks=[lr_finder],
+    )
+
+    # Check no lr_find checkpoint files exist initially
+    lr_find_checkpoints = glob.glob(os.path.join(tmp_path, ".lr_find_*.ckpt"))
+    assert len(lr_find_checkpoints) == 0, "No lr_find checkpoint files should exist initially"
+
+    # Run lr finder and expect it to fail
+    with pytest.raises(RuntimeError, match="Intentional failure for testing cleanup"):
+        trainer.fit(model)
+
+    # Check that no lr_find checkpoint files are left behind
+    lr_find_checkpoints = glob.glob(os.path.join(tmp_path, ".lr_find_*.ckpt"))
+    assert len(lr_find_checkpoints) == 0, (
+        f"lr_find checkpoint files should be cleaned up, but found: {lr_find_checkpoints}"
+    )
47 changes: 47 additions & 0 deletions tests/tests_pytorch/tuner/test_scale_batch_size.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import glob
 import logging
 import os
 from copy import deepcopy
@@ -486,3 +487,49 @@ def test_batch_size_finder_callback_val_batches(tmp_path):
 
     assert trainer.num_val_batches[0] == len(trainer.val_dataloaders)
     assert trainer.num_val_batches[0] != steps_per_trial
+
+
+def test_scale_batch_size_checkpoint_cleanup_on_error(tmp_path):
+    """Test that temporary checkpoint files are cleaned up even when an error occurs during batch size scaling."""
+
+    class FailingModel(BoringModel):
+        def __init__(self, fail_on_step=2):
+            super().__init__()
+            self.fail_on_step = fail_on_step
+            self.current_step = 0
+            self.batch_size = 2
+
+        def training_step(self, batch, batch_idx):
+            self.current_step += 1
+            if self.current_step >= self.fail_on_step:
+                raise RuntimeError("Intentional failure for testing cleanup")
+            return super().training_step(batch, batch_idx)
+
+        def train_dataloader(self):
+            return DataLoader(RandomDataset(32, 64), batch_size=self.batch_size)
+
+    model = FailingModel()
+    batch_size_finder = BatchSizeFinder(max_trials=3, steps_per_trial=2)
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        max_epochs=1,
+        enable_checkpointing=False,
+        enable_progress_bar=False,
+        enable_model_summary=False,
+        logger=False,
+        callbacks=[batch_size_finder],
+    )
+
+    # Check no scale_batch_size checkpoint files exist initially
+    scale_checkpoints = glob.glob(os.path.join(tmp_path, ".scale_batch_size_*.ckpt"))
+    assert len(scale_checkpoints) == 0, "No scale_batch_size checkpoint files should exist initially"
+
+    # Run batch size scaler and expect it to fail
+    with pytest.raises(RuntimeError, match="Intentional failure for testing cleanup"):
+        trainer.fit(model)
+
+    # Check that no scale_batch_size checkpoint files are left behind
+    scale_checkpoints = glob.glob(os.path.join(tmp_path, ".scale_batch_size_*.ckpt"))
+    assert len(scale_checkpoints) == 0, (
+        f"scale_batch_size checkpoint files should be cleaned up, but found: {scale_checkpoints}"
+    )