
Commit 03101ce

BoringDonut (Oleksandra Sokol) authored and Borda committed
Bugfix/18394 batch size finder max val batches (#18854)
Co-authored-by: Oleksandra Sokol <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
(cherry picked from commit e50b68a)
1 parent ce82dc0 commit 03101ce

File tree: 3 files changed, +24 −1 lines changed


src/lightning/pytorch/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -32,6 +32,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed an issue when replacing an existing `last.ckpt` file with a symlink ([#18793](https://github.com/Lightning-AI/lightning/pull/18793))
 
 
+- Fixed an issue where the `BatchSizeFinder` `steps_per_trial` parameter ended up defining how many validation batches to run during the entire training ([#18394](https://github.com/Lightning-AI/lightning/issues/18394))
+
+
 
 ## [2.1.0] - 2023-10-11
 
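
For context, a minimal sketch of the symptom this changelog entry describes. `DemoModel` is illustrative (it mirrors the `BatchSizeModel` helper used by the test further below); `Trainer`, `BatchSizeFinder`, `BoringModel`, and `RandomDataset` are Lightning's own APIs, but the script is a repro sketch inferred from this commit, not part of the patch:

from torch.utils.data import DataLoader

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import BatchSizeFinder
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset


class DemoModel(BoringModel):
    """Illustrative model: BatchSizeFinder needs a tunable `batch_size` attribute."""

    def __init__(self, batch_size=16):
        super().__init__()
        self.batch_size = batch_size

    def train_dataloader(self):
        return DataLoader(RandomDataset(32, 64), batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(RandomDataset(32, 64), batch_size=self.batch_size)


trainer = Trainer(
    max_epochs=1,
    num_sanity_val_steps=0,
    callbacks=[BatchSizeFinder(steps_per_trial=2, max_trials=1)],
)
trainer.fit(DemoModel())

# Before this fix, the trial limit leaked into training and every validation
# epoch ran only `steps_per_trial` (here 2) batches; after it, the full
# validation set runs:
assert trainer.num_val_batches[0] == len(trainer.val_dataloaders)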

src/lightning/pytorch/tuner/batch_size_scaling.py

Lines changed: 3 additions & 0 deletions
@@ -323,6 +323,9 @@ def _reset_dataloaders(trainer: "pl.Trainer") -> None:
     assert loop is not None
     loop._combined_loader = None  # force a reload
     loop.setup_data()
+    if isinstance(loop, pl.loops._FitLoop):
+        loop.epoch_loop.val_loop._combined_loader = None
+        loop.epoch_loop.val_loop.setup_data()
 
 
 def _try_loop_run(trainer: "pl.Trainer", params: Dict[str, Any]) -> None:
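
Why this works (inferred from the `# force a reload` comment in the patch, so treat the mechanism as an assumption): a loop caches its dataloaders in `_combined_loader`, and `setup_data()` returns early while that cache is populated. The batch-size trials temporarily shrink the validation batch count, so the fit loop's validation sub-loop must also drop its cache to pick the real limits back up. A hypothetical helper naming the pattern the patch applies twice:

def _force_reload(loop) -> None:
    # Hypothetical helper, not in the patch: clearing the cached loaders makes
    # the next `setup_data()` rebuild them and recompute batch limits from the
    # trainer's current settings.
    loop._combined_loader = None
    loop.setup_data()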

tests/tests_pytorch/tuner/test_scale_batch_size.py

Lines changed: 18 additions & 1 deletion
@@ -317,7 +317,7 @@ def test_dataloader_reset_with_scale_batch_size(tmp_path, caplog, scale_method,
     assert caplog.text.count("greater or equal than the length") == int(new_batch_size == dataset_len)
 
     assert trainer.train_dataloader.batch_size == new_batch_size
-    assert trainer.val_dataloaders.batch_size == init_batch_size
+    assert trainer.val_dataloaders.batch_size == new_batch_size
 
 
 @pytest.mark.parametrize("trainer_fn", ["validate", "test", "predict"])

@@ -469,3 +469,20 @@ def train_dataloader(self):
     assert new_batch_size == model.batch_size
     assert new_batch_size == expected_batch_size
     assert trainer.train_dataloader.batch_size == expected_batch_size
+
+
+def test_batch_size_finder_callback_val_batches(tmpdir):
+    """Test that `BatchSizeFinder` does not limit the number of val batches during training."""
+    steps_per_trial = 2
+    model = BatchSizeModel(batch_size=16)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        num_sanity_val_steps=0,
+        max_epochs=1,
+        enable_model_summary=False,
+        callbacks=[BatchSizeFinder(steps_per_trial=steps_per_trial, max_trials=1)],
+    )
+    trainer.fit(model)
+
+    assert trainer.num_val_batches[0] == len(trainer.val_dataloaders)
+    assert trainer.num_val_batches[0] != steps_per_trial
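
Note: `BatchSizeModel` is defined (or imported) earlier in this test file rather than in the diff; it is essentially the `DemoModel` sketched after the changelog entry above, a `BoringModel` subclass whose dataloaders read a tunable `batch_size` attribute. The last two assertions pin the fix down: validation covers the whole dataloader rather than stopping at the two `steps_per_trial` batches the finder used for its trials.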
