Commit 00bc6a8

Introducing adaptive LR schedule to Classification task (#2268)
* Change the params
* Fix lr scheduler
* Make faster training
* Final values
* Fix tests
* Fix self-sl error
* Change the interpolation type: bicubic --> default
Parent: 1b11fbf

16 files changed, +95 −114 lines

src/otx/algorithms/classification/adapters/mmcls/task.py

Lines changed: 2 additions & 55 deletions
```diff
@@ -47,14 +47,14 @@
     build_data_parallel,
     get_configs_by_pairs,
     patch_data_pipeline,
+    patch_from_hyperparams,
 )
 from otx.algorithms.common.adapters.mmcv.utils import (
     build_dataloader as otx_build_dataloader,
 )
 from otx.algorithms.common.adapters.mmcv.utils import build_dataset as otx_build_dataset
 from otx.algorithms.common.adapters.mmcv.utils.config_utils import (
     MPAConfig,
-    get_adaptive_num_workers,
     update_or_add_custom_hook,
 )
 from otx.algorithms.common.configs.configuration_enums import BatchSizeAdaptType
@@ -117,7 +117,7 @@ def _init_task(self, export: bool = False):  # noqa
         patch_data_pipeline(self._recipe_cfg, self.data_pipeline_path)
 
         if not export:
-            self._recipe_cfg.merge_from_dict(self._init_hparam())
+            patch_from_hyperparams(self._recipe_cfg, self._hyperparams)
 
         if "custom_hooks" in self.override_configs:
             override_custom_hooks = self.override_configs.pop("custom_hooks")
@@ -656,59 +656,6 @@ def patch_input_shape(deploy_cfg):
 
         return deploy_cfg
 
-    def _init_hparam(self) -> dict:
-        params = self._hyperparams.learning_parameters
-        warmup_iters = int(params.learning_rate_warmup_iters)
-        if self._multilabel:
-            # hack to use 1cycle policy
-            lr_config = ConfigDict(max_lr=params.learning_rate, warmup=None)
-        else:
-            lr_config = (
-                ConfigDict(warmup_iters=warmup_iters) if warmup_iters > 0 else ConfigDict(warmup_iters=0, warmup=None)
-            )
-
-        early_stop = False
-        if self._recipe_cfg is not None:
-            if params.enable_early_stopping and self._recipe_cfg.get("evaluation", None):
-                early_stop = ConfigDict(
-                    start=int(params.early_stop_start),
-                    patience=int(params.early_stop_patience),
-                    iteration_patience=int(params.early_stop_iteration_patience),
-                )
-
-        if self._recipe_cfg.runner.get("type").startswith("IterBasedRunner"):  # type: ignore
-            runner = ConfigDict(max_iters=int(params.num_iters))
-        else:
-            runner = ConfigDict(max_epochs=int(params.num_iters))
-
-        config = ConfigDict(
-            optimizer=ConfigDict(lr=params.learning_rate),
-            lr_config=lr_config,
-            early_stop=early_stop,
-            data=ConfigDict(
-                samples_per_gpu=int(params.batch_size),
-                workers_per_gpu=int(params.num_workers),
-            ),
-            runner=runner,
-        )
-
-        if self._hyperparams.learning_parameters.auto_num_workers:
-            adapted_num_worker = get_adaptive_num_workers()
-            if adapted_num_worker is not None:
-                config.data.workers_per_gpu = adapted_num_worker
-
-        if self._train_type.value == "Semisupervised":
-            unlabeled_config = ConfigDict(
-                data=ConfigDict(
-                    unlabeled_dataloader=ConfigDict(
-                        samples_per_gpu=int(params.unlabeled_batch_size),
-                        workers_per_gpu=int(params.num_workers),
-                    )
-                )
-            )
-            config.update(unlabeled_config)
-        return config
-
     # This should be removed
     def update_override_configurations(self, config):
         """Update override_configs."""
```

src/otx/algorithms/classification/configs/base/data/data_pipeline.py

Lines changed: 3 additions & 5 deletions
```diff
@@ -20,19 +20,17 @@
 __resize_target_size = 224
 
 __train_pipeline = [
-    dict(type="Resize", size=__resize_target_size),
+    dict(type="RandomResizedCrop", size=224, efficientnet_style=True),
     dict(type="RandomFlip", flip_prob=0.5, direction="horizontal"),
-    dict(type="AugMixAugment", config_str="augmix-m5-w3-d1"),
-    dict(type="RandomRotate", p=0.35, angle=(-10, 10)),
-    dict(type="PILImageToNDArray", keys=["img"]),
     dict(type="Normalize", **__img_norm_cfg),
     dict(type="ImageToTensor", keys=["img"]),
     dict(type="ToTensor", keys=["gt_label"]),
     dict(type="Collect", keys=["img", "gt_label"]),
 ]
 
 __test_pipeline = [
-    dict(type="Resize", size=__resize_target_size),
+    dict(type="Resize", size=(256, -1)),
+    dict(type="CenterCrop", crop_size=224),
     dict(type="Normalize", **__img_norm_cfg),
     dict(type="ImageToTensor", keys=["img"]),
     dict(type="Collect", keys=["img"]),
```

src/otx/algorithms/classification/configs/configuration.yaml

Lines changed: 4 additions & 4 deletions
```diff
@@ -171,7 +171,7 @@ learning_parameters:
     visible_in_ui: false
   early_stop_patience:
     affects_outcome_of: TRAINING
-    default_value: 8
+    default_value: 3
     description: Training will stop if the model does not improve within the number of epochs of patience.
     editable: true
     header: Patience for early stopping
@@ -207,17 +207,17 @@ learning_parameters:
     warning: This is applied exclusively when early stopping is enabled.
   use_adaptive_interval:
     affects_outcome_of: TRAINING
-    default_value: false
+    default_value: true
     description: Depending on the size of iteration per epoch, adaptively update the validation interval and related values.
-    editable: false
+    editable: true
     header: Use adaptive validation interval
     type: BOOLEAN
     ui_rules:
       action: DISABLE_EDITING
       operator: AND
       rules: []
       type: UI_RULES
-    visible_in_ui: false
+    visible_in_ui: true
     warning: This will automatically control the patience and interval when early stopping is enabled.
   enable_supcon:
     affects_outcome_of: TRAINING
```
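One way to read the lowered `early_stop_patience` default (8 → 3): patience counts validation runs without improvement, and with `use_adaptive_interval` now on by default the validation interval can grow up to `max_interval`, so the effective wait scales with both. A back-of-envelope sketch (an interpretation, not code from this commit):

```python
# Rough intuition only: wall-clock wait before stopping is roughly
# patience * validation_interval epochs.
def effective_wait_epochs(patience: int, val_interval: int) -> int:
    """Upper bound on epochs trained past the last improvement."""
    return patience * val_interval


# Old defaults: patience=8, interval=1 -> up to ~8 epochs without improvement.
# New defaults: patience=3, interval up to 5 -> up to ~15 epochs on small
# datasets, but as few as 3 once the adaptive interval shrinks to 1.
print(effective_wait_epochs(3, 5))  # 15
```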

src/otx/algorithms/classification/configs/efficientnet_b0_cls_incr/selfsl/hparam.yaml

Lines changed: 2 additions & 0 deletions
```diff
@@ -16,6 +16,8 @@ hyper_parameters:
         default_value: 5000
       enable_early_stopping:
         default_value: false
+      use_adaptive_interval:
+        default_value: false
     algo_backend:
       train_type:
         default_value: Selfsupervised
```

src/otx/algorithms/classification/configs/efficientnet_v2_s_cls_incr/selfsl/hparam.yaml

Lines changed: 2 additions & 0 deletions
```diff
@@ -16,6 +16,8 @@ hyper_parameters:
         default_value: 5000
       enable_early_stopping:
         default_value: false
+      use_adaptive_interval:
+        default_value: false
     algo_backend:
       train_type:
         default_value: Selfsupervised
```

src/otx/algorithms/classification/configs/mobilenet_v3_large_075_cls_incr/selfsl/hparam.yaml

Lines changed: 2 additions & 0 deletions
```diff
@@ -16,6 +16,8 @@ hyper_parameters:
         default_value: 5000
       enable_early_stopping:
         default_value: false
+      use_adaptive_interval:
+        default_value: false
     algo_backend:
       train_type:
         default_value: Selfsupervised
```

src/otx/algorithms/classification/configs/mobilenet_v3_large_1_cls_incr/selfsl/hparam.yaml

Lines changed: 2 additions & 0 deletions
```diff
@@ -16,6 +16,8 @@ hyper_parameters:
         default_value: 5000
       enable_early_stopping:
         default_value: false
+      use_adaptive_interval:
+        default_value: false
     algo_backend:
       train_type:
         default_value: Selfsupervised
```

src/otx/algorithms/classification/configs/mobilenet_v3_small_cls_incr/selfsl/hparam.yaml

Lines changed: 2 additions & 0 deletions
```diff
@@ -16,6 +16,8 @@ hyper_parameters:
         default_value: 5000
       enable_early_stopping:
         default_value: false
+      use_adaptive_interval:
+        default_value: false
     algo_backend:
       train_type:
         default_value: Selfsupervised
```

src/otx/algorithms/common/adapters/mmcv/hooks/adaptive_training_hook.py

Lines changed: 9 additions & 27 deletions
```diff
@@ -23,26 +23,22 @@
 class AdaptiveTrainSchedulingHook(Hook):
     """Adaptive Training Scheduling Hook.
 
-    Depending on the size of iteration per epoch, adaptively update the validation interval and related values.
+    Depending on the size of iteration per epoch, adaptively update the validation interval.
 
     Args:
         max_interval (int): Maximum value of validation interval.
             Defaults to 5.
-        base_lr_patience (int): The value of LR drop patience are expected in total epoch.
-            Patience used when interval is 1, Defaults to 5.
-        min_lr_patience (int): Minumum value of LR drop patience.
-            Defaults to 2.
-        base_es_patience (int): The value of Early-Stopping patience are expected in total epoch.
-            Patience used when interval is 1, Defaults to 10.
+        decay (float): Parameter to control the interval. This value is set by manual manner.
+            Defaults to -0.025.
+        enable_adaptive_interval_hook (bool): If True, adaptive interval will be enabled.
+            Defaults to False.
+        enable_eval_before_run (bool): If True, initial evaluation before training will be enabled.
+            Defaults to False.
     """
 
     def __init__(
         self,
         max_interval=5,
-        base_lr_patience=5,
-        min_lr_patience=2,
-        base_es_patience=10,
-        min_es_patience=3,
         decay=-0.025,
         enable_adaptive_interval_hook=False,
         enable_eval_before_run=False,
@@ -51,10 +47,6 @@ def __init__(
         super().__init__(**kwargs)
 
         self.max_interval = max_interval
-        self.base_lr_patience = base_lr_patience
-        self.min_lr_patience = min_lr_patience
-        self.base_es_patience = base_es_patience
-        self.min_es_patience = min_es_patience
         self.decay = decay
         self.enable_adaptive_interval_hook = enable_adaptive_interval_hook
         self.enable_eval_before_run = enable_eval_before_run
@@ -92,23 +84,13 @@ def before_train_iter(self, runner):
                     logger.info(f"Update EvalHook interval: {hook.interval} -> {adaptive_interval}")
                     hook.interval = adaptive_interval
                 elif isinstance(hook, LrUpdaterHook):
-                    patience = max(
-                        math.ceil((self.base_lr_patience / adaptive_interval)),
-                        self.min_lr_patience,
-                    )
                     if hasattr(hook, "interval") and hasattr(hook, "patience"):
                         hook.interval = adaptive_interval
-                        hook.patience = patience
-                        logger.info(f"Update LrUpdaterHook patience: {hook.patience} -> {patience}")
+                        logger.info(f"Update LrUpdaterHook interval: {hook.interval} -> {adaptive_interval}")
                 elif isinstance(hook, EarlyStoppingHook):
-                    patience = max(
-                        math.ceil((self.base_es_patience / adaptive_interval)),
-                        self.min_es_patience,
-                    )
-                    logger.info(f"Update EarlyStoppingHook patience: {hook.patience} -> {patience}")
+                    logger.info(f"Update EarlyStoppingHook interval: {hook.interval} -> {adaptive_interval}")
                     hook.start = adaptive_interval
                     hook.interval = adaptive_interval
-                    hook.patience = patience
                 elif isinstance(hook, CheckpointHook):
                     # make sure checkpoint is saved at last
                     limit = runner.max_epochs if hook.by_epoch else runner.max_iters
```
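After this change, `decay` and `max_interval` are the only knobs left on the schedule; the patience bookkeeping is gone. The interval formula itself is not part of this diff; a plausible sketch, assuming the exponential form that `decay=-0.025` and `max_interval=5` suggest:

```python
# An assumed formula, not shown in this commit: short epochs keep the
# validation interval near max_interval, long epochs drive it toward 1.
import math


def adaptive_interval(iters_per_epoch: int, max_interval: int = 5, decay: float = -0.025) -> int:
    """Shrink the validation interval exponentially as epochs get longer."""
    return max(round(math.exp(decay * iters_per_epoch) * max_interval), 1)


# Small dataset (10 iters/epoch): validate every ~4 epochs, since epochs are cheap.
# Large dataset (100+ iters/epoch): validate every epoch.
print(adaptive_interval(10), adaptive_interval(100))  # 4 1
```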

src/otx/algorithms/common/adapters/mmcv/hooks/early_stopping_hook.py

Lines changed: 9 additions & 8 deletions
```diff
@@ -40,8 +40,8 @@ class EarlyStoppingHook(Hook):
         continues if the number of iteration is lower than iteration_patience
         This variable makes sure a model is trained enough for some
         iterations after the last improvement before stopping.
-    :param min_delta: Minimal decay applied to lr. If the difference between new and old lr is
-        smaller than eps, the update is ignored
+    :param min_delta_ratio: Minimal ratio value to check the best score. If the difference between current and
+        best score is smaller than (current_score * (1-min_delta_ratio)), best score will not be changed.
     """
 
     rule_map = {"greater": lambda x, y: x > y, "less": lambda x, y: x < y}
@@ -56,16 +56,16 @@ def __init__(
         rule: Optional[str] = None,
         patience: int = 5,
         iteration_patience: int = 500,
-        min_delta: float = 0.0,
+        min_delta_ratio: float = 0.0,
     ):
         super().__init__()
         self.patience = patience
         self.iteration_patience = iteration_patience
         self.interval = interval
-        self.min_delta = min_delta
+        self.min_delta_ratio = min_delta_ratio
         self._init_rule(rule, metric)
 
-        self.min_delta *= 1 if self.rule == "greater" else -1
+        self.min_delta_ratio *= 1 if self.rule == "greater" else -1
         self.last_iter = 0
         self.wait_count = 0
         self.best_score = self.init_value_map[self.rule]
@@ -141,7 +141,7 @@ def _do_check_stopping(self, runner):
             )
 
         key_score = runner.log_buffer.output[self.key_indicator]
-        if self.compare_func(key_score - self.min_delta, self.best_score):
+        if self.compare_func(key_score - (key_score * self.min_delta_ratio), self.best_score):
             self.best_score = key_score
             self.wait_count = 0
             self.last_iter = runner.iter
@@ -184,11 +184,11 @@ def __init__(
         rule: str = None,
         patience: int = 5,
         iteration_patience: int = 500,
-        min_delta: float = 0.0,
+        min_delta_ratio: float = 0.0,
         start: int = None,
     ):
         self.start = start
-        super().__init__(interval, metric, rule, patience, iteration_patience, min_delta)
+        super().__init__(interval, metric, rule, patience, iteration_patience, min_delta_ratio)
 
     def _should_check_stopping(self, runner):
         if self.by_epoch:
@@ -352,6 +352,7 @@ def get_lr(self, runner: BaseRunner, base_lr: float):
                 logger=runner.logger,
             )
             return self.current_lr
+
         self.last_iter = runner.iter
         self.bad_count = 0
         print_log(
```
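The rename from `min_delta` to `min_delta_ratio` makes the improvement margin proportional to the score itself rather than an absolute offset. A small worked example of the comparison in `_do_check_stopping` above, for the greater-is-better rule:

```python
# Mirrors `compare_func(key_score - key_score * min_delta_ratio, best_score)`
# for rule="greater"; the names here are illustrative, not the hook's API.
def is_new_best(key_score: float, best_score: float, min_delta_ratio: float) -> bool:
    """Accept a new best only if it beats the old one by a score-proportional margin."""
    return key_score - key_score * min_delta_ratio > best_score


print(is_new_best(0.802, 0.800, 0.01))  # False: 0.802 * 0.99 = 0.79398 <= 0.800
print(is_new_best(0.850, 0.800, 0.01))  # True:  0.850 * 0.99 = 0.84150 >  0.800
```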

0 commit comments

Comments
 (0)