|
| 1 | +import math |
| 2 | + |
| 3 | +import torch |
| 4 | +from torch.optim.optimizer import Optimizer |
| 5 | + |
| 6 | +from pytorch_optimizer.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS, STATE |
| 7 | + |
| 8 | + |
| 9 | +class DiffRGrad(Optimizer): |
| 10 | + """ |
| 11 | + Reference 1 : https://github.com/shivram1987/diffGrad |
| 12 | + Reference 2 : https://github.com/LiyuanLucasLiu/RAdam |
| 13 | + Reference 3 : https://github.com/lessw2020/Best-Deep-Learning-Optimizers/blob/master/diffgrad/diff_rgrad.py |
| 14 | + Example : |
| 15 | + from pytorch_optimizer import DiffRGrad |
| 16 | + ... |
| 17 | + model = YourModel() |
| 18 | + optimizer = DiffRGrad(model.parameters()) |
| 19 | + ... |
| 20 | + for input, output in data: |
| 21 | + optimizer.zero_grad() |
| 22 | + loss = loss_function(output, model(input)) |
| 23 | + loss.backward() |
| 24 | + optimizer.step() |
| 25 | + """ |
| 26 | + |
| 27 | + def __init__( |
| 28 | + self, |
| 29 | + params: PARAMETERS, |
| 30 | + lr: float = 1e-3, |
| 31 | + betas: BETAS = (0.9, 0.999), |
| 32 | + weight_decay: float = 0.0, |
| 33 | + n_sma_threshold: int = 5, |
| 34 | + degenerated_to_sgd: bool = True, |
| 35 | + eps: float = 1e-8, |
| 36 | + ): |
| 37 | + """Blend RAdam with DiffGrad |
| 38 | + :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups |
| 39 | + :param lr: float. learning rate. |
| 40 | + :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace |
| 41 | + :param weight_decay: float. weight decay (L2 penalty) |
| 42 | + :param n_sma_threshold: int. (recommended is 5) |
| 43 | + :param degenerated_to_sgd: float. |
| 44 | + :param eps: float. term added to the denominator to improve numerical stability |
| 45 | + """ |
| 46 | + self.lr = lr |
| 47 | + self.betas = betas |
| 48 | + self.weight_decay = weight_decay |
| 49 | + self.n_sma_threshold = n_sma_threshold |
| 50 | + self.degenerated_to_sgd = degenerated_to_sgd |
| 51 | + self.eps = eps |
| 52 | + |
| 53 | + self.check_valid_parameters() |
| 54 | + |
| 55 | + if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict): |
| 56 | + for param in params: |
| 57 | + if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]): |
| 58 | + param['buffer'] = [[None, None, None] for _ in range(10)] |
| 59 | + |
| 60 | + defaults: DEFAULTS = dict( |
| 61 | + lr=lr, |
| 62 | + betas=betas, |
| 63 | + eps=eps, |
| 64 | + weight_decay=weight_decay, |
| 65 | + buffer=[[None, None, None] for _ in range(10)], |
| 66 | + ) |
| 67 | + |
| 68 | + super().__init__(params, defaults) |
| 69 | + |
| 70 | + def check_valid_parameters(self): |
| 71 | + if self.lr < 0.0: |
| 72 | + raise ValueError(f'Invalid learning rate : {self.lr}') |
| 73 | + if self.weight_decay < 0.0: |
| 74 | + raise ValueError(f'Invalid weight_decay : {self.weight_decay}') |
| 75 | + if not 0.0 <= self.betas[0] < 1.0: |
| 76 | + raise ValueError(f'Invalid beta_0 : {self.betas[0]}') |
| 77 | + if not 0.0 <= self.betas[1] < 1.0: |
| 78 | + raise ValueError(f'Invalid beta_1 : {self.betas[1]}') |
| 79 | + if self.eps < 0.0: |
| 80 | + raise ValueError(f'Invalid eps : {self.eps}') |
| 81 | + |
| 82 | + def __setstate__(self, state: STATE): |
| 83 | + super().__setstate__(state) |
| 84 | + |
| 85 | + def step(self, closure: CLOSURE = None) -> LOSS: |
| 86 | + loss: LOSS = None |
| 87 | + if closure is not None: |
| 88 | + loss = closure() |
| 89 | + |
| 90 | + for group in self.param_groups: |
| 91 | + for p in group['params']: |
| 92 | + if p.grad is None: |
| 93 | + continue |
| 94 | + |
| 95 | + grad = p.grad.data.float() |
| 96 | + if grad.is_sparse: |
| 97 | + raise RuntimeError('diffGrad does not support sparse gradients') |
| 98 | + |
| 99 | + p_data_fp32 = p.data.float() |
| 100 | + state = self.state[p] |
| 101 | + |
| 102 | + if len(state) == 0: |
| 103 | + state['step'] = 0 |
| 104 | + state['exp_avg'] = torch.zeros_like(p_data_fp32) |
| 105 | + state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) |
| 106 | + state['previous_grad'] = torch.zeros_like(p_data_fp32) |
| 107 | + else: |
| 108 | + state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) |
| 109 | + state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) |
| 110 | + state['previous_grad'] = state['previous_grad'].type_as(p_data_fp32) |
| 111 | + |
| 112 | + exp_avg, exp_avg_sq, previous_grad = ( |
| 113 | + state['exp_avg'], |
| 114 | + state['exp_avg_sq'], |
| 115 | + state['previous_grad'], |
| 116 | + ) |
| 117 | + beta1, beta2 = group['betas'] |
| 118 | + |
| 119 | + state['step'] += 1 |
| 120 | + |
| 121 | + exp_avg.mul_(beta1).add_(1 - beta1, grad) |
| 122 | + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) |
| 123 | + |
| 124 | + # compute diffGrad coefficient (dfc) |
| 125 | + diff = abs(previous_grad - grad) |
| 126 | + dfc = 1.0 / (1.0 + torch.exp(-diff)) |
| 127 | + |
| 128 | + state['previous_grad'] = grad.clone() |
| 129 | + |
| 130 | + buffered = group['buffer'][int(state['step'] % 10)] |
| 131 | + if state['step'] == buffered[0]: |
| 132 | + n_sma, step_size = buffered[1], buffered[2] |
| 133 | + else: |
| 134 | + buffered[0] = state['step'] |
| 135 | + beta2_t = beta2 ** state['step'] |
| 136 | + n_sma_max = 2.0 / (1.0 - beta2) - 1.0 |
| 137 | + n_sma = n_sma_max - 2.0 * state['step'] * beta2_t / (1.0 - beta2_t) |
| 138 | + buffered[1] = n_sma |
| 139 | + |
| 140 | + if n_sma >= self.n_sma_threshold: |
| 141 | + step_size = math.sqrt( |
| 142 | + (1 - beta2_t) |
| 143 | + * (n_sma - 4) |
| 144 | + / (n_sma_max - 4) |
| 145 | + * (n_sma - 2) |
| 146 | + / n_sma |
| 147 | + * n_sma_max |
| 148 | + / (n_sma_max - 2) |
| 149 | + ) / (1.0 - beta1 ** state['step']) |
| 150 | + elif self.degenerated_to_sgd: |
| 151 | + step_size = 1.0 / (1 - beta1 ** state['step']) |
| 152 | + else: |
| 153 | + step_size = -1 |
| 154 | + buffered[2] = step_size |
| 155 | + |
| 156 | + if n_sma >= self.n_sma_threshold: |
| 157 | + if group['weight_decay'] != 0: |
| 158 | + p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) |
| 159 | + |
| 160 | + denom = exp_avg_sq.sqrt().add_(group['eps']) |
| 161 | + |
| 162 | + # update momentum with dfc |
| 163 | + p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg * dfc.float(), denom) |
| 164 | + p.data.copy_(p_data_fp32) |
| 165 | + elif step_size > 0: |
| 166 | + if group['weight_decay'] != 0: |
| 167 | + p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) |
| 168 | + |
| 169 | + p_data_fp32.add_(-step_size * group['lr'], exp_avg) |
| 170 | + p.data.copy_(p_data_fp32) |
| 171 | + |
| 172 | + return loss |
0 commit comments