14 | 14 | # See the License for the specific language governing permissions |
15 | 15 | # and limitations under the License. |
16 | 16 | |
17 | | -import math |
18 | | -from typing import List, Optional, Union |
| 17 | +from typing import List, Optional |
19 | 18 | |
20 | 19 | from mmcv import Config, ConfigDict |
21 | 20 | |
22 | 21 | from otx.algorithms.common.adapters.mmcv.utils import ( |
23 | | -    get_configs_by_keys, |
24 | 22 |     get_configs_by_pairs, |
25 | 23 |     get_dataset_configs, |
26 | 24 |     get_meta_keys, |
27 | | -    is_epoch_based_runner, |
28 | 25 |     patch_color_conversion, |
29 | | -    prepare_work_dir, |
30 | | -    remove_from_config, |
31 | | -    remove_from_configs_by_type, |
32 | | -    update_config, |
33 | | -) |
34 | | -from otx.api.entities.label import Domain, LabelEntity |
35 | | -from otx.api.utils.argument_checks import ( |
36 | | -    DirectoryPathCheck, |
37 | | -    check_input_parameters_type, |
38 | 26 | ) |
| 27 | +from otx.api.entities.label import Domain |
| 28 | +from otx.api.utils.argument_checks import check_input_parameters_type |
39 | 29 | from otx.mpa.utils.logger import get_logger |
40 | 30 | |
41 | 31 | logger = get_logger() |
42 | 32 | |
43 | 33 | |
44 | | -@check_input_parameters_type({"work_dir": DirectoryPathCheck}) |
45 | | -def patch_config( |
46 | | -    config: Config, |
47 | | -    work_dir: str, |
48 | | -    labels: List[LabelEntity], |
49 | | -):  # pylint: disable=too-many-branches |
50 | | -    """Update config function.""" |
51 | | - |
52 | | -    # Add training cancellation hook. |
53 | | -    if "custom_hooks" not in config: |
54 | | -        config.custom_hooks = [] |
55 | | -    if "CancelTrainingHook" not in {hook.type for hook in config.custom_hooks}: |
56 | | -        config.custom_hooks.append(ConfigDict({"type": "CancelTrainingHook"})) |
57 | | - |
58 | | -    # Remove high-level data pipeline definitions, leaving them only inside the `data` section. |
59 | | -    remove_from_config(config, "train_pipeline") |
60 | | -    remove_from_config(config, "test_pipeline") |
61 | | -    remove_from_config(config, "train_pipeline_strong") |
62 | | -    # Remove the cancel interface hook. |
63 | | -    remove_from_configs_by_type(config.custom_hooks, "CancelInterfaceHook") |
64 | | - |
65 | | -    config.checkpoint_config.max_keep_ckpts = 5 |
66 | | -    config.checkpoint_config.interval = config.evaluation.get("interval", 1) |
67 | | - |
68 | | -    set_data_classes(config, labels) |
69 | | - |
70 | | -    config.gpu_ids = range(1) |
71 | | -    config.work_dir = work_dir |
72 | | - |
73 | | - |
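For readers tracking what moves out of this module: the removed `patch_config` mainly injected a cancellation hook and aligned checkpointing with the evaluation interval. A minimal sketch of that pattern on a toy mmcv config (the `evaluation` and `checkpoint_config` values below are made-up placeholders, not from any OTX recipe):

```python
from mmcv import Config, ConfigDict

# Toy config; real OTX recipes carry many more keys.
config = Config(dict(evaluation=dict(interval=2), checkpoint_config=dict()))

# Register the cancellation hook only once, mirroring the guard above.
if "custom_hooks" not in config:
    config.custom_hooks = []
if "CancelTrainingHook" not in {hook.type for hook in config.custom_hooks}:
    config.custom_hooks.append(ConfigDict({"type": "CancelTrainingHook"}))

# Cap retained checkpoints and save at the same cadence as evaluation.
config.checkpoint_config.max_keep_ckpts = 5
config.checkpoint_config.interval = config.evaluation.get("interval", 1)

print(config.custom_hooks)  # [{'type': 'CancelTrainingHook'}]
```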
74 | | -@check_input_parameters_type() |
75 | | -def patch_model_config( |
76 | | -    config: Config, |
77 | | -    labels: List[LabelEntity], |
78 | | -): |
79 | | -    """Patch model config.""" |
80 | | -    set_num_classes(config, len(labels)) |
81 | | - |
82 | | - |
83 | | -@check_input_parameters_type() |
84 | | -def patch_adaptive_repeat_dataset( |
85 | | -    config: Union[Config, ConfigDict], |
86 | | -    num_samples: int, |
87 | | -    decay: float = -0.002, |
88 | | -    factor: float = 30, |
89 | | -): |
90 | | -    """Patch the repeat times and training epochs adaptively. |
91 | | - |
92 | | -    Frequent dataloader initializations and evaluations slow down training |
93 | | -    when the sample size is small. Adjusting the epoch count and dataset |
94 | | -    repetition based on an empirical exponential decay improves training |
95 | | -    time by applying a high repeat value to small datasets and a low |
96 | | -    repeat value to large ones. |
97 | | - |
98 | | -    :param config: mmcv config |
99 | | -    :param num_samples: number of training samples |
100 | | -    :param decay: decay rate |
101 | | -    :param factor: base repeat factor |
102 | | -    """ |
103 | | -    data_train = config.data.train |
104 | | -    if data_train.type == "RepeatDataset" and getattr(data_train, "adaptive_repeat_times", False): |
105 | | -        if is_epoch_based_runner(config.runner): |
106 | | -            cur_epoch = config.runner.max_epochs |
107 | | -            new_repeat = max(round(math.exp(decay * num_samples) * factor), 1) |
108 | | -            new_epoch = math.ceil(cur_epoch / new_repeat) |
109 | | -            if new_epoch == 1: |
110 | | -                return |
111 | | -            config.runner.max_epochs = new_epoch |
112 | | -            data_train.times = new_repeat |
113 | | - |
114 | | - |
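The removed adaptive-repeat logic is easiest to follow with concrete numbers. A self-contained sketch of the same formula with its defaults `decay=-0.002` and `factor=30` (the function name here is ours, for illustration only):

```python
import math

def adaptive_repeat(num_samples: int, decay: float = -0.002, factor: float = 30) -> int:
    # Same formula as above: exponential decay in dataset size, floored at 1.
    return max(round(math.exp(decay * num_samples) * factor), 1)

print(adaptive_repeat(100))   # 25 -> tiny datasets get repeated heavily
print(adaptive_repeat(1000))  # 4
print(adaptive_repeat(5000))  # 1 -> large datasets are left alone
```

With `max_epochs = 90` and 100 training samples, the runner would drop to `ceil(90 / 25) = 4` epochs while each epoch sees the data 25 times, cutting per-epoch dataloader restarts and evaluations without reducing the total number of iterations much.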
115 | | -@check_input_parameters_type() |
116 | | -def prepare_for_training( |
117 | | -    config: Union[Config, ConfigDict], |
118 | | -    data_config: ConfigDict, |
119 | | -) -> Union[Config, ConfigDict]: |
120 | | -    """Prepare configs for training phase.""" |
121 | | -    prepare_work_dir(config) |
122 | | - |
123 | | -    train_num_samples = 0 |
124 | | -    for subset in ["train", "val", "test"]: |
125 | | -        data_config_ = data_config.data.get(subset) |
126 | | -        config_ = config.data.get(subset) |
127 | | -        if data_config_ is None: |
128 | | -            continue |
129 | | -        for key in ["otx_dataset", "labels"]: |
130 | | -            found = get_configs_by_keys(data_config_, key, return_path=True) |
131 | | -            if len(found) == 0: |
132 | | -                continue |
133 | | -            assert len(found) == 1 |
134 | | -            if subset == "train" and key == "otx_dataset": |
135 | | -                found_value = list(found.values())[0] |
136 | | -                if found_value: |
137 | | -                    train_num_samples = len(found_value) |
138 | | -            update_config(config_, found) |
139 | | - |
140 | | -    if train_num_samples > 0: |
141 | | -        patch_adaptive_repeat_dataset(config, train_num_samples) |
142 | | - |
143 | | -    return config |
144 | | - |
145 | | - |
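`prepare_for_training` above is essentially a targeted merge: it locates the `otx_dataset` and `labels` placeholders in each subset, copies the runtime values in, and counts the train samples along the way. A rough stand-in using plain `ConfigDict`s (dummy payloads; the real code walks arbitrarily nested configs via the OTX helpers `get_configs_by_keys` and `update_config`):

```python
from mmcv import ConfigDict

# Runtime values handed over at train time (dummy dataset items).
data_config = ConfigDict(data=dict(train=dict(otx_dataset=["img0", "img1"], labels=["cat"])))
# Static recipe config whose placeholders get filled in.
config = ConfigDict(data=dict(train=dict(otx_dataset=None, labels=None)))

train_num_samples = 0
for subset in ["train", "val", "test"]:
    src = data_config.data.get(subset)
    dst = config.data.get(subset)
    if src is None or dst is None:
        continue
    for key in ["otx_dataset", "labels"]:
        if key in src:
            dst[key] = src[key]  # flat stand-in for update_config on nested paths
            if subset == "train" and key == "otx_dataset" and src[key]:
                train_num_samples = len(src[key])

print(train_num_samples)  # 2 -> would then feed patch_adaptive_repeat_dataset
```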
146 | | -@check_input_parameters_type() |
147 | | -def set_data_classes(config: Config, labels: List[LabelEntity]): |
148 | | -    """Set data classes in the config.""" |
149 | | -    # Save labels in data configs. |
150 | | -    for subset in ("train", "val", "test"): |
151 | | -        for cfg in get_dataset_configs(config, subset): |
152 | | -            cfg.labels = labels |
153 | | - |
154 | | - |
155 | | -@check_input_parameters_type() |
156 | | -def set_num_classes(config: Config, num_classes: int): |
157 | | -    """Set num classes.""" |
158 | | -    head_names = ["head"] |
159 | | -    for head_name in head_names: |
160 | | -        if head_name in config.model: |
161 | | -            config.model[head_name].num_classes = num_classes |
162 | | - |
163 | | - |
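The two setters just removed are thin, but they pin down where labels live: per-dataset configs get a `labels` list, and the model head gets `num_classes`. A toy illustration of the head update (dummy model dict, not a real OTX recipe):

```python
from mmcv import Config

config = Config(dict(model=dict(head=dict(num_classes=0))))

# Mirror set_num_classes: patch only heads that exist in this model.
for head_name in ["head"]:
    if head_name in config.model:
        config.model[head_name].num_classes = 3

print(config.model.head.num_classes)  # 3
```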
164 | 34 | @check_input_parameters_type() |
165 | 35 | def patch_datasets( |
166 | 36 |     config: Config, |