Skip to content

Commit 7744c89

Browse files
authored
Relieve memory usage criteria on batch size 2 during adaptive_bs (#4009)
* relieve memory usage criteria on batch size 2 during adaptive_bs * update unit test * update unit test
1 parent 758ea97 commit 7744c89

File tree

2 files changed

+32
-6
lines changed

2 files changed

+32
-6
lines changed

src/otx/engine/adaptive_bs/bs_search_algo.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,14 @@ def auto_decrease_batch_size(self) -> int:
112112
break
113113

114114
if available_bs == 0:
115-
msg = "Current device can't train model even with 2."
116-
raise RuntimeError(msg)
115+
if oom:
116+
msg = "Current device can't train model even with 2."
117+
raise RuntimeError(msg)
118+
logger.warning(
119+
"Even with a batch size of 2, most of the memory is used, "
120+
"which could cause the training to fail midway.",
121+
)
122+
available_bs = 2
117123

118124
return available_bs
119125

@@ -141,8 +147,14 @@ def find_big_enough_batch_size(self, drop_last: bool = False) -> int:
141147
if oom or bs_mem_usage > self._mem_upper_bound:
142148
self._default_bs -= 2
143149
if self._default_bs <= 0:
144-
msg = "Current device can't train model even with 2."
145-
raise RuntimeError(msg)
150+
if oom:
151+
msg = "Current device can't train model even with 2."
152+
raise RuntimeError(msg)
153+
logger.warning(
154+
"Even with a batch size of 2, most of the memory is used, "
155+
"which could cause the training to fail midway.",
156+
)
157+
return 2
146158

147159
return self.auto_decrease_batch_size()
148160

tests/unit/engine/adaptive_bs/test_bs_search_algo.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,12 +99,19 @@ def test_auto_decrease_batch_size(self):
9999
assert adapted_bs == 80
100100

101101
def test_find_max_usable_bs_gpu_memory_too_small(self):
102-
mock_train_func = self.get_mock_train_func(cuda_oom_bound=4, max_runnable_bs=1)
102+
mock_train_func = self.get_mock_train_func(cuda_oom_bound=1, max_runnable_bs=1)
103103

104104
bs_search_algo = BsSearchAlgo(mock_train_func, 128, 1000)
105105
with pytest.raises(RuntimeError):
106106
bs_search_algo.auto_decrease_batch_size()
107107

108+
def test_auto_decrease_batch_size_bs2_not_oom_but_most_mem(self):
109+
"""Batch size 2 doesn't make oom but use most of memory."""
110+
mock_train_func = self.get_mock_train_func(cuda_oom_bound=2, max_runnable_bs=1)
111+
112+
bs_search_algo = BsSearchAlgo(mock_train_func, 128, 1000)
113+
assert bs_search_algo.auto_decrease_batch_size() == 2
114+
108115
@pytest.mark.parametrize(
109116
("max_runnable_bs", "max_bs", "expected_bs"),
110117
[
@@ -126,12 +133,19 @@ def test_find_big_enough_batch_size(self, max_runnable_bs, max_bs, expected_bs):
126133
assert adapted_bs == expected_bs
127134

128135
def test_find_big_enough_batch_size_gpu_memory_too_small(self):
129-
mock_train_func = self.get_mock_train_func(cuda_oom_bound=4, max_runnable_bs=1)
136+
mock_train_func = self.get_mock_train_func(cuda_oom_bound=1, max_runnable_bs=1)
130137

131138
bs_search_algo = BsSearchAlgo(mock_train_func, 128, 1000)
132139
with pytest.raises(RuntimeError):
133140
bs_search_algo.find_big_enough_batch_size()
134141

142+
def test_find_big_enough_batch_size_bs2_not_oom_but_most_mem(self):
143+
"""Batch size 2 doesn't make oom but use most of memory."""
144+
mock_train_func = self.get_mock_train_func(cuda_oom_bound=2, max_runnable_bs=1)
145+
146+
bs_search_algo = BsSearchAlgo(mock_train_func, 2, 1000)
147+
assert bs_search_algo.find_big_enough_batch_size() == 2
148+
135149
def test_find_big_enough_batch_size_gradient_zero(self):
136150
def mock_train_func(batch_size) -> int:
137151
if batch_size > 1000:

0 commit comments

Comments
 (0)