Skip to content

Commit f225bb3

Browse files
authored
Fix adaptive batch size to run on CPU (#4499)
* Add a warning instead of raising an error
* Fix the unit test
1 parent 6a3f198 commit f225bb3

File tree

4 files changed

+11
-5
lines changed

4 files changed

+11
-5
lines changed

src/otx/backend/native/lightning/accelerators/xpu.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,5 +60,6 @@ def teardown(self) -> None:
6060
AcceleratorRegistry.register(
6161
XPUAccelerator.accelerator_name,
6262
XPUAccelerator,
63+
override=True,
6364
description="Accelerator supports XPU devices",
6465
)

src/otx/backend/native/lightning/strategies/xpu_single.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,5 +47,6 @@ def __init__(
4747
StrategyRegistry.register(
4848
SingleXPUStrategy.strategy_name,
4949
SingleXPUStrategy,
50+
override=True,
5051
description="Strategy that enables training on single XPU",
5152
)

src/otx/backend/native/tools/adaptive_bs/runner.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,9 @@ def adapt_batch_size(
4242
callbacks (list[Callback] | Callback | None, optional): callbacks used during training. Defaults to None.
4343
"""
4444
if not (is_cuda_available() or is_xpu_available()):
45-
msg = "Adaptive batch size supports CUDA or XPU."
46-
raise RuntimeError(msg)
45+
msg = "Adaptive batch size supports only CUDA or XPU."
46+
logger.warning(msg)
47+
return
4748

4849
engine.model.patch_optimizer_and_scheduler_for_adaptive_bs()
4950
default_bs = engine.datamodule.train_subset.batch_size

tests/unit/backend/native/tools/adaptive_bs/test_adaptive_bs_api.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -180,15 +180,18 @@ def test_adapt_batch_size_dist_sub_proc(
180180
assert int(mock_os.environ["ADAPTIVE_BS_FOR_DIST"]) == cur_bs
181181

182182

183-
def test_adapt_batch_size_no_accelerator(
183+
def test_adapt_batch_size_cpu(
184184
mock_is_cuda_available,
185185
mock_is_xpu_available,
186186
mock_engine,
187187
train_args,
188+
mocker,
188189
):
189190
mock_is_cuda_available.return_value = False
190-
with pytest.raises(RuntimeError, match="Adaptive batch size supports CUDA or XPU."):
191-
adapt_batch_size(mock_engine, **train_args)
191+
mock_is_xpu_available.return_value = False
192+
mock_logger = mocker.patch("otx.backend.native.tools.adaptive_bs.runner.logger")
193+
adapt_batch_size(mock_engine, **train_args)
194+
mock_logger.warning.assert_called_once_with("Adaptive batch size supports only CUDA or XPU.")
192195

193196

194197
def test_adjust_train_args(train_args):

0 commit comments

Comments (0)