@@ -99,12 +99,19 @@ def test_auto_decrease_batch_size(self):
9999 assert adapted_bs == 80
100100
101101 def test_find_max_usable_bs_gpu_memory_too_small (self ):
102- mock_train_func = self .get_mock_train_func (cuda_oom_bound = 4 , max_runnable_bs = 1 )
102+ mock_train_func = self .get_mock_train_func (cuda_oom_bound = 1 , max_runnable_bs = 1 )
103103
104104 bs_search_algo = BsSearchAlgo (mock_train_func , 128 , 1000 )
105105 with pytest .raises (RuntimeError ):
106106 bs_search_algo .auto_decrease_batch_size ()
107107
108+ def test_auto_decrease_batch_size_bs2_not_oom_but_most_mem (self ):
109+ """Batch size 2 doesn't make oom but use most of memory."""
110+ mock_train_func = self .get_mock_train_func (cuda_oom_bound = 2 , max_runnable_bs = 1 )
111+
112+ bs_search_algo = BsSearchAlgo (mock_train_func , 128 , 1000 )
113+ assert bs_search_algo .auto_decrease_batch_size () == 2
114+
108115 @pytest .mark .parametrize (
109116 ("max_runnable_bs" , "max_bs" , "expected_bs" ),
110117 [
@@ -126,12 +133,19 @@ def test_find_big_enough_batch_size(self, max_runnable_bs, max_bs, expected_bs):
126133 assert adapted_bs == expected_bs
127134
128135 def test_find_big_enough_batch_size_gpu_memory_too_small (self ):
129- mock_train_func = self .get_mock_train_func (cuda_oom_bound = 4 , max_runnable_bs = 1 )
136+ mock_train_func = self .get_mock_train_func (cuda_oom_bound = 1 , max_runnable_bs = 1 )
130137
131138 bs_search_algo = BsSearchAlgo (mock_train_func , 128 , 1000 )
132139 with pytest .raises (RuntimeError ):
133140 bs_search_algo .find_big_enough_batch_size ()
134141
142+ def test_find_big_enough_batch_size_bs2_not_oom_but_most_mem (self ):
143+ """Batch size 2 doesn't make oom but use most of memory."""
144+ mock_train_func = self .get_mock_train_func (cuda_oom_bound = 2 , max_runnable_bs = 1 )
145+
146+ bs_search_algo = BsSearchAlgo (mock_train_func , 2 , 1000 )
147+ assert bs_search_algo .find_big_enough_batch_size () == 2
148+
135149 def test_find_big_enough_batch_size_gradient_zero (self ):
136150 def mock_train_func (batch_size ) -> int :
137151 if batch_size > 1000 :
0 commit comments