@@ -486,3 +486,48 @@ def test_batch_size_finder_callback_val_batches(tmp_path):
486
486
487
487
assert trainer .num_val_batches [0 ] == len (trainer .val_dataloaders )
488
488
assert trainer .num_val_batches [0 ] != steps_per_trial
489
+
490
+
491
+ @pytest .mark .parametrize ("margin" , [0.0 , 0.1 , 0.2 ])
492
+ def test_scale_batch_size_margin_and_max_val (tmp_path , margin ):
493
+ """Test margin feature for batch size scaling by comparing results with and without margin."""
494
+ # First, find the batch size without margin
495
+ model1 = BatchSizeModel (batch_size = 2 )
496
+ trainer1 = Trainer (default_root_dir = tmp_path , max_epochs = 1 , logger = False , enable_checkpointing = False )
497
+ tuner1 = Tuner (trainer1 )
498
+
499
+ result_without_margin = tuner1 .scale_batch_size (
500
+ model1 , mode = "binsearch" , max_trials = 2 , steps_per_trial = 1 , margin = 0.0
501
+ )
502
+
503
+ model2 = BatchSizeModel (batch_size = 2 )
504
+ trainer2 = Trainer (default_root_dir = tmp_path , max_epochs = 1 , logger = False , enable_checkpointing = False )
505
+ tuner2 = Tuner (trainer2 )
506
+
507
+ result_with_margin = tuner2 .scale_batch_size (
508
+ model2 , mode = "binsearch" , max_trials = 2 , steps_per_trial = 1 , margin = margin
509
+ )
510
+
511
+ assert result_without_margin is not None
512
+ assert result_with_margin is not None
513
+
514
+ if margin == 0.0 :
515
+ assert result_with_margin == result_without_margin
516
+ else :
517
+ expected_with_margin = max (1 , int (result_without_margin * (1 - margin )))
518
+ assert result_with_margin == expected_with_margin
519
+ assert result_with_margin <= result_without_margin
520
+
521
+
522
+ @pytest .mark .parametrize ("mode" , ["power" , "binsearch" ])
523
+ def test_scale_batch_size_max_val_limit (tmp_path , mode ):
524
+ """Test that max_val limits the batch size for both power and binsearch modes."""
525
+ model = BatchSizeModel (batch_size = 2 )
526
+ trainer = Trainer (default_root_dir = tmp_path , max_epochs = 1 )
527
+ tuner = Tuner (trainer )
528
+
529
+ max_val = 8 # Set a low max value
530
+ result = tuner .scale_batch_size (model , mode = mode , max_trials = 5 , steps_per_trial = 1 , max_val = max_val )
531
+
532
+ assert result is not None
533
+ assert result <= max_val
0 commit comments