|
| 1 | +from typing import Callable, Dict, List, Tuple |
| 2 | + |
| 3 | +import torch |
| 4 | +from torch import nn |
| 5 | + |
| 6 | +from pytorch_optimizer.base.optimizer import BaseOptimizer |
| 7 | +from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, OPTIMIZER |
| 8 | + |
| 9 | + |
| 10 | +def polyval(x: torch.Tensor, coef: torch.Tensor) -> torch.Tensor: |
| 11 | + r"""Implement of the Horner scheme to evaluate a polynomial. |
| 12 | +
|
| 13 | + taken from https://discuss.pytorch.org/t/polynomial-evaluation-by-horner-rule/67124 |
| 14 | +
|
| 15 | + :param x: torch.Tensor. variable. |
| 16 | + :param coef: torch.Tensor. coefficients of the polynomial. |
| 17 | + """ |
| 18 | + result = coef[0].clone() |
| 19 | + |
| 20 | + for c in coef[1:]: |
| 21 | + result = (result * x) + c |
| 22 | + |
| 23 | + return result[0] |
| 24 | + |
| 25 | + |
| 26 | +class ERF1994(nn.Module): |
| 27 | + r"""Implementation of ERF1994. |
| 28 | +
|
| 29 | + :param num_coefs: int. The number of polynomial coefficients to use in the approximation. |
| 30 | + """ |
| 31 | + |
| 32 | + def __init__(self, num_coefs: int = 128) -> None: |
| 33 | + super().__init__() |
| 34 | + |
| 35 | + self.n: int = num_coefs |
| 36 | + |
| 37 | + self.i: torch.Tensor = torch.complex(torch.tensor(0.0), torch.tensor(1.0)) |
| 38 | + self.m = 2 * self.n |
| 39 | + self.m2 = 2 * self.m |
| 40 | + self.k = torch.linspace(-self.m + 1, self.m - 1, self.m2 - 1) |
| 41 | + self.l = torch.sqrt(self.n / torch.sqrt(torch.tensor(2.0))) |
| 42 | + self.theta = self.k * torch.pi / self.m |
| 43 | + self.t = self.l * torch.tan(self.theta / 2.0) |
| 44 | + self.f = torch.exp(-self.t ** 2) * (self.l ** 2 + self.t ** 2) # fmt: skip |
| 45 | + self.a = torch.fft.fft(torch.fft.fftshift(self.f)).real / self.m2 |
| 46 | + self.a = torch.flipud(self.a[1:self.n + 1]) # fmt: skip |
| 47 | + |
| 48 | + def w_algorithm(self, z: torch.Tensor) -> torch.Tensor: |
| 49 | + r"""Compute the Faddeeva function of a complex number. |
| 50 | +
|
| 51 | + :param z: torch.Tensor. A tensor of complex numbers. |
| 52 | + """ |
| 53 | + self.l = self.l.to(z.device) |
| 54 | + self.i = self.i.to(z.device) |
| 55 | + self.a = self.a.to(z.device) |
| 56 | + |
| 57 | + iz = self.i * z |
| 58 | + lp_iz, ln_iz = self.l + iz, self.l - iz |
| 59 | + |
| 60 | + z_ = lp_iz / ln_iz |
| 61 | + p = polyval(z_.unsqueeze(0), self.a) |
| 62 | + return 2 * p / ln_iz.pow(2) + (1.0 / torch.sqrt(torch.tensor(torch.pi))) / ln_iz |
| 63 | + |
| 64 | + def forward(self, z: torch.Tensor) -> torch.Tensor: |
| 65 | + r"""Compute the error function of a complex number. |
| 66 | +
|
| 67 | + :param z: torch.Tensor. A tensor of complex numbers. |
| 68 | + """ |
| 69 | + sign_r = torch.sign(z.real) |
| 70 | + sign_i = torch.sign(z.imag) |
| 71 | + z = torch.complex(torch.abs(z.real), torch.abs(z.imag)) |
| 72 | + out = -torch.exp(torch.log(self.w_algorithm(z * self.i)) - z ** 2) + 1 # fmt: skip |
| 73 | + return torch.complex(out.real * sign_r, out.imag * sign_i) |
| 74 | + |
| 75 | + |
| 76 | +class TRAC(BaseOptimizer): |
| 77 | + r"""A Parameter-Free Optimizer for Lifelong Reinforcement Learning. |
| 78 | +
|
| 79 | + Example: |
| 80 | + ------- |
| 81 | + Here's an example:: |
| 82 | +
|
| 83 | + model = YourModel() |
| 84 | + optimizer = TRAC(AdamW(model.parameters())) |
| 85 | +
|
| 86 | + for input, output in data: |
| 87 | + optimizer.zero_grad() |
| 88 | +
|
| 89 | + loss = loss_fn(model(input), output) |
| 90 | + loss.backward() |
| 91 | +
|
| 92 | + optimizer.step() |
| 93 | +
|
| 94 | + :param optimizer: Optimizer. base optimizer. |
| 95 | + :param betas: List[float]. list of beta values. |
| 96 | + :param num_coefs: int. the number of polynomial coefficients to use in the approximation. |
| 97 | + :param s_prev: float. initial scale value. |
| 98 | + :param eps: float. term added to the denominator to improve numerical stability. |
| 99 | + """ |
| 100 | + |
| 101 | + def __init__( |
| 102 | + self, |
| 103 | + optimizer: OPTIMIZER, |
| 104 | + betas: List[float] = (0.9, 0.99, 0.999, 0.9999, 0.99999, 0.999999), |
| 105 | + num_coefs: int = 128, |
| 106 | + s_prev: float = 1e-8, |
| 107 | + eps: float = 1e-8, |
| 108 | + ): |
| 109 | + self.validate_positive(num_coefs, 'num_coefs') |
| 110 | + self.validate_non_negative(s_prev, 's_prev') |
| 111 | + self.validate_non_negative(eps, 'eps') |
| 112 | + |
| 113 | + self._optimizer_step_pre_hooks: Dict[int, Callable] = {} |
| 114 | + self._optimizer_step_post_hooks: Dict[int, Callable] = {} |
| 115 | + |
| 116 | + self.erf = ERF1994(num_coefs=num_coefs) |
| 117 | + self.betas = betas |
| 118 | + self.s_prev = s_prev |
| 119 | + self.eps = eps |
| 120 | + |
| 121 | + self.f_term = self.s_prev / self.erf_imag(1.0 / torch.sqrt(torch.tensor(2.0))) |
| 122 | + |
| 123 | + self.optimizer = optimizer |
| 124 | + self.defaults: DEFAULTS = optimizer.defaults |
| 125 | + |
| 126 | + def __str__(self) -> str: |
| 127 | + return 'TRAC' |
| 128 | + |
| 129 | + @property |
| 130 | + def param_groups(self): |
| 131 | + return self.optimizer.param_groups |
| 132 | + |
| 133 | + @property |
| 134 | + def state(self): |
| 135 | + return self.optimizer.state |
| 136 | + |
| 137 | + @torch.no_grad() |
| 138 | + def reset(self): |
| 139 | + device = self.param_groups[0]['params'][0].device |
| 140 | + |
| 141 | + self.state['trac'] = { |
| 142 | + 'betas': torch.tensor(self.betas, device=device), |
| 143 | + 's': torch.zeros(len(self.betas), device=device), |
| 144 | + 'variance': torch.zeros(len(self.betas), device=device), |
| 145 | + 'sigma': torch.full((len(self.betas),), 1e-8, device=device), |
| 146 | + 'step': 0, |
| 147 | + } |
| 148 | + |
| 149 | + for group in self.param_groups: |
| 150 | + for p in group['params']: |
| 151 | + self.state['trac'][p] = p.clone() |
| 152 | + |
| 153 | + @torch.no_grad() |
| 154 | + def zero_grad(self) -> None: |
| 155 | + self.optimizer.zero_grad(set_to_none=True) |
| 156 | + |
| 157 | + @torch.no_grad() |
| 158 | + def erf_imag(self, x: torch.Tensor) -> torch.Tensor: |
| 159 | + if not torch.is_floating_point(x): |
| 160 | + x = x.to(torch.float32) |
| 161 | + |
| 162 | + ix = torch.complex(torch.zeros_like(x), x) |
| 163 | + |
| 164 | + return self.erf(ix).imag |
| 165 | + |
| 166 | + @torch.no_grad() |
| 167 | + def backup_params_and_grads(self) -> Tuple[Dict, Dict]: |
| 168 | + updates, grads = {}, {} |
| 169 | + |
| 170 | + for group in self.param_groups: |
| 171 | + for p in group['params']: |
| 172 | + updates[p] = p.clone() |
| 173 | + grads[p] = p.grad.clone() if p.grad is not None else None |
| 174 | + |
| 175 | + return updates, grads |
| 176 | + |
| 177 | + @torch.no_grad() |
| 178 | + def trac_step(self, updates: Dict, grads: Dict) -> None: |
| 179 | + self.state['trac']['step'] += 1 |
| 180 | + |
| 181 | + deltas = {} |
| 182 | + |
| 183 | + device = self.param_groups[0]['params'][0].device |
| 184 | + |
| 185 | + h = torch.zeros((1,), device=device) |
| 186 | + for group in self.param_groups: |
| 187 | + for p in group['params']: |
| 188 | + if grads[p] is None: |
| 189 | + continue |
| 190 | + |
| 191 | + theta_ref = self.state['trac'][p] |
| 192 | + update = updates[p] |
| 193 | + |
| 194 | + deltas[p] = (update - theta_ref) / torch.sum(self.state['trac']['s']).add_(self.eps) |
| 195 | + update.neg_().add_(p) |
| 196 | + |
| 197 | + grad, delta = grads[p], deltas[p] |
| 198 | + |
| 199 | + product = torch.dot(delta.flatten(), grad.flatten()) |
| 200 | + h.add_(product) |
| 201 | + |
| 202 | + delta.add_(update) |
| 203 | + |
| 204 | + s = self.state['trac']['s'] |
| 205 | + betas = self.state['trac']['betas'] |
| 206 | + variance = self.state['trac']['variance'] |
| 207 | + sigma = self.state['trac']['sigma'] |
| 208 | + |
| 209 | + variance.mul_(betas.pow(2)).add_(h.pow(2)) |
| 210 | + sigma.mul_(betas).sub_(h) |
| 211 | + |
| 212 | + s_term = self.erf_imag(sigma / (2.0 * variance).sqrt_().add_(self.eps)) |
| 213 | + s_term.mul_(self.f_term) |
| 214 | + s.copy_(s_term) |
| 215 | + |
| 216 | + scale = max(torch.sum(s), 0.0) |
| 217 | + |
| 218 | + for group in self.param_groups: |
| 219 | + for p in group['params']: |
| 220 | + if grads[p] is None: |
| 221 | + continue |
| 222 | + |
| 223 | + delta = deltas[p] |
| 224 | + delta.mul_(scale).add_(self.state['trac'][p]) |
| 225 | + |
| 226 | + p.copy_(delta) |
| 227 | + |
| 228 | + @torch.no_grad() |
| 229 | + def step(self, closure: CLOSURE = None) -> LOSS: |
| 230 | + # TODO: backup is first to get the delta of param and grad, but it does not work. |
| 231 | + with torch.enable_grad(): |
| 232 | + loss = self.optimizer.step(closure) |
| 233 | + |
| 234 | + updates, grads = self.backup_params_and_grads() |
| 235 | + |
| 236 | + if 'trac' not in self.state: |
| 237 | + device = self.param_groups[0]['params'][0].device |
| 238 | + |
| 239 | + self.state['trac'] = { |
| 240 | + 'betas': torch.tensor(self.betas, device=device), |
| 241 | + 's': torch.zeros(len(self.betas), device=device), |
| 242 | + 'variance': torch.zeros(len(self.betas), device=device), |
| 243 | + 'sigma': torch.full((len(self.betas),), 1e-8, device=device), |
| 244 | + 'step': 0, |
| 245 | + } |
| 246 | + |
| 247 | + for group in self.param_groups: |
| 248 | + for p in group['params']: |
| 249 | + self.state['trac'][p] = updates[p].clone() |
| 250 | + |
| 251 | + self.trac_step(updates, grads) |
| 252 | + |
| 253 | + return loss |
0 commit comments