|
| 1 | +import torch |
| 2 | +from torch.optim import Optimizer |
| 3 | + |
| 4 | +from pytorch_optimizer.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS |
| 5 | + |
| 6 | + |
| 7 | +class LARS(Optimizer): |
| 8 | + """ |
| 9 | + Reference : https://github.com/facebookresearch/mae/blob/main/util/lars.py |
| 10 | + Example : |
| 11 | + from pytorch_optimizer import LARS |
| 12 | + ... |
| 13 | + model = YourModel() |
| 14 | + optimizer = LARS(model.parameters()) |
| 15 | + ... |
| 16 | + for input, output in data: |
| 17 | + optimizer.zero_grad() |
| 18 | + loss = loss_function(output, model(input)) |
| 19 | + loss.backward() |
| 20 | + optimizer.step() |
| 21 | + """ |
| 22 | + |
| 23 | + def __init__( |
| 24 | + self, |
| 25 | + params: PARAMETERS, |
| 26 | + lr: float = 1e-3, |
| 27 | + weight_decay: float = 0.0, |
| 28 | + momentum: float = 0.9, |
| 29 | + trust_coefficient: float = 0.001, |
| 30 | + eps: float = 1e-6, |
| 31 | + ): |
| 32 | + """LARS optimizer, no rate scaling or weight decay for parameters <= 1D |
| 33 | + :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups |
| 34 | + :param lr: float. learning rate |
| 35 | + :param weight_decay: float. weight decay (L2 penalty) |
| 36 | + :param momentum: float. momentum |
| 37 | + :param trust_coefficient: float. trust_coefficient |
| 38 | + :param eps: float. epsilon |
| 39 | + """ |
| 40 | + self.lr = lr |
| 41 | + self.weight_decay = weight_decay |
| 42 | + self.momentum = momentum |
| 43 | + self.trust_coefficient = trust_coefficient |
| 44 | + self.eps = eps |
| 45 | + |
| 46 | + self.check_valid_parameters() |
| 47 | + |
| 48 | + defaults: DEFAULTS = dict( |
| 49 | + lr=lr, |
| 50 | + weight_decay=weight_decay, |
| 51 | + momentum=momentum, |
| 52 | + trust_coefficient=trust_coefficient, |
| 53 | + ) |
| 54 | + super().__init__(params, defaults) |
| 55 | + |
| 56 | + def check_valid_parameters(self): |
| 57 | + if self.lr < 0.0: |
| 58 | + raise ValueError(f'Invalid learning rate : {self.lr}') |
| 59 | + if self.weight_decay < 0.0: |
| 60 | + raise ValueError(f'Invalid weight_decay : {self.weight_decay}') |
| 61 | + if self.momentum < 0.0: |
| 62 | + raise ValueError(f'Invalid momentum : {self.momentum}') |
| 63 | + if self.trust_coefficient < 0.0: |
| 64 | + raise ValueError(f'Invalid trust_coefficient : {self.trust_coefficient}') |
| 65 | + if self.eps < 0.0: |
| 66 | + raise ValueError(f'Invalid eps : {self.eps}') |
| 67 | + |
| 68 | + @torch.no_grad() |
| 69 | + def step(self, closure: CLOSURE = None) -> LOSS: |
| 70 | + loss: LOSS = None |
| 71 | + if closure is not None: |
| 72 | + loss = closure() |
| 73 | + |
| 74 | + for g in self.param_groups: |
| 75 | + for p in g['params']: |
| 76 | + if p.grad is None: |
| 77 | + continue |
| 78 | + |
| 79 | + if p.grad.data.is_sparse: |
| 80 | + raise RuntimeError('LARS does not support sparse gradients') |
| 81 | + |
| 82 | + dp = p.grad |
| 83 | + |
| 84 | + if p.ndim > 1: # if not normalization gamma/beta or bias |
| 85 | + dp = dp.add(p, alpha=g['weight_decay']) |
| 86 | + param_norm = torch.norm(p) |
| 87 | + update_norm = torch.norm(dp) |
| 88 | + one = torch.ones_like(param_norm) |
| 89 | + |
| 90 | + q = torch.where( |
| 91 | + param_norm > 0.0, |
| 92 | + torch.where(update_norm > 0.0, (g['trust_coefficient'] * param_norm / update_norm), one), |
| 93 | + one, |
| 94 | + ) |
| 95 | + dp = dp.mul(q) |
| 96 | + |
| 97 | + param_state = self.state[p] |
| 98 | + if 'mu' not in param_state: |
| 99 | + param_state['mu'] = torch.zeros_like(p) |
| 100 | + |
| 101 | + mu = param_state['mu'] |
| 102 | + mu.mul_(g['momentum']).add_(dp) |
| 103 | + |
| 104 | + p.add_(mu, alpha=-g['lr']) |
| 105 | + |
| 106 | + return loss |
0 commit comments