|
| 1 | +import torch |
| 2 | +from torch.optim import Optimizer |
| 3 | + |
| 4 | +from pytorch_optimizer.types import ( |
| 5 | + BETAS, |
| 6 | + CLOSURE, |
| 7 | + DEFAULT_PARAMETERS, |
| 8 | + LOSS, |
| 9 | + PARAMS, |
| 10 | +) |
| 11 | + |
| 12 | + |
| 13 | +class AdaHessian(Optimizer): |
| 14 | + """ |
| 15 | + Reference : https://github.com/davda54/ada-hessian/blob/master/ada_hessian.py |
| 16 | + Example : |
| 17 | + from pytorch_optimizer import AdaHessian |
| 18 | + ... |
| 19 | + model = YourModel() |
| 20 | + optimizer = AdaHessian(model.parameters()) |
| 21 | + ... |
| 22 | + for input, output in data: |
| 23 | + optimizer.zero_grad() |
| 24 | + loss = loss_function(output, model(input)) |
| 25 | + loss.backward(create_graph=True) # this is the important line! |
| 26 | + optimizer.step() |
| 27 | + """ |
| 28 | + |
| 29 | + def __init__( |
| 30 | + self, |
| 31 | + params: PARAMS, |
| 32 | + lr: float = 1e-3, |
| 33 | + betas: BETAS = (0.9, 0.999), |
| 34 | + eps: float = 1e-8, |
| 35 | + weight_decay: float = 0.0, |
| 36 | + hessian_power: float = 1.0, |
| 37 | + update_each: int = 1, |
| 38 | + n_samples: int = 1, |
| 39 | + average_conv_kernel: bool = False, |
| 40 | + seed: int = 2147483647, |
| 41 | + ): |
| 42 | + """ |
| 43 | + :param params: PARAMS. iterable of parameters to optimize or dicts defining parameter groups |
| 44 | + :param lr: float. learning rate. |
| 45 | + :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace |
| 46 | + :param eps: float. term added to the denominator to improve numerical stability |
| 47 | + :param weight_decay: float. weight decay (L2 penalty) |
| 48 | + :param hessian_power: float. exponent of the hessian trace |
| 49 | + :param update_each: int. compute the hessian trace approximation only after *this* number of steps |
| 50 | + :param n_samples: int. how many times to sample `z` for the approximation of the hessian trace |
| 51 | + :param average_conv_kernel: bool. average out the hessian traces of convolutional kernels as in the paper. |
| 52 | + :param seed: int. |
| 53 | + """ |
| 54 | + self.lr = lr |
| 55 | + self.eps = eps |
| 56 | + self.betas = betas |
| 57 | + self.weight_decay = weight_decay |
| 58 | + self.hessian_power = hessian_power |
| 59 | + self.n_samples = n_samples |
| 60 | + self.update_each = update_each |
| 61 | + self.average_conv_kernel = average_conv_kernel |
| 62 | + self.seed = seed |
| 63 | + |
| 64 | + self.check_valid_parameters() |
| 65 | + |
| 66 | + # use a separate generator that deterministically generates the same `z`s across all GPUs |
| 67 | + # in case of distributed training |
| 68 | + self.generator: torch.Generator = torch.Generator().manual_seed( |
| 69 | + self.seed |
| 70 | + ) |
| 71 | + |
| 72 | + defaults: DEFAULT_PARAMETERS = dict( |
| 73 | + lr=lr, |
| 74 | + betas=betas, |
| 75 | + eps=eps, |
| 76 | + weight_decay=weight_decay, |
| 77 | + hessian_power=hessian_power, |
| 78 | + ) |
| 79 | + super().__init__(params, defaults) |
| 80 | + |
| 81 | + for p in self.get_params(): |
| 82 | + p.hess = 0.0 |
| 83 | + self.state[p]['hessian_step'] = 0 |
| 84 | + |
| 85 | + def check_valid_parameters(self): |
| 86 | + if 0.0 > self.lr: |
| 87 | + raise ValueError(f'Invalid learning rate : {self.lr}') |
| 88 | + if 0.0 > self.eps: |
| 89 | + raise ValueError(f'Invalid eps : {self.eps}') |
| 90 | + if 0.0 > self.weight_decay: |
| 91 | + raise ValueError(f'Invalid weight_decay : {self.weight_decay}') |
| 92 | + if not 0.0 <= self.betas[0] < 1.0: |
| 93 | + raise ValueError(f'Invalid beta_0 : {self.betas[0]}') |
| 94 | + if not 0.0 <= self.betas[1] < 1.0: |
| 95 | + raise ValueError(f'Invalid beta_1 : {self.betas[1]}') |
| 96 | + if not 0.0 <= self.hessian_power < 1.0: |
| 97 | + raise ValueError(f'Invalid hessian_power : {self.hessian_power}') |
| 98 | + |
| 99 | + def get_params(self): |
| 100 | + """Gets all parameters in all param_groups with gradients""" |
| 101 | + return ( |
| 102 | + p |
| 103 | + for group in self.param_groups |
| 104 | + for p in group['params'] |
| 105 | + if p.requires_grad |
| 106 | + ) |
| 107 | + |
| 108 | + def zero_hessian(self): |
| 109 | + """Zeros out the accumulated hessian traces.""" |
| 110 | + for p in self.get_params(): |
| 111 | + if ( |
| 112 | + not isinstance(p.hess, float) |
| 113 | + and self.state[p]['hessian_step'] % self.update_each == 0 |
| 114 | + ): |
| 115 | + p.hess.zero_() |
| 116 | + |
| 117 | + @torch.no_grad() |
| 118 | + def set_hessian(self): |
| 119 | + """Computes the Hutchinson approximation of the hessian trace and accumulates it for each trainable parameter""" |
| 120 | + params = [] |
| 121 | + for p in filter( |
| 122 | + lambda param: param.grad is not None, self.get_params() |
| 123 | + ): |
| 124 | + # compute the trace only each `update_each` step |
| 125 | + if self.state[p]['hessian_step'] % self.update_each == 0: |
| 126 | + params.append(p) |
| 127 | + self.state[p]['hessian_step'] += 1 |
| 128 | + |
| 129 | + if len(params) == 0: |
| 130 | + return |
| 131 | + |
| 132 | + if self.generator.device != params[0].device: |
| 133 | + # hackish way of casting the generator to the right device |
| 134 | + self.generator = torch.Generator(params[0].device).manual_seed( |
| 135 | + self.seed |
| 136 | + ) |
| 137 | + |
| 138 | + grads = [p.grad for p in params] |
| 139 | + |
| 140 | + for i in range(self.n_samples): |
| 141 | + # Rademacher distribution {-1.0, 1.0} |
| 142 | + zs = [ |
| 143 | + torch.randint( |
| 144 | + 0, 2, p.size(), generator=self.generator, device=p.device |
| 145 | + ) |
| 146 | + * 2.0 |
| 147 | + - 1.0 |
| 148 | + for p in params |
| 149 | + ] |
| 150 | + h_zs = torch.autograd.grad( |
| 151 | + grads, |
| 152 | + params, |
| 153 | + grad_outputs=zs, |
| 154 | + only_inputs=True, |
| 155 | + retain_graph=i < self.n_samples - 1, |
| 156 | + ) |
| 157 | + for h_z, z, p in zip(h_zs, zs, params): |
| 158 | + # approximate the expected values of z * (H@z) |
| 159 | + p.hess += h_z * z / self.n_samples |
| 160 | + |
| 161 | + @torch.no_grad() |
| 162 | + def step(self, closure: CLOSURE = None) -> LOSS: |
| 163 | + loss: LOSS = None |
| 164 | + if closure is not None: |
| 165 | + loss = closure() |
| 166 | + |
| 167 | + self.zero_hessian() |
| 168 | + self.set_hessian() |
| 169 | + |
| 170 | + for group in self.param_groups: |
| 171 | + for p in group['params']: |
| 172 | + if p.grad is None or p.hess is None: |
| 173 | + continue |
| 174 | + |
| 175 | + if self.average_conv_kernel and p.dim() == 4: |
| 176 | + p.hess = ( |
| 177 | + torch.abs(p.hess) |
| 178 | + .mean(dim=[2, 3], keepdim=True) |
| 179 | + .expand_as(p.hess) |
| 180 | + .clone() |
| 181 | + ) |
| 182 | + |
| 183 | + # Perform correct step-weight decay as in AdamW |
| 184 | + p.mul_(1 - group['lr'] * group['weight_decay']) |
| 185 | + |
| 186 | + state = self.state[p] |
| 187 | + |
| 188 | + if len(state) == 1: |
| 189 | + state['step'] = 0 |
| 190 | + state['exp_avg'] = torch.zeros_like(p.data) |
| 191 | + state['exp_hessian_diag_sq'] = torch.zeros_like(p.data) |
| 192 | + |
| 193 | + exp_avg, exp_hessian_diag_sq = ( |
| 194 | + state['exp_avg'], |
| 195 | + state['exp_hessian_diag_sq'], |
| 196 | + ) |
| 197 | + beta1, beta2 = group['betas'] |
| 198 | + state['step'] += 1 |
| 199 | + |
| 200 | + # Decay the first and second moment running average coefficient |
| 201 | + exp_avg.mul_(beta1).add_(p.grad, alpha=1 - beta1) |
| 202 | + exp_hessian_diag_sq.mul_(beta2).addcmul_( |
| 203 | + p.hess, p.hess, value=1 - beta2 |
| 204 | + ) |
| 205 | + |
| 206 | + bias_correction1 = 1 - beta1 ** state['step'] |
| 207 | + bias_correction2 = 1 - beta2 ** state['step'] |
| 208 | + |
| 209 | + k = group['hessian_power'] |
| 210 | + denom = ( |
| 211 | + (exp_hessian_diag_sq / bias_correction2) |
| 212 | + .pow_(k / 2) |
| 213 | + .add_(group['eps']) |
| 214 | + ) |
| 215 | + |
| 216 | + step_size = group['lr'] / bias_correction1 |
| 217 | + p.addcdiv_(exp_avg, denom, value=-step_size) |
| 218 | + |
| 219 | + return loss |
0 commit comments