Commit 5c18933

Merge pull request #176 from i404788/main
Implement SophiaH & AdaHessian
2 parents c78dcd6 + d0a9b1b commit 5c18933

File tree

10 files changed: +435 −10 lines


pytorch_optimizer/__init__.py

Lines changed: 6 additions & 1 deletion
@@ -62,7 +62,7 @@
  from pytorch_optimizer.optimizer.ranger21 import Ranger21
  from pytorch_optimizer.optimizer.rotograd import RotoGrad
  from pytorch_optimizer.optimizer.sam import SAM
- from pytorch_optimizer.optimizer.sgd import ASGD, SGDW, AccSGD
+ from pytorch_optimizer.optimizer.sgd import ASGD, SGDW, AccSGD, SignSGD
  from pytorch_optimizer.optimizer.sgdp import SGDP
  from pytorch_optimizer.optimizer.shampoo import ScalableShampoo, Shampoo
  from pytorch_optimizer.optimizer.shampoo_utils import (
@@ -83,6 +83,8 @@
  from pytorch_optimizer.optimizer.sm3 import SM3
  from pytorch_optimizer.optimizer.srmm import SRMM
  from pytorch_optimizer.optimizer.swats import SWATS
+ from pytorch_optimizer.optimizer.adahessian import AdaHessian
+ from pytorch_optimizer.optimizer.sophiah import SophiaH
  from pytorch_optimizer.optimizer.utils import (
      clip_grad_norm,
      disable_running_stats,
@@ -147,6 +149,9 @@
      AdaShift,
      AdaDelta,
      Amos,
+     AdaHessian,
+     SophiaH,
+     SignSGD
  ]
  OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}

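With the new classes registered both in the top-level imports and in OPTIMIZER_LIST, they become reachable through the lowercase-keyed OPTIMIZERS mapping built above. A minimal sketch of both access paths (the linear model and hyper-parameters are made up for illustration):

```
import torch
from pytorch_optimizer import OPTIMIZERS, SignSGD

model = torch.nn.Linear(10, 1)

# keys are lowercased class names: 'adahessian', 'sophiah', 'signsgd', ...
optimizer = OPTIMIZERS['adahessian'](model.parameters(), lr=1e-1)

# or instantiate directly from the package namespace
sign_sgd = SignSGD(model.parameters(), lr=1e-3, beta=0.9)
```
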
pytorch_optimizer/base/optimizer.py

Lines changed: 68 additions & 2 deletions
@@ -4,13 +4,79 @@

  import torch

- from pytorch_optimizer.base.exception import NegativeLRError, NegativeStepError
- from pytorch_optimizer.base.types import BETAS
+ from pytorch_optimizer.base.exception import NegativeLRError, NegativeStepError, NoSparseGradientError
+ from pytorch_optimizer.base.types import BETAS, HUTCHINSON_G


  class BaseOptimizer(ABC):
      r"""Base optimizer class."""

+     @torch.no_grad()
+     def set_hessian(self, hessian):
+         """Set the Hessian state from an external source.
+
+         Generally useful when using functorch as a base.
+
+         Example usage:
+         ```
+         # Hutchinson's estimator using an HVP
+         noise = tree_map(lambda v: torch.randn_like(v), params)
+         loss_, hvp_est = jvp(grad(run_model_fn), (params,), (noise,))
+         hessian_diag_est = tree_map(lambda a, b: a * b, hvp_est, noise)
+
+         optimizer.set_hessian(hessian_diag_est)
+         # or
+         optimizer.step(hessian=hessian_diag_est)
+         ```
+         """
+         i = 0
+         for group in self.param_groups:
+             for p in group['params']:
+                 assert p.shape == hessian[i].shape
+                 self.state[p]['hessian'] = hessian[i]
+                 i += 1
+
+     @torch.no_grad()
+     def compute_hutchinson_hessian(self, nsamples: int = 1, pre_zero: bool = True, alpha: float = 1.0, distribution: HUTCHINSON_G = 'gaussian'):
+         """Hutchinson's approximation of the Hessian diagonal, accumulated in the state under the key 'hessian'."""
+         if distribution not in ('gaussian', 'rademacher'):
+             raise NotImplementedError(f'Hessian with distribution {distribution} is not implemented')
+
+         params = []
+         for group in self.param_groups:
+             for p in group['params']:
+                 if p.requires_grad and p.grad is not None:
+                     if p.grad.is_sparse:
+                         raise NoSparseGradientError(str(self))
+                     # initialize the Hessian state
+                     if 'hessian' in self.state[p]:
+                         if pre_zero:
+                             self.state[p]['hessian'].zero_()
+                     else:
+                         self.state[p]['hessian'] = torch.zeros_like(p.data)
+                     params.append(p)
+
+         if len(params) == 0:
+             return
+
+         grads = [p.grad for p in params]
+
+         for i in range(nsamples):
+             if distribution == 'gaussian':
+                 # Gaussian N(0, Id)
+                 zs = [torch.randn(p.size(), device=p.device) for p in params]
+             elif distribution == 'rademacher':
+                 # Rademacher distribution {-1.0, 1.0}
+                 zs = [torch.randint(0, 2, p.size(), dtype=p.dtype, device=p.device) * 2.0 - 1.0 for p in params]
+
+             h_zs = torch.autograd.grad(grads, params, grad_outputs=zs, retain_graph=i < nsamples - 1)
+             for h_z, z, p in zip(h_zs, zs, params):
+                 # approximate the expected value of z * (H @ z), i.e. the Hessian diagonal
+                 self.state[p]['hessian'].add_(h_z * z, alpha=(1 / nsamples) * alpha)
+
      @staticmethod
      def apply_weight_decay(
          p: torch.Tensor,

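The helper above relies on Hutchinson's identity: for random vectors z with E[z zᵀ] = I (Gaussian or Rademacher), E[z ⊙ (Hz)] equals the diagonal of H, so averaging z * (H @ z) over a few samples gives a cheap diagonal-Hessian estimate. A small self-contained sanity check of that identity, using plain tensors rather than the optimizer state (the matrix A and sample count are purely illustrative):

```
import torch

# Verify E[z * (A @ z)] ~= diag(A) for Rademacher z, as used by compute_hutchinson_hessian.
torch.manual_seed(0)
A = torch.randn(5, 5)
A = (A + A.T) / 2                                        # symmetric, stands in for a Hessian

n_samples = 10_000
est = torch.zeros(5)
for _ in range(n_samples):
    z = torch.randint(0, 2, (5,)).float() * 2.0 - 1.0    # Rademacher {-1, +1}
    est += z * (A @ z) / n_samples

print(torch.allclose(est, torch.diag(A), atol=0.1))      # True with high probability
```
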
pytorch_optimizer/base/types.py

Lines changed: 3 additions & 1 deletion
@@ -1,4 +1,4 @@
- from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Type, Union
+ from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Type, Union, Literal

  import torch
  from torch.optim import Optimizer
@@ -12,3 +12,5 @@
  STATE = Dict[str, Any]
  OPTIMIZER = Type[Optimizer]
  SCHEDULER = Type[_LRScheduler]
+
+ HUTCHINSON_G = Literal['gaussian', 'rademacher']

pytorch_optimizer/optimizer/adahessian.py

Lines changed: 128 additions & 0 deletions
@@ -0,0 +1,128 @@
+ import torch
+ from torch.optim.optimizer import Optimizer
+
+ from pytorch_optimizer.base.exception import NoSparseGradientError
+ from pytorch_optimizer.base.optimizer import BaseOptimizer
+ from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS, HUTCHINSON_G
+
+ # Modified from https://github.com/davda54/ada-hessian/blob/master/ada_hessian.py (MIT, David Samuel)
+
+
+ class AdaHessian(Optimizer, BaseOptimizer):
+     r"""An Adaptive Second Order Optimizer for Machine Learning.
+
+     Requires `loss.backward(create_graph=True)` in order to calculate Hessians.
+
+     :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
+     :param lr: float. learning rate.
+     :param betas: BETAS. coefficients used for computing running averages of the gradient and the squared hessian trace.
+     :param weight_decay: float. weight decay (L2 penalty).
+     :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
+     :param fixed_decay: bool. fix weight decay.
+     :param hessian_power: float. exponent of the hessian trace.
+     :param update_period: int. number of steps after which to apply the hessian approximation.
+     :param n_samples: int. times to sample `z` for the approximation of the hessian trace.
+     :param hessian_distribution: HUTCHINSON_G. distribution used to sample `z` ('gaussian' or 'rademacher').
+     :param eps: float. term added to the denominator to improve numerical stability.
+     """
+
+     def __init__(self,
+                  params: PARAMETERS,
+                  lr: float = 1e-1,
+                  betas: BETAS = (0.9, 0.999),
+                  weight_decay: float = 0.0,
+                  weight_decouple: bool = True,
+                  fixed_decay: bool = False,
+                  hessian_power: float = 1.0,
+                  update_period: int = 1,
+                  n_samples: int = 1,
+                  hessian_distribution: HUTCHINSON_G = 'rademacher',
+                  eps: float = 1e-16):
+         self.validate_learning_rate(lr)
+         self.validate_betas(betas)
+         self.validate_non_negative(weight_decay, 'weight_decay')
+         self.validate_non_negative(eps, 'eps')
+         self.validate_range(hessian_power, 'Hessian Power', 0, 1, range_type='(]')
+
+         self.distribution = hessian_distribution
+         self.update_period = update_period
+         self.n_samples = n_samples
+         defaults: DEFAULTS = {
+             'lr': lr,
+             'betas': betas,
+             'weight_decay': weight_decay,
+             'weight_decouple': weight_decouple,
+             'fixed_decay': fixed_decay,
+             'hessian_power': hessian_power,
+             'eps': eps,
+         }
+         self._step = 0
+         super().__init__(params, defaults)
+
+     @torch.no_grad()
+     def reset(self):
+         self._step = 0
+         for group in self.param_groups:
+             for p in group['params']:
+                 state = self.state[p]
+                 state['exp_avg'] = torch.zeros_like(p)
+                 state['exp_hessian_diag_sq'] = torch.zeros_like(p)
+
+     @torch.no_grad()
+     def step(self, closure: CLOSURE = None, hessian: tuple[torch.Tensor] = None) -> LOSS:
+         loss: LOSS = None
+         if closure is not None:
+             with torch.enable_grad():
+                 loss = closure()
+
+         if hessian is not None:
+             self.set_hessian(hessian)
+         elif self._step % self.update_period == 0:
+             self.compute_hutchinson_hessian(self.n_samples, distribution=self.distribution)
+
+         for group in self.param_groups:
+             for p in group['params']:
+                 if p.grad is None:
+                     continue
+
+                 grad = p.grad
+                 if grad.is_sparse:
+                     raise NoSparseGradientError(str(self))
+
+                 # state initialization
+                 state = self.state[p]
+                 if 'exp_avg' not in state:
+                     state['exp_avg'] = torch.zeros_like(p.data)
+                     state['exp_hessian_diag_sq'] = torch.zeros_like(p.data)
+
+                 self.apply_weight_decay(
+                     p=p,
+                     grad=grad,
+                     lr=group['lr'],
+                     weight_decay=group['weight_decay'],
+                     weight_decouple=group['weight_decouple'],
+                     fixed_decay=group['fixed_decay'],
+                 )
+
+                 exp_avg, exp_hessian_diag_sq = state['exp_avg'], state['exp_hessian_diag_sq']
+                 beta1, beta2 = group['betas']
+
+                 # decay the first and second moment running average coefficients
+                 exp_avg.mul_(beta1).add_(p.grad, alpha=1 - beta1)
+                 if (self._step % self.update_period == 0 or hessian is not None) and 'hessian' in state:
+                     # if self.average_conv_kernel and p.dim() == 4:
+                     #     state['hessian'] = torch.abs(state['hessian']).mean(dim=[2, 3], keepdim=True).expand_as(state['hessian']).clone()
+                     exp_hessian_diag_sq.mul_(beta2).addcmul_(state['hessian'], state['hessian'], value=1 - beta2)
+
+                 bias_correction1 = 1 - beta1 ** (self._step + 1)
+                 bias_correction2 = 1 - beta2 ** (self._step + 1)
+
+                 k = group['hessian_power']
+                 denom = (exp_hessian_diag_sq / bias_correction2).pow_(k / 2).add_(group['eps'])
+
+                 # make the update
+                 step_size = group['lr'] / bias_correction1
+                 p.addcdiv_(exp_avg, denom, value=-step_size)
+
+         self._step += 1
+         return loss

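As the class docstring notes, AdaHessian builds Hessian-vector products from the gradients, so the first backward pass must keep its graph alive. A minimal training-step sketch (model, data, and hyper-parameters are illustrative, not part of the commit):

```
import torch
from pytorch_optimizer import AdaHessian

model = torch.nn.Sequential(torch.nn.Linear(4, 16), torch.nn.Tanh(), torch.nn.Linear(16, 1))
optimizer = AdaHessian(model.parameters(), lr=1e-1, update_period=1, n_samples=1)

x, y = torch.randn(32, 4), torch.randn(32, 1)
for _ in range(10):
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward(create_graph=True)   # required: step() differentiates through the gradients
    optimizer.step()                   # runs compute_hutchinson_hessian, then the Adam-style update
```
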
pytorch_optimizer/optimizer/sgd.py

Lines changed: 72 additions & 0 deletions
@@ -311,3 +311,75 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                  p.add_(grad, alpha=-new_lr)

          return loss
+
+
+ class SignSGD(Optimizer, BaseOptimizer):
+     r"""SignSGD: Compressed Optimisation for Non-Convex Problems.
+
+     :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
+     :param lr: float. learning rate.
+     :param beta: float. momentum factor (0.0 = SignSGD, > 0.0 = Signum).
+     :param weight_decay: float. weight decay (L2 penalty).
+     :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
+     """
+
+     def __init__(
+         self,
+         params: PARAMETERS,
+         lr: float = 1e-3,
+         beta: float = 0.9,
+         weight_decay: float = 0.0,
+         weight_decouple: bool = True,
+     ):
+         self.validate_learning_rate(lr)
+         self.validate_range(beta, 'beta', 0.0, 1.0)
+         self.validate_non_negative(weight_decay, 'weight_decay')
+
+         defaults: DEFAULTS = {
+             'lr': lr,
+             'beta': beta,
+             'weight_decay': weight_decay,
+             'weight_decouple': weight_decouple,
+         }
+
+         super().__init__(params, defaults)
+
+     @torch.no_grad()
+     def reset(self):
+         for group in self.param_groups:
+             for p in group['params']:
+                 state = self.state[p]
+
+                 if group['beta'] > 0.0:
+                     state['momentum_buffer'] = p.grad.clone()
+
+     @torch.no_grad()
+     def step(self, closure: CLOSURE = None) -> LOSS:
+         loss: LOSS = None
+         if closure is not None:
+             with torch.enable_grad():
+                 loss = closure()
+
+         for group in self.param_groups:
+             beta = group['beta']
+             for p in group['params']:
+                 if p.grad is None:
+                     continue
+
+                 if p.grad.is_sparse:
+                     raise NoSparseGradientError(str(self))
+
+                 state = self.state[p]
+
+                 if beta > 0.0:
+                     if len(state) == 0:
+                         state['momentum_buffer'] = p.grad.clone()
+
+                     buf = state['momentum_buffer']
+                     buf.mul_(beta).add_(p.grad, alpha=1.0 - beta)
+                 else:
+                     buf = p.grad
+
+                 p.add_(torch.sign(buf), alpha=-group['lr'])
+
+         return loss

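Because the update uses only the sign of the (optionally momentum-smoothed) gradient, every coordinate moves by exactly lr per step, regardless of gradient magnitude. A quick illustrative check on a single parameter (values are made up for the example):

```
import torch
from pytorch_optimizer import SignSGD

p = torch.nn.Parameter(torch.tensor([1.0, -2.0, 3.0]))
optimizer = SignSGD([p], lr=0.1, beta=0.0)    # beta=0.0: plain SignSGD, no momentum buffer

p.grad = torch.tensor([0.5, -0.01, 100.0])    # wildly different gradient magnitudes
optimizer.step()

print(p.data)  # tensor([0.9000, -1.9000, 2.9000]): each coordinate moved by exactly lr
```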