diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 87f75e7d7eb00..28e1a60b4ae4b 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -31,7 +31,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed -- +- Fixed edge case when `max_trials` is reached in `Tuner.scale_batch_size` ([#21187](https://github.com/Lightning-AI/pytorch-lightning/pull/21187)) --- diff --git a/src/lightning/pytorch/tuner/batch_size_scaling.py b/src/lightning/pytorch/tuner/batch_size_scaling.py index 78d2aa52f5725..67aadf35b1f04 100644 --- a/src/lightning/pytorch/tuner/batch_size_scaling.py +++ b/src/lightning/pytorch/tuner/batch_size_scaling.py @@ -178,7 +178,8 @@ def _run_power_scaling( # this flag is used to determine whether the previously scaled batch size, right before OOM, was a success or not # if it was we exit, else we continue downscaling in case we haven't encountered a single optimal batch size any_success = False - for _ in range(max_trials): + last_successful_size = new_size + for i in range(max_trials): garbage_collection_cuda() # reset after each try @@ -186,6 +187,13 @@ def _run_power_scaling( try: _try_loop_run(trainer, params) + last_successful_size = new_size # Store the current size before doubling + + # Check if this is the last trial before trying to double + if i + 1 >= max_trials: + new_size = last_successful_size + break + new_size, changed = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc="succeeded") if not changed: @@ -224,6 +232,7 @@ def _run_binary_scaling( low = 1 high = None count = 0 + last_successful_size = new_size while True: garbage_collection_cuda() @@ -233,9 +242,14 @@ def _run_binary_scaling( try: # run loop _try_loop_run(trainer, params) + last_successful_size = new_size # Store the current size before doubling count += 1 - if count > max_trials: + + # Check if we've reached max_trials before trying to adjust batch size + if 
count >= max_trials: + new_size = last_successful_size break + # Double in size low = new_size if high: diff --git a/tests/tests_pytorch/tuner/test_scale_batch_size.py b/tests/tests_pytorch/tuner/test_scale_batch_size.py index f0e5fbe6a3c49..8c50ecd3578a0 100644 --- a/tests/tests_pytorch/tuner/test_scale_batch_size.py +++ b/tests/tests_pytorch/tuner/test_scale_batch_size.py @@ -69,7 +69,7 @@ def test_scale_batch_size_method_with_model_or_datamodule(tmp_path, model_bs, dm tuner = Tuner(trainer) new_batch_size = tuner.scale_batch_size(model, mode="binsearch", init_val=4, max_trials=2, datamodule=datamodule) - assert new_batch_size == 16 + assert new_batch_size == 8 if model_bs is not None: assert model.batch_size == new_batch_size @@ -314,7 +314,9 @@ def test_dataloader_reset_with_scale_batch_size(tmp_path, caplog, scale_method, dataset_len = len(trainer.train_dataloader.dataset) assert dataset_len == 64 - assert caplog.text.count("trying batch size") == (max_trials if init_batch_size < dataset_len else 0) + # With our fix, when max_trials is reached, we don't try the doubled batch size, so we get max_trials - 1 messages + expected_tries = max_trials - 1 if init_batch_size < dataset_len and max_trials > 0 else 0 + assert caplog.text.count("trying batch size") == expected_tries assert caplog.text.count("greater or equal than the length") == int(new_batch_size == dataset_len) assert trainer.train_dataloader.batch_size == new_batch_size @@ -326,7 +328,8 @@ def test_tuner_with_evaluation_methods(tmp_path, trainer_fn): """Test batch size tuner with Trainer's evaluation methods.""" before_batch_size = 2 max_trials = 4 - expected_scaled_batch_size = before_batch_size ** (max_trials + 1) + # With our fix, we return the last successful batch size, not the doubled untested one + expected_scaled_batch_size = before_batch_size**max_trials # 2^4 = 16, not 2^5 = 32 model = BatchSizeModel(batch_size=before_batch_size) trainer = Trainer(default_root_dir=tmp_path, max_epochs=100) @@ 
-349,7 +352,8 @@ def test_batch_size_finder_callback(tmp_path, trainer_fn): before_batch_size = 2 max_trials = 4 max_epochs = 2 - expected_scaled_batch_size = before_batch_size ** (max_trials + 1) + # With our fix, we return the last successful batch size, not the doubled untested one + expected_scaled_batch_size = before_batch_size**max_trials # 2^4 = 16, not 2^5 = 32 model = BatchSizeModel(batch_size=before_batch_size) batch_size_finder = BatchSizeFinder(max_trials=max_trials, batch_arg_name="batch_size") @@ -533,3 +537,49 @@ def train_dataloader(self): assert len(scale_checkpoints) == 0, ( f"scale_batch_size checkpoint files should be cleaned up, but found: {scale_checkpoints}" ) + + +class AlwaysSucceedingBoringModel(BoringModel): + """A BoringModel that never fails with OOM errors for batch size scaling tests.""" + + def __init__(self, batch_size=2): + super().__init__() + self.batch_size = batch_size + + +class FailsAtBatchSizeBoringModel(BoringModel): + """A BoringModel that fails when batch size reaches a certain threshold.""" + + def __init__(self, batch_size=2, fail_at=16): + super().__init__() + self.batch_size = batch_size + self.fail_at = fail_at + + def training_step(self, batch, batch_idx): + # Simulate OOM error when batch size is too large + if self.batch_size >= self.fail_at: + raise RuntimeError("CUDA error: out of memory") + return super().training_step(batch, batch_idx) + + +@pytest.mark.parametrize( + ("max_trials", "mode", "init_val", "expected"), + [ + (3, "power", 2, 8), + (3, "binsearch", 2, 8), + (1, "power", 4, 4), + (0, "power", 2, 2), + ], +) +def test_scale_batch_size_max_trials_modes(tmp_path, max_trials, mode, init_val, expected): + model = AlwaysSucceedingBoringModel(batch_size=init_val) + trainer = Trainer(default_root_dir=tmp_path, max_epochs=1) + tuner = Tuner(trainer) + result = tuner.scale_batch_size( + model, + mode=mode, + steps_per_trial=1, + max_trials=max_trials, + init_val=init_val, + ) + assert result == expected