 # limitations under the License.
 import glob
 import logging
+import math
 import os
 from copy import deepcopy
 from unittest.mock import patch
@@ -69,7 +70,7 @@ def test_scale_batch_size_method_with_model_or_datamodule(tmp_path, model_bs, dm |

     tuner = Tuner(trainer)
     new_batch_size = tuner.scale_batch_size(model, mode="binsearch", init_val=4, max_trials=2, datamodule=datamodule)
-    assert new_batch_size == 8
+    assert new_batch_size == 7  # applied margin of 5% on 8 -> int(8 * 0.95) = 7

     if model_bs is not None:
         assert model.batch_size == new_batch_size
@@ -317,7 +318,12 @@ def test_dataloader_reset_with_scale_batch_size(tmp_path, caplog, scale_method, |
     # With our fix, when max_trials is reached, we don't try the doubled batch size, so we get max_trials - 1 messages
     expected_tries = max_trials - 1 if init_batch_size < dataset_len and max_trials > 0 else 0
     assert caplog.text.count("trying batch size") == expected_tries
-    assert caplog.text.count("greater or equal than the length") == int(new_batch_size == dataset_len)
+
+    # Determine the largest batch size that was actually tested.
+    # For "power" this is the final found size; for "binsearch" we applied a 5% margin
+    # when storing the final value, so the largest tested value is the one before applying margin.
+    largest_tested_batch_size = new_batch_size if scale_method == "power" else int(math.ceil(new_batch_size * 100 / 95))
+    assert caplog.text.count("greater or equal than the length") == int(largest_tested_batch_size >= dataset_len)

     assert trainer.train_dataloader.batch_size == new_batch_size
     assert trainer.val_dataloaders.batch_size == new_batch_size
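For readers skimming the diff, here is a minimal sketch of the 5% safety-margin arithmetic that the comments in the hunk above describe. It assumes a flat 5% margin; the helper names are illustrative only and not part of the Lightning API:

```python
import math

MARGIN = 0.05  # assumed 5% safety margin, as described in the diff comments above


def apply_margin(found_batch_size: int) -> int:
    # Shrink the largest batch size that succeeded, e.g. 8 -> int(8 * 0.95) = 7.
    return int(found_batch_size * (1 - MARGIN))


def largest_tested(stored_batch_size: int) -> int:
    # Invert the margin to recover the largest batch size actually tried,
    # e.g. 7 -> ceil(7 * 100 / 95) = 8, or 95 -> 100.
    return int(math.ceil(stored_batch_size * 100 / 95))


assert apply_margin(8) == 7
assert largest_tested(7) == 8
assert largest_tested(95) == 100
```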
@@ -453,7 +459,7 @@ def val_dataloader(self): |
     tuner.scale_batch_size(model, method="validate")


-@pytest.mark.parametrize(("scale_method", "expected_batch_size"), [("power", 62), ("binsearch", 100)])
+@pytest.mark.parametrize(("scale_method", "expected_batch_size"), [("power", 62), ("binsearch", 95)])
 @patch("lightning.pytorch.tuner.batch_size_scaling.is_oom_error", return_value=True)
 def test_dataloader_batch_size_updated_on_failure(_, tmp_path, scale_method, expected_batch_size):
     class CustomBatchSizeModel(BatchSizeModel):
@@ -611,7 +617,7 @@ def training_step(self, batch, batch_idx): |
     ("max_trials", "mode", "init_val", "expected"),
     [
         (3, "power", 2, 8),
-        (3, "binsearch", 2, 8),
+        (3, "binsearch", 2, 7),  # applied margin of 5% on 8 -> int(8 * 0.95) = 7
         (1, "power", 4, 4),
         (0, "power", 2, 2),
     ],
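The binsearch expectations in the two parametrizations above follow the same arithmetic; a quick check, reusing the illustrative `apply_margin` helper from the earlier sketch:

```python
assert apply_margin(8) == 7     # (3, "binsearch", 2, 7) expectation
assert apply_margin(100) == 95  # ("binsearch", 95) expectation
```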