1 | | -from torch.optim.lr_scheduler import MultiplicativeLR |
| 1 | +from torch.optim.lr_scheduler import MultiplicativeLR, SequentialLR as _SequentialLR, \ |
| 2 | + ReduceLROnPlateau as _ReduceLROnPlateau |
2 | 3 | from torch.optim import Optimizer |
3 | 4 | from typing import Union, Callable, List |
4 | 5 | import warnings |
| 6 | +from bisect import bisect_right |
5 | 7 | |
6 | | -__all__ = ['WarmUp'] |
| 8 | +from ..util.util import has_argument |
| 9 | + |
| 10 | +__all__ = ['WarmUp', 'SequentialLR', 'ReduceLROnPlateau'] |
7 | 11 | |
8 | 12 | |
9 | 13 | def linear_schedule(step, steps): |
@@ -68,3 +72,82 @@ def get_lr(self): |
68 | 72 |         if self.last_epoch <= self.steps: |
69 | 73 |             return [lr * lmbda(self.last_epoch, self.steps) for lmbda, lr in zip(self.lr_lambdas, self.base_lrs)] |
70 | 74 |         return [group['lr'] for group in self.optimizer.param_groups] |
| 75 | + |
| 76 | + |
| 77 | +class SequentialLR(_SequentialLR): |
| 78 | + |
| 79 | +    def step(self, metrics=None):  # fixes the TypeError raised when a metric is passed to the stock SequentialLR.step |
| 80 | +        self.last_epoch += 1 |
| 81 | +        idx = bisect_right(self._milestones, self.last_epoch) |
| 82 | +        scheduler = self._schedulers[idx] |
| 83 | +        if idx > 0 and self._milestones[idx - 1] == self.last_epoch: |
| 84 | +            scheduler.step(0)  # just crossed a milestone: restart the new scheduler at epoch 0, as upstream does |
| 85 | +        else: |
| 86 | + |
| 87 | +            if metrics is None: |
| 88 | +                scheduler.step() |
| 89 | +            else: |
| 90 | +                # Some schedulers require metrics and others do not; passing it blindly would raise a TypeError |
| 91 | +                if has_argument(scheduler.step, 'metric', 'metrics', mode='any'): |
| 92 | +                    scheduler.step(metrics) |
| 93 | +                else: |
| 94 | +                    scheduler.step() |
| 95 | + |
| 96 | +        self._last_lr = scheduler.get_last_lr() |
| 97 | + |
| 98 | + |
| 99 | +class ReduceLROnPlateau(_ReduceLROnPlateau): |
| 100 | +    def __init__(self, optimizer, mode='min', factor=0.1, patience=10, |
| 101 | +                 threshold=1e-4, threshold_mode='rel', cooldown=0, |
| 102 | +                 min_lr=0, eps=1e-8, warmup=1, verbose="deprecated"): |
| 103 | + """ |
| 104 | + Initializes the ReduceLROnPlateau object. This scheduler decreases the learning rate |
| 105 | + when a metric has stopped improving, which is commonly used to fine-tune a model in |
| 106 | + machine learning. |
| 107 | +
|
| 108 | + Notes: |
| 109 | + - Adds the warmup option to PyTorch's ``ReduceLROnPlateau``. |
| 110 | +
|
| 111 | + Args: |
| 112 | + optimizer (Optimizer): Wrapped optimizer. |
| 113 | + mode (str): One of `min` or `max`. In `min` mode, the learning rate will be reduced |
| 114 | + when the quantity monitored has stopped decreasing; in `max` mode, it will |
| 115 | + be reduced when the quantity monitored has stopped increasing. Default: 'min'. |
| 116 | + factor (float): Factor by which the learning rate will be reduced. `new_lr = lr * factor`. |
| 117 | + Default: 0.1. |
| 118 | + patience (int): Number of epochs with no improvement after which learning rate will be |
| 119 | + reduced. Default: 10. |
| 120 | + threshold (float): Threshold for measuring the new optimum, to only focus on significant |
| 121 | + changes. Default: 1e-4. |
| 122 | + threshold_mode (str): One of `rel`, `abs`. In `rel` mode, dynamic_threshold = best * (1 + threshold) |
| 123 | + in 'max' mode or best * (1 - threshold) in `min` mode. In `abs` mode, |
| 124 | + dynamic_threshold = best + threshold in `max` mode or best - threshold in |
| 125 | + `min` mode. Default: 'rel'. |
| 126 | + cooldown (int): Number of epochs to wait before resuming normal operation after lr has been reduced. |
| 127 | + Default: 0. |
| 128 | + min_lr (float or list): A scalar or a list of scalars. A lower bound on the learning rate of |
| 129 | + all param groups or each group respectively. Default: 0. |
| 130 | + eps (float): Minimal decay applied to lr. If the difference between new and old lr is smaller |
| 131 | + than eps, the update is ignored. Default: 1e-8. |
| 132 | + warmup (int): Number of epochs to wait before initially starting normal operation. Default: 1. |
| 133 | + verbose (str): Deprecated argument. Not used. Default: "deprecated". |
| 134 | + """ |
| 135 | +        super().__init__(optimizer=optimizer, mode=mode, factor=factor, patience=patience, |
| 136 | +                         threshold=threshold, threshold_mode=threshold_mode, cooldown=cooldown, |
| 137 | +                         min_lr=min_lr, eps=eps, verbose=verbose) |
| 138 | +        self.warmup_counter = int(warmup)  # ignore bad epochs during the first `warmup` steps |
| 139 | + |
| 140 | +    def get_last_lr(self):  # expected by callers such as SequentialLR.step above |
| 141 | +        return self._last_lr |
| 142 | + |
| 143 | +    def step(self, metrics, epoch=None): |
| 144 | +        best_ = None |
| 145 | +        if self.warmup_counter: |
| 146 | +            self.warmup_counter -= 1 |
| 147 | +            best_ = self.best  # remember the previous best so a warm-up metric cannot overwrite it |
| 148 | + |
| 149 | +        res = super().step(metrics, epoch) |
| 150 | +        if best_ is not None: |
| 151 | +            self.best = best_ |
| 152 | +            self.num_bad_epochs = 0  # ignore any bad epochs in warmup |
| 153 | +        return res |
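
A minimal usage sketch of the added `warmup` behaviour (not part of the commit; the import path `mypkg.lr_scheduler` is hypothetical, substitute the module this file lives in): the first `warmup` calls to `step` are not counted as bad epochs, so the patience counter only starts moving once the warm-up window has passed.

```python
import torch
from mypkg.lr_scheduler import ReduceLROnPlateau  # hypothetical path to the patched class above

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=1, warmup=2)

for loss in [1.0, 1.0, 1.0, 1.0, 1.0]:   # a metric that never improves
    scheduler.step(loss)
    print(scheduler.get_last_lr())
# The first two calls fall inside the warm-up window and are not counted as bad
# epochs, so the reduction to lr=0.05 happens two epochs later than with the
# stock ReduceLROnPlateau.
```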
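A second sketch showing how the two overrides are meant to work together: a standard warm-up scheduler chained with the plateau scheduler, with the validation metric passed to every `step` call, which the patched `SequentialLR.step` forwards only to schedulers whose `step` accepts one. This assumes a PyTorch version whose `SequentialLR` constructor accepts a `ReduceLROnPlateau` instance (the situation this patch targets) and the same hypothetical import path as above.

```python
import torch
from torch.optim.lr_scheduler import LinearLR
from mypkg.lr_scheduler import SequentialLR, ReduceLROnPlateau  # hypothetical import path

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

warmup = LinearLR(optimizer, start_factor=0.1, total_iters=5)  # 5-epoch linear warm-up
plateau = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, warmup=1)
# Assumes the installed SequentialLR allows a ReduceLROnPlateau in the chain.
scheduler = SequentialLR(optimizer, schedulers=[warmup, plateau], milestones=[5])

for epoch in range(20):
    val_loss = 1.0 / (epoch + 1)   # stand-in for a real validation loss
    # The patched step() passes val_loss only to schedulers whose step() takes a
    # metrics argument (the plateau scheduler); LinearLR gets a plain step(), so
    # neither phase raises a TypeError.
    scheduler.step(val_loss)
    print(epoch, scheduler.get_last_lr())
```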