
Commit 4eaa87b

fix tests
1 parent dce919a commit 4eaa87b

2 files changed, 16 insertions(+), 5 deletions(-)


src/lightning/pytorch/tuner/batch_size_scaling.py (6 additions, 1 deletion)

@@ -313,6 +313,8 @@ def _run_binsearch_scaling(
             f"Applying margin of {margin:.1%}, reducing batch size from {new_size} to {margin_reduced_size}"
         )
         new_size = margin_reduced_size
+        # update attribute in the model/datamodule as well
+        lightning_setattr(trainer.lightning_module, batch_arg_name, new_size)
 
     return new_size
 
@@ -353,7 +355,10 @@ def _adjust_batch_size(
     try:
         combined_dataset_length = combined_loader._dataset_length()
         if batch_size >= combined_dataset_length:
-            rank_zero_info(f"The batch size {batch_size} is greater or equal than the length of your dataset.")
+            rank_zero_info(
+                f"The batch size {batch_size} is greater or equal than"
+                f" the length of your dataset: {combined_dataset_length}."
+            )
             return batch_size, False
     except NotImplementedError:
         # all datasets are iterable style
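
Note on the margin arithmetic behind this hunk: a minimal standalone sketch, assuming the 5% margin and int-truncation described in this commit's test comments. The helper name apply_margin is purely illustrative, not Lightning API; only lightning_setattr(trainer.lightning_module, batch_arg_name, new_size) appears in the diff itself.

# Sketch only: `margin` is assumed to be a fraction (e.g. 0.05); `apply_margin`
# is a hypothetical helper for illustration, not part of lightning.pytorch.
def apply_margin(new_size: int, margin: float = 0.05) -> int:
    # Truncates toward zero, so int(8 * 0.95) = 7 and int(100 * 0.95) = 95,
    # matching the expected values in the updated tests below.
    return max(1, int(new_size * (1 - margin)))

assert apply_margin(8) == 7
assert apply_margin(100) == 95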

tests/tests_pytorch/tuner/test_scale_batch_size.py (10 additions, 4 deletions)

@@ -13,6 +13,7 @@
 # limitations under the License.
 import glob
 import logging
+import math
 import os
 from copy import deepcopy
 from unittest.mock import patch
@@ -69,7 +70,7 @@ def test_scale_batch_size_method_with_model_or_datamodule(tmp_path, model_bs, dm
 
     tuner = Tuner(trainer)
     new_batch_size = tuner.scale_batch_size(model, mode="binsearch", init_val=4, max_trials=2, datamodule=datamodule)
-    assert new_batch_size == 8
+    assert new_batch_size == 7  # applied margin of 5% on 8 -> int(8 * 0.95) = 7
 
     if model_bs is not None:
         assert model.batch_size == new_batch_size
@@ -317,7 +318,12 @@ def test_dataloader_reset_with_scale_batch_size(tmp_path, caplog, scale_method,
     # With our fix, when max_trials is reached, we don't try the doubled batch size, so we get max_trials - 1 messages
     expected_tries = max_trials - 1 if init_batch_size < dataset_len and max_trials > 0 else 0
     assert caplog.text.count("trying batch size") == expected_tries
-    assert caplog.text.count("greater or equal than the length") == int(new_batch_size == dataset_len)
+
+    # Determine the largest batch size that was actually tested.
+    # For "power" this is the final found size; for "binsearch" we applied a 5% margin
+    # when storing the final value, so the largest tested value is the one before applying margin.
+    largest_tested_batch_size = new_batch_size if scale_method == "power" else int(math.ceil(new_batch_size * 100 / 95))
+    assert caplog.text.count("greater or equal than the length") == int(largest_tested_batch_size >= dataset_len)
 
     assert trainer.train_dataloader.batch_size == new_batch_size
     assert trainer.val_dataloaders.batch_size == new_batch_size
@@ -453,7 +459,7 @@ def val_dataloader(self):
     tuner.scale_batch_size(model, method="validate")
 
 
-@pytest.mark.parametrize(("scale_method", "expected_batch_size"), [("power", 62), ("binsearch", 100)])
+@pytest.mark.parametrize(("scale_method", "expected_batch_size"), [("power", 62), ("binsearch", 95)])
 @patch("lightning.pytorch.tuner.batch_size_scaling.is_oom_error", return_value=True)
 def test_dataloader_batch_size_updated_on_failure(_, tmp_path, scale_method, expected_batch_size):
     class CustomBatchSizeModel(BatchSizeModel):
@@ -611,7 +617,7 @@ def training_step(self, batch, batch_idx):
     ("max_trials", "mode", "init_val", "expected"),
     [
         (3, "power", 2, 8),
-        (3, "binsearch", 2, 8),
+        (3, "binsearch", 2, 7),  # applied margin of 5% on 8 -> int(8 * 0.95) = 7
        (1, "power", 4, 4),
        (0, "power", 2, 2),
    ],
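
As a sanity check on the inverse-margin arithmetic the updated test relies on: a short standalone sketch, assuming the 5% margin from the test comments. The helper name largest_tested is illustrative only; the expression mirrors int(math.ceil(new_batch_size * 100 / 95)) from the diff above.

import math

# Sketch only: recover the largest batch size binsearch actually tested from
# the stored margin-reduced value (the inverse of int(size * 0.95), rounded up).
def largest_tested(new_batch_size: int) -> int:
    return int(math.ceil(new_batch_size * 100 / 95))

assert largest_tested(7) == 8     # stored int(8 * 0.95) = 7, so 8 was tested
assert largest_tested(95) == 100  # stored int(100 * 0.95) = 95, so 100 was tested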
