|
| 1 | +import math |
| 2 | +from typing import Optional |
| 3 | + |
| 4 | +import torch |
| 5 | +from torch import distributed as dist |
| 6 | +from torch import nn |
| 7 | +from torch.optim.optimizer import Optimizer |
| 8 | + |
| 9 | +from pytorch_optimizer.base.exception import NoSparseGradientError |
| 10 | +from pytorch_optimizer.base.optimizer import BaseOptimizer |
| 11 | +from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS |
| 12 | + |
| 13 | + |
| 14 | +class AdamMini(Optimizer, BaseOptimizer): |
| 15 | + r"""Use Fewer Learning Rates To Gain More. |
| 16 | +
|
| 17 | + :param model: nn.Module. model instance. |
| 18 | + :param model_sharding: bool. set to True if you are using model parallelism with more than 1 GPU, including FSDP |
| 19 | + and zero_1, 2, 3 in Deepspeed. Set to False if otherwise. |
| 20 | + :param lr: float. learning rate. |
| 21 | + :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace. |
| 22 | + :param weight_decay: float. weight decay (L2 penalty). |
| 23 | + :param num_embeds: int. number of embedding dimensions. could be unspecified if you are training non-transformer |
| 24 | + models. |
| 25 | + :param num_heads: int. number of attention heads. could be unspecified if you are training non-transformer models. |
| 26 | + :param num_query_groups: Optional[int]. number of query groups in Group Query Attention (GQA). if not specified, it |
| 27 | + will be equal to num_heads. could be unspecified if you are training non-transformer models. |
| 28 | + :param eps: float. term added to the denominator to improve numerical stability. |
| 29 | + """ |
| 30 | + |
| 31 | + def __init__( |
| 32 | + self, |
| 33 | + model: nn.Module, |
| 34 | + lr: float = 1.0, |
| 35 | + betas: BETAS = (0.9, 0.999), |
| 36 | + weight_decay: float = 0.1, |
| 37 | + model_sharding: bool = False, |
| 38 | + num_embeds: int = 2048, |
| 39 | + num_heads: int = 32, |
| 40 | + num_query_groups: Optional[int] = None, |
| 41 | + eps: float = 1e-8, |
| 42 | + ): |
| 43 | + self.validate_learning_rate(lr) |
| 44 | + self.validate_betas(betas) |
| 45 | + self.validate_non_negative(weight_decay, 'weight_decay') |
| 46 | + self.validate_non_negative(num_embeds, 'num_embeds') |
| 47 | + self.validate_non_negative(num_heads, 'num_heads') |
| 48 | + self.validate_non_negative(eps, 'eps') |
| 49 | + |
| 50 | + self.num_query_groups: int = num_query_groups if num_query_groups is not None else num_embeds |
| 51 | + self.validate_mod(num_embeds, self.num_query_groups) |
| 52 | + |
| 53 | + self.world_size: int = torch.cuda.device_count() |
| 54 | + |
| 55 | + self.model = model |
| 56 | + self.model_sharding = model_sharding |
| 57 | + self.num_embeds = num_embeds |
| 58 | + self.num_heads = num_heads |
| 59 | + |
| 60 | + groups = self.get_optimizer_groups(weight_decay) |
| 61 | + |
| 62 | + defaults: DEFAULTS = {'lr': lr, 'betas': betas, 'eps': eps} |
| 63 | + super().__init__(groups, defaults) |
| 64 | + |
| 65 | + def __str__(self) -> str: |
| 66 | + return 'AdamMini' |
| 67 | + |
| 68 | + def get_optimizer_groups(self, weight_decay: float): |
| 69 | + groups = [] |
| 70 | + for name, param in self.model.named_parameters(): |
| 71 | + if not param.requires_grad: |
| 72 | + continue |
| 73 | + |
| 74 | + group = { |
| 75 | + 'name': name, |
| 76 | + 'params': param, |
| 77 | + 'weight_decay': 0.0 if ('norm' in name or 'ln_f' in name) else weight_decay, |
| 78 | + } |
| 79 | + |
| 80 | + if ( |
| 81 | + 'self_attn.k_proj.weight' in name |
| 82 | + or 'self_attn.q_proj.weight' in name |
| 83 | + or 'attn.wq.weight' in name |
| 84 | + or 'attn.wk.weight' in name |
| 85 | + ): |
| 86 | + group['parameter_per_head'] = self.num_embeds * self.num_embeds // self.num_heads |
| 87 | + |
| 88 | + if 'attn.attn.weight' in name or 'attn.qkv.weight' in name: |
| 89 | + group['n_head'] = self.num_heads |
| 90 | + group['q_per_kv'] = self.num_embeds // self.num_query_groups |
| 91 | + |
| 92 | + groups.append(group) |
| 93 | + |
| 94 | + return groups |
| 95 | + |
| 96 | + @torch.no_grad() |
| 97 | + def reset(self): |
| 98 | + for group in self.param_groups: |
| 99 | + group['step'] = 0 |
| 100 | + for p in group['params']: |
| 101 | + state = self.state[p] |
| 102 | + |
| 103 | + state['m'] = torch.zeros_like(p, dtype=torch.float32) |
| 104 | + state['v'] = torch.zeros_like(p, dtype=torch.float32) |
| 105 | + |
| 106 | + @staticmethod |
| 107 | + def step_embed( |
| 108 | + p, |
| 109 | + grad, |
| 110 | + state, |
| 111 | + lr: float, |
| 112 | + beta1: float, |
| 113 | + beta2: float, |
| 114 | + bias_correction1: float, |
| 115 | + bias_correction2_sq: float, |
| 116 | + eps: float, |
| 117 | + ) -> None: |
| 118 | + if len(state) == 0: |
| 119 | + state['m'] = torch.zeros_like(p, dtype=torch.float32) |
| 120 | + state['v'] = torch.zeros_like(p, dtype=torch.float32) |
| 121 | + |
| 122 | + m, v = state['m'], state['v'] |
| 123 | + |
| 124 | + m.lerp_(grad, weight=1.0 - beta1) |
| 125 | + v.mul_(beta2).addcmul_(grad, grad.conj(), value=1.0 - beta2) |
| 126 | + |
| 127 | + h = (v.sqrt() / bias_correction2_sq).add_(eps) |
| 128 | + |
| 129 | + p.addcdiv_(m, h, value=-lr / bias_correction1) |
| 130 | + |
| 131 | + @staticmethod |
| 132 | + def step_attn_proj( |
| 133 | + p, |
| 134 | + grad, |
| 135 | + state, |
| 136 | + parameter_per_head: int, |
| 137 | + lr: float, |
| 138 | + beta1: float, |
| 139 | + beta2: float, |
| 140 | + bias_correction1: float, |
| 141 | + bias_correction2_sq: float, |
| 142 | + eps: float, |
| 143 | + ) -> None: |
| 144 | + if len(state) == 0: |
| 145 | + state['m'] = torch.zeros_like(p, dtype=torch.float32).view(-1, parameter_per_head) |
| 146 | + state['head'] = state['m'].shape[0] |
| 147 | + state['v_mean'] = torch.zeros(state['head'], device=state['m'].device) |
| 148 | + |
| 149 | + m, v = state['m'], state['v_mean'] |
| 150 | + |
| 151 | + head: int = state['head'] |
| 152 | + grad = grad.view(head, parameter_per_head) |
| 153 | + |
| 154 | + m.lerp_(grad, weight=1.0 - beta1) |
| 155 | + |
| 156 | + tmp_lr = torch.mean(grad * grad, dim=1).to(m.device) |
| 157 | + v.mul_(beta2).add_(tmp_lr, alpha=1.0 - beta2) |
| 158 | + |
| 159 | + h = (v.sqrt() / bias_correction2_sq).add_(eps) |
| 160 | + |
| 161 | + update = (1 / (h * bias_correction1)).view(head, 1).mul_(m) |
| 162 | + |
| 163 | + if p.dim() > 1: |
| 164 | + d0, d1 = p.size() |
| 165 | + update = update.view(d0, d1) |
| 166 | + else: |
| 167 | + update = update.view(-1) |
| 168 | + |
| 169 | + p.add_(update, alpha=-lr) |
| 170 | + |
| 171 | + @staticmethod |
| 172 | + def step_attn( |
| 173 | + p, |
| 174 | + grad, |
| 175 | + state, |
| 176 | + num_heads: int, |
| 177 | + q_per_kv: int, |
| 178 | + lr: float, |
| 179 | + beta1: float, |
| 180 | + beta2: float, |
| 181 | + bias_correction1: float, |
| 182 | + bias_correction2_sq: float, |
| 183 | + eps: float, |
| 184 | + ) -> None: |
| 185 | + if len(state) == 0: |
| 186 | + state['m'] = torch.zeros_like(p, dtype=torch.float32).view(num_heads, q_per_kv + 2, -1) |
| 187 | + state['v_mean'] = torch.zeros(num_heads, q_per_kv + 2, device=state['m'].device) |
| 188 | + |
| 189 | + m, v = state['m'], state['v_mean'] |
| 190 | + |
| 191 | + grad = grad.view(num_heads, q_per_kv + 2, -1) |
| 192 | + |
| 193 | + m.lerp_(grad, weight=1.0 - beta1) |
| 194 | + |
| 195 | + tmp_lr = torch.mean(grad * grad, dim=2).to(m.device) |
| 196 | + v.mul_(beta2).add_(tmp_lr, alpha=1.0 - beta2) |
| 197 | + |
| 198 | + h = (v.sqrt() / bias_correction2_sq).add_(eps) |
| 199 | + |
| 200 | + update = (1 / (h * bias_correction1)).view(num_heads, q_per_kv + 2, -1).mul_(m) |
| 201 | + |
| 202 | + if p.dim() > 1: |
| 203 | + d0, d1 = p.size() |
| 204 | + update = update.view(d0, d1) |
| 205 | + else: |
| 206 | + update = update.view(-1) |
| 207 | + |
| 208 | + p.add_(update, alpha=-lr) |
| 209 | + |
| 210 | + def step_lefts( |
| 211 | + self, |
| 212 | + p, |
| 213 | + grad, |
| 214 | + state, |
| 215 | + lr: float, |
| 216 | + beta1: float, |
| 217 | + beta2: float, |
| 218 | + bias_correction1: float, |
| 219 | + bias_correction2_sq: float, |
| 220 | + eps: float, |
| 221 | + ) -> None: # pragma: no cover |
| 222 | + if len(state) == 0: |
| 223 | + dim = torch.tensor(p.numel(), device=p.device, dtype=torch.float32) |
| 224 | + |
| 225 | + reduced: bool = False |
| 226 | + if self.model_sharding and self.world_size > 1: |
| 227 | + tensor_list = [torch.zeros_like(dim) for _ in range(self.world_size)] |
| 228 | + dist.all_gather(tensor_list, dim) |
| 229 | + |
| 230 | + s, dim = 0, 0 |
| 231 | + for d in tensor_list: |
| 232 | + if d > 0: |
| 233 | + s += 1 |
| 234 | + dim += d |
| 235 | + |
| 236 | + if s >= 2: |
| 237 | + reduced = True |
| 238 | + |
| 239 | + state['m'] = torch.zeros_like(p, dtype=torch.float32) |
| 240 | + state['v_mean'] = torch.tensor(0.0, device=state['m'].device) |
| 241 | + state['dimension'] = dim |
| 242 | + state['reduced'] = reduced |
| 243 | + |
| 244 | + tmp_lr = torch.sum(grad * grad) |
| 245 | + |
| 246 | + if state['reduced']: |
| 247 | + dist.all_reduce(tmp_lr, op=dist.ReduceOp.SUM) |
| 248 | + |
| 249 | + tmp_lr.div_(state['dim']) |
| 250 | + |
| 251 | + m, v = state['m'], state['v_mean'] |
| 252 | + |
| 253 | + m.lerp_(grad, weight=1.0 - beta1) |
| 254 | + v.mul_(beta2).add_(tmp_lr, value=1.0 - beta2) |
| 255 | + |
| 256 | + h = (v.sqrt() / bias_correction2_sq).add_(eps) |
| 257 | + |
| 258 | + update = 1 / (bias_correction1 * h).mul_(m) |
| 259 | + |
| 260 | + p.add_(update, alpha=-lr) |
| 261 | + |
| 262 | + @torch.no_grad() |
| 263 | + def step(self, closure: CLOSURE = None) -> LOSS: |
| 264 | + loss: LOSS = None |
| 265 | + if closure is not None: |
| 266 | + with torch.enable_grad(): |
| 267 | + loss = closure() |
| 268 | + |
| 269 | + for group in self.param_groups: |
| 270 | + if 'step' in group: |
| 271 | + group['step'] += 1 |
| 272 | + else: |
| 273 | + group['step'] = 1 |
| 274 | + |
| 275 | + name = group['name'] |
| 276 | + |
| 277 | + beta1, beta2 = group['betas'] |
| 278 | + |
| 279 | + bias_correction1: float = 1.0 - beta1 ** group['step'] |
| 280 | + bias_correction2: float = 1.0 - beta2 ** group['step'] |
| 281 | + bias_correction2_sq: float = math.sqrt(bias_correction2) |
| 282 | + |
| 283 | + for p in group['params']: |
| 284 | + if p.grad is None: |
| 285 | + continue |
| 286 | + |
| 287 | + grad = p.grad |
| 288 | + if grad.is_sparse: |
| 289 | + raise NoSparseGradientError(str(self)) |
| 290 | + |
| 291 | + grad = grad.to(torch.float32) |
| 292 | + |
| 293 | + state = self.state[p] |
| 294 | + |
| 295 | + self.apply_weight_decay( |
| 296 | + p=p, |
| 297 | + grad=grad, |
| 298 | + lr=group['lr'], |
| 299 | + weight_decay=group['weight_decay'], |
| 300 | + weight_decouple=True, |
| 301 | + fixed_decay=False, |
| 302 | + ) |
| 303 | + |
| 304 | + if 'embed_tokens' in name or 'wte' in name or 'lm_head' in name: |
| 305 | + self.step_embed( |
| 306 | + p, grad, state, group['lr'], beta1, beta2, bias_correction1, bias_correction2_sq, group['eps'] |
| 307 | + ) |
| 308 | + elif ( |
| 309 | + 'self_attn.k_proj.weight' in name |
| 310 | + or 'self_attn.q_proj.weight' in name |
| 311 | + or 'attn.wq.weight' in name |
| 312 | + or 'attn.wk.weight' in name |
| 313 | + ): |
| 314 | + self.step_attn_proj( |
| 315 | + p, |
| 316 | + grad, |
| 317 | + state, |
| 318 | + group['parameter_per_head'], |
| 319 | + group['lr'], |
| 320 | + beta1, |
| 321 | + beta2, |
| 322 | + bias_correction1, |
| 323 | + bias_correction2_sq, |
| 324 | + group['eps'], |
| 325 | + ) |
| 326 | + elif 'attn.attn.weight' in name or 'attn.qkv.weight' in name: |
| 327 | + self.step_attn( |
| 328 | + p, |
| 329 | + grad, |
| 330 | + state, |
| 331 | + group['n_head'], |
| 332 | + group['q_per_kv'], |
| 333 | + group['lr'], |
| 334 | + beta1, |
| 335 | + beta2, |
| 336 | + bias_correction1, |
| 337 | + bias_correction2_sq, |
| 338 | + group['eps'], |
| 339 | + ) |
| 340 | + else: |
| 341 | + self.step_lefts( |
| 342 | + p, |
| 343 | + grad, |
| 344 | + state, |
| 345 | + group['lr'], |
| 346 | + beta1, |
| 347 | + beta2, |
| 348 | + bias_correction1, |
| 349 | + bias_correction2_sq, |
| 350 | + group['eps'], |
| 351 | + ) |
| 352 | + |
| 353 | + return loss |
0 commit comments