
Commit 73ea1b0

Workaround for batch size search on xpu devices (#4513)
* provide workaround for XPU batch search
* return back parameters for MaskRCNN
* fix unit test
* switch off adaptive_bs by default
1 parent ab56cdf commit 73ea1b0

File tree: 7 files changed (+25, -11 lines)


src/otx/backend/native/callbacks/batchsize_finder.py

Lines changed: 5 additions & 1 deletion
@@ -10,6 +10,8 @@
 from lightning.pytorch.callbacks import Callback
 from lightning.pytorch.loggers.logger import DummyLogger
 
+from otx.utils.device import is_xpu_available
+
 if TYPE_CHECKING:
     from lightning import LightningModule
     from lightning.pytorch.trainer import Trainer
@@ -53,13 +55,15 @@ def _try_loop_run(trainer: Trainer) -> None:
 def _scale_batch_reset_params(trainer: Trainer, steps_per_trial: int) -> None:
     trainer.logger = DummyLogger() if trainer.logger is not None else None
     trainer.callbacks = []
+    # For XPU devices 1 epoch sometimes is not enough to catch an error
+    max_epochs = 2 if is_xpu_available() else 1
 
     loop = trainer._active_loop  # noqa: SLF001
     if loop is None:
         msg = "There is no active loop."
         raise RuntimeError(msg)
     if trainer.fit_loop.epoch_loop.max_steps == -1:  # epoch based loop
-        trainer.fit_loop.max_epochs = 1
+        trainer.fit_loop.max_epochs = max_epochs
         trainer.limit_train_batches = steps_per_trial
     else:  # iter based loop
         trainer.fit_loop.epoch_loop.max_steps = steps_per_trial
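For reference, a minimal sketch of the intent behind this change (the helper below is illustrative only, not part of the diff): on XPU, an out-of-memory error sometimes only surfaces in the second epoch, so each batch-size trial now runs two epochs there.

def trial_max_epochs(xpu_available: bool) -> int:
    # Illustrative mirror of the `2 if is_xpu_available() else 1` logic above.
    return 2 if xpu_available else 1

assert trial_max_epochs(False) == 1  # CUDA / CPU path keeps a single trial epoch
assert trial_max_epochs(True) == 2   # XPU path runs an extra epoch to provoke OOM earlier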

src/otx/backend/native/engine.py

Lines changed: 2 additions & 0 deletions
@@ -162,6 +162,7 @@ def train(
         adaptive_bs: Literal["None", "Safe", "Full"] = "None",
         check_val_every_n_epoch: int | None = 1,
         num_sanity_val_steps: int | None = 0,
+        log_every_n_steps: int | None = 1,
         **kwargs,
     ) -> dict[str, Any]:
         r"""Trains the model using the provided LightningModule and OTXDataModule.
@@ -245,6 +246,7 @@ def train(
            val_check_interval=val_check_interval,
            check_val_every_n_epoch=check_val_every_n_epoch,
            num_sanity_val_steps=num_sanity_val_steps,
+           log_every_n_steps=log_every_n_steps,
            **kwargs,
        )
        fit_kwargs: dict[str, Any] = {}
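A hedged usage sketch of the newly exposed parameter; `engine` is assumed to be an already constructed OTXEngine instance, and the argument values shown are simply the defaults from the signature above. Judging by the second hunk, `log_every_n_steps` is forwarded along with the other trainer keyword arguments.

# Illustrative call only; values mirror the defaults in the diff.
results = engine.train(
    adaptive_bs="None",    # batch-size search stays off unless explicitly requested
    num_sanity_val_steps=0,
    log_every_n_steps=1,   # newly exposed and passed through with the trainer kwargs
)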

src/otx/backend/native/tools/adaptive_bs/algorithm.py

Lines changed: 1 addition & 0 deletions
@@ -269,6 +269,7 @@ def _run_trial(train_func: Callable[[int], Any], bs: int, trial_queue: mp.Queue)
            or "XPU out of memory" in str(e)
            or "UR_RESULT_ERROR_OUT_OF_DEVICE_MEMORY" in str(e)
            or "UR error" in str(e)
+           or "UR_RESULT_ERROR_UNKNOWN" in str(e)
        ):  # XPU OOM
            oom = True
        else:
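The check above treats several driver/runtime messages as an XPU out-of-memory condition. A self-contained sketch of the same idea (the helper name and marker tuple are illustrative; only the message strings come from the diff):

_XPU_OOM_MARKERS = (
    "XPU out of memory",
    "UR_RESULT_ERROR_OUT_OF_DEVICE_MEMORY",
    "UR error",
    "UR_RESULT_ERROR_UNKNOWN",  # newly added: some XPU OOMs surface as an unknown UR error
)

def looks_like_xpu_oom(exc: BaseException) -> bool:
    # Substring matching on the exception message, as in _run_trial above.
    message = str(exc)
    return any(marker in message for marker in _XPU_OOM_MARKERS)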

src/otx/backend/native/tools/adaptive_bs/runner.py

Lines changed: 11 additions & 7 deletions
@@ -28,7 +28,6 @@
 def adapt_batch_size(
     engine: OTXEngine,
     not_increase: bool = True,
-    callbacks: list[Callback] | Callback | None = None,
     **train_args,
 ) -> None:
     """Change the actual batch size depending on the current GPU status.
@@ -39,7 +38,6 @@ def adapt_batch_size(
     Args:
         engine (OTXEngine): engine instance.
         not_increase (bool) : Whether adapting batch size to larger value than default value or not.
-        callbacks (list[Callback] | Callback | None, optional): callbacks used during training. Defaults to None.
     """
     if not (is_cuda_available() or is_xpu_available()):
         msg = "Adaptive batch size supports only CUDA or XPU."
@@ -55,7 +53,7 @@ def adapt_batch_size(
         _apply_new_batch_size(engine, new_batch_size)
         return
 
-    train_func = partial(_train_model, engine=engine, callbacks=callbacks, **_adjust_train_args(train_args))
+    train_func = partial(_train_model, engine=engine, **_adjust_train_args(train_args))
     bs_search_algo = BsSearchAlgo(
         train_func=train_func,
         default_bs=default_bs,
@@ -85,11 +83,12 @@ def adapt_batch_size(
 def _adjust_train_args(train_args: dict[str, Any]) -> dict[str, Any]:
     train_args.update(train_args.pop("kwargs", {}))
     train_args.pop("self", None)
-    train_args.pop("adaptive_bs")
+    train_args.pop("adaptive_bs", None)
+    train_args.pop("callbacks", None)
     return train_args
 
 
-def _train_model(bs: int, engine: OTXEngine, callbacks: list[Callback] | Callback | None = None, **train_args) -> None:
+def _train_model(bs: int, engine: OTXEngine, **train_args) -> None:
     if bs <= 0:
         msg = f"Batch size should be greater than 0, but {bs} is given."
         raise ValueError(msg)
@@ -100,7 +99,8 @@ def _train_model(bs: int, engine: OTXEngine, callbacks: list[Callback] | Callbac
     engine.datamodule.val_subset.batch_size = bs
     engine.datamodule.test_subset.batch_size = bs
     train_args["adaptive_bs"] = "None"
-    engine.train(callbacks=_register_callback(callbacks), **train_args)
+    print(f"Running training trial with bs = {bs} ...")
+    engine.train(callbacks=_register_callback(), **train_args)
 
 
 def _register_callback(callbacks: list[Callback] | Callback | None = None) -> list[Callback]:
@@ -114,9 +114,13 @@ def _register_callback(callbacks: list[Callback] | Callback | None = None) -> li
 
 def _apply_new_batch_size(engine: OTXEngine, new_batch_size: int) -> None:
     origin_bs = engine.datamodule.train_subset.batch_size
+    if is_xpu_available() and new_batch_size != 1:
+        new_batch_size -= 1  # for safety reasons
     if new_batch_size == origin_bs:
         return
     engine.datamodule.train_subset.batch_size = new_batch_size
     engine.datamodule.val_subset.batch_size = new_batch_size
     engine.datamodule.test_subset.batch_size = new_batch_size
-    engine.model.optimizer_callable.optimizer_kwargs["lr"] *= sqrt(new_batch_size / origin_bs)  # type: ignore[attr-defined]
+    new_lr = engine.model.optimizer_callable.optimizer_kwargs["lr"] * sqrt(new_batch_size / origin_bs)  # type: ignore[attr-defined]
+    print(f"new batch size = {new_batch_size} with learning rate = {new_lr} is set for the training and validation.")
+    engine.model.optimizer_callable.optimizer_kwargs["lr"] = new_lr  # type: ignore[attr-defined]
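To make the arithmetic in `_apply_new_batch_size` concrete, here is a standalone sketch of the same update (the function name and example numbers are illustrative; the one-sample XPU safety margin and the square-root learning-rate scaling follow the diff):

from math import sqrt

def rescaled_bs_and_lr(origin_bs: int, found_bs: int, lr: float, xpu: bool) -> tuple[int, float]:
    if xpu and found_bs != 1:
        found_bs -= 1  # keep one sample of headroom on XPU, as in the diff
    if found_bs == origin_bs:
        return origin_bs, lr  # nothing to change
    return found_bs, lr * sqrt(found_bs / origin_bs)  # scale LR by sqrt of the batch-size ratio

print(rescaled_bs_and_lr(origin_bs=8, found_bs=32, lr=1e-3, xpu=True))
# -> (31, ~0.00197) on an XPU device: 32 is reduced to 31, and lr grows by sqrt(31/8)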

src/otx/recipe/instance_segmentation/maskrcnn_r50_tv.yaml

Lines changed: 1 addition & 1 deletion
@@ -57,7 +57,7 @@ overrides:
   data:
     train_subset:
       batch_size: 4
-      num_workers: 8
+      num_workers: 4
 
     val_subset:
       num_workers: 4

tests/unit/backend/native/tools/adaptive_bs/test_adaptive_bs_api.py

Lines changed: 3 additions & 2 deletions
@@ -17,6 +17,7 @@
     _train_model,
     adapt_batch_size,
 )
+from otx.utils.device import is_xpu_available
 
 
 @pytest.fixture()
@@ -263,7 +264,7 @@ def test_on_fit_start(self, mock_trainer, mock_active_loop):
         # check steps_per_trial is set well
         assert mock_trainer.limit_val_batches == steps_per_trial
         assert mock_trainer.fit_loop.epoch_loop.max_steps == -1
-        assert mock_trainer.fit_loop.max_epochs == 1
+        assert mock_trainer.fit_loop.max_epochs == 1 if not is_xpu_available() else 2
         assert mock_trainer.limit_train_batches == steps_per_trial
         # check active_loop is run
         assert mock_active_loop.restarting is False
@@ -281,7 +282,7 @@ def test_on_fit_start_no_val(self, mock_trainer, mock_active_loop):
         # check steps_per_trial is set well
         assert mock_trainer.limit_val_batches == 0
         assert mock_trainer.fit_loop.epoch_loop.max_steps == -1
-        assert mock_trainer.fit_loop.max_epochs == 1
+        assert mock_trainer.fit_loop.max_epochs == 1 if not is_xpu_available() else 2
         assert mock_trainer.limit_train_batches == steps_per_trial
         # check active_loop is run
         assert mock_active_loop.restarting is False
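One reading note on the changed assertions: in Python a conditional expression binds looser than `==`, so `max_epochs == 1 if not is_xpu_available() else 2` groups as `(max_epochs == 1) if not is_xpu_available() else 2`. A form that makes the intended comparison explicit would be:

expected_max_epochs = 2 if is_xpu_available() else 1
assert mock_trainer.fit_loop.max_epochs == expected_max_epochs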

tests/unit/tools/test_converter.py

Lines changed: 2 additions & 0 deletions
@@ -114,3 +114,5 @@ def test_instantiate(self, tmp_path):
         if "logger" in train_kwargs and train_kwargs["logger"] is not None:
             assert len(train_kwargs["logger"]) == len(config["logger"])
         assert train_kwargs["max_epochs"] == 100
+        assert "adaptive_bs" in train_kwargs
+        assert train_kwargs["adaptive_bs"] == "Safe"
