|
| 1 | +import math |
| 2 | +from typing import Optional |
| 3 | + |
| 4 | +import torch |
| 5 | +from torch.optim.optimizer import Optimizer |
| 6 | + |
| 7 | +from pytorch_optimizer.base.exception import NoSparseGradientError |
| 8 | +from pytorch_optimizer.base.optimizer import BaseOptimizer |
| 9 | +from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS |
| 10 | + |
| 11 | + |
| 12 | +class Prodigy(Optimizer, BaseOptimizer): |
| 13 | + r"""An Expeditiously Adaptive Parameter-Free Learner. |
| 14 | +
|
| 15 | + Leave LR set to 1 unless you encounter instability. |
| 16 | +
|
| 17 | + :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. |
| 18 | + :param lr: float. learning rate. |
| 19 | + :param betas: BETAS. betas. |
| 20 | + :param beta3: float. coefficients for computing the Prodidy step-size using running averages. If set to None, |
| 21 | + uses the value of square root of beta2. |
| 22 | + :param d0: float. initial D estimate for D-adaptation (default 1e-6). Rarely needs changing. |
| 23 | + :param d_coef: float. Coefficient in the expression for the estimate of d. |
| 24 | + :param growth_rate: float. prevent the D estimate from growing faster than this multiplicative rate. |
| 25 | + :param weight_decay: float. weight decay (L2 penalty). |
| 26 | + :param weight_decouple: bool. use AdamW style weight decay. |
| 27 | + :param fixed_decay: bool. fix weight decay. |
| 28 | + :param bias_correction: bool. turn on Adam's bias correction. |
| 29 | + :param safeguard_warmup: bool. remove lr from the denominator of D estimate to avoid issues during warm-up stage. |
| 30 | + :param eps: float. term added to the denominator to improve numerical stability. |
| 31 | + """ |
| 32 | + |
| 33 | + def __init__( |
| 34 | + self, |
| 35 | + params: PARAMETERS, |
| 36 | + lr: float = 1.0, |
| 37 | + betas: BETAS = (0.9, 0.999), |
| 38 | + beta3: Optional[float] = None, |
| 39 | + d0: float = 1e-6, |
| 40 | + d_coef: float = 1.0, |
| 41 | + growth_rate: float = float('inf'), |
| 42 | + weight_decay: float = 0.0, |
| 43 | + weight_decouple: bool = True, |
| 44 | + fixed_decay: bool = False, |
| 45 | + bias_correction: bool = False, |
| 46 | + safeguard_warmup: bool = False, |
| 47 | + eps: float = 1e-8, |
| 48 | + ): |
| 49 | + self.validate_learning_rate(lr) |
| 50 | + self.validate_betas((*betas, beta3)) |
| 51 | + self.validate_non_negative(weight_decay, 'weight_decay') |
| 52 | + self.validate_non_negative(eps, 'eps') |
| 53 | + |
| 54 | + defaults: DEFAULTS = { |
| 55 | + 'lr': lr, |
| 56 | + 'betas': betas, |
| 57 | + 'beta3': beta3, |
| 58 | + 'd': d0, |
| 59 | + 'd0': d0, |
| 60 | + 'd_max': d0, |
| 61 | + 'd_coef': d_coef, |
| 62 | + 'growth_rate': growth_rate, |
| 63 | + 'weight_decay': weight_decay, |
| 64 | + 'weight_decouple': weight_decouple, |
| 65 | + 'fixed_decay': fixed_decay, |
| 66 | + 'bias_correction': bias_correction, |
| 67 | + 'safeguard_warmup': safeguard_warmup, |
| 68 | + 'step': 1, |
| 69 | + 'eps': eps, |
| 70 | + } |
| 71 | + super().__init__(params, defaults) |
| 72 | + |
| 73 | + def __str__(self) -> str: |
| 74 | + return 'Prodigy' |
| 75 | + |
| 76 | + @torch.no_grad() |
| 77 | + def reset(self): |
| 78 | + for group in self.param_groups: |
| 79 | + group['step'] = 1 |
| 80 | + for p in group['params']: |
| 81 | + if p.grad is None: |
| 82 | + continue |
| 83 | + |
| 84 | + state = self.state[p] |
| 85 | + |
| 86 | + state['s'] = torch.zeros_like(p) |
| 87 | + state['exp_avg'] = torch.zeros_like(p) |
| 88 | + state['exp_avg_sq'] = torch.zeros_like(p) |
| 89 | + |
| 90 | + @torch.no_grad() |
| 91 | + def step(self, closure: CLOSURE = None) -> LOSS: |
| 92 | + loss: LOSS = None |
| 93 | + if closure is not None: |
| 94 | + with torch.enable_grad(): |
| 95 | + loss = closure() |
| 96 | + |
| 97 | + group = self.param_groups[0] |
| 98 | + device = group['params'][0].device |
| 99 | + |
| 100 | + d_de_nom = torch.tensor([0.0], device=device) |
| 101 | + |
| 102 | + beta1, beta2 = group['betas'] |
| 103 | + beta3 = group['beta3'] if group['beta3'] is not None else math.sqrt(beta2) |
| 104 | + |
| 105 | + bias_correction1: float = 1.0 - beta1 ** group['step'] |
| 106 | + bias_correction2_sq: float = math.sqrt(1.0 - beta2 ** group['step']) |
| 107 | + bias_correction: float = (bias_correction1 / bias_correction2_sq) if group['bias_correction'] else 1.0 |
| 108 | + |
| 109 | + d, d0 = group['d'], group['d0'] |
| 110 | + d_lr: float = d * group['lr'] / bias_correction |
| 111 | + |
| 112 | + if 'd_numerator' not in group: |
| 113 | + group['d_numerator'] = torch.tensor([0.0], device=device) |
| 114 | + |
| 115 | + d_numerator = group['d_numerator'] |
| 116 | + d_numerator.mul_(beta3) |
| 117 | + |
| 118 | + for group in self.param_groups: |
| 119 | + for p in group['params']: |
| 120 | + if p.grad is None: |
| 121 | + continue |
| 122 | + |
| 123 | + grad = p.grad |
| 124 | + if grad.is_sparse: |
| 125 | + raise NoSparseGradientError(str(self)) |
| 126 | + |
| 127 | + state = self.state[p] |
| 128 | + if len(state) == 0: |
| 129 | + state['s'] = torch.zeros_like(p) |
| 130 | + state['p0'] = p.clone() |
| 131 | + state['exp_avg'] = torch.zeros_like(p) |
| 132 | + state['exp_avg_sq'] = torch.zeros_like(p) |
| 133 | + |
| 134 | + p0, exp_avg, exp_avg_sq = state['p0'], state['exp_avg'], state['exp_avg_sq'] |
| 135 | + |
| 136 | + d_numerator.add_(torch.dot(grad.flatten(), (p0 - p).flatten()), alpha=(d / d0) * d_lr) |
| 137 | + |
| 138 | + exp_avg.mul_(beta1).add_(grad, alpha=d * (1.0 - beta1)) |
| 139 | + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=d * d * (1.0 - beta2)) |
| 140 | + |
| 141 | + s = state['s'] |
| 142 | + s.mul_(beta3).add_(grad, alpha=(d / d0) * (d if group['safeguard_warmup'] else d_lr)) |
| 143 | + |
| 144 | + d_de_nom.add_(s.abs().sum()) |
| 145 | + |
| 146 | + if d_de_nom == 0: |
| 147 | + return loss |
| 148 | + |
| 149 | + d_hat = (group['d_coef'] * d_numerator / d_de_nom).item() |
| 150 | + if d == group['d0']: |
| 151 | + d = max(d, d_hat) |
| 152 | + |
| 153 | + d_max = max(group['d_max'], d_hat) |
| 154 | + d = min(d_max, d * group['growth_rate']) |
| 155 | + |
| 156 | + for group in self.param_groups: |
| 157 | + group['step'] += 1 |
| 158 | + |
| 159 | + group['d_numerator'] = d_numerator |
| 160 | + group['d_de_nom'] = d_de_nom |
| 161 | + group['d'] = d |
| 162 | + group['d_max'] = d_max |
| 163 | + group['d_hat'] = d_hat |
| 164 | + |
| 165 | + for p in group['params']: |
| 166 | + if p.grad is None: |
| 167 | + continue |
| 168 | + |
| 169 | + state = self.state[p] |
| 170 | + |
| 171 | + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] |
| 172 | + |
| 173 | + de_nom = exp_avg_sq.sqrt().add_(d * group['eps']) |
| 174 | + |
| 175 | + self.apply_weight_decay( |
| 176 | + p, |
| 177 | + p.grad, |
| 178 | + lr=d_lr, |
| 179 | + weight_decay=group['weight_decay'], |
| 180 | + weight_decouple=group['weight_decouple'], |
| 181 | + fixed_decay=group['fixed_decay'], |
| 182 | + ) |
| 183 | + |
| 184 | + p.addcdiv_(exp_avg, de_nom, value=-d_lr) |
| 185 | + |
| 186 | + return loss |
0 commit comments