
Commit e767029

Author: ferris
Add SignSGD, allow hessian approximation distribution selection, fix adahessian instability
1 parent 2dbad2c

File tree: 6 files changed, +99 -15 lines

pytorch_optimizer/__init__.py

Lines changed: 3 additions & 2 deletions
@@ -62,7 +62,7 @@
 from pytorch_optimizer.optimizer.ranger21 import Ranger21
 from pytorch_optimizer.optimizer.rotograd import RotoGrad
 from pytorch_optimizer.optimizer.sam import SAM
-from pytorch_optimizer.optimizer.sgd import ASGD, SGDW, AccSGD
+from pytorch_optimizer.optimizer.sgd import ASGD, SGDW, AccSGD, SignSGD
 from pytorch_optimizer.optimizer.sgdp import SGDP
 from pytorch_optimizer.optimizer.shampoo import ScalableShampoo, Shampoo
 from pytorch_optimizer.optimizer.shampoo_utils import (
@@ -150,7 +150,8 @@
     AdaDelta,
     Amos,
     AdaHessian,
-    SophiaH
+    SophiaH,
+    SignSGD
 ]
 OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}

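Because OPTIMIZERS maps every class in OPTIMIZER_LIST to its lowercased name, the new optimizer is now reachable through the registry under the key 'signsgd'. A minimal sketch of looking it up (the linear model is just a placeholder):

import torch

from pytorch_optimizer import OPTIMIZERS

opt_cls = OPTIMIZERS['signsgd']          # resolved via the lowercased class name
model = torch.nn.Linear(4, 1)            # placeholder model
optimizer = opt_cls(model.parameters(), lr=1e-3)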
pytorch_optimizer/base/optimizer.py

Lines changed: 13 additions & 7 deletions
@@ -5,17 +5,20 @@
 import torch

 from pytorch_optimizer.base.exception import NegativeLRError, NegativeStepError
-from pytorch_optimizer.base.types import BETAS
+from pytorch_optimizer.base.types import BETAS, HUTCHINSON_G


 class BaseOptimizer(ABC):
     r"""Base optimizer class."""

     @torch.no_grad()
-    def compute_hutchinson_hessian(self, nsamples: int = 1, pre_zero=True, alpha=1.0):
+    def compute_hutchinson_hessian(self, nsamples: int = 1, pre_zero=True, alpha=1.0, distribution: HUTCHINSON_G = 'gaussian'):
         """
         Hutchinsons approximate hessian, added to the state under key 'hessian'
         """
+        if distribution not in ['gaussian', 'rademacher']:
+            raise NotImplementedError(f"Hessian with distribution {distribution} is not implemented")
+
         params = []
         for group in self.param_groups:
             for p in group['params']:
@@ -34,14 +37,17 @@ def compute_hutchinson_hessian(self, nsamples: int = 1, pre_zero=True, alpha=1.0
         grads = [p.grad for p in params]

         for i in range(nsamples):
-            # Gaussian N(0,Id)
-            zs = [torch.randn(p.size(), device=p.device) for p in params]
-            # Rademacher distribution {-1.0, 1.0}
-            # zs = [torch.randint(0, 2, p.size(), device=p.device) * 2.0 - 1.0 for p in params]
+            if distribution == 'gaussian':
+                # Gaussian N(0,Id)
+                zs = [torch.randn(p.size(), device=p.device) for p in params]
+            elif distribution == 'rademacher':
+                # Rademacher distribution {-1.0, 1.0}
+                zs = [torch.randint(0, 2, p.size(), dtype=p.dtype, device=p.device) * 2.0 - 1.0 for p in params]
+
             h_zs = torch.autograd.grad(grads, params, grad_outputs=zs, retain_graph=i < nsamples - 1)
             for h_z, z, p in zip(h_zs, zs, params):
                 # approximate the expected values of z*(H@z)
-                self.state[p]['hessian'].add_(h_z * z, alpha=1/nsamples * alpha)
+                self.state[p]['hessian'].add_(h_z * z, alpha=(1/nsamples) * alpha)

     @staticmethod
     def apply_weight_decay(

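For context, compute_hutchinson_hessian accumulates Hutchinson's diagonal estimate diag(H) ≈ E[z * (H @ z)], where z is a zero-mean probe vector (Gaussian or Rademacher, now selectable) and H @ z is obtained by differentiating the gradients a second time, which is why the graph has to be kept alive between calls. A self-contained sketch of the same estimator on a toy loss, independent of the library:

import torch

# Toy loss whose exact Hessian diagonal is known: loss = sum(w_i^3) -> diag(H) = 6 * w
w = torch.randn(5, requires_grad=True)
loss = (w ** 3).sum()

# Keep the graph so the gradient can be differentiated again (Hessian-vector products).
(grad,) = torch.autograd.grad(loss, w, create_graph=True)

n_samples = 10
diag_est = torch.zeros_like(w)
for i in range(n_samples):
    # Rademacher probe vector in {-1.0, +1.0}
    z = torch.randint(0, 2, w.shape, dtype=w.dtype) * 2.0 - 1.0
    # h_z = H @ z via a second backward pass through the gradient
    (h_z,) = torch.autograd.grad(grad, w, grad_outputs=z, retain_graph=i < n_samples - 1)
    diag_est += z * h_z / n_samples      # E[z * (H @ z)] approximates diag(H)

print(diag_est)   # close to 6 * w (exact here, since the Hessian is diagonal)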
pytorch_optimizer/base/types.py

Lines changed: 3 additions & 1 deletion
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Type, Union
+from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Type, Union, Literal

 import torch
 from torch.optim import Optimizer
@@ -12,3 +12,5 @@
 STATE = Dict[str, Any]
 OPTIMIZER = Type[Optimizer]
 SCHEDULER = Type[_LRScheduler]
+
+HUTCHINSON_G = Literal['gaussian', 'rademacher']

pytorch_optimizer/optimizer/adahessian.py

Lines changed: 6 additions & 4 deletions
@@ -3,7 +3,7 @@

 from pytorch_optimizer.base.exception import NoSparseGradientError
 from pytorch_optimizer.base.optimizer import BaseOptimizer
-from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
+from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS, HUTCHINSON_G

 # Modified from https://github.com/davda54/ada-hessian/blob/master/ada_hessian.py (MIT David Samuel)

@@ -35,6 +35,7 @@ def __init__(self,
                  hessian_power: float = 1.0,
                  update_period: int = 1,
                  n_samples: int = 1,
+                 hessian_distribution: HUTCHINSON_G = 'rademacher',
                  eps: float = 1e-16):

         self.validate_learning_rate(lr)
@@ -64,7 +65,7 @@ def reset(self):
             for p in group['params']:
                 state = self.state[p]
                 state['exp_avg'] = torch.zeros_like(p)
-                state['exp_hessian_diag_sq'] = torch.zeros_like(p)
+                state['exp_hessian_diag_sq'] = state['hessian'].clone()

     @torch.no_grad()
     def step(self, closure: CLOSURE = None) -> LOSS:
@@ -88,8 +89,9 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 # State initialization
                 state = self.state[p]
                 if 'exp_avg' not in state:
-                    state['exp_avg'] = torch.zeros_like(p.data)  # Exponential moving average of gradient values
-                    state['exp_hessian_diag_sq'] = torch.zeros_like(p.data)  # Exponential moving average of Hessian diagonal square values
+                    state['exp_avg'] = torch.zeros_like(p.data)
+                    # NOTE: zeroing-out the hessian causes instability
+                    state['exp_hessian_diag_sq'] = state['hessian'].clone()

                 self.apply_weight_decay(
                     p=p,

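The instability fix above seeds exp_hessian_diag_sq from the current Hutchinson estimate instead of zeros, so the first preconditioned step no longer divides by a near-zero second moment. A hedged usage sketch of the new keyword (the model, data, and the create_graph=True backward are illustrative assumptions about the surrounding training loop, not part of this diff):

import torch

from pytorch_optimizer import AdaHessian

model = torch.nn.Linear(10, 1)                 # placeholder model
optimizer = AdaHessian(
    model.parameters(),
    lr=1e-1,
    hessian_distribution='rademacher',         # new keyword from this commit
)

x, y = torch.randn(32, 10), torch.randn(32, 1)
loss = torch.nn.functional.mse_loss(model(x), y)

optimizer.zero_grad()
loss.backward(create_graph=True)   # keep the graph for the Hessian-vector products
optimizer.step()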
pytorch_optimizer/optimizer/sgd.py

Lines changed: 72 additions & 0 deletions
@@ -311,3 +311,75 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 p.add_(grad, alpha=-new_lr)

         return loss
+
+
+class SignSGD(Optimizer, BaseOptimizer):
+    r"""SignSGD: Compressed Optimisation for Non-Convex Problems.
+
+    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
+    :param lr: float. learning rate.
+    :param beta: float. momentum factor (0.0 = SignSGD, > 0.0 = Signum).
+    :param weight_decay: float. weight decay (L2 penalty).
+    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
+    """
+
+    def __init__(
+        self,
+        params: PARAMETERS,
+        lr: float = 1e-3,
+        beta: float = 0.9,
+        weight_decay: float = 0.0,
+        weight_decouple: bool = True,
+    ):
+        self.validate_learning_rate(lr)
+        self.validate_range(beta, 'beta', 0.0, 1.0)
+        self.validate_non_negative(weight_decay, 'weight_decay')
+
+        defaults: DEFAULTS = {
+            'lr': lr,
+            'beta': beta,
+            'weight_decay': weight_decay,
+            'weight_decouple': weight_decouple,
+        }
+
+        super().__init__(params, defaults)
+
+    @torch.no_grad()
+    def reset(self):
+        for group in self.param_groups:
+            for p in group['params']:
+                state = self.state[p]
+
+                if group['beta'] > 0.0:
+                    state['momentum_buffer'] = p.grad.clone()
+
+    @torch.no_grad()
+    def step(self, closure: CLOSURE = None) -> LOSS:
+        loss: LOSS = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            beta = group['beta']
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+
+                if p.grad.is_sparse:
+                    raise NoSparseGradientError(str(self))
+
+                state = self.state[p]
+
+                if beta > 0.0:
+                    if len(state) == 0:
+                        state['momentum_buffer'] = p.grad.clone()
+
+                    buf = state['momentum_buffer']
+                    buf.mul_(beta).add_(p.grad, alpha=1.0 - beta)
+                else:
+                    buf = p.grad
+
+                p.add_(torch.sign(buf), alpha=-group['lr'])
+
+        return loss

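SignSGD's update is simply p <- p - lr * sign(buf): with beta = 0.0, buf is the raw gradient (plain SignSGD); with beta > 0.0, buf is an exponential moving average of gradients (Signum). A hedged usage sketch (model and data are placeholders):

import torch

from pytorch_optimizer import SignSGD

model = torch.nn.Linear(10, 2)                               # placeholder model
optimizer = SignSGD(model.parameters(), lr=1e-3, beta=0.9)   # beta=0.0 recovers plain SignSGD

x, y = torch.randn(16, 10), torch.randint(0, 2, (16,))
loss = torch.nn.functional.cross_entropy(model(x), y)

optimizer.zero_grad()
loss.backward()
optimizer.step()   # each parameter moves by -lr * sign(momentum buffer)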
pytorch_optimizer/optimizer/sophiah.py

Lines changed: 2 additions & 1 deletion
@@ -3,7 +3,7 @@

 from pytorch_optimizer.base.exception import NoSparseGradientError
 from pytorch_optimizer.base.optimizer import BaseOptimizer
-from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
+from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS, HUTCHINSON_G


 class SophiaH(Optimizer, BaseOptimizer):
@@ -33,6 +33,7 @@ def __init__(self,
                  p: float = 25.,
                  update_period: int = 10,
                  n_samples: int = 1,
+                 hessian_distribution: HUTCHINSON_G = 'gaussian',
                  eps: float = 1e-12):

         self.validate_learning_rate(lr)
