|
| 1 | +import math |
| 2 | + |
| 3 | +import torch |
| 4 | + |
| 5 | +from pytorch_optimizer.base.exception import NoComplexParameterError, NoSparseGradientError |
| 6 | +from pytorch_optimizer.base.optimizer import BaseOptimizer |
| 7 | +from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, GROUP, LOSS, PARAMETERS |
| 8 | + |
| 9 | + |
| 10 | +def closest_smaller_divisor_of_n_to_k(n: int, k: int) -> int: |
| 11 | + r"""Get closest smaller divisor of n to k.""" |
| 12 | + if n % k == 0: |
| 13 | + return k |
| 14 | + |
| 15 | + if n <= 1 or k <= 1: |
| 16 | + raise ValueError |
| 17 | + |
| 18 | + for i in range(k, 0, -1): |
| 19 | + if n % i == 0: |
| 20 | + return i |
| 21 | + return -1 # pragma: no cover |
| 22 | + |
| 23 | + |
| 24 | +class AdamWSN(BaseOptimizer): |
| 25 | + r"""Lean and Mean Adaptive Optimization via Subset-Norm and Subspace-Momentum with Convergence Guarantees. |
| 26 | +
|
| 27 | + .. code-block:: python |
| 28 | +
|
| 29 | + sn_params = [module.weight for module in model.modules() if isinstance(module, nn.Linear)] |
| 30 | + sn_param_ids = [id(p) for p in sn_params] |
| 31 | + regular_params = [p for p in model.parameters() if id(p) not in sn_param_ids] |
| 32 | + param_groups = [{'params': regular_params, 'sn': False}, {'params': sn_params, 'sn': True}] |
| 33 | + optimizer = AdamWSN(param_groups, lr=args.lr, weight_decay=args.weight_decay, subset_size=args.subset_size) |
| 34 | +
|
| 35 | + :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. |
| 36 | + :param lr: float. learning rate. |
| 37 | + :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace. |
| 38 | + :param weight_decay: float. weight decay (L2 penalty). |
| 39 | + :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW. |
| 40 | + :param fixed_decay: bool. fix weight decay. |
| 41 | + :param subset_size: int. If you do not know what subset_size to set, a good rule of thumb is to set it as d/2 where |
| 42 | + d is the hidden dimension of your transformer model. For example, the hidden dimension is 4096 for Llama 7B and |
| 43 | + so a good subset_size could be 2048. You can leave the subset_size argument to its default value of -1 to use |
| 44 | + the recommended subset size as stated above. |
| 45 | + :param eps: float. term added to the denominator to improve numerical stability. |
| 46 | + :param maximize: bool. maximize the objective with respect to the params, instead of minimizing. |
| 47 | + """ |
| 48 | + |
| 49 | + def __init__( |
| 50 | + self, |
| 51 | + params: PARAMETERS, |
| 52 | + lr: float = 1e-3, |
| 53 | + betas: BETAS = (0.9, 0.999), |
| 54 | + weight_decay: float = 0.0, |
| 55 | + weight_decouple: bool = True, |
| 56 | + fixed_decay: bool = False, |
| 57 | + subset_size: int = -1, |
| 58 | + eps: float = 1e-8, |
| 59 | + maximize: bool = False, |
| 60 | + **kwargs, |
| 61 | + ): |
| 62 | + self.validate_learning_rate(lr) |
| 63 | + self.validate_betas(betas) |
| 64 | + self.validate_non_negative(weight_decay, 'weight_decay') |
| 65 | + self.validate_non_negative(eps, 'eps') |
| 66 | + |
| 67 | + self.maximize = maximize |
| 68 | + |
| 69 | + defaults: DEFAULTS = { |
| 70 | + 'lr': lr, |
| 71 | + 'betas': betas, |
| 72 | + 'weight_decay': weight_decay, |
| 73 | + 'weight_decouple': weight_decouple, |
| 74 | + 'fixed_decay': fixed_decay, |
| 75 | + 'subset_size': subset_size, |
| 76 | + 'eps': eps, |
| 77 | + **kwargs, |
| 78 | + } |
| 79 | + |
| 80 | + super().__init__(params, defaults) |
| 81 | + |
| 82 | + def __str__(self) -> str: |
| 83 | + return 'AdamWSN' |
| 84 | + |
| 85 | + def init_group(self, group: GROUP, **kwargs) -> None: |
| 86 | + for p in group['params']: |
| 87 | + if p.grad is None: |
| 88 | + continue |
| 89 | + |
| 90 | + grad = p.grad |
| 91 | + if grad.is_sparse: |
| 92 | + raise NoSparseGradientError(str(self)) |
| 93 | + |
| 94 | + if torch.is_complex(p): |
| 95 | + raise NoComplexParameterError(str(self)) |
| 96 | + |
| 97 | + state = self.state[p] |
| 98 | + |
| 99 | + if len(state) == 0: |
| 100 | + state['exp_avg'] = torch.zeros_like(grad) |
| 101 | + |
| 102 | + if group.get('sn'): |
| 103 | + size: int = grad.numel() |
| 104 | + |
| 105 | + if 'subset_size' not in state: |
| 106 | + state['subset_size'] = closest_smaller_divisor_of_n_to_k( |
| 107 | + size, |
| 108 | + ( |
| 109 | + group['subset_size'] |
| 110 | + if group['subset_size'] > 0 |
| 111 | + else int(math.sqrt(size) / abs(int(group['subset_size']))) |
| 112 | + ), |
| 113 | + ) |
| 114 | + |
| 115 | + reshaped_grad = grad.view(size // state['subset_size'], state['subset_size']) |
| 116 | + second_moment_update = torch.sum(reshaped_grad ** 2, dim=1, keepdim=True) # fmt: skip |
| 117 | + state['exp_avg_sq'] = torch.zeros_like(second_moment_update) |
| 118 | + else: |
| 119 | + state['exp_avg_sq'] = torch.zeros_like(grad) |
| 120 | + |
| 121 | + @torch.no_grad() |
| 122 | + def step(self, closure: CLOSURE = None) -> LOSS: |
| 123 | + loss: LOSS = None |
| 124 | + if closure is not None: |
| 125 | + with torch.enable_grad(): |
| 126 | + loss = closure() |
| 127 | + |
| 128 | + for group in self.param_groups: |
| 129 | + if 'step' not in group: |
| 130 | + self.init_group(group) |
| 131 | + group['step'] = 1 |
| 132 | + else: |
| 133 | + group['step'] += 1 |
| 134 | + |
| 135 | + beta1, beta2 = group['betas'] |
| 136 | + |
| 137 | + bias_correction1: float = self.debias(beta1, group['step']) |
| 138 | + bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step'])) |
| 139 | + |
| 140 | + step_size: float = group['lr'] * bias_correction2_sq / bias_correction1 |
| 141 | + |
| 142 | + for p in group['params']: |
| 143 | + if p.grad is None: |
| 144 | + continue |
| 145 | + |
| 146 | + grad = p.grad |
| 147 | + size = grad.numel() |
| 148 | + |
| 149 | + self.maximize_gradient(grad, maximize=self.maximize) |
| 150 | + |
| 151 | + state = self.state[p] |
| 152 | + |
| 153 | + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] |
| 154 | + |
| 155 | + if group.get('sn'): |
| 156 | + reshaped_grad = grad.view(size // state['subset_size'], state['subset_size']) |
| 157 | + second_moment_update = torch.sum(reshaped_grad ** 2, dim=1, keepdim=True) # fmt: skip |
| 158 | + else: |
| 159 | + second_moment_update = grad.pow(2) |
| 160 | + |
| 161 | + exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1) |
| 162 | + exp_avg_sq.mul_(beta2).add_(second_moment_update, alpha=1.0 - beta2) |
| 163 | + |
| 164 | + de_nom = exp_avg_sq.sqrt().add_(group['eps']) |
| 165 | + |
| 166 | + if group.get('sn'): |
| 167 | + numerator = exp_avg.view(size // state['subset_size'], state['subset_size']) |
| 168 | + norm_grad = (numerator / de_nom).reshape(p.shape) |
| 169 | + p.add_(norm_grad, alpha=-step_size) |
| 170 | + else: |
| 171 | + p.addcdiv_(exp_avg, de_nom, value=-step_size) |
| 172 | + |
| 173 | + self.apply_weight_decay( |
| 174 | + p=p, |
| 175 | + grad=grad, |
| 176 | + lr=group['lr'], |
| 177 | + weight_decay=group['weight_decay'], |
| 178 | + weight_decouple=group['weight_decouple'], |
| 179 | + fixed_decay=group['fixed_decay'], |
| 180 | + ) |
| 181 | + |
| 182 | + return loss |
0 commit comments