@@ -68,9 +68,9 @@ def mock_train_func(batch_size) -> int:
6868 msg = "CUDA out of memory."
6969 raise RuntimeError (msg )
7070 if batch_size > max_runnable_bs :
71- mem_usage = 8500 + 1500 * batch_size / (cuda_oom_bound - max_runnable_bs )
71+ mem_usage = 9000 + 1500 * batch_size / (cuda_oom_bound - max_runnable_bs )
7272 else :
73- mem_usage = 8500 * batch_size / max_runnable_bs
73+ mem_usage = 9000 * batch_size / max_runnable_bs
7474
7575 self .mock_torch .cuda .max_memory_reserved .return_value = mem_usage
7676 return mem_usage
@@ -110,14 +110,14 @@ def test_find_max_usable_bs_gpu_memory_too_small(self):
110110 mock_train_func = self .get_mock_train_func (cuda_oom_bound = 1 , max_runnable_bs = 1 )
111111
112112 bs_search_algo = BsSearchAlgo (mock_train_func , 128 , 1000 )
113- assert bs_search_algo .auto_decrease_batch_size () == 2
113+ assert bs_search_algo .auto_decrease_batch_size () == 1
114114
115115 def test_auto_decrease_batch_size_bs2_not_oom_but_most_mem (self ):
116116 """Batch size 2 doesn't make oom but use most of memory."""
117117 mock_train_func = self .get_mock_train_func (cuda_oom_bound = 2 , max_runnable_bs = 1 )
118118
119119 bs_search_algo = BsSearchAlgo (mock_train_func , 128 , 1000 )
120- assert bs_search_algo .auto_decrease_batch_size () == 2
120+ assert bs_search_algo .auto_decrease_batch_size () == 1
121121
122122 @pytest .mark .parametrize (
123123 ("max_runnable_bs" , "max_bs" , "expected_bs" ),
@@ -135,22 +135,22 @@ def test_find_big_enough_batch_size(self, max_runnable_bs, max_bs, expected_bs):
135135 adapted_bs = bs_search_algo .find_big_enough_batch_size ()
136136
137137 if expected_bs is None :
138- assert 7500 <= mock_train_func (adapted_bs ) <= 8500
138+ assert 7500 <= mock_train_func (adapted_bs ) <= 9000
139139 else :
140140 assert adapted_bs == expected_bs
141141
142142 def test_find_big_enough_batch_size_gpu_memory_too_small (self ):
143143 mock_train_func = self .get_mock_train_func (cuda_oom_bound = 1 , max_runnable_bs = 1 )
144144
145145 bs_search_algo = BsSearchAlgo (mock_train_func , 128 , 1000 )
146- assert bs_search_algo .find_big_enough_batch_size () == 2
146+ assert bs_search_algo .find_big_enough_batch_size () == 1
147147
148148 def test_find_big_enough_batch_size_bs2_not_oom_but_most_mem (self ):
149149 """Batch size 2 doesn't make oom but use most of memory."""
150150 mock_train_func = self .get_mock_train_func (cuda_oom_bound = 2 , max_runnable_bs = 1 )
151151
152152 bs_search_algo = BsSearchAlgo (mock_train_func , 2 , 1000 )
153- assert bs_search_algo .find_big_enough_batch_size () == 2
153+ assert bs_search_algo .find_big_enough_batch_size () == 1
154154
155155 def test_find_big_enough_batch_size_gradient_zero (self ):
156156 def mock_train_func (batch_size ) -> int :
@@ -167,7 +167,7 @@ def mock_train_func(batch_size) -> int:
167167 bs_search_algo = BsSearchAlgo (mock_train_func , 64 , 1000 )
168168 adapted_bs = bs_search_algo .find_big_enough_batch_size ()
169169
170- assert adapted_bs == 100
170+ assert adapted_bs == 102
171171
172172 def test_find_big_enough_batch_size_not_exceed_upper_bound (self ):
173173 def mock_train_func (batch_size ) -> int :
@@ -184,7 +184,7 @@ def mock_train_func(batch_size) -> int:
184184 bs_search_algo = BsSearchAlgo (mock_train_func , 64 , 1000 )
185185 adapted_bs = bs_search_algo .find_big_enough_batch_size ()
186186
187- assert mock_train_func (adapted_bs ) <= 8500
187+ assert mock_train_func (adapted_bs ) <= 9000
188188
189189 def test_find_big_enough_batch_size_drop_last (self ):
190190 mock_train_func = self .get_mock_train_func (cuda_oom_bound = 10000 , max_runnable_bs = 180 )
0 commit comments