
Commit 2ed564d

Author: ferris (committed)
Allow external hessian source & fix remaining adahessian bugs
1 parent e767029 commit 2ed564d

File tree: 3 files changed (+41, -12 lines)

pytorch_optimizer/base/optimizer.py

Lines changed: 26 additions & 0 deletions
@@ -11,6 +11,32 @@
 class BaseOptimizer(ABC):
     r"""Base optimizer class."""
 
+    @torch.no_grad()
+    def set_hessian(self, hessian):
+        """
+        Helper function to set the hessian state from an external source.
+        Generally useful when using functorch as a base.
+
+        Example usage:
+        ```
+        # Hutchinson's estimator using an HVP
+        noise = tree_map(lambda v: torch.randn_like(v), params)
+        loss_, hvp_est = jvp(grad(run_model_fn), (params,), (noise,))
+        hessian_diag_est = tree_map(lambda a, b: a * b, hvp_est, noise)
+
+        optimizer.set_hessian(hessian_diag_est)
+        # OR
+        optimizer.step(hessian=hessian_diag_est)
+        ```
+
+        """
+        i = 0
+        for group in self.param_groups:
+            for p in group['params']:
+                assert p.shape == hessian[i].shape
+                self.state[p]['hessian'] = hessian[i]
+                i += 1
+
     @torch.no_grad()
     def compute_hutchinson_hessian(self, nsamples: int = 1, pre_zero=True, alpha=1.0, distribution: HUTCHINSON_G = 'gaussian'):
         """

pytorch_optimizer/optimizer/adahessian.py

Lines changed: 4 additions & 4 deletions
@@ -44,6 +44,7 @@ def __init__(self,
         self.validate_non_negative(eps, 'eps')
         self.validate_range(hessian_power, "Hessian Power", 0, 1, range_type='(]')
 
+        self.distribution = hessian_distribution
         self.update_period = update_period
         self.n_samples = n_samples
         defaults: DEFAULTS = {
@@ -65,7 +66,7 @@ def reset(self):
             for p in group['params']:
                 state = self.state[p]
                 state['exp_avg'] = torch.zeros_like(p)
-                state['exp_hessian_diag_sq'] = state['hessian'].clone()
+                state['exp_hessian_diag_sq'] = torch.zeros_like(p)
 
     @torch.no_grad()
     def step(self, closure: CLOSURE = None) -> LOSS:
@@ -75,7 +76,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 loss = closure()
 
         if self._step % self.update_period == 0:
-            self.compute_hutchinson_hessian(self.n_samples)
+            self.compute_hutchinson_hessian(self.n_samples, distribution=self.distribution)
 
         for group in self.param_groups:
             for p in group['params']:
@@ -90,8 +91,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 state = self.state[p]
                 if 'exp_avg' not in state:
                     state['exp_avg'] = torch.zeros_like(p.data)
-                    # NOTE: zeroing-out the hessian causes instability
-                    state['exp_hessian_diag_sq'] = state['hessian'].clone()
+                    state['exp_hessian_diag_sq'] = torch.zeros_like(p.data)
 
                 self.apply_weight_decay(
                     p=p,
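
For reference, a minimal AdaHessian loop under the fixed plumbing (the toy model and data are assumptions). The create_graph=True backward is the usual requirement for Hutchinson estimators that differentiate through the gradients, not something introduced here.

```
# Illustrative loop: after this fix, hessian_distribution actually reaches
# compute_hutchinson_hessian(), and state['exp_hessian_diag_sq'] starts at zero.
import torch
import torch.nn as nn
from pytorch_optimizer import AdaHessian  # assumed import path

model = nn.Linear(16, 1)
x, y = torch.randn(32, 16), torch.randn(32, 1)
optimizer = AdaHessian(model.parameters(), hessian_distribution='gaussian')

for _ in range(10):
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward(create_graph=True)  # keep the graph so the HVP can differentiate p.grad
    optimizer.step()
```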

pytorch_optimizer/optimizer/sophiah.py

Lines changed: 11 additions & 8 deletions
@@ -25,12 +25,12 @@ class SophiaH(Optimizer, BaseOptimizer):
 
     def __init__(self,
                  params: PARAMETERS,
-                 lr: float = 1e-1,
-                 betas: BETAS = (0.965, 0.99),
+                 lr: float = 6e-2,
+                 betas: BETAS = (0.96, 0.99),
                  weight_decay: float = 0.0,
                  weight_decouple: bool = True,
                  fixed_decay: bool = False,
-                 p: float = 25.,
+                 p: float = 1e-2,
                  update_period: int = 10,
                  n_samples: int = 1,
                  hessian_distribution: HUTCHINSON_G = 'gaussian',
@@ -40,8 +40,9 @@ def __init__(self,
         self.validate_betas(betas)
         self.validate_non_negative(weight_decay, 'weight_decay')
         self.validate_non_negative(eps, 'eps')
-        self.validate_positive(p, "p (gradient clip)")
+        self.validate_non_negative(p, "p (gradient clip)")
 
+        self.distribution = hessian_distribution
         defaults: DEFAULTS = {
             'lr': lr,
             'betas': betas,
@@ -66,14 +67,16 @@ def reset(self):
                 state['hessian_moment'] = torch.zeros_like(p)
 
     @torch.no_grad()
-    def step(self, closure: CLOSURE = None) -> LOSS:
+    def step(self, closure: CLOSURE = None, hessian: tuple[torch.Tensor] = None) -> LOSS:
         loss: LOSS = None
         if closure is not None:
             with torch.enable_grad():
                 loss = closure()
 
-        if self._step % self.update_period == 0:
-            self.compute_hutchinson_hessian(self.n_samples)
+        if hessian is not None:
+            self.set_hessian(hessian)
+        elif self._step % self.update_period == 0:
+            self.compute_hutchinson_hessian(self.n_samples, distribution=self.distribution)
 
         for group in self.param_groups:
             for p in group['params']:
@@ -103,7 +106,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 momentum, hessian_moment = state['momentum'], state['hessian_moment']
 
                 momentum.mul_(beta1).add_(p.grad, alpha=1.0-beta1)
-                if self._step % self.update_period == 0:
+                if self._step % self.update_period == 0 or hessian is not None:
                     hessian_moment.mul_(beta2).add_(state['hessian'], alpha=1.0-beta2)
 
                 # See https://shreyansh26.github.io/post/2023-05-28_sophia_scalable_second_order_optimizer_llms/#per-coordinate-clipping
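
And a sketch of the new external-Hessian path in SophiaH.step(). The constant estimate below is only a stand-in for something like the torch.func estimate sketched under base/optimizer.py; any tuple of tensors whose shapes match the parameters works.

```
# Illustrative loop feeding step() an external Hessian-diagonal estimate
# (placeholder values only).
import torch
import torch.nn as nn
from pytorch_optimizer import SophiaH  # assumed import path

model = nn.Linear(16, 1)
x, y = torch.randn(32, 16), torch.randn(32, 1)
optimizer = SophiaH(model.parameters())

for _ in range(10):
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()  # no create_graph needed: the internal Hutchinson path is skipped
    hess_estimate = tuple(torch.ones_like(p) for p in model.parameters())
    # Passing hessian= routes through set_hessian() and updates the hessian-moment
    # EMA on this step instead of waiting for update_period.
    optimizer.step(hessian=hess_estimate)
```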
