Commit e8d7149

kprokofi and leoll2 authored
Fix OOM bug on XPU (#4872)
Co-authored-by: Leonardo Lai <[email protected]>
1 parent 236f32e commit e8d7149

7 files changed: 42 additions, 31 deletions

lib/src/otx/backend/native/callbacks/batchsize_finder.py (5 additions & 4 deletions)
@@ -27,7 +27,7 @@ class BatchSizeFinder(Callback):

     def __init__(
         self,
-        steps_per_trial: int = 3,
+        steps_per_trial: int = 5,
     ) -> None:
         self._steps_per_trial = steps_per_trial

@@ -52,11 +52,12 @@ def _try_loop_run(trainer: Trainer) -> None:
     loop.run()


-def _scale_batch_reset_params(trainer: Trainer, steps_per_trial: int) -> None:
+def _scale_batch_reset_params(trainer: Trainer, steps_per_trial: int, max_epochs: int = 1) -> None:
     trainer.logger = DummyLogger() if trainer.logger is not None else None
     trainer.callbacks = []
-    # For XPU devices 1 epoch sometimes is not enough to catch an error
-    max_epochs = 2 if is_xpu_available() else 1
+    # For XPU devices 1 epoch sometimes is not enough to catch an error.
+    # Emperically enlarge this to 15 iterations (steps_per_trial * epochs)
+    max_epochs = 3 if is_xpu_available() else 1

     loop = trainer._active_loop  # noqa: SLF001
     if loop is None:
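
For context, a minimal sketch (not part of the commit) of the trial-length arithmetic the new defaults aim for, assuming is_xpu_available() behaves like the helper imported in batchsize_finder.py; the function name below is hypothetical:

def trial_iterations(steps_per_trial: int = 5, xpu_available: bool = False) -> int:
    """Return how many optimization steps one batch-size trial runs (sketch only)."""
    # XPU sometimes needs more than one short epoch to surface an OOM,
    # so the trial is stretched to steps_per_trial * 3 = 15 iterations there.
    max_epochs = 3 if xpu_available else 1
    return steps_per_trial * max_epochs


assert trial_iterations(xpu_available=True) == 15   # XPU: 5 steps x 3 epochs
assert trial_iterations(xpu_available=False) == 5   # other devices: 5 steps x 1 epoch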

lib/src/otx/backend/native/tools/adaptive_bs/algorithm.py (13 additions & 10 deletions)
@@ -47,8 +47,8 @@ def __init__(
         self._max_bs = max_bs
         self._bs_try_history: dict[int, int] = {}
         self._total_mem = _get_total_memory_size()
-        self._mem_lower_bound = 0.8 * self._total_mem
-        self._mem_upper_bound = 0.85 * self._total_mem
+        self._mem_lower_bound = 0.75 * self._total_mem
+        self._mem_upper_bound = 0.9 * self._total_mem
         self._mp_ctx = mp.get_context("spawn")

     def _try_batch_size(self, bs: int) -> tuple[bool, int]:
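
To illustrate the widened acceptance window (a sketch under the assumption that peak memory is compared against these bounds, not the actual BsSearchAlgo code): a trial batch size is now considered a good fit when peak memory lands between 75% and 90% of total device memory, rather than the previous 80% to 85% band.

def within_memory_window(peak_mem: float, total_mem: float) -> bool:
    """Hypothetical helper: accept a batch size whose peak memory sits in the new band."""
    lower_bound = 0.75 * total_mem   # was 0.8
    upper_bound = 0.9 * total_mem    # was 0.85
    return lower_bound <= peak_mem <= upper_bound


assert within_memory_window(8_000, 10_000)        # 80% of memory: inside the band
assert not within_memory_window(9_500, 10_000)    # 95%: too close to OOM
assert not within_memory_window(5_000, 10_000)    # 50%: batch size can still grow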
@@ -115,16 +115,16 @@ def auto_decrease_batch_size(self) -> int:
             if oom:
                 logger.warning(
                     "The auto batch size algorithm attempted to use a batch size of 2 but still "
-                    "encountered a CUDA OOM error. OTX will proceed with training at batch size 2; "
-                    "however, you will likely encounter a CUDA OOM error once training starts. "
-                    "If the issue persists, please report it accordingly.",
+                    "encountered a CUDA OOM error. OTX will proceed with training at batch size 1; "
+                    "however, it is also possible to encounter a CUDA OOM error during training.",
                 )
-                return 2
+                return 1
             logger.warning(
                 "Even with a batch size of 2, most of the memory is used, "
-                "which could cause the training to fail midway.",
+                "which could cause the training to fail midway."
+                "For safety reasons, decrease bs to 1.",
             )
-            available_bs = 2
+            available_bs = 1

         return available_bs

@@ -157,9 +157,10 @@ def find_big_enough_batch_size(self, drop_last: bool = False) -> int:
                 raise RuntimeError(msg)
             logger.warning(
                 "Even with a batch size of 2, most of the memory is used, "
-                "which could cause the training to fail midway.",
+                "which could cause the training to fail midway."
+                "For safety reasons, decrease bs to 1.",
             )
-            return 2
+            return 1

         return self.auto_decrease_batch_size()
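
Both hunks above change the fallback for the tightest devices: when even a batch size of 2 either triggers OOM or consumes most of the memory, the search now settles on 1 instead of 2. A hedged sketch of that decision with made-up names; the real logic lives in BsSearchAlgo:

def fallback_batch_size(oom_at_bs2: bool, bs2_uses_most_memory: bool) -> int:
    """Illustrative only: pick the safety floor when batch size 2 is already too big."""
    if oom_at_bs2 or bs2_uses_most_memory:
        return 1  # previously 2
    return 2


assert fallback_batch_size(oom_at_bs2=True, bs2_uses_most_memory=False) == 1
assert fallback_batch_size(oom_at_bs2=False, bs2_uses_most_memory=True) == 1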

@@ -270,6 +271,8 @@ def _run_trial(train_func: Callable[[int], Any], bs: int, trial_queue: mp.Queue)
             or "UR_RESULT_ERROR_OUT_OF_DEVICE_MEMORY" in str(e)
             or "UR error" in str(e)
             or "UR_RESULT_ERROR_UNKNOWN" in str(e)
+            or "UR_RESULT_ERROR_OUT_OF_HOST_MEMORY" in str(e)
+            or "UR_RESULT_ERROR" in str(e)
         ):  # XPU OOM
             oom = True
         else:
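
The last hunk broadens the set of runtime-error messages that are treated as XPU out-of-memory. A minimal sketch of that substring check using the marker list from the diff; the helper name and tuple are illustrative, not OTX API:

XPU_OOM_MARKERS = (
    "UR_RESULT_ERROR_OUT_OF_DEVICE_MEMORY",
    "UR error",
    "UR_RESULT_ERROR_UNKNOWN",
    "UR_RESULT_ERROR_OUT_OF_HOST_MEMORY",
    "UR_RESULT_ERROR",  # broad catch-all added by this commit
)


def looks_like_xpu_oom(error: RuntimeError) -> bool:
    """Classify a RuntimeError as XPU OOM if its message contains any known UR marker."""
    return any(marker in str(error) for marker in XPU_OOM_MARKERS)


assert looks_like_xpu_oom(RuntimeError("UR_RESULT_ERROR_OUT_OF_HOST_MEMORY on device"))
assert not looks_like_xpu_oom(RuntimeError("shape mismatch in forward pass"))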

lib/src/otx/backend/native/tools/adaptive_bs/runner.py (0 additions & 4 deletions)
@@ -114,10 +114,6 @@ def _register_callback(callbacks: list[Callback] | Callback | None = None) -> li

 def _apply_new_batch_size(engine: OTXEngine, new_batch_size: int) -> None:
     origin_bs = engine.datamodule.train_subset.batch_size
-    if is_xpu_available() and new_batch_size != 1:
-        new_batch_size -= 1  # for safety reasons
-    if new_batch_size == origin_bs:
-        return
     engine.datamodule.train_subset.batch_size = new_batch_size
     engine.datamodule.val_subset.batch_size = new_batch_size
     engine.datamodule.test_subset.batch_size = new_batch_size
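
With the XPU-specific decrement removed, the batch size found by the search is applied to the datamodule as-is. A before/after sketch; the function and flag names are hypothetical and only mirror the deleted branch:

def applied_batch_size(found_bs: int, xpu: bool, legacy_behavior: bool) -> int:
    """Sketch: what ends up in the train/val/test subsets for a given found batch size."""
    if legacy_behavior and xpu and found_bs != 1:
        return found_bs - 1  # the old "for safety reasons" decrement
    return found_bs


assert applied_batch_size(8, xpu=True, legacy_behavior=True) == 7   # before this commit
assert applied_batch_size(8, xpu=True, legacy_behavior=False) == 8  # after this commit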

lib/src/otx/tools/converter.py (10 additions & 0 deletions)
@@ -272,6 +272,15 @@ def update_num_iters(param_value: int | None, config: dict) -> None:
     config["max_epochs"] = param_value


+def update_batch_size(param_value: int | None, config: dict) -> None:
+    """Update batch size in the config."""
+    if param_value is None:
+        logging.info("Batch size is not provided, skipping update.")
+        return
+    config["data"]["train_subset"]["batch_size"] = param_value
+    config["data"]["val_subset"]["batch_size"] = param_value
+
+
 def update_early_stopping(early_stopping_cfg: dict | None, config: dict) -> None:
     """Update early stopping parameters in the config."""
     if early_stopping_cfg is None:
@@ -483,6 +492,7 @@ def _update_params(config: dict, param_dict: dict) -> None:
     update_tiling(tiling, config)
     update_augmentations(augmentation_params, config)
     update_learning_rate(training_parameters.get("learning_rate", None), config)
+    update_batch_size(training_parameters.get("batch_size", None), config)
     update_num_iters(training_parameters.get("max_epochs", None), config)
     update_early_stopping(training_parameters.get("early_stopping", None), config)
     update_input_size(
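
A short usage sketch of the new update_batch_size helper added above, run against a toy config dict that only contains the keys the function touches; everything outside those keys is made up:

import logging


def update_batch_size(param_value, config):
    """Update batch size in the config (same shape as the helper added in this commit)."""
    if param_value is None:
        logging.info("Batch size is not provided, skipping update.")
        return
    config["data"]["train_subset"]["batch_size"] = param_value
    config["data"]["val_subset"]["batch_size"] = param_value


config = {"data": {"train_subset": {"batch_size": 8}, "val_subset": {"batch_size": 8}}}
update_batch_size(4, config)
assert config["data"]["train_subset"]["batch_size"] == 4
assert config["data"]["val_subset"]["batch_size"] == 4
update_batch_size(None, config)  # no value provided: logs and leaves the config unchanged
assert config["data"]["train_subset"]["batch_size"] == 4

Note that the helper only touches the train and val subsets, which is consistent with the converter tests below still expecting a test-subset batch size of 8.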

lib/tests/assets/geti/model_configs/detection.yaml (1 addition & 0 deletions)
@@ -70,6 +70,7 @@ hyperparameters:
       enable: true
       patience: 10
     learning_rate: 0.001
+    batch_size: 4
     input_size_width: 800
     input_size_height: 992
   evaluation:

lib/tests/unit/backend/native/tools/adaptive_bs/test_bs_search_algo.py (9 additions & 9 deletions)
@@ -68,9 +68,9 @@ def mock_train_func(batch_size) -> int:
                msg = "CUDA out of memory."
                raise RuntimeError(msg)
            if batch_size > max_runnable_bs:
-                mem_usage = 8500 + 1500 * batch_size / (cuda_oom_bound - max_runnable_bs)
+                mem_usage = 9000 + 1500 * batch_size / (cuda_oom_bound - max_runnable_bs)
            else:
-                mem_usage = 8500 * batch_size / max_runnable_bs
+                mem_usage = 9000 * batch_size / max_runnable_bs

            self.mock_torch.cuda.max_memory_reserved.return_value = mem_usage
            return mem_usage
@@ -110,14 +110,14 @@ def test_find_max_usable_bs_gpu_memory_too_small(self):
        mock_train_func = self.get_mock_train_func(cuda_oom_bound=1, max_runnable_bs=1)

        bs_search_algo = BsSearchAlgo(mock_train_func, 128, 1000)
-        assert bs_search_algo.auto_decrease_batch_size() == 2
+        assert bs_search_algo.auto_decrease_batch_size() == 1

    def test_auto_decrease_batch_size_bs2_not_oom_but_most_mem(self):
        """Batch size 2 doesn't make oom but use most of memory."""
        mock_train_func = self.get_mock_train_func(cuda_oom_bound=2, max_runnable_bs=1)

        bs_search_algo = BsSearchAlgo(mock_train_func, 128, 1000)
-        assert bs_search_algo.auto_decrease_batch_size() == 2
+        assert bs_search_algo.auto_decrease_batch_size() == 1

    @pytest.mark.parametrize(
        ("max_runnable_bs", "max_bs", "expected_bs"),
@@ -135,22 +135,22 @@ def test_find_big_enough_batch_size(self, max_runnable_bs, max_bs, expected_bs):
        adapted_bs = bs_search_algo.find_big_enough_batch_size()

        if expected_bs is None:
-            assert 7500 <= mock_train_func(adapted_bs) <= 8500
+            assert 7500 <= mock_train_func(adapted_bs) <= 9000
        else:
            assert adapted_bs == expected_bs

    def test_find_big_enough_batch_size_gpu_memory_too_small(self):
        mock_train_func = self.get_mock_train_func(cuda_oom_bound=1, max_runnable_bs=1)

        bs_search_algo = BsSearchAlgo(mock_train_func, 128, 1000)
-        assert bs_search_algo.find_big_enough_batch_size() == 2
+        assert bs_search_algo.find_big_enough_batch_size() == 1

    def test_find_big_enough_batch_size_bs2_not_oom_but_most_mem(self):
        """Batch size 2 doesn't make oom but use most of memory."""
        mock_train_func = self.get_mock_train_func(cuda_oom_bound=2, max_runnable_bs=1)

        bs_search_algo = BsSearchAlgo(mock_train_func, 2, 1000)
-        assert bs_search_algo.find_big_enough_batch_size() == 2
+        assert bs_search_algo.find_big_enough_batch_size() == 1

    def test_find_big_enough_batch_size_gradient_zero(self):
        def mock_train_func(batch_size) -> int:
@@ -167,7 +167,7 @@ def mock_train_func(batch_size) -> int:
        bs_search_algo = BsSearchAlgo(mock_train_func, 64, 1000)
        adapted_bs = bs_search_algo.find_big_enough_batch_size()

-        assert adapted_bs == 100
+        assert adapted_bs == 102

    def test_find_big_enough_batch_size_not_exceed_upper_bound(self):
        def mock_train_func(batch_size) -> int:
@@ -184,7 +184,7 @@ def mock_train_func(batch_size) -> int:
        bs_search_algo = BsSearchAlgo(mock_train_func, 64, 1000)
        adapted_bs = bs_search_algo.find_big_enough_batch_size()

-        assert mock_train_func(adapted_bs) <= 8500
+        assert mock_train_func(adapted_bs) <= 9000

    def test_find_big_enough_batch_size_drop_last(self):
        mock_train_func = self.get_mock_train_func(cuda_oom_bound=10000, max_runnable_bs=180)
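
The test updates track the new thresholds: the mocked memory model now scales to 9000 units at the largest runnable batch size, and the expected fallback is 1. Below is a standalone, simplified stand-in for the test's mock (the parameter values are illustrative, and the real mock also records memory via the patched torch module):

def make_mock_train_func(cuda_oom_bound: int, max_runnable_bs: int):
    """Return a fake train function that reports peak memory and OOMs past a bound."""
    def mock_train_func(batch_size: int) -> float:
        if batch_size >= cuda_oom_bound:
            raise RuntimeError("CUDA out of memory.")
        if batch_size > max_runnable_bs:
            return 9000 + 1500 * batch_size / (cuda_oom_bound - max_runnable_bs)
        return 9000 * batch_size / max_runnable_bs

    return mock_train_func


train = make_mock_train_func(cuda_oom_bound=10, max_runnable_bs=5)
assert train(5) == 9000   # largest runnable batch size hits the new 9000 scale
assert train(2) == 3600   # smaller batches scale linearly below it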

lib/tests/unit/tools/test_converter.py (4 additions & 4 deletions)
@@ -15,8 +15,8 @@ def test_convert(self):
        config = GetiConfigConverter.convert(asdict(otx_config))

        assert config["data"]["input_size"] == (992, 800)
-        assert config["data"]["train_subset"]["batch_size"] == 8
-        assert config["data"]["val_subset"]["batch_size"] == 8
+        assert config["data"]["train_subset"]["batch_size"] == 4
+        assert config["data"]["val_subset"]["batch_size"] == 4
        assert config["data"]["test_subset"]["batch_size"] == 8
        assert config["model"]["init_args"]["optimizer"]["init_args"]["lr"] == 0.001
        assert config["max_epochs"] == 100
@@ -266,8 +266,8 @@ def test_instantiate(self, tmp_path):
        assert engine.work_dir == tmp_path

        assert engine.datamodule.data_root == data_root
-        assert engine.datamodule.train_subset.batch_size == 8
-        assert engine.datamodule.val_subset.batch_size == 8
+        assert engine.datamodule.train_subset.batch_size == 4
+        assert engine.datamodule.val_subset.batch_size == 4
        assert engine.datamodule.test_subset.batch_size == 8
        assert engine.datamodule.train_subset.num_workers == 2
        assert engine.datamodule.val_subset.num_workers == 2
