|
| 1 | +import math |
| 2 | + |
| 3 | +import torch |
| 4 | + |
| 5 | +from pytorch_optimizer.base.exception import NoSparseGradientError |
| 6 | +from pytorch_optimizer.base.optimizer import BaseOptimizer |
| 7 | +from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS |
| 8 | + |
| 9 | + |
| 10 | +class CosineDecay: |
| 11 | + r"""Applies cosine decay to a parameter (death_rate), using PyTorch's built-in `CosineAnnealingLR`. |
| 12 | +
|
| 13 | + :param death_rate: float. initial value to be decayed. |
| 14 | + :param t_max: int. maximum number of iterations for the decay. |
| 15 | + :param eta_min: Optional[float]. minimum value of the parameter after decay. defaults to 0. |
| 16 | + :param last_epoch: Optional[int]. the index of the last epoch. Defaults to -1. |
| 17 | + """ |
| 18 | + |
| 19 | + def __init__(self, death_rate: float, t_max: int, eta_min: float = 0.0, last_epoch: int = -1): |
| 20 | + self.sgd = torch.optim.SGD( |
| 21 | + torch.nn.ParameterList([torch.nn.Parameter(torch.zeros(1))]), |
| 22 | + lr=death_rate, |
| 23 | + ) |
| 24 | + self.cosine_stepper = torch.optim.lr_scheduler.CosineAnnealingLR(self.sgd, t_max + 1, eta_min, last_epoch) |
| 25 | + self.T_max = t_max |
| 26 | + self.eta_min = eta_min |
| 27 | + |
| 28 | + def step(self, current_step: int) -> None: |
| 29 | + r"""One step of the cosine decay scheduler. |
| 30 | +
|
| 31 | + :param current_step: int. Current step index. |
| 32 | + """ |
| 33 | + self.cosine_stepper.step(current_step) |
| 34 | + |
| 35 | + def get_death_rate(self, current_step: int) -> float: |
| 36 | + r"""Get the updated rate (death_rate) at the given step. |
| 37 | +
|
| 38 | + :param current_step: int. Current step index. |
| 39 | + """ |
| 40 | + if current_step >= self.T_max: |
| 41 | + return self.eta_min |
| 42 | + |
| 43 | + self.step(current_step) |
| 44 | + |
| 45 | + return self.sgd.param_groups[0]['lr'] |
| 46 | + |
| 47 | + |
| 48 | +class SPAM(BaseOptimizer): |
| 49 | + r"""Spike-Aware Adam with Momentum Reset for Stable LLM Training. |
| 50 | +
|
| 51 | + :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. |
| 52 | + :param lr: float. learning rate. |
| 53 | + :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace. |
| 54 | + :param density: float. density parameter. only used for 2d parameters (e.g. Linear). |
| 55 | + :param weight_decay: float. weight decay (L2 penalty). |
| 56 | + :param warmup_epoch: int: number of epochs to warm up. defaults to 50. |
| 57 | + :param threshold: int. threshold for gradient masking. defaults to 5000. |
| 58 | + :param grad_accu_steps: int. gradient accumulation steps before threshold-based masking applies. defaults to 20. |
| 59 | + :param update_proj_gap: int. update projection gap. |
| 60 | + :param eps: float. term added to the denominator to improve numerical stability. |
| 61 | + """ |
| 62 | + |
| 63 | + def __init__( |
| 64 | + self, |
| 65 | + params: PARAMETERS, |
| 66 | + lr: float = 1e-3, |
| 67 | + betas: BETAS = (0.9, 0.999), |
| 68 | + density: float = 1.0, |
| 69 | + weight_decay: float = 0.0, |
| 70 | + warmup_epoch: int = 150, |
| 71 | + threshold: int = 5000, |
| 72 | + grad_accu_steps: int = 20, |
| 73 | + update_proj_gap: int = 500, |
| 74 | + eps: float = 1e-6, |
| 75 | + **kwargs, |
| 76 | + ): |
| 77 | + self.validate_learning_rate(lr) |
| 78 | + self.validate_betas(betas) |
| 79 | + self.validate_non_negative(weight_decay, 'weight_decay') |
| 80 | + self.validate_non_negative(warmup_epoch, 'warmup_epoch') |
| 81 | + self.validate_non_negative(density, 'density') |
| 82 | + self.validate_non_negative(threshold, 'threshold') |
| 83 | + self.validate_non_negative(grad_accu_steps, 'grad_accu_steps') |
| 84 | + self.validate_non_negative(update_proj_gap, 'update_proj_gap') |
| 85 | + self.validate_non_negative(eps, 'eps') |
| 86 | + |
| 87 | + self.density = density |
| 88 | + self.warmup_epoch = warmup_epoch |
| 89 | + self.threshold = threshold |
| 90 | + self.grad_accu_steps = grad_accu_steps |
| 91 | + self.update_proj_gap = update_proj_gap |
| 92 | + self.warmup = CosineDecay(0.99, warmup_epoch) |
| 93 | + |
| 94 | + defaults: DEFAULTS = { |
| 95 | + 'lr': lr, |
| 96 | + 'betas': betas, |
| 97 | + 'weight_decay': weight_decay, |
| 98 | + 'eps': eps, |
| 99 | + **kwargs, |
| 100 | + } |
| 101 | + super().__init__(params, defaults) |
| 102 | + |
| 103 | + self.init_masks() |
| 104 | + |
| 105 | + self.state['total_step'] = 0 |
| 106 | + self.state['current_step'] = warmup_epoch + 1 |
| 107 | + |
| 108 | + @staticmethod |
| 109 | + def initialize_random_rank_boolean_tensor(m: int, n: int, density: float) -> torch.Tensor: |
| 110 | + r"""Create an (m x n) boolean tensor with `density` fraction of True entries. |
| 111 | +
|
| 112 | + :param m: int. number of rows. |
| 113 | + :param n: int. number of columns. |
| 114 | + :param density: float. fraction of True entries. 1.0 means all True. |
| 115 | + """ |
| 116 | + total_elements: int = m * n |
| 117 | + non_zero_count: int = int(density * total_elements) |
| 118 | + |
| 119 | + tensor = torch.zeros((m, n), dtype=torch.bool) |
| 120 | + |
| 121 | + if non_zero_count == 0: |
| 122 | + return tensor |
| 123 | + |
| 124 | + indices = torch.randperm(total_elements)[:non_zero_count] |
| 125 | + rows, cols = indices // n, indices % n |
| 126 | + tensor[rows, cols] = True |
| 127 | + |
| 128 | + return tensor |
| 129 | + |
| 130 | + def update_mask_random(self, density: float, p: torch.Tensor, old_mask: torch.Tensor) -> torch.Tensor: |
| 131 | + r"""Update a random mask. |
| 132 | +
|
| 133 | + Create a new random mask with the same density, compute overlap ratio with old_mask, and update the EMA for |
| 134 | + the overlap region. |
| 135 | +
|
| 136 | + :param density: float. fraction of elements to keep. |
| 137 | + :param p: torch.Tensor. parameter to which the mask is applied. |
| 138 | + :param old_mask: torch.Tensor. previous binary mask. |
| 139 | + """ |
| 140 | + new_mask: torch.Tensor = torch.rand_like(p) < density |
| 141 | + |
| 142 | + exp_avg = torch.zeros_like(p[new_mask]) |
| 143 | + exp_avg_sq = torch.zeros_like(p[new_mask]) |
| 144 | + |
| 145 | + intersection_mask = new_mask & old_mask |
| 146 | + new_intersection_indices = intersection_mask[new_mask] |
| 147 | + old_intersection_indices = intersection_mask[old_mask] |
| 148 | + |
| 149 | + state = self.state[p] |
| 150 | + exp_avg[new_intersection_indices] = state['exp_avg'][old_intersection_indices] |
| 151 | + exp_avg_sq[new_intersection_indices] = state['exp_avg_sq'][old_intersection_indices] |
| 152 | + |
| 153 | + state['exp_avg'] = exp_avg |
| 154 | + state['exp_avg_sq'] = exp_avg_sq |
| 155 | + |
| 156 | + return new_mask |
| 157 | + |
| 158 | + def update_masks(self) -> None: |
| 159 | + r"""Update masks in each parameter group that has 'density'. |
| 160 | +
|
| 161 | + The new mask is selected randomly, and the overlap ratio with the old mask is printed. |
| 162 | + """ |
| 163 | + for group in self.param_groups: |
| 164 | + for p in group['params']: |
| 165 | + state = self.state[p] |
| 166 | + if 'mask' in state: |
| 167 | + new_mask = self.update_mask_random(self.density, p, state['mask']) |
| 168 | + state['mask'] = new_mask |
| 169 | + p.mask = new_mask |
| 170 | + |
| 171 | + def init_masks(self) -> None: |
| 172 | + r"""Initialize random masks for each parameter group that has 'density'.""" |
| 173 | + for group in self.param_groups: |
| 174 | + for p in group['params']: |
| 175 | + state = self.state[p] |
| 176 | + if p.dim() == 2 and 'mask' not in state: |
| 177 | + state['mask'] = self.initialize_random_rank_boolean_tensor( |
| 178 | + p.shape[0], |
| 179 | + p.shape[1], |
| 180 | + density=self.density, |
| 181 | + ).to(p.device) |
| 182 | + |
| 183 | + def __str__(self) -> str: |
| 184 | + return 'SPAM' |
| 185 | + |
| 186 | + @torch.no_grad() |
| 187 | + def reset(self): |
| 188 | + for group in self.param_groups: |
| 189 | + group['step'] = 0 |
| 190 | + for p in group['params']: |
| 191 | + state = self.state[p] |
| 192 | + |
| 193 | + state['exp_avg'] = torch.zeros_like(p) |
| 194 | + state['exp_avg_sq'] = torch.zeros_like(p) |
| 195 | + |
| 196 | + @torch.no_grad() |
| 197 | + def step(self, closure: CLOSURE = None) -> LOSS: |
| 198 | + loss: LOSS = None |
| 199 | + if closure is not None: |
| 200 | + with torch.enable_grad(): |
| 201 | + loss = closure() |
| 202 | + |
| 203 | + scale_factor: float = 1.0 - self.warmup.get_death_rate(self.state['current_step']) |
| 204 | + |
| 205 | + for group in self.param_groups: |
| 206 | + if 'step' not in group: |
| 207 | + group['step'] = 1 |
| 208 | + else: |
| 209 | + group['step'] += 1 |
| 210 | + |
| 211 | + beta1, beta2 = group['betas'] |
| 212 | + |
| 213 | + bias_correction1: float = self.debias(beta1, group['step']) |
| 214 | + bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step'])) |
| 215 | + |
| 216 | + step_size: float = group['lr'] * bias_correction2_sq / bias_correction1 |
| 217 | + |
| 218 | + for p in group['params']: |
| 219 | + if p.grad is None: |
| 220 | + continue |
| 221 | + |
| 222 | + grad = p.grad |
| 223 | + if grad.is_sparse: |
| 224 | + raise NoSparseGradientError(str(self)) |
| 225 | + |
| 226 | + state = self.state[p] |
| 227 | + |
| 228 | + if 'mask' in state: |
| 229 | + grad = grad[state['mask']] |
| 230 | + |
| 231 | + if len(state) == 0: |
| 232 | + state['exp_avg'] = torch.zeros_like(grad) |
| 233 | + state['exp_avg_sq'] = torch.zeros_like(grad) |
| 234 | + |
| 235 | + if (self.state['total_step'] + 1) % self.update_proj_gap == 0: |
| 236 | + state['exp_avg'] = torch.zeros_like(grad) |
| 237 | + state['exp_avg_sq'] = torch.zeros_like(grad) |
| 238 | + |
| 239 | + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] |
| 240 | + |
| 241 | + if self.threshold != 0: |
| 242 | + current_step: int = self.state['total_step'] + 1 |
| 243 | + if current_step >= self.grad_accu_steps and ( |
| 244 | + self.update_proj_gap == 0 or current_step % self.update_proj_gap >= self.grad_accu_steps |
| 245 | + ): |
| 246 | + mask = grad.pow(2) > (self.threshold * exp_avg_sq) |
| 247 | + grad[mask].sign_().mul_(torch.sqrt(exp_avg_sq[mask] * self.threshold)) |
| 248 | + |
| 249 | + exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1) |
| 250 | + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2) |
| 251 | + |
| 252 | + de_nom = exp_avg_sq.sqrt().add_(group['eps']) |
| 253 | + |
| 254 | + if 'mask' in state: |
| 255 | + grad_full = torch.zeros_like(p.grad) |
| 256 | + grad_full[state['mask']] = exp_avg / de_nom |
| 257 | + p.add_(grad_full, alpha=-step_size * scale_factor) |
| 258 | + else: |
| 259 | + p.addcdiv_(exp_avg, de_nom, value=-step_size * scale_factor) |
| 260 | + |
| 261 | + if group['weight_decay'] > 0: |
| 262 | + if 'mask' in state: |
| 263 | + p[state['mask']].add_(p[state['mask']], alpha=-group['lr'] * group['weight_decay']) |
| 264 | + else: |
| 265 | + p.add_(p, alpha=-group['lr'] * group['weight_decay']) |
| 266 | + |
| 267 | + self.state['total_step'] += 1 |
| 268 | + self.state['current_step'] += 1 |
| 269 | + |
| 270 | + if (self.state['total_step'] != 0) and (self.state['total_step'] + 1) % self.update_proj_gap == 0: |
| 271 | + self.update_masks() |
| 272 | + self.state['current_step'] = 0 |
| 273 | + self.warmup = CosineDecay(0.99, self.warmup_epoch) |
| 274 | + |
| 275 | + return loss |
0 commit comments