|
| 1 | +"""Algorithm to find a proper batch size which is fit to current GPU device for tasks using mmcv.""" |
| 2 | + |
| 3 | +# Copyright (C) 2023 Intel Corporation |
| 4 | +# SPDX-License-Identifier: Apache-2.0 |
| 5 | + |
| 6 | +from copy import deepcopy |
| 7 | +from typing import Callable, Dict, List |
| 8 | + |
| 9 | +import numpy as np |
| 10 | + |
| 11 | +from otx.algorithms.common.adapters.torch.utils import adapt_batch_size as adapt_torch_model_bs |
| 12 | +from otx.algorithms.common.utils.logger import get_logger |
| 13 | + |
| 14 | +logger = get_logger() |
| 15 | + |
| 16 | + |
| 17 | +def _set_value_at_dict_in_dict(target: Dict, key_path: str, value): |
| 18 | + """Set value at dictionary hierarchy structure. |
| 19 | +
|
| 20 | + This function is for setting a value at leaf dictionary node in dictionary hierarchy structure. |
| 21 | + If key doesn't exist in the middle node dictionaray, then make a new dictionary at that and keep going. |
| 22 | + For example, if you want to set value at target["a"]["b"]["c"], then you can call the function as below. |
| 23 | + _set_value_at_dict_in_dict(target, "a.b.c", value) |
| 24 | +
|
| 25 | + Args: |
| 26 | + target (Dict): Target variable. |
| 27 | + key_path (str): Dot delimited dictionary key string. |
| 28 | + value : Value to set. |
| 29 | + """ |
| 30 | + keys = key_path.split(".") |
| 31 | + for key in keys[:-1]: |
| 32 | + if key not in target: |
| 33 | + target[key] = {} |
| 34 | + target = target[key] |
| 35 | + |
| 36 | + target[keys[-1]] = value |
| 37 | + |
| 38 | + |
| 39 | +def adapt_batch_size(train_func: Callable, cfg, datasets: List, validate: bool = False): |
| 40 | + """Decrease batch size if default batch size isn't fit to current GPU device. |
| 41 | +
|
| 42 | + This function just setup for single iteration training to reduce time for adapting. |
| 43 | + The core part of adapting batch size is done in adapt_batch_size in the torch.utils package. |
| 44 | +
|
| 45 | + Args: |
| 46 | + train_func (Callable): The function to train a model. |
| 47 | + Only cfg, dataset and meta are passed to the function when invoking it. |
| 48 | + cfg: Configuration of a training. |
| 49 | + meta (Dict): A dict records some meta information of a training. |
| 50 | + datasets (List): List of datasets. |
| 51 | + validate (bool): Whether do vlidation or not. |
| 52 | + """ |
| 53 | + |
| 54 | + def train_func_single_iter(batch_size): |
| 55 | + copied_cfg = deepcopy(cfg) |
| 56 | + _set_batch_size(copied_cfg, batch_size) |
| 57 | + |
| 58 | + # setup for training a single iter to reduce time |
| 59 | + if copied_cfg.runner.get("type") == "AccuracyAwareRunner": # nncf case |
| 60 | + if "nncf_config" in copied_cfg.runner: |
| 61 | + _set_value_at_dict_in_dict( |
| 62 | + copied_cfg.runner["nncf_config"], "accuracy_aware_training.params.maximal_total_epochs", 1 |
| 63 | + ) |
| 64 | + else: |
| 65 | + copied_cfg.runner["max_epochs"] = 1 |
| 66 | + |
| 67 | + if not validate: # disable validation |
| 68 | + for hook in copied_cfg.custom_hooks: |
| 69 | + if hook["type"] == "AdaptiveTrainSchedulingHook": |
| 70 | + hook["enable_eval_before_run"] = False |
| 71 | + |
| 72 | + new_datasets = [SubDataset(datasets[0], batch_size)] |
| 73 | + |
| 74 | + train_func( |
| 75 | + dataset=new_datasets, |
| 76 | + cfg=copied_cfg, |
| 77 | + validate=validate, |
| 78 | + ) |
| 79 | + |
| 80 | + default_bs = _get_batch_size(cfg) |
| 81 | + |
| 82 | + available_bs = adapt_torch_model_bs( |
| 83 | + train_func=train_func_single_iter, |
| 84 | + current_bs=default_bs, |
| 85 | + trainset_size=len(datasets[0]), |
| 86 | + ) |
| 87 | + |
| 88 | + if default_bs != available_bs: |
| 89 | + _set_batch_size(cfg, available_bs) |
| 90 | + origin_lr = cfg.optimizer.lr |
| 91 | + cfg.optimizer.lr *= available_bs / default_bs |
| 92 | + |
| 93 | + logger.info("Adapting batch size is done.") |
| 94 | + logger.info(f"Batch size is adapted : {default_bs} -> {available_bs}") |
| 95 | + logger.info(f"learning rate is adapted : {origin_lr} -> {cfg.optimizer.lr}") |
| 96 | + else: |
| 97 | + logger.info("Adapting batch size is done. Current batch size is availble.") |
| 98 | + |
| 99 | + |
| 100 | +def _get_batch_size(cfg) -> int: |
| 101 | + if "action" in str(cfg.domain).lower(): |
| 102 | + return cfg.data.videos_per_gpu |
| 103 | + return cfg.data.train_dataloader["samples_per_gpu"] |
| 104 | + |
| 105 | + |
| 106 | +def _set_batch_size(cfg, batch_size: int): |
| 107 | + if "action" in str(cfg.domain).lower(): |
| 108 | + cfg.data.videos_per_gpu = batch_size |
| 109 | + else: |
| 110 | + cfg.data.train_dataloader["samples_per_gpu"] = batch_size |
| 111 | + |
| 112 | + |
| 113 | +class SubDataset: |
| 114 | + """Wrapper class to make dataset pretend to have specified number of images. |
| 115 | +
|
| 116 | + Args: |
| 117 | + fullset: Original dataset. |
| 118 | + num_samples (int): Number of images to pretend to have. It should be positive. |
| 119 | + """ |
| 120 | + |
| 121 | + def __init__(self, fullset, num_sampels: int): |
| 122 | + if num_sampels <= 0: |
| 123 | + raise ValueError(f"num_sampels should be positive. But, current value is {num_sampels}.") |
| 124 | + |
| 125 | + self.fullset = fullset |
| 126 | + self.num_sampels = num_sampels |
| 127 | + |
| 128 | + def __len__(self) -> int: |
| 129 | + """Get length of subset.""" |
| 130 | + return self.num_sampels |
| 131 | + |
| 132 | + def __getitem__(self, indx) -> dict: |
| 133 | + """Get dataset at index.""" |
| 134 | + return self.fullset[indx] |
| 135 | + |
| 136 | + def __getattr__(self, name): |
| 137 | + """When trying to get other attributes, not dataset, get values from fullset.""" |
| 138 | + if name == "__setstate__": |
| 139 | + raise AttributeError(name) |
| 140 | + return getattr(self.fullset, name) |
| 141 | + |
| 142 | + @property |
| 143 | + def flag(self): |
| 144 | + """Getter of flag for detection task. |
| 145 | +
|
| 146 | + Sampler of the detection task decides length of dataset checking sum of flag array. |
| 147 | + To consider that case, return flag array with length of num_samples. |
| 148 | +
|
| 149 | + """ |
| 150 | + return np.zeros(self.num_sampels, dtype=np.uint8) |
0 commit comments