2 changes: 1 addition & 1 deletion src/lightning/pytorch/CHANGELOG.md
@@ -31,7 +31,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
--
+- Fixed edgecase when `max_trials` is reached in `Tuner.scale_batch_size` ([#21187](https://github.com/Lightning-AI/pytorch-lightning/pull/21187))
 
 
 ---
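For reference, the user-facing call affected by this fix is `Tuner.scale_batch_size`. Below is a minimal usage sketch that mirrors the parametrized test added later in this PR; the `ToyModel` helper and the argument values are illustrative, not part of the PR itself.

```python
# Illustrative sketch only: mirrors the parametrized test added in this PR.
# `ToyModel` is a hypothetical helper; any LightningModule exposing a
# `batch_size` attribute (or hparams entry) works with the tuner.
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.tuner import Tuner


class ToyModel(BoringModel):
    def __init__(self, batch_size=2):
        super().__init__()
        self.batch_size = batch_size  # attribute the tuner scales


model = ToyModel(batch_size=2)
trainer = Trainer(max_epochs=1)
tuner = Tuner(trainer)

# With the fix, the returned value is the last batch size that actually ran a
# trial (8 for these arguments), not the untested doubled size (16).
new_size = tuner.scale_batch_size(model, mode="power", steps_per_trial=1, init_val=2, max_trials=3)
```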
18 changes: 16 additions & 2 deletions src/lightning/pytorch/tuner/batch_size_scaling.py
@@ -178,14 +178,22 @@ def _run_power_scaling(
     # this flag is used to determine whether the previously scaled batch size, right before OOM, was a success or not
     # if it was we exit, else we continue downscaling in case we haven't encountered a single optimal batch size
     any_success = False
-    for _ in range(max_trials):
+    last_successful_size = new_size
+    for i in range(max_trials):
         garbage_collection_cuda()
 
         # reset after each try
         _reset_progress(trainer)
 
         try:
             _try_loop_run(trainer, params)
+            last_successful_size = new_size  # Store the current size before doubling
+
+            # Check if this is the last trial before trying to double
+            if i + 1 >= max_trials:
+                new_size = last_successful_size
+                break
+
             new_size, changed = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc="succeeded")
 
             if not changed:
@@ -224,6 +232,7 @@ def _run_binary_scaling(
     low = 1
     high = None
     count = 0
+    last_successful_size = new_size
     while True:
         garbage_collection_cuda()
 
@@ -233,9 +242,14 @@
         try:
             # run loop
             _try_loop_run(trainer, params)
+            last_successful_size = new_size  # Store the current size before doubling
             count += 1
-            if count > max_trials:
+
+            # Check if we've reached max_trials before trying to adjust batch size
+            if count >= max_trials:
+                new_size = last_successful_size
                 break
+
             # Double in size
             low = new_size
             if high:
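To make the off-by-one concrete, here is a standalone sketch of the success path of the power-scaling loop, assuming every trial fits in memory (as in the `AlwaysSucceedingBoringModel` test below). The helper names are hypothetical; only the control flow mirrors the diff above.

```python
# Standalone sketch (not Lightning code): success path of the power-scaling loop.
# Assumes every trial at `size` fits in memory, so only the max_trials guard matters.


def scaled_size_before_fix(init_val: int, max_trials: int) -> int:
    size = init_val
    for _ in range(max_trials):
        # trial at `size` succeeds...
        size *= 2  # ...and the size is doubled even on the final trial
    return size  # e.g. init 2, 4 trials -> 32, a size that was never actually run


def scaled_size_after_fix(init_val: int, max_trials: int) -> int:
    size = init_val
    for i in range(max_trials):
        # trial at `size` succeeds
        if i + 1 >= max_trials:
            break  # last trial: keep the size that just ran instead of doubling
        size *= 2
    return size  # e.g. init 2, 4 trials -> 16, the last size that was tested


assert scaled_size_before_fix(2, 4) == 32  # old expectation: 2 ** (max_trials + 1)
assert scaled_size_after_fix(2, 4) == 16   # new expectation: 2 ** max_trials
assert scaled_size_after_fix(2, 3) == 8    # matches the (3, "power", 2, 8) test case
```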
58 changes: 54 additions & 4 deletions tests/tests_pytorch/tuner/test_scale_batch_size.py
@@ -69,7 +69,7 @@ def test_scale_batch_size_method_with_model_or_datamodule(tmp_path, model_bs, dm
 
     tuner = Tuner(trainer)
     new_batch_size = tuner.scale_batch_size(model, mode="binsearch", init_val=4, max_trials=2, datamodule=datamodule)
-    assert new_batch_size == 16
+    assert new_batch_size == 8
 
     if model_bs is not None:
         assert model.batch_size == new_batch_size
@@ -314,7 +314,9 @@ def test_dataloader_reset_with_scale_batch_size(tmp_path, caplog, scale_method,
 
     dataset_len = len(trainer.train_dataloader.dataset)
     assert dataset_len == 64
-    assert caplog.text.count("trying batch size") == (max_trials if init_batch_size < dataset_len else 0)
+    # With our fix, when max_trials is reached, we don't try the doubled batch size, so we get max_trials - 1 messages
+    expected_tries = max_trials - 1 if init_batch_size < dataset_len and max_trials > 0 else 0
+    assert caplog.text.count("trying batch size") == expected_tries
     assert caplog.text.count("greater or equal than the length") == int(new_batch_size == dataset_len)
 
     assert trainer.train_dataloader.batch_size == new_batch_size
@@ -326,7 +328,8 @@ def test_tuner_with_evaluation_methods(tmp_path, trainer_fn):
     """Test batch size tuner with Trainer's evaluation methods."""
    before_batch_size = 2
     max_trials = 4
-    expected_scaled_batch_size = before_batch_size ** (max_trials + 1)
+    # With our fix, we return the last successful batch size, not the doubled untested one
+    expected_scaled_batch_size = before_batch_size**max_trials  # 2^4 = 16, not 2^5 = 32
 
     model = BatchSizeModel(batch_size=before_batch_size)
     trainer = Trainer(default_root_dir=tmp_path, max_epochs=100)
@@ -349,7 +352,8 @@ def test_batch_size_finder_callback(tmp_path, trainer_fn):
     before_batch_size = 2
     max_trials = 4
     max_epochs = 2
-    expected_scaled_batch_size = before_batch_size ** (max_trials + 1)
+    # With our fix, we return the last successful batch size, not the doubled untested one
+    expected_scaled_batch_size = before_batch_size**max_trials  # 2^4 = 16, not 2^5 = 32
 
     model = BatchSizeModel(batch_size=before_batch_size)
     batch_size_finder = BatchSizeFinder(max_trials=max_trials, batch_arg_name="batch_size")
@@ -533,3 +537,49 @@ def train_dataloader(self):
     assert len(scale_checkpoints) == 0, (
         f"scale_batch_size checkpoint files should be cleaned up, but found: {scale_checkpoints}"
     )
+
+
+class AlwaysSucceedingBoringModel(BoringModel):
+    """A BoringModel that never fails with OOM errors for batch size scaling tests."""
+
+    def __init__(self, batch_size=2):
+        super().__init__()
+        self.batch_size = batch_size
+
+
+class FailsAtBatchSizeBoringModel(BoringModel):
+    """A BoringModel that fails when batch size reaches a certain threshold."""
+
+    def __init__(self, batch_size=2, fail_at=16):
+        super().__init__()
+        self.batch_size = batch_size
+        self.fail_at = fail_at
+
+    def training_step(self, batch, batch_idx):
+        # Simulate OOM error when batch size is too large
+        if self.batch_size >= self.fail_at:
+            raise RuntimeError("CUDA error: out of memory")
+        return super().training_step(batch, batch_idx)
+
+
+@pytest.mark.parametrize(
+    ("max_trials", "mode", "init_val", "expected"),
+    [
+        (3, "power", 2, 8),
+        (3, "binsearch", 2, 8),
+        (1, "power", 4, 4),
+        (0, "power", 2, 2),
+    ],
+)
+def test_scale_batch_size_max_trials_modes(tmp_path, max_trials, mode, init_val, expected):
+    model = AlwaysSucceedingBoringModel(batch_size=init_val)
+    trainer = Trainer(default_root_dir=tmp_path, max_epochs=1)
+    tuner = Tuner(trainer)
+    result = tuner.scale_batch_size(
+        model,
+        mode=mode,
+        steps_per_trial=1,
+        max_trials=max_trials,
+        init_val=init_val,
+    )
+    assert result == expected
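The binsearch path applies the same guard through `count >= max_trials`. A similar standalone sketch of its success path, where no OOM ever occurs so the upper bound `high` never comes into play (the helper name is hypothetical):

```python
# Standalone sketch (not Lightning code): success path of the binsearch loop.
# With no OOM failures the upper bound `high` is never set, so the loop just
# doubles until the trial counter reaches max_trials.


def binsearch_all_success(init_val: int, max_trials: int) -> int:
    new_size = init_val
    last_successful_size = new_size
    count = 0
    while True:
        # trial at `new_size` succeeds
        last_successful_size = new_size
        count += 1
        if count >= max_trials:
            return last_successful_size  # stop at the last size that actually ran
        new_size *= 2  # no `high` bound yet, so keep doubling


assert binsearch_all_success(2, 3) == 8  # matches the (3, "binsearch", 2, 8) case
assert binsearch_all_success(4, 2) == 8  # matches the updated assertion above (was 16)
```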