|
| 1 | +import math |
| 2 | +from typing import Literal, Optional, Tuple, Union |
| 3 | + |
| 4 | +import torch |
| 5 | +from torch.optim.optimizer import Optimizer |
| 6 | + |
| 7 | +from pytorch_optimizer.base.exception import NoSparseGradientError |
| 8 | +from pytorch_optimizer.base.optimizer import BaseOptimizer |
| 9 | +from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS |
| 10 | + |
| 11 | +PROJECTION_TYPE = Literal['std', 'reverse_std', 'right', 'left', 'full'] |
| 12 | + |
| 13 | + |
| 14 | +class GaLoreProjector: |
| 15 | + r"""Memory-Efficient LLM Training by Gradient Low-Rank Projection. |
| 16 | +
|
| 17 | + :param rank: int. low rank to project. |
| 18 | + :param update_proj_gap: int. num steps to update the projection. |
| 19 | + :param scale: float. scale factor. |
| 20 | + :param projection_type: PROJECTION_TYPE. type of projection. 'std', 'reverse_std', 'right', 'left', 'full' are |
| 21 | + supported. |
| 22 | + """ |
| 23 | + |
| 24 | + def __init__( |
| 25 | + self, rank: int = 128, update_proj_gap: int = 50, scale: float = 1.0, projection_type: PROJECTION_TYPE = 'std' |
| 26 | + ): |
| 27 | + self.rank = rank |
| 28 | + self.update_proj_gap = update_proj_gap |
| 29 | + self.scale = scale |
| 30 | + self.projection_type = projection_type |
| 31 | + |
| 32 | + self.ortho_matrix: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None |
| 33 | + |
| 34 | + @staticmethod |
| 35 | + def get_orthogonal_matrix( |
| 36 | + weights: torch.Tensor, rank: int, projection_type: str |
| 37 | + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: |
| 38 | + if projection_type not in {'right', 'left', 'full'}: |
| 39 | + raise ValueError('projection_type should be one of left, right or full') |
| 40 | + |
| 41 | + original_type = weights.data.dtype |
| 42 | + original_device = weights.data.device |
| 43 | + is_float: bool = original_type == torch.float |
| 44 | + |
| 45 | + u, s, vh = torch.linalg.svd(weights if is_float else weights.float(), full_matrices=False) |
| 46 | + |
| 47 | + if projection_type == 'right': |
| 48 | + b = vh[:rank, :] |
| 49 | + return b if is_float else b.to(original_device).type(original_type) |
| 50 | + if projection_type == 'left': |
| 51 | + a = u[:, :rank] |
| 52 | + return a if is_float else a.to(original_device).type(original_type) |
| 53 | + |
| 54 | + a = u[:, :rank] |
| 55 | + b = vh[:rank, :] |
| 56 | + |
| 57 | + return ( |
| 58 | + (a, b) |
| 59 | + if is_float |
| 60 | + else (a.to(original_device).type(original_type), b.to(original_device).type(original_type)) |
| 61 | + ) |
| 62 | + |
| 63 | + def get_low_rank_grad_std(self, grad: torch.Tensor, steps: int) -> torch.Tensor: |
| 64 | + if grad.shape[0] >= grad.shape[1]: |
| 65 | + if self.ortho_matrix is None or steps % self.update_proj_gap == 0: |
| 66 | + self.ortho_matrix = self.get_orthogonal_matrix(grad, self.rank, projection_type='right') |
| 67 | + return torch.matmul(grad, self.ortho_matrix.t()) |
| 68 | + |
| 69 | + if self.ortho_matrix is None or steps % self.update_proj_gap == 0: |
| 70 | + self.ortho_matrix = self.get_orthogonal_matrix(grad, self.rank, projection_type='left') |
| 71 | + |
| 72 | + return torch.matmul(self.ortho_matrix.t(), grad) |
| 73 | + |
| 74 | + def get_low_rank_grad_reverse_std(self, grad: torch.Tensor, steps: int) -> torch.Tensor: |
| 75 | + if grad.shape[0] >= grad.shape[1]: |
| 76 | + if self.ortho_matrix is None or steps % self.update_proj_gap == 0: |
| 77 | + self.ortho_matrix = self.get_orthogonal_matrix(grad, self.rank, projection_type='left') |
| 78 | + return torch.matmul(self.ortho_matrix.t(), grad) |
| 79 | + |
| 80 | + if self.ortho_matrix is None or steps % self.update_proj_gap == 0: |
| 81 | + self.ortho_matrix = self.get_orthogonal_matrix(grad, self.rank, projection_type='right') |
| 82 | + |
| 83 | + return torch.matmul(grad, self.ortho_matrix.t()) |
| 84 | + |
| 85 | + def get_low_rank_grad_right(self, grad: torch.Tensor, steps: int) -> torch.Tensor: |
| 86 | + if self.ortho_matrix is None or steps % self.update_proj_gap == 0: |
| 87 | + self.ortho_matrix = self.get_orthogonal_matrix(grad, self.rank, projection_type='right') |
| 88 | + return torch.matmul(grad, self.ortho_matrix.t()) |
| 89 | + |
| 90 | + def get_low_rank_grad_left(self, grad: torch.Tensor, steps: int) -> torch.Tensor: |
| 91 | + if self.ortho_matrix is None or steps % self.update_proj_gap == 0: |
| 92 | + self.ortho_matrix = self.get_orthogonal_matrix(grad, self.rank, projection_type='left') |
| 93 | + return torch.matmul(self.ortho_matrix.t(), grad) |
| 94 | + |
| 95 | + def get_low_rank_grad_full(self, grad: torch.Tensor, steps: int) -> torch.Tensor: |
| 96 | + if self.ortho_matrix is None or steps % self.update_proj_gap == 0: |
| 97 | + self.ortho_matrix = self.get_orthogonal_matrix(grad, self.rank, projection_type='full') |
| 98 | + return torch.matmul(self.ortho_matrix[0].t(), grad) @ self.ortho_matrix[1].t() |
| 99 | + |
| 100 | + def project(self, full_rank_grad: torch.Tensor, steps: int) -> torch.Tensor: |
| 101 | + if self.projection_type == 'std': |
| 102 | + return self.get_low_rank_grad_std(full_rank_grad, steps) |
| 103 | + if self.projection_type == 'reverse_std': |
| 104 | + return self.get_low_rank_grad_reverse_std(full_rank_grad, steps) |
| 105 | + if self.projection_type == 'right': |
| 106 | + return self.get_low_rank_grad_right(full_rank_grad, steps) |
| 107 | + if self.projection_type == 'left': |
| 108 | + return self.get_low_rank_grad_left(full_rank_grad, steps) |
| 109 | + if self.projection_type == 'full': |
| 110 | + return self.get_low_rank_grad_full(full_rank_grad, steps) |
| 111 | + raise NotImplementedError |
| 112 | + |
| 113 | + def project_back(self, low_rank_grad: torch.Tensor) -> torch.Tensor: |
| 114 | + if self.projection_type == 'std': |
| 115 | + return ( |
| 116 | + torch.matmul(low_rank_grad, self.ortho_matrix) |
| 117 | + if low_rank_grad.shape[0] >= low_rank_grad.shape[1] |
| 118 | + else torch.matmul(self.ortho_matrix, low_rank_grad) |
| 119 | + ) * self.scale |
| 120 | + if self.projection_type == 'reverse_std': |
| 121 | + return ( |
| 122 | + torch.matmul(self.ortho_matrix, low_rank_grad.t()) |
| 123 | + if low_rank_grad.shape[0] <= low_rank_grad.shape[1] |
| 124 | + else torch.matmul(low_rank_grad, self.ortho_matrix.t()) |
| 125 | + ) * self.scale |
| 126 | + if self.projection_type == 'right': |
| 127 | + return torch.matmul(low_rank_grad, self.ortho_matrix.t()) * self.scale |
| 128 | + if self.projection_type == 'left': |
| 129 | + return torch.matmul(self.ortho_matrix, low_rank_grad) * self.scale |
| 130 | + if self.projection_type == 'full': |
| 131 | + return torch.matmul(self.ortho_matrix[0], low_rank_grad) @ self.ortho_matrix[1].t() * self.scale |
| 132 | + |
| 133 | + raise NotImplementedError |
| 134 | + |
| 135 | + |
| 136 | +class GaLore(Optimizer, BaseOptimizer): |
| 137 | + r"""AdamW optimizer with GaLore projector. |
| 138 | +
|
| 139 | + :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. |
| 140 | + :param lr: float. learning rate. |
| 141 | + :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace. |
| 142 | + :param weight_decay: float. weight decay (L2 penalty). |
| 143 | + :param eps: float. term added to the denominator to improve numerical stability. |
| 144 | + """ |
| 145 | + |
| 146 | + def __init__( |
| 147 | + self, |
| 148 | + params: PARAMETERS, |
| 149 | + lr: float = 1e-3, |
| 150 | + betas: BETAS = (0.9, 0.999), |
| 151 | + weight_decay: float = 0.0, |
| 152 | + eps: float = 1e-6, |
| 153 | + **kwargs, |
| 154 | + ): |
| 155 | + self.validate_learning_rate(lr) |
| 156 | + self.validate_betas(betas) |
| 157 | + self.validate_non_negative(weight_decay, 'weight_decay') |
| 158 | + self.validate_non_negative(eps, 'eps') |
| 159 | + |
| 160 | + defaults: DEFAULTS = { |
| 161 | + 'lr': lr, |
| 162 | + 'betas': betas, |
| 163 | + 'weight_decay': weight_decay, |
| 164 | + 'eps': eps, |
| 165 | + **kwargs, |
| 166 | + } |
| 167 | + |
| 168 | + super().__init__(params, defaults) |
| 169 | + |
| 170 | + def __str__(self) -> str: |
| 171 | + return 'GaLore' |
| 172 | + |
| 173 | + @torch.no_grad() |
| 174 | + def reset(self): |
| 175 | + for group in self.param_groups: |
| 176 | + for p in group['params']: |
| 177 | + state = self.state[p] |
| 178 | + |
| 179 | + state['exp_avg'] = torch.zeros_like(p) |
| 180 | + state['exp_avg_sq'] = torch.zeros_like(p) |
| 181 | + |
| 182 | + @torch.no_grad() |
| 183 | + def step(self, closure: CLOSURE = None) -> LOSS: |
| 184 | + loss: LOSS = None |
| 185 | + if closure is not None: |
| 186 | + with torch.enable_grad(): |
| 187 | + loss = closure() |
| 188 | + |
| 189 | + for group in self.param_groups: |
| 190 | + if 'step' in group: |
| 191 | + group['step'] += 1 |
| 192 | + else: |
| 193 | + group['step'] = 1 |
| 194 | + |
| 195 | + beta1, beta2 = group['betas'] |
| 196 | + |
| 197 | + bias_correction1: float = 1.0 - beta1 ** group['step'] |
| 198 | + bias_correction2_sq: float = math.sqrt(1.0 - beta2 ** group['step']) |
| 199 | + |
| 200 | + step_size: float = group['lr'] * bias_correction2_sq / bias_correction1 |
| 201 | + |
| 202 | + for p in group['params']: |
| 203 | + if p.grad is None: |
| 204 | + continue |
| 205 | + |
| 206 | + grad = p.grad |
| 207 | + if grad.is_sparse: |
| 208 | + raise NoSparseGradientError(str(self)) |
| 209 | + |
| 210 | + state = self.state[p] |
| 211 | + |
| 212 | + if len(state) == 0: |
| 213 | + state['exp_avg'] = torch.zeros_like(p) |
| 214 | + state['exp_avg_sq'] = torch.zeros_like(p) |
| 215 | + |
| 216 | + if 'rank' in group and p.dim() > 1: |
| 217 | + if 'projector' not in state: |
| 218 | + state['projector'] = GaLoreProjector( |
| 219 | + rank=group['rank'], |
| 220 | + update_proj_gap=group['update_proj_gap'], |
| 221 | + scale=group['scale'], |
| 222 | + projection_type=group['projection_type'], |
| 223 | + ) |
| 224 | + |
| 225 | + grad = state['projector'].project(grad, group['step']) |
| 226 | + |
| 227 | + self.apply_weight_decay( |
| 228 | + p=p, |
| 229 | + grad=None, |
| 230 | + lr=group['lr'], |
| 231 | + weight_decay=group['weight_decay'], |
| 232 | + weight_decouple=True, |
| 233 | + fixed_decay=False, |
| 234 | + ) |
| 235 | + |
| 236 | + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] |
| 237 | + exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1) |
| 238 | + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2) |
| 239 | + |
| 240 | + de_nom = exp_avg_sq.sqrt().add_(group['eps']) |
| 241 | + |
| 242 | + norm_grad = exp_avg / de_nom |
| 243 | + |
| 244 | + if 'rank' in group and p.dim() > 1: |
| 245 | + norm_grad = state['projector'].project_back(norm_grad) |
| 246 | + |
| 247 | + p.add_(norm_grad, alpha=-step_size) |
| 248 | + |
| 249 | + return loss |
0 commit comments