
Commit acdfa8a

Merge branch 'master' into batch_size_scaler_newargs
2 parents 03d3e04 + e1e2534

8 files changed: +165, -45 lines

src/lightning/fabric/CHANGELOG.md

Lines changed: 3 additions & 0 deletions

@@ -30,6 +30,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed with adding a missing device id for pytorch 2.8 ([#21105](https://github.com/Lightning-AI/pytorch-lightning/pull/21105))
 
 
+- Respect `verbose=False` in `seed_everything` when no seed is provided
+
+
 ---
 
 ## [2.5.4] - 2025-08-29

src/lightning/fabric/utilities/seed.py

Lines changed: 2 additions & 1 deletion

@@ -40,7 +40,8 @@ def seed_everything(seed: Optional[int] = None, workers: bool = False, verbose:
     env_seed = os.environ.get("PL_GLOBAL_SEED")
     if env_seed is None:
         seed = 0
-        rank_zero_warn(f"No seed found, seed set to {seed}")
+        if verbose:
+            rank_zero_warn(f"No seed found, seed set to {seed}")
     else:
         try:
             seed = int(env_seed)
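
For context, a minimal usage sketch of the behaviour this patch enables (illustrative, not part of the commit); it assumes the patched `seed_everything` from `lightning.fabric.utilities.seed` and an environment without `PL_GLOBAL_SEED`:

```python
import os
import warnings

from lightning.fabric.utilities.seed import seed_everything

# Illustrative setup: no seed argument and no PL_GLOBAL_SEED in the environment,
# so seed_everything falls back to seed 0.
os.environ.pop("PL_GLOBAL_SEED", None)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    seed = seed_everything(verbose=False)  # with this fix, no "No seed found" warning

assert seed == 0
assert caught == []  # verbose=False now silences the fallback warning
```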

src/lightning/pytorch/CHANGELOG.md

Lines changed: 3 additions & 0 deletions

@@ -45,6 +45,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed `TQDMProgressBar` not resetting correctly when using both a finite and iterable dataloader ([#21147](https://github.com/Lightning-AI/pytorch-lightning/pull/21147))
 
+
+- Fixed cleanup of temporary files from `Tuner` on crashes ([#21162](https://github.com/Lightning-AI/pytorch-lightning/pull/21162))
+
 ---
 
 ## [2.5.4] - 2025-08-29

src/lightning/pytorch/tuner/batch_size_scaling.py

Lines changed: 16 additions & 13 deletions

@@ -82,24 +82,27 @@ def _scale_batch_size(
     if trainer.progress_bar_callback:
         trainer.progress_bar_callback.disable()
 
-    new_size, _ = _adjust_batch_size(trainer, batch_arg_name, value=init_val)
-
-    if mode == "power":
-        new_size = _run_power_scaling(trainer, new_size, batch_arg_name, max_trials, params, max_val)
-    elif mode == "binsearch":
-        new_size = _run_binsearch_scaling(trainer, new_size, batch_arg_name, max_trials, params, margin, max_val)
+    try:
+        new_size, _ = _adjust_batch_size(trainer, batch_arg_name, value=init_val)
 
-    garbage_collection_cuda()
+        if mode == "power":
+            new_size = _run_power_scaling(trainer, new_size, batch_arg_name, max_trials, params, max_val)
+        elif mode == "binsearch":
+            new_size = _run_binsearch_scaling(trainer, new_size, batch_arg_name, max_trials, params, margin, max_val)
 
-    log.info(f"Finished batch size finder, will continue with full run using batch size {new_size}")
+        garbage_collection_cuda()
 
-    __scale_batch_restore_params(trainer, params)
+        log.info(f"Finished batch size finder, will continue with full run using batch size {new_size}")
+    except Exception as ex:
+        raise ex
+    finally:
+        __scale_batch_restore_params(trainer, params)
 
-    if trainer.progress_bar_callback:
-        trainer.progress_bar_callback.enable()
+        if trainer.progress_bar_callback:
+            trainer.progress_bar_callback.enable()
 
-    trainer._checkpoint_connector.restore(ckpt_path)
-    trainer.strategy.remove_checkpoint(ckpt_path)
+        trainer._checkpoint_connector.restore(ckpt_path)
+        trainer.strategy.remove_checkpoint(ckpt_path)
 
     return new_size
 
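
For orientation, a minimal sketch of how `_scale_batch_size` is typically reached through the public `Tuner` API (the model below is illustrative and not part of this commit). With the `try`/`finally` above, a crash during the search still restores the patched trainer settings and removes the temporary `.scale_batch_size_*.ckpt` checkpoint:

```python
from torch.utils.data import DataLoader

from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
from lightning.pytorch.tuner import Tuner


class TunedModel(BoringModel):  # illustrative model with the batch_size attribute the scaler adjusts
    def __init__(self):
        super().__init__()
        self.batch_size = 2  # corresponds to batch_arg_name="batch_size"

    def train_dataloader(self):
        return DataLoader(RandomDataset(32, 64), batch_size=self.batch_size)


trainer = Trainer(max_epochs=1, enable_progress_bar=False, logger=False)
tuner = Tuner(trainer)
# "power" doubles the batch size until it no longer fits; "binsearch" then narrows it down.
new_batch_size = tuner.scale_batch_size(TunedModel(), mode="power", max_trials=3)
print(new_batch_size)
```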

src/lightning/pytorch/tuner/lr_finder.py

Lines changed: 36 additions & 31 deletions

@@ -257,40 +257,45 @@ def _lr_find(
     # Initialize lr finder object (stores results)
     lr_finder = _LRFinder(mode, min_lr, max_lr, num_training)
 
-    # Configure optimizer and scheduler
-    lr_finder._exchange_scheduler(trainer)
-
-    # Fit, lr & loss logged in callback
-    _try_loop_run(trainer, params)
-
-    # Prompt if we stopped early
-    if trainer.global_step != num_training + start_steps:
-        log.info(f"LR finder stopped early after {trainer.global_step} steps due to diverging loss.")
-
-    # Transfer results from callback to lr finder object
-    lr_finder.results.update({"lr": trainer.callbacks[0].lrs, "loss": trainer.callbacks[0].losses})
-    lr_finder._total_batch_idx = trainer.fit_loop.total_batch_idx  # for debug purpose
-
-    __lr_finder_restore_params(trainer, params)
-
-    if trainer.progress_bar_callback:
-        trainer.progress_bar_callback.enable()
-
-    # Update results across ranks
-    lr_finder.results = trainer.strategy.broadcast(lr_finder.results)
-
-    # Restore initial state of model (this will also restore the original optimizer state)
-    trainer._checkpoint_connector.restore(ckpt_path)
-    trainer.strategy.remove_checkpoint(ckpt_path)
-    trainer.fit_loop.restarting = False  # reset restarting flag as checkpoint restoring sets it to True
-    trainer.fit_loop.epoch_loop.restarting = False  # reset restarting flag as checkpoint restoring sets it to True
-    trainer.fit_loop.epoch_loop.val_loop._combined_loader = None
-    trainer.fit_loop._combined_loader = None  # reset data fetcher to avoid issues with the next fit
-    trainer.fit_loop.setup_data()
+    lr_finder_finished = False
+    try:
+        # Configure optimizer and scheduler
+        lr_finder._exchange_scheduler(trainer)
+
+        # Fit, lr & loss logged in callback
+        _try_loop_run(trainer, params)
+
+        # Prompt if we stopped early
+        if trainer.global_step != num_training + start_steps:
+            log.info(f"LR finder stopped early after {trainer.global_step} steps due to diverging loss.")
+
+        # Transfer results from callback to lr finder object
+        lr_finder.results.update({"lr": trainer.callbacks[0].lrs, "loss": trainer.callbacks[0].losses})
+        lr_finder._total_batch_idx = trainer.fit_loop.total_batch_idx  # for debug purpose
+
+        __lr_finder_restore_params(trainer, params)
+
+        if trainer.progress_bar_callback:
+            trainer.progress_bar_callback.enable()
+
+        # Update results across ranks
+        lr_finder.results = trainer.strategy.broadcast(lr_finder.results)
+        lr_finder_finished = True
+    except Exception as ex:
+        raise ex
+    finally:
+        # Restore initial state of model (this will also restore the original optimizer state)
+        trainer._checkpoint_connector.restore(ckpt_path)
+        trainer.strategy.remove_checkpoint(ckpt_path)
+        trainer.fit_loop.restarting = False  # reset restarting flag as checkpoint restoring sets it to True
+        trainer.fit_loop.epoch_loop.restarting = False  # reset restarting flag as checkpoint restoring sets it to True
+        trainer.fit_loop.epoch_loop.val_loop._combined_loader = None
+        trainer.fit_loop._combined_loader = None  # reset data fetcher to avoid issues with the next fit
+        trainer.fit_loop.setup_data()
 
     # Apply LR suggestion after restoring so it persists for the real training run
     # When used as a callback, the suggestion would otherwise be lost due to checkpoint restore
-    if update_attr:
+    if update_attr and lr_finder_finished:
         lr = lr_finder.suggestion()
         if lr is not None:
             # update the attribute on the LightningModule (e.g., lr or learning_rate)
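
Similarly, a minimal sketch of reaching `_lr_find` through the public `Tuner` API (illustrative, not part of this commit). With the `try`/`finally` above, a failure mid-search still removes the temporary `.lr_find_*.ckpt` and resets the fit loop, and the suggested learning rate is only written back to the model when the search actually finished (`lr_finder_finished`):

```python
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.tuner import Tuner


class TunedModel(BoringModel):  # illustrative, not part of this commit
    def __init__(self):
        super().__init__()
        self.learning_rate = 1e-3  # attribute written back when update_attr=True


trainer = Trainer(max_epochs=1, enable_progress_bar=False, logger=False)
tuner = Tuner(trainer)
lr_finder = tuner.lr_find(TunedModel(), num_training=20, update_attr=True)
if lr_finder is not None:
    print(lr_finder.suggestion())  # suggested learning rate, or None if it could not be computed
```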

tests/tests_fabric/utilities/test_seed.py

Lines changed: 8 additions & 0 deletions

@@ -72,6 +72,14 @@ def test_seed_everything_accepts_valid_seed_from_env():
     assert seed_everything() == 17
 
 
+@mock.patch.dict(os.environ, {}, clear=True)
+def test_seed_everything_non_verbose_no_warning():
+    """Ensure that no warning is emitted when verbose is False and no seed is provided."""
+    with warnings.catch_warnings(record=True) as caught:
+        seed_everything(verbose=False)
+    assert caught == []
+
+
 def test_reset_seed_no_op():
     """Test that the reset_seed function is a no-op when seed_everything() was not used."""
     assert "PL_GLOBAL_SEED" not in os.environ

tests/tests_pytorch/tuner/test_lr_finder.py

Lines changed: 50 additions & 0 deletions

@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import glob
 import logging
 import math
 import os
@@ -750,3 +751,52 @@ def __init__(self):
     assert not torch.allclose(gradients, gradients_no_spacing, rtol=0.1), (
         "Gradients should differ significantly in exponential mode when using proper spacing"
     )
+
+
+def test_lr_finder_checkpoint_cleanup_on_error(tmp_path):
+    """Test that temporary checkpoint files are cleaned up even when an error occurs during lr finding."""
+
+    class FailingModel(BoringModel):
+        def __init__(self, fail_on_step=2):
+            super().__init__()
+            self.fail_on_step = fail_on_step
+            self.current_step = 0
+            self.learning_rate = 1e-3
+
+        def training_step(self, batch, batch_idx):
+            self.current_step += 1
+            if self.current_step >= self.fail_on_step:
+                raise RuntimeError("Intentional failure for testing cleanup")
+            return super().training_step(batch, batch_idx)
+
+        def configure_optimizers(self):
+            optimizer = torch.optim.SGD(self.parameters(), lr=self.learning_rate)
+            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
+            return [optimizer], [lr_scheduler]
+
+    model = FailingModel()
+    lr_finder = LearningRateFinder(num_training_steps=5)
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        max_epochs=1,
+        enable_checkpointing=False,
+        enable_progress_bar=False,
+        enable_model_summary=False,
+        logger=False,
+        callbacks=[lr_finder],
+    )
+
+    # Check no lr_find checkpoint files exist initially
+    lr_find_checkpoints = glob.glob(os.path.join(tmp_path, ".lr_find_*.ckpt"))
+    assert len(lr_find_checkpoints) == 0, "No lr_find checkpoint files should exist initially"
+
+    # Run lr finder and expect it to fail
+    with pytest.raises(RuntimeError, match="Intentional failure for testing cleanup"):
+        trainer.fit(model)
+
+    # Check that no lr_find checkpoint files are left behind
+    lr_find_checkpoints = glob.glob(os.path.join(tmp_path, ".lr_find_*.ckpt"))
+    assert len(lr_find_checkpoints) == 0, (
+        f"lr_find checkpoint files should be cleaned up, but found: {lr_find_checkpoints}"
+    )

tests/tests_pytorch/tuner/test_scale_batch_size.py

Lines changed: 47 additions & 0 deletions

@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import glob
 import logging
 import os
 from copy import deepcopy
@@ -531,3 +532,49 @@ def test_scale_batch_size_max_val_limit(tmp_path, mode):
 
     assert result is not None
     assert result <= max_val
+
+
+def test_scale_batch_size_checkpoint_cleanup_on_error(tmp_path):
+    """Test that temporary checkpoint files are cleaned up even when an error occurs during batch size scaling."""
+
+    class FailingModel(BoringModel):
+        def __init__(self, fail_on_step=2):
+            super().__init__()
+            self.fail_on_step = fail_on_step
+            self.current_step = 0
+            self.batch_size = 2
+
+        def training_step(self, batch, batch_idx):
+            self.current_step += 1
+            if self.current_step >= self.fail_on_step:
+                raise RuntimeError("Intentional failure for testing cleanup")
+            return super().training_step(batch, batch_idx)
+
+        def train_dataloader(self):
+            return DataLoader(RandomDataset(32, 64), batch_size=self.batch_size)
+
+    model = FailingModel()
+    batch_size_finder = BatchSizeFinder(max_trials=3, steps_per_trial=2)
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        max_epochs=1,
+        enable_checkpointing=False,
+        enable_progress_bar=False,
+        enable_model_summary=False,
+        logger=False,
+        callbacks=[batch_size_finder],
+    )
+
+    # Check no scale_batch_size checkpoint files exist initially
+    scale_checkpoints = glob.glob(os.path.join(tmp_path, ".scale_batch_size_*.ckpt"))
+    assert len(scale_checkpoints) == 0, "No scale_batch_size checkpoint files should exist initially"
+
+    # Run batch size scaler and expect it to fail
+    with pytest.raises(RuntimeError, match="Intentional failure for testing cleanup"):
+        trainer.fit(model)
+
+    # Check that no scale_batch_size checkpoint files are left behind
+    scale_checkpoints = glob.glob(os.path.join(tmp_path, ".scale_batch_size_*.ckpt"))
+    assert len(scale_checkpoints) == 0, (
+        f"scale_batch_size checkpoint files should be cleaned up, but found: {scale_checkpoints}"
+    )
