File tree Expand file tree Collapse file tree 4 files changed +11
-5
lines changed
tests/unit/backend/native/tools/adaptive_bs Expand file tree Collapse file tree 4 files changed +11
-5
lines changed Original file line number Diff line number Diff line change @@ -60,5 +60,6 @@ def teardown(self) -> None:
6060AcceleratorRegistry .register (
6161 XPUAccelerator .accelerator_name ,
6262 XPUAccelerator ,
63+ override = True ,
6364 description = "Accelerator supports XPU devices" ,
6465)
Original file line number Diff line number Diff line change @@ -47,5 +47,6 @@ def __init__(
4747StrategyRegistry .register (
4848 SingleXPUStrategy .strategy_name ,
4949 SingleXPUStrategy ,
50+ override = True ,
5051 description = "Strategy that enables training on single XPU" ,
5152)
Original file line number Diff line number Diff line change @@ -42,8 +42,9 @@ def adapt_batch_size(
4242 callbacks (list[Callback] | Callback | None, optional): callbacks used during training. Defaults to None.
4343 """
4444 if not (is_cuda_available () or is_xpu_available ()):
45- msg = "Adaptive batch size supports CUDA or XPU."
46- raise RuntimeError (msg )
45+ msg = "Adaptive batch size supports only CUDA or XPU."
46+ logger .warning (msg )
47+ return
4748
4849 engine .model .patch_optimizer_and_scheduler_for_adaptive_bs ()
4950 default_bs = engine .datamodule .train_subset .batch_size
Original file line number Diff line number Diff line change @@ -180,15 +180,18 @@ def test_adapt_batch_size_dist_sub_proc(
180180 assert int (mock_os .environ ["ADAPTIVE_BS_FOR_DIST" ]) == cur_bs
181181
182182
183- def test_adapt_batch_size_no_accelerator (
183+ def test_adapt_batch_size_cpu (
184184 mock_is_cuda_available ,
185185 mock_is_xpu_available ,
186186 mock_engine ,
187187 train_args ,
188+ mocker ,
188189):
189190 mock_is_cuda_available .return_value = False
190- with pytest .raises (RuntimeError , match = "Adaptive batch size supports CUDA or XPU." ):
191- adapt_batch_size (mock_engine , ** train_args )
191+ mock_is_xpu_available .return_value = False
192+ mock_logger = mocker .patch ("otx.backend.native.tools.adaptive_bs.runner.logger" )
193+ adapt_batch_size (mock_engine , ** train_args )
194+ mock_logger .warning .assert_called_once_with ("Adaptive batch size supports only CUDA or XPU." )
192195
193196
194197def test_adjust_train_args (train_args ):
You can’t perform that action at this time.
0 commit comments