Skip to content

Commit 2dbad2c

Browse files
author
ferris
committed
Make AdaHessian & SophiaH functional
1 parent a28afe2 commit 2dbad2c

File tree

4 files changed

+38
-15
lines changed

4 files changed

+38
-15
lines changed

pytorch_optimizer/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,8 @@
8383
from pytorch_optimizer.optimizer.sm3 import SM3
8484
from pytorch_optimizer.optimizer.srmm import SRMM
8585
from pytorch_optimizer.optimizer.swats import SWATS
86+
from pytorch_optimizer.optimizer.adahessian import AdaHessian
87+
from pytorch_optimizer.optimizer.sophiah import SophiaH
8688
from pytorch_optimizer.optimizer.utils import (
8789
clip_grad_norm,
8890
disable_running_stats,
@@ -147,6 +149,8 @@
147149
AdaShift,
148150
AdaDelta,
149151
Amos,
152+
AdaHessian,
153+
SophiaH
150154
]
151155
OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}
152156

pytorch_optimizer/base/optimizer.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,11 @@ def compute_hutchinson_hessian(self, nsamples: int = 1, pre_zero=True, alpha=1.0
3434
grads = [p.grad for p in params]
3535

3636
for i in range(nsamples):
37+
# Gaussian N(0,Id)
38+
zs = [torch.randn(p.size(), device=p.device) for p in params]
3739
# Rademacher distribution {-1.0, 1.0}
38-
zs = [torch.randint(0, 2, p.size(), device=p.device) * 2.0 - 1.0 for p in params]
39-
h_zs = torch.autograd.grad(grads, params, grad_outputs=zs, only_inputs=True, retain_graph=i < nsamples - 1)
40+
# zs = [torch.randint(0, 2, p.size(), device=p.device) * 2.0 - 1.0 for p in params]
41+
h_zs = torch.autograd.grad(grads, params, grad_outputs=zs, retain_graph=i < nsamples - 1)
4042
for h_z, z, p in zip(h_zs, zs, params):
4143
# approximate the expected values of z*(H@z)
4244
self.state[p]['hessian'].add_(h_z * z, alpha=1/nsamples * alpha)

pytorch_optimizer/optimizer/adahessian.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -43,20 +43,29 @@ def __init__(self,
4343
self.validate_non_negative(eps, 'eps')
4444
self.validate_range(hessian_power, "Hessian Power", 0, 1, range_type='(]')
4545

46+
self.update_period = update_period
47+
self.n_samples = n_samples
4648
defaults: DEFAULTS = {
4749
'lr': lr,
4850
'betas': betas,
4951
'weight_decay': weight_decay,
5052
'weight_decouple': weight_decouple,
5153
'fixed_decay': fixed_decay,
5254
'hessian_power': hessian_power,
53-
'update_period': update_period,
54-
'n_samples': n_samples,
5555
'eps': eps,
5656
}
5757
self._step = 0
5858
super().__init__(params, defaults)
5959

60+
@torch.no_grad()
61+
def reset(self):
62+
self._step = 0
63+
for group in self.param_groups:
64+
for p in group['params']:
65+
state = self.state[p]
66+
state['exp_avg'] = torch.zeros_like(p)
67+
state['exp_hessian_diag_sq'] = torch.zeros_like(p)
68+
6069
@torch.no_grad()
6170
def step(self, closure: CLOSURE = None) -> LOSS:
6271
loss: LOSS = None
@@ -72,9 +81,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:
7281
if p.grad is None:
7382
continue
7483

75-
if self.average_conv_kernel and p.dim() == 4:
76-
p.hess = torch.abs(p.hess).mean(dim=[2, 3], keepdim=True).expand_as(p.hess).clone()
77-
7884
grad = p.grad
7985
if grad.is_sparse:
8086
raise NoSparseGradientError(str(self))
@@ -100,10 +106,12 @@ def step(self, closure: CLOSURE = None) -> LOSS:
100106
# Decay the first and second moment running average coefficient
101107
exp_avg.mul_(beta1).add_(p.grad, alpha=1 - beta1)
102108
if self._step % self.update_period == 0:
109+
# if self.average_conv_kernel and p.dim() == 4:
110+
# state['hessian'] = torch.abs(state['hessian']).mean(dim=[2, 3], keepdim=True).expand_as(state['hessian']).clone()
103111
exp_hessian_diag_sq.mul_(beta2).addcmul_(state['hessian'], state['hessian'], value=1 - beta2)
104112

105-
bias_correction1 = 1 - beta1 ** self._step
106-
bias_correction2 = 1 - beta2 ** self._step
113+
bias_correction1 = 1 - beta1 ** (self._step+1)
114+
bias_correction2 = 1 - beta2 ** (self._step+1)
107115

108116
k = group['hessian_power']
109117
denom = (exp_hessian_diag_sq / bias_correction2).pow_(k / 2).add_(group['eps'])

pytorch_optimizer/optimizer/sophiah.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,22 @@ def __init__(self,
4848
'weight_decouple': weight_decouple,
4949
'fixed_decay': fixed_decay,
5050
'p': p,
51-
'update_period': update_period,
52-
'n_samples': n_samples,
5351
'eps': eps,
5452
}
53+
self.n_samples = n_samples
54+
self.update_period = update_period
5555
self._step = 0
5656
super().__init__(params, defaults)
5757

58+
@torch.no_grad()
59+
def reset(self):
60+
self._step = 0
61+
for group in self.param_groups:
62+
for p in group['params']:
63+
state = self.state[p]
64+
state['momentum'] = torch.zeros_like(p)
65+
state['hessian_moment'] = torch.zeros_like(p)
66+
5867
@torch.no_grad()
5968
def step(self, closure: CLOSURE = None) -> LOSS:
6069
loss: LOSS = None
@@ -63,7 +72,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
6372
loss = closure()
6473

6574
if self._step % self.update_period == 0:
66-
self.compute_hutchinson_hessian(self.n_smaples)
75+
self.compute_hutchinson_hessian(self.n_samples)
6776

6877
for group in self.param_groups:
6978
for p in group['params']:
@@ -77,8 +86,8 @@ def step(self, closure: CLOSURE = None) -> LOSS:
7786
# State initialization
7887
state = self.state[p]
7988
if 'momentum' not in state:
80-
state['momentum'] = torch.zeros_like(p.data)
81-
state['hessian_moment'] = torch.zeros_like(p.data)
89+
state['momentum'] = torch.zeros_like(p)
90+
state['hessian_moment'] = torch.zeros_like(p)
8291

8392
self.apply_weight_decay(
8493
p=p,
@@ -100,7 +109,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
100109
# The official implementation uses a different method to achieve the same thing (might be faster?):
101110
# https://github.com/Liuhong99/Sophia/blob/bff9df9b584e2084fe037af1ab38f4db31f0acca/sophia.py#L201
102111
update = torch.clip(momentum/torch.clip(hessian_moment, group['eps']), -group['p'], group['p'])
103-
p.add_(update, value=-group['lr'])
112+
p.add_(update, alpha=-group['lr'])
104113

105114
self._step += 1
106115
return loss

0 commit comments

Comments
 (0)