|
| 1 | +import os |
| 2 | +from typing import List, Optional, Tuple |
| 3 | + |
| 4 | +import torch |
| 5 | +from torch.distributed import ReduceOp, all_reduce |
| 6 | + |
| 7 | +from pytorch_optimizer.base.exception import NoSparseGradientError |
| 8 | +from pytorch_optimizer.base.optimizer import BaseOptimizer |
| 9 | +from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS |
| 10 | + |
| 11 | + |
| 12 | +def zero_power_via_newton_schulz_5( |
| 13 | + g: torch.Tensor, num_steps: int = 10, eps: float = 1e-7, weights: Tuple[int, int, int] = (3.4445, -4.7750, 2.0315) |
| 14 | +) -> torch.Tensor: |
| 15 | + r"""Compute the zeroth power / orthogonalization of G. |
| 16 | +
|
| 17 | + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a quintic iteration |
| 18 | + whose coefficients are selected to maximize the slope at zero. For the purpose of minimizing steps, it turns out |
| 19 | + to be empirically effective to keep increasing the slope at zero even beyond the point where the iteration no |
| 20 | + longer converges all the way to one everywhere on the interval. This iteration therefore does not produce UV^T but |
| 21 | + rather something like US'V^T where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt |
| 22 | + model performance at all relative to UV^T, where USV^T = G is the SVD. |
| 23 | +
|
| 24 | + :param g: torch.Tensor. matrix. |
| 25 | + :param num_steps: int. number of iterations. |
| 26 | + :param eps: float. add this times I to G, to make is positive definite. For scaling, we multiply it by the largest |
| 27 | + eigenvalue of G. |
| 28 | + :param weights: Tuple[int, int, int]. weights. |
| 29 | + """ |
| 30 | + if len(g.shape) != 2: |
| 31 | + raise ValueError('shape of g must be 2-dimensional') |
| 32 | + |
| 33 | + x = g.bfloat16() |
| 34 | + x.div_(x.norm().add_(eps)) |
| 35 | + |
| 36 | + if g.size(0) > g.size(1): |
| 37 | + x = x.T |
| 38 | + |
| 39 | + for _ in range(num_steps): |
| 40 | + a = x @ x.T |
| 41 | + b = weights[1] * a + weights[2] * a @ a |
| 42 | + x = weights[0] * x + b @ x |
| 43 | + |
| 44 | + if g.size(0) > g.size(1): |
| 45 | + x = x.T |
| 46 | + |
| 47 | + return x |
| 48 | + |
| 49 | + |
| 50 | +class Muon(BaseOptimizer): |
| 51 | + r"""MomentUm Orthogonalized by Newton-schulz. |
| 52 | +
|
| 53 | + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-processing step, in which |
| 54 | + each 2D parameter's update is replaced with the nearest orthogonal matrix. To efficiently orthogonalize each |
| 55 | + update, we use a Newton-Schulz iteration, which has the advantage that it can be stably run in bfloat16 on the GPU. |
| 56 | +
|
| 57 | + Some warnings: |
| 58 | + - We believe this optimizer is unlikely to work well for training with small batch size. |
| 59 | + - We believe it may not work well for fine-tuning pretrained models, but we haven't tested this. |
| 60 | +
|
| 61 | + :param params: PARAMETERS. the parameters to be optimized by Muon. |
| 62 | + :param lr: float. learning rate. |
| 63 | + :param momentum: float. the momentum used by the internal SGD. |
| 64 | + :param betas: The betas for the internal AdamW. |
| 65 | + :param nesterov: bool. whether to use nesterov momentum. |
| 66 | + :param ns_steps: int. the number of Newton-Schulz iterations to run. (6 is probably always enough) |
| 67 | + :param adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are {0, 1}-D or |
| 68 | + are detected as being the embed or lm_head will be optimized by AdamW as well. |
| 69 | + :param adamw_lr: The learning rate for the internal AdamW. |
| 70 | + :param adamw_wd: The weight decay for the internal AdamW. |
| 71 | + :param adamw_eps: The epsilon for the internal AdamW. |
| 72 | + """ |
| 73 | + |
| 74 | + def __init__( |
| 75 | + self, |
| 76 | + params: PARAMETERS, |
| 77 | + lr: float = 2e-2, |
| 78 | + momentum: float = 0.95, |
| 79 | + betas: BETAS = (0.95, 0.95), |
| 80 | + nesterov: bool = True, |
| 81 | + ns_steps: int = 6, |
| 82 | + adamw_params: Optional[PARAMETERS] = None, |
| 83 | + adamw_lr: float = 3e-4, |
| 84 | + adamw_wd: float = 0, |
| 85 | + adamw_eps: float = 1e-8, |
| 86 | + **kwargs, |
| 87 | + ): |
| 88 | + self.validate_learning_rate(lr) |
| 89 | + self.validate_learning_rate(adamw_lr) |
| 90 | + self.validate_range(momentum, 'momentum', 0.0, 1.0, range_type='[)') |
| 91 | + self.validate_positive(ns_steps, 'ns_steps') |
| 92 | + self.validate_betas(betas) |
| 93 | + self.validate_non_negative(adamw_wd, 'adamw_wd') |
| 94 | + self.validate_non_negative(adamw_eps, 'adamw_eps') |
| 95 | + |
| 96 | + params = self.get_parameters(params) |
| 97 | + adamw_params = self.get_parameters(adamw_params) if adamw_params is not None else [] |
| 98 | + params.extend(adamw_params) |
| 99 | + |
| 100 | + self.world_size: int = int(os.environ.get('WORLD_SIZE', 1)) |
| 101 | + self.rank: int = int(os.environ.get('RANK', 0)) |
| 102 | + |
| 103 | + defaults: DEFAULTS = { |
| 104 | + 'lr': lr, |
| 105 | + 'momentum': momentum, |
| 106 | + 'nesterov': nesterov, |
| 107 | + 'ns_steps': ns_steps, |
| 108 | + 'adamw_lr': adamw_lr, |
| 109 | + 'adamw_lr_ratio': adamw_lr / lr, |
| 110 | + 'adamw_betas': betas, |
| 111 | + 'adamw_wd': adamw_wd, |
| 112 | + 'adamw_eps': adamw_eps, |
| 113 | + } |
| 114 | + super().__init__(params, defaults) |
| 115 | + |
| 116 | + self.set_muon_state(params, adamw_params) |
| 117 | + |
| 118 | + def __str__(self) -> str: |
| 119 | + return 'Muon' |
| 120 | + |
| 121 | + @staticmethod |
| 122 | + def get_parameters(params: PARAMETERS) -> List[torch.Tensor]: |
| 123 | + if isinstance(params, list) and isinstance(params[0], torch.Tensor): |
| 124 | + return params |
| 125 | + |
| 126 | + new_params = [] |
| 127 | + for group in params: |
| 128 | + if isinstance(group, dict) and 'params' in group: |
| 129 | + new_params.extend(list(group['params'])) |
| 130 | + else: |
| 131 | + new_params.append(group) |
| 132 | + |
| 133 | + return new_params |
| 134 | + |
| 135 | + def set_muon_state(self, params: PARAMETERS, adamw_params: PARAMETERS, threshold: int = 8192) -> None: |
| 136 | + r"""Set use_muon flag.""" |
| 137 | + for p in params: |
| 138 | + self.state[p]['use_muon'] = p.ndim >= 2 and p.size(0) < threshold |
| 139 | + |
| 140 | + for p in adamw_params: |
| 141 | + self.state[p]['use_muon'] = False |
| 142 | + |
| 143 | + @torch.no_grad() |
| 144 | + def reset(self): |
| 145 | + for group in self.param_groups: |
| 146 | + group['step'] = 0 |
| 147 | + for p in group['params']: |
| 148 | + state = self.state[p] |
| 149 | + |
| 150 | + state['momentum_buffer'] = torch.zeros_like(p) |
| 151 | + state['moment1'] = torch.zeros_like(p) |
| 152 | + state['moment2'] = torch.zeros_like(p) |
| 153 | + |
| 154 | + @torch.no_grad() |
| 155 | + def step(self, closure: CLOSURE = None) -> LOSS: |
| 156 | + loss: LOSS = None |
| 157 | + if closure is not None: |
| 158 | + with torch.enable_grad(): |
| 159 | + loss = closure() |
| 160 | + |
| 161 | + for group in self.param_groups: |
| 162 | + if 'step' in group: |
| 163 | + group['step'] += 1 |
| 164 | + else: |
| 165 | + group['step'] = 1 |
| 166 | + |
| 167 | + params = [] |
| 168 | + for p in group['params']: |
| 169 | + if p.grad is not None and self.state[p]['use_muon']: |
| 170 | + if p.grad.is_sparse: |
| 171 | + raise NoSparseGradientError(str(self)) |
| 172 | + params.append(p) |
| 173 | + |
| 174 | + if len(params) == 0: |
| 175 | + continue |
| 176 | + |
| 177 | + lr = group['lr'] |
| 178 | + momentum = group['momentum'] |
| 179 | + |
| 180 | + total_params: int = sum(p.numel() for p in params) |
| 181 | + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) |
| 182 | + curr_idx: int = 0 |
| 183 | + |
| 184 | + for i, p in enumerate(params): |
| 185 | + if i % self.world_size != self.rank: |
| 186 | + curr_idx += p.numel() |
| 187 | + continue |
| 188 | + |
| 189 | + g = p.grad |
| 190 | + if g.ndim > 2: |
| 191 | + g = g.view(g.size(0), -1) |
| 192 | + |
| 193 | + state = self.state[p] |
| 194 | + if 'momentum_buffer' not in state: |
| 195 | + state['momentum_buffer'] = torch.zeros_like(g) |
| 196 | + |
| 197 | + buf = state['momentum_buffer'] |
| 198 | + buf.mul_(momentum).add_(g) |
| 199 | + |
| 200 | + if group['nesterov']: |
| 201 | + g.add_(buf, alpha=momentum) |
| 202 | + else: |
| 203 | + g = buf |
| 204 | + |
| 205 | + g = zero_power_via_newton_schulz_5(g, num_steps=group['ns_steps']) |
| 206 | + g.mul_(max(1.0, g.size(0) / g.size(1)) ** 0.5) |
| 207 | + |
| 208 | + updates_flat[curr_idx:curr_idx + p.numel()] = g.flatten() # fmt: skip |
| 209 | + |
| 210 | + if self.world_size > 1: # pragma: no cover |
| 211 | + all_reduce(updates_flat, op=ReduceOp.SUM) |
| 212 | + |
| 213 | + curr_idx: int = 0 |
| 214 | + for p in params: |
| 215 | + g = updates_flat[curr_idx:curr_idx + p.numel()].view_as(p).type_as(p) # fmt: skip |
| 216 | + p.add_(g, alpha=-lr) |
| 217 | + curr_idx += p.numel() |
| 218 | + |
| 219 | + params = [p for p in group['params'] if p.grad is not None and not self.state[p]['use_muon']] |
| 220 | + |
| 221 | + lr: float = group['adamw_lr_ratio'] * group['lr'] |
| 222 | + beta1, beta2 = group['adamw_betas'] |
| 223 | + |
| 224 | + bias_correction1: float = self.debias(beta1, group['step']) |
| 225 | + bias_correction2: float = self.debias(beta2, group['step']) |
| 226 | + scale: float = bias_correction1 / bias_correction2 ** 0.5 # fmt: skip |
| 227 | + step_size: float = lr / scale |
| 228 | + |
| 229 | + for p in params: |
| 230 | + grad = p.grad |
| 231 | + state = self.state[p] |
| 232 | + if 'moment1' not in state: |
| 233 | + state['moment1'] = torch.zeros_like(grad) |
| 234 | + state['moment2'] = torch.zeros_like(grad) |
| 235 | + |
| 236 | + buf1, buf2 = state['moment1'], state['moment2'] |
| 237 | + buf1.lerp_(grad, weight=1.0 - beta1) |
| 238 | + buf2.lerp_(grad.square(), weight=1.0 - beta2) |
| 239 | + |
| 240 | + update = buf1 / buf2.sqrt().add_(group['adamw_eps']) |
| 241 | + |
| 242 | + self.apply_weight_decay( |
| 243 | + p, |
| 244 | + grad, |
| 245 | + lr=lr, |
| 246 | + weight_decay=group['adamw_wd'], |
| 247 | + weight_decouple=True, |
| 248 | + fixed_decay=False, |
| 249 | + ) |
| 250 | + |
| 251 | + p.add_(update, alpha=-step_size) |
| 252 | + |
| 253 | + return loss |
0 commit comments