
Commit 1d197c0

Add lr scheduler
1 parent e0939cf

File tree: 1 file changed (+85 −2)


celldetection/optim/lr_scheduler.py

Lines changed: 85 additions & 2 deletions
```diff
@@ -1,9 +1,13 @@
-from torch.optim.lr_scheduler import MultiplicativeLR
+from torch.optim.lr_scheduler import MultiplicativeLR, SequentialLR as _SequentialLR, \
+    ReduceLROnPlateau as _ReduceLROnPlateau
 from torch.optim import Optimizer
 from typing import Union, Callable, List
 import warnings
+from bisect import bisect_right

-__all__ = ['WarmUp']
+from ..util.util import has_argument
+
+__all__ = ['WarmUp', 'SequentialLR', 'ReduceLROnPlateau']


 def linear_schedule(step, steps):
@@ -68,3 +72,82 @@ def get_lr(self):
         if self.last_epoch <= self.steps:
             return [lr * lmbda(self.last_epoch, self.steps) for lmbda, lr in zip(self.lr_lambdas, self.base_lrs)]
         return [group['lr'] for group in self.optimizer.param_groups]
+
+
+class SequentialLR(_SequentialLR):
+
+    def step(self, metrics=None):  # fixes TypeError caused by use of metrics
+        self.last_epoch += 1
+        idx = bisect_right(self._milestones, self.last_epoch)
+        scheduler = self._schedulers[idx]
+        if idx > 0 and self._milestones[idx - 1] == self.last_epoch:
+            scheduler.step(0)
+        else:
+            if metrics is None:
+                scheduler.step()
+            else:
+                # Some schedulers require metrics, others do not. If not handled here, it'll raise a TypeError
+                if has_argument(scheduler.step, 'metric', 'metrics', mode='any'):
+                    scheduler.step(metrics)
+                else:
+                    scheduler.step()
+
+        self._last_lr = scheduler.get_last_lr()
+
+
+class ReduceLROnPlateau(_ReduceLROnPlateau):
+    def __init__(self, optimizer, mode='min', factor=0.1, patience=10,
+                 threshold=1e-4, threshold_mode='rel', cooldown=0,
+                 min_lr=0, eps=1e-8, warmup=1, verbose="deprecated"):
+        """
+        Initializes the ReduceLROnPlateau object. This scheduler decreases the learning rate
+        when a metric has stopped improving, which is commonly used to fine-tune a model in
+        machine learning.
+
+        Notes:
+            - Adds the ``warmup`` option to PyTorch's ``ReduceLROnPlateau``.
+
+        Args:
+            optimizer (Optimizer): Wrapped optimizer.
+            mode (str): One of `min` or `max`. In `min` mode, the learning rate will be reduced
+                when the quantity monitored has stopped decreasing; in `max` mode, it will
+                be reduced when the quantity monitored has stopped increasing. Default: 'min'.
+            factor (float): Factor by which the learning rate will be reduced. `new_lr = lr * factor`.
+                Default: 0.1.
+            patience (int): Number of epochs with no improvement after which the learning rate will
+                be reduced. Default: 10.
+            threshold (float): Threshold for measuring the new optimum, to only focus on significant
+                changes. Default: 1e-4.
+            threshold_mode (str): One of `rel`, `abs`. In `rel` mode, dynamic_threshold = best * (1 + threshold)
+                in `max` mode or best * (1 - threshold) in `min` mode. In `abs` mode,
+                dynamic_threshold = best + threshold in `max` mode or best - threshold in
+                `min` mode. Default: 'rel'.
+            cooldown (int): Number of epochs to wait before resuming normal operation after the lr
+                has been reduced. Default: 0.
+            min_lr (float or list): A scalar or a list of scalars. A lower bound on the learning rate of
+                all param groups or each group respectively. Default: 0.
+            eps (float): Minimal decay applied to lr. If the difference between new and old lr is smaller
+                than eps, the update is ignored. Default: 1e-8.
+            warmup (int): Number of epochs to wait before initially starting normal operation. Default: 1.
+            verbose (str): Deprecated argument. Not used. Default: "deprecated".
+        """
+        super().__init__(optimizer=optimizer, mode=mode, factor=factor, patience=patience,
+                         threshold=threshold, threshold_mode=threshold_mode, cooldown=cooldown,
+                         min_lr=min_lr, eps=eps, verbose=verbose)
+        self.warmup_counter = int(warmup)  # ignore bad epochs for the first `warmup` steps
+
+    def get_last_lr(self):  # required by PyTorch functions; must be implemented right here
+        return self._last_lr
+
+    def step(self, metrics, epoch=None):
+        best_ = None
+        if self.warmup_counter:
+            self.warmup_counter -= 1
+            best_ = self.best
+
+        res = super().step(metrics, epoch)
+        if best_ is not None:
+            self.best = best_
+            self.num_bad_epochs = 0  # ignore any bad epochs in warmup
+        return res
```
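The diff relies on `has_argument` from `..util.util`, which is not shown in this commit. A minimal sketch of what such a helper might look like, assuming it merely inspects the callable's signature for the given parameter names (a hypothetical reconstruction, not the package's actual code):

```python
import inspect


def has_argument(fn, *names, mode='any'):
    """Hypothetical stand-in for celldetection.util.util.has_argument.

    Returns True if ``fn`` accepts any (mode='any') or all (mode='all')
    of the given parameter names.
    """
    params = inspect.signature(fn).parameters  # bound methods exclude `self`
    hits = [name in params for name in names]
    return any(hits) if mode == 'any' else all(hits)
```

With this, `has_argument(scheduler.step, 'metric', 'metrics', mode='any')` is True for `ReduceLROnPlateau.step(self, metrics, epoch=None)` and False for a plain `step(self, epoch=None)`, which is exactly the dispatch the overridden `SequentialLR.step` needs.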

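Taken together, the two subclasses let a warm-up schedule hand off to plateau-based decay, which stock `SequentialLR` cannot do because passing a metric to `step` raises a `TypeError`. Below is a hypothetical usage sketch; the `WarmUp(optimizer, steps)` call and the behavior of `SequentialLR`'s constructor on newer PyTorch versions are assumptions, not shown in this diff:

```python
import torch
from celldetection.optim.lr_scheduler import WarmUp, SequentialLR, ReduceLROnPlateau

model = torch.nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

scheduler = SequentialLR(
    optimizer,
    schedulers=[
        WarmUp(optimizer, 5),  # assumed signature: ramp up the lr over the first 5 epochs
        ReduceLROnPlateau(optimizer, patience=3, warmup=1),
    ],
    milestones=[5],  # switch from WarmUp to ReduceLROnPlateau after epoch 5
)

for epoch in range(20):
    val_loss = 1.0 / (epoch + 1)  # placeholder for a real validation metric
    # The overridden step forwards val_loss only to schedulers whose
    # step() accepts a metrics argument (here: ReduceLROnPlateau).
    scheduler.step(val_loss)
```

Passing the metric every epoch is safe: during warm-up it is simply dropped, and after the milestone it drives the plateau logic.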