Commit d705378

Decrease batch size if CUDA OOM occurs (#2022)
* implement adapting bs
* refine impl
* implement adaptive bs also in cls, seg task
* refine adapt bs algo to consider gpu util
* refactor code
* write comment and docstring
* implement decreasing bs on action task
* update learning rate after decreasing batch size
* implement test code of mmcv automatic_bs file
* remove meta modification
* remove unused import
* implement test code of torch automatic_bs file
* align with pre commit
* add line to tell not supporting anomaly
* update CHANGELOG
* update docs
* change argument help
* change file name
* apply pr comment
* add auto_decrease_bs in learning parameters
* align with pre commit
* fix typo
* add integration test
* bugfix
* update test code
* not execute algo in nncf
* support nncf
* apply comment
* align with pre commit
* change method to set value
* refine warning comment
* remove breakpoint
* make hpo not use auto decrease batch size
* refine warning & typo fix
* align with pre commit
* Update CHANGELOG.md

  Co-authored-by: Sungman Cho <[email protected]>

* update unit test
* update integration test
* bugfix
1 parent 0d3002a commit d705378
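
At a high level, the commit wires a trial run into each task before real training starts: train for one iteration, shrink the batch size when CUDA runs out of memory, then rescale the learning rate linearly with the final batch size. A minimal sketch of that idea follows; the helper name is hypothetical and the actual search lives in bs_search_algo, which this page does not show.

import torch


def fit_batch_size(trial_run, batch_size: int, lr: float):
    """Hypothetical sketch: trial_run(bs) is assumed to train one iteration at batch size bs."""
    default_bs = batch_size
    while batch_size > 1:
        try:
            trial_run(batch_size)  # one cheap training iteration
            break
        except RuntimeError as err:  # torch reports CUDA OOM as a RuntimeError
            if "out of memory" not in str(err):
                raise
            torch.cuda.empty_cache()
            batch_size //= 2
    # linear LR scaling, mirroring cfg.optimizer.lr *= available_bs / default_bs below
    return batch_size, lr * batch_size / default_bs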

File tree

27 files changed: +633 −1 lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -23,6 +23,7 @@ All notable changes to this project will be documented in this file.
 - Segmentation task refactoring (<https://github.com/openvinotoolkit/training_extensions/pull/1977>)
 - Action task refactoring (<https://github.com/openvinotoolkit/training_extensions/pull/1993>)
 - Optimize data preprocessing time and enhance overall performance in semantic segmentation (<https://github.com/openvinotoolkit/training_extensions/pull/2020>)
+- Support automatic batch size decrease when there is not enough GPU memory (<https://github.com/openvinotoolkit/training_extensions/pull/2022>)

 ### Bug fixes

otx/algorithms/action/adapters/mmaction/task.py

Lines changed: 7 additions & 0 deletions
@@ -18,6 +18,7 @@
 import os
 import time
 from copy import deepcopy
+from functools import partial
 from typing import Optional, Union

 import torch
@@ -34,6 +35,7 @@
 )
 from otx.algorithms.action.task import OTXActionTask
 from otx.algorithms.common.adapters.mmcv.utils import (
+    adapt_batch_size,
     build_data_parallel,
     get_configs_by_pairs,
     patch_adaptive_interval_training,
@@ -263,6 +265,11 @@ def _train_model(
             torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

         validate = bool(cfg.data.get("val", None))
+
+        if self._hyperparams.learning_parameters.auto_decrease_batch_size:
+            train_func = partial(train_model, meta=deepcopy(meta), model=deepcopy(model), distributed=False)
+            adapt_batch_size(train_func, cfg, datasets, validate)
+
         train_model(
             model,
             datasets,

otx/algorithms/action/configs/classification/configuration.yaml

Lines changed: 17 additions & 0 deletions
@@ -198,6 +198,23 @@ learning_parameters:
       type: UI_RULES
     visible_in_ui: true
     warning: This will automatically control the patience and interval when early stopping is enabled.
+  auto_decrease_batch_size:
+    affects_outcome_of: TRAINING
+    default_value: false
+    description: Find a proper batch size by training for one iteration with a few batch size candidates.
+    editable: true
+    header: Decrease batch size if the current batch size does not fit into CUDA memory.
+    type: BOOLEAN
+    ui_rules:
+      action: DISABLE_EDITING
+      operator: AND
+      rules: []
+      type: UI_RULES
+    visible_in_ui: true
+    warning:
+      Enabling this option could reduce the actual batch size if the current setting results in an out-of-memory error.
+      The learning rate could also be adjusted according to the adapted batch size.
+      This process might take some extra computation time to try a few batch size candidates.
   type: PARAMETER_GROUP
   visible_in_ui: true
 postprocessing:
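
The same parameter block is repeated for the other tasks below. Assuming the standard OTX `params` override syntax (the CLI itself is not touched by this commit), enabling the flag from the command line would look roughly like this:

    otx train <model-template> params --learning_parameters.auto_decrease_batch_size True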

otx/algorithms/action/configs/detection/configuration.yaml

Lines changed: 17 additions & 0 deletions
@@ -198,6 +198,23 @@ learning_parameters:
       type: UI_RULES
     visible_in_ui: true
     warning: This will automatically control the patience and interval when early stopping is enabled.
+  auto_decrease_batch_size:
+    affects_outcome_of: TRAINING
+    default_value: false
+    description: Find a proper batch size by training for one iteration with a few batch size candidates.
+    editable: true
+    header: Decrease batch size if the current batch size does not fit into CUDA memory.
+    type: BOOLEAN
+    ui_rules:
+      action: DISABLE_EDITING
+      operator: AND
+      rules: []
+      type: UI_RULES
+    visible_in_ui: true
+    warning:
+      Enabling this option could reduce the actual batch size if the current setting results in an out-of-memory error.
+      The learning rate could also be adjusted according to the adapted batch size.
+      This process might take some extra computation time to try a few batch size candidates.
   type: PARAMETER_GROUP
   visible_in_ui: true
 postprocessing:

otx/algorithms/classification/adapters/mmcls/task.py

Lines changed: 8 additions & 0 deletions
@@ -19,6 +19,7 @@
 import time
 from contextlib import nullcontext
 from copy import deepcopy
+from functools import partial
 from typing import Any, Dict, Optional, Union

 import torch
@@ -40,6 +41,7 @@
     ReciproCAMHook,
 )
 from otx.algorithms.common.adapters.mmcv.utils import (
+    adapt_batch_size,
     build_data_parallel,
     get_configs_by_pairs,
     patch_data_pipeline,
@@ -53,6 +55,7 @@
     update_or_add_custom_hook,
 )
 from otx.algorithms.common.configs.training_base import TrainType
+from otx.algorithms.common.tasks.nncf_task import NNCFBaseTask
 from otx.algorithms.common.utils import set_random_seed
 from otx.algorithms.common.utils.data import get_dataset
 from otx.algorithms.common.utils.logger import get_logger
@@ -404,6 +407,11 @@ def _train_model(
             )
         )

+        if self._hyperparams.learning_parameters.auto_decrease_batch_size:
+            validate = isinstance(self, NNCFBaseTask)  # NNCF needs eval hooks
+            train_func = partial(train_model, meta=deepcopy(meta), model=deepcopy(model), distributed=False)
+            adapt_batch_size(train_func, cfg, datasets, validate)
+
         train_model(
             model,
             datasets,

otx/algorithms/classification/configs/configuration.yaml

Lines changed: 17 additions & 0 deletions
@@ -221,6 +221,23 @@ learning_parameters:
       type: UI_RULES
     visible_in_ui: true
     warning: null
+  auto_decrease_batch_size:
+    affects_outcome_of: TRAINING
+    default_value: false
+    description: Find a proper batch size by training for one iteration with a few batch size candidates.
+    editable: true
+    header: Decrease batch size if the current batch size does not fit into CUDA memory.
+    type: BOOLEAN
+    ui_rules:
+      action: DISABLE_EDITING
+      operator: AND
+      rules: []
+      type: UI_RULES
+    visible_in_ui: true
+    warning:
+      Enabling this option could reduce the actual batch size if the current setting results in an out-of-memory error.
+      The learning rate could also be adjusted according to the adapted batch size.
+      This process might take some extra computation time to try a few batch size candidates.
   type: PARAMETER_GROUP
   visible_in_ui: true
 pot_parameters:

otx/algorithms/common/adapters/mmcv/utils/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -1,11 +1,12 @@
 """OTX Adapters - mmcv.utils."""

-# Copyright (C) 2022 Intel Corporation
+# Copyright (C) 2023 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0

 from ._builder_build_data_parallel import build_data_parallel
 from ._config_utils_get_configs_by_keys import get_configs_by_keys
 from ._config_utils_get_configs_by_pairs import get_configs_by_pairs
+from .automatic_bs import adapt_batch_size
 from .builder import build_dataloader, build_dataset
 from .config_utils import (
     MPAConfig,
@@ -58,4 +59,5 @@
     "prepare_work_dir",
     "get_data_cfg",
     "MPAConfig",
+    "adapt_batch_size",
 ]

otx/algorithms/common/adapters/mmcv/utils/automatic_bs.py

Lines changed: 150 additions & 0 deletions

@@ -0,0 +1,150 @@
+"""Algorithm to find a proper batch size that fits the current GPU device, for tasks using mmcv."""
+
+# Copyright (C) 2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from copy import deepcopy
+from typing import Callable, Dict, List
+
+import numpy as np
+
+from otx.algorithms.common.adapters.torch.utils import adapt_batch_size as adapt_torch_model_bs
+from otx.algorithms.common.utils.logger import get_logger
+
+logger = get_logger()
+
+
+def _set_value_at_dict_in_dict(target: Dict, key_path: str, value):
+    """Set a value at a leaf node of a nested dictionary.
+
+    If a key does not exist at an intermediate node, a new dictionary is created
+    there and traversal continues. For example, to set a value at
+    target["a"]["b"]["c"], call the function as below.
+    _set_value_at_dict_in_dict(target, "a.b.c", value)
+
+    Args:
+        target (Dict): Target variable.
+        key_path (str): Dot-delimited dictionary key string.
+        value : Value to set.
+    """
+    keys = key_path.split(".")
+    for key in keys[:-1]:
+        if key not in target:
+            target[key] = {}
+        target = target[key]
+
+    target[keys[-1]] = value
+
+
+def adapt_batch_size(train_func: Callable, cfg, datasets: List, validate: bool = False):
+    """Decrease the batch size if the default batch size does not fit the current GPU device.
+
+    This function only sets up a single-iteration training run to reduce adaptation time.
+    The core batch size adaptation is done by adapt_batch_size in the torch.utils package.
+
+    Args:
+        train_func (Callable): The function to train a model.
+            Only cfg, dataset and validate are passed to the function when invoking it.
+        cfg: Configuration of a training.
+        datasets (List): List of datasets.
+        validate (bool): Whether to run validation or not.
+    """
+
+    def train_func_single_iter(batch_size):
+        copied_cfg = deepcopy(cfg)
+        _set_batch_size(copied_cfg, batch_size)
+
+        # set up training for a single iteration to reduce time
+        if copied_cfg.runner.get("type") == "AccuracyAwareRunner":  # NNCF case
+            if "nncf_config" in copied_cfg.runner:
+                _set_value_at_dict_in_dict(
+                    copied_cfg.runner["nncf_config"], "accuracy_aware_training.params.maximal_total_epochs", 1
+                )
+        else:
+            copied_cfg.runner["max_epochs"] = 1
+
+        if not validate:  # disable validation
+            for hook in copied_cfg.custom_hooks:
+                if hook["type"] == "AdaptiveTrainSchedulingHook":
+                    hook["enable_eval_before_run"] = False
+
+        new_datasets = [SubDataset(datasets[0], batch_size)]
+
+        train_func(
+            dataset=new_datasets,
+            cfg=copied_cfg,
+            validate=validate,
+        )
+
+    default_bs = _get_batch_size(cfg)
+
+    available_bs = adapt_torch_model_bs(
+        train_func=train_func_single_iter,
+        current_bs=default_bs,
+        trainset_size=len(datasets[0]),
+    )
+
+    if default_bs != available_bs:
+        _set_batch_size(cfg, available_bs)
+        origin_lr = cfg.optimizer.lr
+        cfg.optimizer.lr *= available_bs / default_bs
+
+        logger.info("Adapting batch size is done.")
+        logger.info(f"Batch size is adapted: {default_bs} -> {available_bs}")
+        logger.info(f"Learning rate is adapted: {origin_lr} -> {cfg.optimizer.lr}")
+    else:
+        logger.info("Adapting batch size is done. The current batch size is available.")
+
+
+def _get_batch_size(cfg) -> int:
+    if "action" in str(cfg.domain).lower():
+        return cfg.data.videos_per_gpu
+    return cfg.data.train_dataloader["samples_per_gpu"]
+
+
+def _set_batch_size(cfg, batch_size: int):
+    if "action" in str(cfg.domain).lower():
+        cfg.data.videos_per_gpu = batch_size
+    else:
+        cfg.data.train_dataloader["samples_per_gpu"] = batch_size
+
+
+class SubDataset:
+    """Wrapper class to make a dataset pretend to have the specified number of images.
+
+    Args:
+        fullset: Original dataset.
+        num_samples (int): Number of images to pretend to have. It should be positive.
+    """
+
+    def __init__(self, fullset, num_samples: int):
+        if num_samples <= 0:
+            raise ValueError(f"num_samples should be positive. But, current value is {num_samples}.")
+
+        self.fullset = fullset
+        self.num_samples = num_samples
+
+    def __len__(self) -> int:
+        """Get length of subset."""
+        return self.num_samples
+
+    def __getitem__(self, indx) -> dict:
+        """Get dataset item at index."""
+        return self.fullset[indx]
+
+    def __getattr__(self, name):
+        """Delegate attribute access other than the dataset itself to the fullset."""
+        if name == "__setstate__":
+            raise AttributeError(name)
+        return getattr(self.fullset, name)
+
+    @property
+    def flag(self):
+        """Getter of flag for the detection task.
+
+        The sampler of the detection task decides the dataset length by checking the
+        sum of the flag array. To cover that case, return a flag array with length
+        num_samples.
+        """
+        return np.zeros(self.num_samples, dtype=np.uint8)
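
The `_set_value_at_dict_in_dict` helper above can be exercised in isolation; a quick usage example (values are illustrative, semantics follow the code as shown):

cfg = {}
_set_value_at_dict_in_dict(cfg, "accuracy_aware_training.params.maximal_total_epochs", 1)
assert cfg == {"accuracy_aware_training": {"params": {"maximal_total_epochs": 1}}}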

otx/algorithms/common/adapters/torch/utils/__init__.py

Lines changed: 8 additions & 0 deletions

@@ -0,0 +1,8 @@
+"""Utils for modules using torch."""
+
+# Copyright (C) 2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from .bs_search_algo import adapt_batch_size
+
+__all__ = ["adapt_batch_size"]
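
The bs_search_algo module itself is not shown on this page. Judging only from its call site in automatic_bs.py (keyword arguments train_func, current_bs, trainset_size), its contract can be exercised roughly like this; the trial function below is fake and the expected behavior is an assumption:

from otx.algorithms.common.adapters.torch.utils import adapt_batch_size


def fake_trial(batch_size: int):
    # pretend batch sizes above 16 exhaust GPU memory
    if batch_size > 16:
        raise RuntimeError("CUDA out of memory")


new_bs = adapt_batch_size(train_func=fake_trial, current_bs=64, trainset_size=1000)
# expected: a batch size of at most 16, assuming an OOM-driven decrease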
