|
| 1 | +import torch |
| 2 | +from torch.optim.optimizer import Optimizer |
| 3 | + |
| 4 | +from pytorch_optimizer.base.exception import NoSparseGradientError |
| 5 | +from pytorch_optimizer.base.optimizer import BaseOptimizer |
| 6 | +from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS |
| 7 | + |
| 8 | + |
| 9 | +class AggMo(Optimizer, BaseOptimizer): |
| 10 | + r"""Aggregated Momentum: Stability Through Passive Damping. |
| 11 | +
|
| 12 | + :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. |
| 13 | + :param lr: float. learning rate. |
| 14 | + :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace. |
| 15 | + :param weight_decay: float. weight decay (L2 penalty). |
| 16 | + :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW. |
| 17 | + """ |
| 18 | + |
| 19 | + def __init__( |
| 20 | + self, |
| 21 | + params: PARAMETERS, |
| 22 | + lr: float = 1e-3, |
| 23 | + betas: BETAS = (0.0, 0.9, 0.99), |
| 24 | + weight_decay: float = 0.0, |
| 25 | + weight_decouple: bool = False, |
| 26 | + ): |
| 27 | + self.lr = lr |
| 28 | + self.betas = betas |
| 29 | + self.weight_decay = weight_decay |
| 30 | + |
| 31 | + self.validate_parameters() |
| 32 | + |
| 33 | + defaults: DEFAULTS = { |
| 34 | + 'lr': lr, |
| 35 | + 'betas': betas, |
| 36 | + 'weight_decay': weight_decay, |
| 37 | + 'weight_decouple': weight_decouple, |
| 38 | + } |
| 39 | + super().__init__(params, defaults) |
| 40 | + |
| 41 | + def validate_parameters(self): |
| 42 | + self.validate_learning_rate(self.lr) |
| 43 | + self.validate_betas(self.betas) |
| 44 | + self.validate_weight_decay(self.weight_decay) |
| 45 | + |
| 46 | + def __str__(self) -> str: |
| 47 | + return 'AggMo' |
| 48 | + |
| 49 | + @torch.no_grad() |
| 50 | + def reset(self): |
| 51 | + for group in self.param_groups: |
| 52 | + group['step'] = 0 |
| 53 | + for p in group['params']: |
| 54 | + state = self.state[p] |
| 55 | + |
| 56 | + state['momentum_buffer'] = {beta: torch.zeros_like(p) for beta in group['betas']} |
| 57 | + |
| 58 | + @torch.no_grad() |
| 59 | + def step(self, closure: CLOSURE = None) -> LOSS: |
| 60 | + loss: LOSS = None |
| 61 | + if closure is not None: |
| 62 | + with torch.enable_grad(): |
| 63 | + loss = closure() |
| 64 | + |
| 65 | + for group in self.param_groups: |
| 66 | + if 'step' in group: |
| 67 | + group['step'] += 1 |
| 68 | + else: |
| 69 | + group['step'] = 1 |
| 70 | + |
| 71 | + betas = group['betas'] |
| 72 | + |
| 73 | + for p in group['params']: |
| 74 | + if p.grad is None: |
| 75 | + continue |
| 76 | + |
| 77 | + grad = p.grad |
| 78 | + if grad.is_sparse: |
| 79 | + raise NoSparseGradientError(str(self)) |
| 80 | + |
| 81 | + state = self.state[p] |
| 82 | + |
| 83 | + if len(state) == 0: |
| 84 | + state['momentum_buffer'] = {beta: torch.zeros_like(p) for beta in betas} |
| 85 | + |
| 86 | + if group['weight_decouple']: |
| 87 | + p.mul_(1.0 - group['weight_decay'] * group['lr']) |
| 88 | + elif group['weight_decay'] > 0.0: |
| 89 | + grad.add_(p, alpha=group['weight_decay']) |
| 90 | + |
| 91 | + for beta in betas: |
| 92 | + buf = state['momentum_buffer'][beta] |
| 93 | + buf.mul_(beta).add_(grad) |
| 94 | + |
| 95 | + p.add_(buf, alpha=-group['lr'] / len(betas)) |
| 96 | + |
| 97 | + return loss |
0 commit comments