2323class AdaptiveTrainSchedulingHook (Hook ):
2424 """Adaptive Training Scheduling Hook.
2525
26- Depending on the size of iteration per epoch, adaptively update the validation interval.
26+ Depending on the size of iteration per epoch, adaptively update the validation interval and related values .
2727
2828 Args:
29+ base_lr_patience (int): The value of LR drop patience are expected in total epoch.
30+ Patience used when interval is 1, Defaults to 5.
31+ min_lr_patience (int): Minumum value of LR drop patience.
32+ Defaults to 2.
33+ base_es_patience (int): The value of Early-Stopping patience are expected in total epoch.
34+ Patience used when interval is 1, Defaults to 10.
2935 max_interval (int): Maximum value of validation interval.
3036 Defaults to 5.
3137 decay (float): Parameter to control the interval. This value is set by manual manner.
@@ -39,6 +45,10 @@ class AdaptiveTrainSchedulingHook(Hook):
3945 def __init__ (
4046 self ,
4147 max_interval = 5 ,
48+ base_lr_patience = 5 ,
49+ min_lr_patience = 2 ,
50+ base_es_patience = 10 ,
51+ min_es_patience = 3 ,
4252 decay = - 0.025 ,
4353 enable_adaptive_interval_hook = False ,
4454 enable_eval_before_run = False ,
@@ -47,6 +57,10 @@ def __init__(
4757 super ().__init__ (** kwargs )
4858
4959 self .max_interval = max_interval
60+ self .base_lr_patience = base_lr_patience
61+ self .min_lr_patience = min_lr_patience
62+ self .base_es_patience = base_es_patience
63+ self .min_es_patience = min_es_patience
5064 self .decay = decay
5165 self .enable_adaptive_interval_hook = enable_adaptive_interval_hook
5266 self .enable_eval_before_run = enable_eval_before_run
@@ -84,13 +98,23 @@ def before_train_iter(self, runner):
8498 logger .info (f"Update EvalHook interval: { hook .interval } -> { adaptive_interval } " )
8599 hook .interval = adaptive_interval
86100 elif isinstance (hook , LrUpdaterHook ):
101+ patience = max (
102+ math .ceil ((self .base_lr_patience / adaptive_interval )),
103+ self .min_lr_patience ,
104+ )
87105 if hasattr (hook , "interval" ) and hasattr (hook , "patience" ):
88106 hook .interval = adaptive_interval
89- logger .info (f"Update LrUpdaterHook interval: { hook .interval } -> { adaptive_interval } " )
107+ hook .patience = patience
108+ logger .info (f"Update LrUpdaterHook patience: { hook .patience } -> { patience } " )
90109 elif isinstance (hook , EarlyStoppingHook ):
91- logger .info (f"Update EarlyStoppingHook interval: { hook .interval } -> { adaptive_interval } " )
110+ patience = max (
111+ math .ceil ((self .base_es_patience / adaptive_interval )),
112+ self .min_es_patience ,
113+ )
114+ logger .info (f"Update EarlyStoppingHook patience: { hook .patience } -> { patience } " )
92115 hook .start = adaptive_interval
93116 hook .interval = adaptive_interval
117+ hook .patience = patience
94118 elif isinstance (hook , CheckpointHook ):
95119 # make sure checkpoint is saved at last
96120 limit = runner .max_epochs if hook .by_epoch else runner .max_iters
0 commit comments