Skip to content

Commit 03d3e04

Browse files
committed
add testing
1 parent cef0d56 commit 03d3e04

File tree

1 file changed

+45
-0
lines changed

1 file changed

+45
-0
lines changed

tests/tests_pytorch/tuner/test_scale_batch_size.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -486,3 +486,48 @@ def test_batch_size_finder_callback_val_batches(tmp_path):
486486

487487
assert trainer.num_val_batches[0] == len(trainer.val_dataloaders)
488488
assert trainer.num_val_batches[0] != steps_per_trial
489+
490+
491+
@pytest.mark.parametrize("margin", [0.0, 0.1, 0.2])
492+
def test_scale_batch_size_margin_and_max_val(tmp_path, margin):
493+
"""Test margin feature for batch size scaling by comparing results with and without margin."""
494+
# First, find the batch size without margin
495+
model1 = BatchSizeModel(batch_size=2)
496+
trainer1 = Trainer(default_root_dir=tmp_path, max_epochs=1, logger=False, enable_checkpointing=False)
497+
tuner1 = Tuner(trainer1)
498+
499+
result_without_margin = tuner1.scale_batch_size(
500+
model1, mode="binsearch", max_trials=2, steps_per_trial=1, margin=0.0
501+
)
502+
503+
model2 = BatchSizeModel(batch_size=2)
504+
trainer2 = Trainer(default_root_dir=tmp_path, max_epochs=1, logger=False, enable_checkpointing=False)
505+
tuner2 = Tuner(trainer2)
506+
507+
result_with_margin = tuner2.scale_batch_size(
508+
model2, mode="binsearch", max_trials=2, steps_per_trial=1, margin=margin
509+
)
510+
511+
assert result_without_margin is not None
512+
assert result_with_margin is not None
513+
514+
if margin == 0.0:
515+
assert result_with_margin == result_without_margin
516+
else:
517+
expected_with_margin = max(1, int(result_without_margin * (1 - margin)))
518+
assert result_with_margin == expected_with_margin
519+
assert result_with_margin <= result_without_margin
520+
521+
522+
@pytest.mark.parametrize("mode", ["power", "binsearch"])
523+
def test_scale_batch_size_max_val_limit(tmp_path, mode):
524+
"""Test that max_val limits the batch size for both power and binsearch modes."""
525+
model = BatchSizeModel(batch_size=2)
526+
trainer = Trainer(default_root_dir=tmp_path, max_epochs=1)
527+
tuner = Tuner(trainer)
528+
529+
max_val = 8 # Set a low max value
530+
result = tuner.scale_batch_size(model, mode=mode, max_trials=5, steps_per_trial=1, max_val=max_val)
531+
532+
assert result is not None
533+
assert result <= max_val

0 commit comments

Comments
 (0)