Skip to content

Commit b152d9e

Browse files
authored
Re introduce adaptive scheduling for training (#2541)
* Re-introduce adaptive patience for training * Revert unit tests
1 parent 4494955 commit b152d9e

File tree

2 files changed

+29
-5
lines changed

2 files changed

+29
-5
lines changed

src/otx/algorithms/common/adapters/mmcv/hooks/adaptive_training_hook.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,15 @@
2323
class AdaptiveTrainSchedulingHook(Hook):
2424
"""Adaptive Training Scheduling Hook.
2525
26-
Depending on the size of iteration per epoch, adaptively update the validation interval.
26+
Depending on the size of iteration per epoch, adaptively update the validation interval and related values.
2727
2828
Args:
29+
base_lr_patience (int): The value of LR drop patience are expected in total epoch.
30+
Patience used when interval is 1, Defaults to 5.
31+
min_lr_patience (int): Minumum value of LR drop patience.
32+
Defaults to 2.
33+
base_es_patience (int): The value of Early-Stopping patience are expected in total epoch.
34+
Patience used when interval is 1, Defaults to 10.
2935
max_interval (int): Maximum value of validation interval.
3036
Defaults to 5.
3137
decay (float): Parameter to control the interval. This value is set by manual manner.
@@ -39,6 +45,10 @@ class AdaptiveTrainSchedulingHook(Hook):
3945
def __init__(
4046
self,
4147
max_interval=5,
48+
base_lr_patience=5,
49+
min_lr_patience=2,
50+
base_es_patience=10,
51+
min_es_patience=3,
4252
decay=-0.025,
4353
enable_adaptive_interval_hook=False,
4454
enable_eval_before_run=False,
@@ -47,6 +57,10 @@ def __init__(
4757
super().__init__(**kwargs)
4858

4959
self.max_interval = max_interval
60+
self.base_lr_patience = base_lr_patience
61+
self.min_lr_patience = min_lr_patience
62+
self.base_es_patience = base_es_patience
63+
self.min_es_patience = min_es_patience
5064
self.decay = decay
5165
self.enable_adaptive_interval_hook = enable_adaptive_interval_hook
5266
self.enable_eval_before_run = enable_eval_before_run
@@ -84,13 +98,23 @@ def before_train_iter(self, runner):
8498
logger.info(f"Update EvalHook interval: {hook.interval} -> {adaptive_interval}")
8599
hook.interval = adaptive_interval
86100
elif isinstance(hook, LrUpdaterHook):
101+
patience = max(
102+
math.ceil((self.base_lr_patience / adaptive_interval)),
103+
self.min_lr_patience,
104+
)
87105
if hasattr(hook, "interval") and hasattr(hook, "patience"):
88106
hook.interval = adaptive_interval
89-
logger.info(f"Update LrUpdaterHook interval: {hook.interval} -> {adaptive_interval}")
107+
hook.patience = patience
108+
logger.info(f"Update LrUpdaterHook patience: {hook.patience} -> {patience}")
90109
elif isinstance(hook, EarlyStoppingHook):
91-
logger.info(f"Update EarlyStoppingHook interval: {hook.interval} -> {adaptive_interval}")
110+
patience = max(
111+
math.ceil((self.base_es_patience / adaptive_interval)),
112+
self.min_es_patience,
113+
)
114+
logger.info(f"Update EarlyStoppingHook patience: {hook.patience} -> {patience}")
92115
hook.start = adaptive_interval
93116
hook.interval = adaptive_interval
117+
hook.patience = patience
94118
elif isinstance(hook, CheckpointHook):
95119
# make sure checkpoint is saved at last
96120
limit = runner.max_epochs if hook.by_epoch else runner.max_iters

tests/unit/algorithms/common/adapters/mmcv/hooks/test_adaptive_training_hooks.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def test_before_train_iter(self) -> None:
8686
assert hook._original_interval is None
8787
assert eval_hook.interval == 4
8888
assert lr_hook.interval == 4
89-
assert lr_hook.patience == 1
89+
assert lr_hook.patience == 2
9090
assert early_hook.interval == 4
91-
assert early_hook.patience == 1
91+
assert early_hook.patience == 3
9292
assert ckpt_hook.interval == 4

0 commit comments

Comments
 (0)