Skip to content

Commit aceebda

Browse files
Songki Choisungmanceunwoosh
authored
Fix CPU training issue on non-CUDA system (#2655)
Fix bug that auto adaptive batch size raises an error if CUDA isn't available (#2410) --------- Co-authored-by: Sungman Cho <[email protected]> Co-authored-by: Eunwoo Shin <[email protected]>
1 parent a0780a8 commit aceebda

File tree

2 files changed

+18
-0
lines changed

2 files changed

+18
-0
lines changed

src/otx/algorithms/common/adapters/mmcv/utils/automatic_bs.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from typing import Callable, Dict, List
99

1010
import numpy as np
11+
from torch.cuda import is_available as cuda_available
1112

1213
from otx.algorithms.common.adapters.torch.utils import BsSearchAlgo
1314
from otx.algorithms.common.utils.logger import get_logger
@@ -53,6 +54,10 @@ def adapt_batch_size(train_func: Callable, cfg, datasets: List, validate: bool =
5354
not_increase (bool) : Whether adapting batch size to larger value than default value or not.
5455
"""
5556

57+
if not cuda_available():
58+
logger.warning("Skip Auto-adaptive batch size: CUDA should be available, but it isn't.")
59+
return
60+
5661
def train_func_single_iter(batch_size):
5762
copied_cfg = deepcopy(cfg)
5863
_set_batch_size(copied_cfg, batch_size)

tests/unit/algorithms/common/adapters/mmcv/utils/test_automatic_bs.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,19 @@ def test_adapt_batch_size(
109109
assert len(mock_train_func.call_args_list[0].kwargs["cfg"].custom_hooks) == 1
110110

111111

112+
def test_adapt_batch_size_no_gpu(mocker, common_cfg, mock_dataset):
113+
# prepare
114+
mock_train_func = mocker.MagicMock()
115+
mock_config = set_mock_cfg_not_action(common_cfg)
116+
mocker.patch.object(automatic_bs, "cuda_available", return_value=False)
117+
118+
# execute
119+
adapt_batch_size(mock_train_func, mock_config, mock_dataset, False, True)
120+
121+
# check train function ins't called.
122+
mock_train_func.assert_not_called()
123+
124+
112125
class TestSubDataset:
113126
@pytest.fixture(autouse=True)
114127
def set_up(self, mocker):

0 commit comments

Comments
 (0)