import math

import torch
from torch.optim import Optimizer

from pytorch_optimizer.types import CLOSURE, DEFAULT_PARAMETERS, LOSS


class MADGRAD(Optimizer):
    """
    A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic Optimization.
    Reference : https://github.com/facebookresearch/madgrad/blob/main/madgrad/madgrad.py
    """

    def __init__(
        self,
        params,
        lr: float = 1e-3,
        momentum: float = 0.9,
        weight_decay: float = 0.0,
        eps: float = 1e-6,
    ):
        self.lr = lr
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.eps = eps

        self.check_valid_parameters()

        defaults: DEFAULT_PARAMETERS = dict(
            lr=lr, eps=eps, momentum=momentum, weight_decay=weight_decay
        )
        super().__init__(params, defaults)

    def check_valid_parameters(self):
        if 0.0 > self.lr:
            raise ValueError(f'Invalid learning rate : {self.lr}')
        if 0.0 > self.eps:
            raise ValueError(f'Invalid eps : {self.eps}')
        if 0.0 > self.weight_decay:
            raise ValueError(f'Invalid weight_decay : {self.weight_decay}')
        if 0.0 > self.momentum or 1.0 <= self.momentum:
            raise ValueError(f'Invalid momentum : {self.momentum}')

    @property
    def supports_memory_efficient_fp16(self) -> bool:
        return False

    @property
    def supports_flat_params(self) -> bool:
        return True

    def step(self, closure: CLOSURE = None) -> LOSS:
        """Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss: LOSS = None
        if closure is not None:
            loss = closure()

        # step counter must be stored in state to ensure correct behavior under
        # optimizer sharding
        if 'k' not in self.state:
            self.state['k'] = torch.tensor([0], dtype=torch.long)

        k = self.state['k'].item()

        for group in self.param_groups:
            eps = group['eps']
            lr = group['lr'] + eps
            decay = group['weight_decay']
            momentum = group['momentum']

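            # ck weights the iterate averaging (1 - momentum); _lambda = lr * sqrt(k + 1)
            # is the dual-averaging weight applied to this step's gradient statistics.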
            ck: float = 1.0 - momentum
            _lambda = lr * math.pow(k + 1, 0.5)

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad.data
                state = self.state[p]

                if 'grad_sum_sq' not in state:
                    state['grad_sum_sq'] = torch.zeros_like(p.data).detach()
                    state['s'] = torch.zeros_like(p.data).detach()
                    if momentum != 0:
                        state['x0'] = torch.clone(p.data).detach()

                if momentum != 0.0 and grad.is_sparse:
                    raise RuntimeError(
                        'momentum != 0 is not compatible with sparse gradients'
                    )

                grad_sum_sq = state['grad_sum_sq']
                s = state['s']

                if decay != 0:
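                    # Weight decay is applied by adding decay * p to the gradient
                    # (L2 regularization), not as a decoupled decay of the weights.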
                    if grad.is_sparse:
                        raise RuntimeError(
                            'weight_decay option is not compatible with sparse gradients'
                        )

                    grad.add_(p.data, alpha=decay)

                if grad.is_sparse:
                    grad = grad.coalesce()
                    grad_val = grad._values()

                    p_masked = p.sparse_mask(grad)
                    grad_sum_sq_masked = grad_sum_sq.sparse_mask(grad)
                    s_masked = s.sparse_mask(grad)

                    # Compute x_0 from other known quantities
                    rms_masked_vals = (
                        grad_sum_sq_masked._values().pow(1 / 3).add_(eps)
                    )
                    x0_masked_vals = p_masked._values().addcdiv(
                        s_masked._values(), rms_masked_vals, value=1
                    )

                    # Dense + sparse op
                    grad_sq = grad * grad
                    grad_sum_sq.add_(grad_sq, alpha=_lambda)
                    grad_sum_sq_masked.add_(grad_sq, alpha=_lambda)

                    rms_masked_vals = (
                        grad_sum_sq_masked._values().pow_(1 / 3).add_(eps)
                    )

                    s.add_(grad, alpha=_lambda)
                    s_masked._values().add_(grad_val, alpha=_lambda)

                    # update masked copy of p
                    p_kp1_masked_values = x0_masked_vals.addcdiv(
                        s_masked._values(), rms_masked_vals, value=-1
                    )

                    # Copy updated masked p to dense p using an add operation
                    p_masked._values().add_(p_kp1_masked_values, alpha=-1)
                    p.data.add_(p_masked, alpha=-1)
                else:
                    if momentum == 0:
                        # Compute x_0 from other known quantities
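                        # (with momentum off, p stays equal to x0 - s / rms between steps,
                        # so x0 is recovered as p + s / rms instead of being stored)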
                        rms = grad_sum_sq.pow(1 / 3).add_(eps)
                        x0 = p.data.addcdiv(s, rms, value=1)
                    else:
                        x0 = state['x0']

                    # Accumulate second moments
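                    # (weighted by _lambda; the adaptive scale below is the cube root of
                    # this running sum, where AdaGrad-style methods use a square root)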
                    grad_sum_sq.addcmul_(grad, grad, value=_lambda)
                    rms = grad_sum_sq.pow(1 / 3).add_(eps)

                    # Update s
                    s.data.add_(grad, alpha=_lambda)

                    # Step
                    if momentum == 0:
                        p.data.copy_(x0.addcdiv(s, rms, value=-1))
                    else:
                        z = x0.addcdiv(s, rms, value=-1)

                        # p is a moving average of z
                        p.data.mul_(1 - ck).add_(z, alpha=ck)

        self.state['k'] += 1

        return loss
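

if __name__ == '__main__':
    # Minimal usage sketch: the linear model, random data, and hyperparameters below
    # are illustrative placeholders, not part of the optimizer itself.
    model = torch.nn.Linear(10, 1)
    optimizer = MADGRAD(model.parameters(), lr=1e-3, momentum=0.9, weight_decay=0.0)

    x = torch.randn(32, 10)
    y = torch.randn(32, 1)

    for _ in range(100):
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
        optimizer.step()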