
Commit d0a9b1b

Author: ferris
Message: Pass all tests
Parent: 2ed564d

File tree: 7 files changed (+50, -13 lines)


pytorch_optimizer/base/optimizer.py (4 additions, 2 deletions)

@@ -4,7 +4,7 @@
 
 import torch
 
-from pytorch_optimizer.base.exception import NegativeLRError, NegativeStepError
+from pytorch_optimizer.base.exception import NegativeLRError, NegativeStepError, NoSparseGradientError
 from pytorch_optimizer.base.types import BETAS, HUTCHINSON_G
 
 
@@ -48,7 +48,9 @@ def compute_hutchinson_hessian(self, nsamples: int = 1, pre_zero=True, alpha=1.0
         params = []
         for group in self.param_groups:
             for p in group['params']:
-                if p.grad is not None:
+                if p.requires_grad and p.grad is not None:
+                    if p.grad.is_sparse:
+                        raise NoSparseGradientError(str(self))
                     # Initialize Hessian state
                     if 'hessian' in self.state[p]:
                         if pre_zero:
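For context, this guard sits in front of the library's Hutchinson pass. The sketch below is a standalone illustration of the Hutchinson diagonal-Hessian estimate (diag(H) ≈ E[z ⊙ Hz]) that compute_hutchinson_hessian relies on; the function name, signature, and Rademacher/Gaussian handling here are illustrative, not the library's internals.

import torch

def hutchinson_diag_estimate(loss, params, n_samples: int = 1, distribution: str = 'rademacher'):
    """Illustrative estimator: diag(H) ~= E[z * (H @ z)] over random probe vectors z."""
    # Gradients must be built with a graph so a second backward (the HVP) is possible.
    grads = torch.autograd.grad(loss, params, create_graph=True)
    estimates = [torch.zeros_like(p) for p in params]

    for _ in range(n_samples):
        if distribution == 'gaussian':
            zs = [torch.randn_like(p) for p in params]
        else:  # Rademacher probes: entries in {-1, +1}
            zs = [torch.randint_like(p, 0, 2) * 2.0 - 1.0 for p in params]

        # Hessian-vector products: the gradient of (g . z) w.r.t. the parameters.
        h_zs = torch.autograd.grad(grads, params, grad_outputs=zs, retain_graph=True)
        for est, z, h_z in zip(estimates, zs, h_zs):
            est.add_(z * h_z, alpha=1.0 / n_samples)

    return estimates

# For the sphere loss used in the tests, f(x) = (x ** 2).sum(), the exact diagonal is 2
# everywhere, and Rademacher probes recover it exactly: z * (2 * z) == 2.
x = torch.randn(3, requires_grad=True)
print(hutchinson_diag_estimate((x ** 2).sum(), [x]))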

pytorch_optimizer/optimizer/adahessian.py (5 additions, 3 deletions)

@@ -69,13 +69,15 @@ def reset(self):
                 state['exp_hessian_diag_sq'] = torch.zeros_like(p)
 
     @torch.no_grad()
-    def step(self, closure: CLOSURE = None) -> LOSS:
+    def step(self, closure: CLOSURE = None, hessian: tuple[torch.Tensor] = None) -> LOSS:
         loss: LOSS = None
         if closure is not None:
             with torch.enable_grad():
                 loss = closure()
 
-        if self._step % self.update_period == 0:
+        if hessian is not None:
+            self.set_hessian(hessian)
+        elif self._step % self.update_period == 0:
             self.compute_hutchinson_hessian(self.n_samples, distribution=self.distribution)
 
         for group in self.param_groups:
@@ -107,7 +109,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
                 # Decay the first and second moment running average coefficient
                 exp_avg.mul_(beta1).add_(p.grad, alpha=1 - beta1)
-                if self._step % self.update_period == 0:
+                if (self._step % self.update_period == 0 or hessian is not None) and 'hessian' in state:
                     # if self.average_conv_kernel and p.dim() == 4:
                     #     state['hessian'] = torch.abs(state['hessian']).mean(dim=[2, 3], keepdim=True).expand_as(state['hessian']).clone()
                     exp_hessian_diag_sq.mul_(beta2).addcmul_(state['hessian'], state['hessian'], value=1 - beta2)
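With this change, step() either runs the internal Hutchinson estimate or accepts a pre-computed Hessian. A usage sketch, modelled on the new test further below (load_optimizer, the constructor kwargs, and the shape of the hessian argument mirror that test; everything else here is illustrative):

import torch
from pytorch_optimizer import load_optimizer

def sphere_loss(x) -> torch.Tensor:
    return (x ** 2).sum()

param = torch.randn(2, 2, requires_grad=True)
optimizer = load_optimizer('adahessian')([param], hessian_distribution='gaussian', n_samples=2)

# Internal Hutchinson estimate: backward with create_graph=True so the optimizer
# can take Hessian-vector products through the gradient graph.
sphere_loss(param).backward(create_graph=True)
optimizer.step()
optimizer.zero_grad(set_to_none=True)

# External estimate: a plain backward pass plus an explicitly supplied Hessian.
# The test passes zeros just to exercise this code path; the exact diagonal of
# the sphere loss would be 2 * torch.ones_like(param).
sphere_loss(param).backward()
optimizer.step(hessian=torch.zeros_like(param).unsqueeze(0))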

pytorch_optimizer/optimizer/sophiah.py (2 additions, 2 deletions)

@@ -83,12 +83,12 @@ def step(self, closure: CLOSURE = None, hessian: tuple[torch.Tensor] = None) ->
                 if p.grad is None:
                     continue
 
+                state = self.state[p]
                 grad = p.grad
                 if grad.is_sparse:
                     raise NoSparseGradientError(str(self))
 
                 # State initialization
-                state = self.state[p]
                 if 'momentum' not in state:
                     state['momentum'] = torch.zeros_like(p)
                     state['hessian_moment'] = torch.zeros_like(p)
@@ -106,7 +106,7 @@ def step(self, closure: CLOSURE = None, hessian: tuple[torch.Tensor] = None) ->
                 momentum, hessian_moment = state['momentum'], state['hessian_moment']
 
                 momentum.mul_(beta1).add_(p.grad, alpha=1.0 - beta1)
-                if self._step % self.update_period == 0 or hessian is not None:
+                if (self._step % self.update_period == 0 or hessian is not None) and 'hessian' in state:
                     hessian_moment.mul_(beta2).add_(state['hessian'], alpha=1.0 - beta2)
 
                 # See https://shreyansh26.github.io/post/2023-05-28_sophia_scalable_second_order_optimizer_llms/#per-coordinate-clipping
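The per-coordinate clipping the linked post describes, in the form given in the Sophia paper, amounts to clamping the pre-conditioned momentum element-wise before applying it. A hedged sketch in paper notation (gamma, eps, and the default values here are illustrative, not the library's argument names or values):

import torch

def sophia_clipped_update(param: torch.Tensor, momentum: torch.Tensor, hessian_moment: torch.Tensor,
                          lr: float = 1e-3, gamma: float = 1e-2, eps: float = 1e-12) -> None:
    # theta <- theta - lr * clip(m / max(gamma * h, eps), 1), clipping element-wise to [-1, 1].
    precond = momentum / (gamma * hessian_moment).clamp(min=eps)
    param.add_(precond.clamp(-1.0, 1.0), alpha=-lr)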

tests/test_general_optimizer_parameters.py (1 addition, 0 deletions)

@@ -41,6 +41,7 @@ def test_epsilon(optimizer_name):
         'alig',
         'gravity',
         'srmm',
+        'signsgd'
     ):
         pytest.skip(f'skip {optimizer_name} optimizer')
 

tests/test_gradients.py (16 additions, 5 deletions)

@@ -24,12 +24,18 @@ def test_no_gradients(optimizer_name):
     else:
         optimizer = load_optimizer(optimizer_name)(params)
 
+    def sphere_loss(x) -> torch.Tensor:
+        return (x ** 2).sum()
+
     optimizer.zero_grad()
-    p1.grad = torch.zeros(1, 1)
-    p2.grad = None
-    p3.grad = torch.zeros(1, 1)
-    p4.grad = None
+    sphere_loss(p1 + p3).backward(create_graph=True)
+    # p1.grad = torch.zeros(1, 1)
+    # p2.grad = None
+    # p3.grad = torch.zeros(1, 1)
+    # p4.grad = None
     optimizer.step(lambda: 0.1)  # for AliG optimizer
+    if optimizer_name != 'lookahead':
+        optimizer.zero_grad(set_to_none=True)
 
 
 @pytest.mark.parametrize('no_sparse_optimizer', NO_SPARSE_OPTIMIZERS)
@@ -109,12 +115,17 @@ def test_bf16_gradient(optimizer_name):
     if optimizer_name == 'shampoo':
         pytest.skip(f'skip {optimizer_name}')
 
+    def sphere_loss(x) -> torch.Tensor:
+        return (x ** 2).sum()
+
     param = torch.randn(1, 1).bfloat16().requires_grad_(True)
-    param.grad = torch.randn(1, 1).bfloat16()
 
     opt = load_optimizer(optimizer=optimizer_name)
     optimizer = opt([param], num_iterations=1) if optimizer_name == 'ranger21' else opt([param])
+
+    sphere_loss(param).backward(create_graph=True)
     optimizer.step(lambda: 0.1)
+    optimizer.zero_grad(True)
 
 
 def test_sam_no_gradient():
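Switching these tests from hand-assigned .grad tensors to a real backward(create_graph=True) call matters because the Hessian-based optimizers differentiate a second time through the gradients. A minimal standalone illustration (not part of the test suite):

import torch

x = torch.randn(3, requires_grad=True)
loss = (x ** 2).sum()

# create_graph=True keeps the gradient itself on the autograd graph, so a
# Hessian-vector product (a gradient of the gradient) can be taken afterwards.
(g,) = torch.autograd.grad(loss, [x], create_graph=True)

z = torch.ones_like(x)
(hz,) = torch.autograd.grad(g, [x], grad_outputs=[z])  # H @ z; equals 2 * z for the sphere loss
print(hz)  # tensor([2., 2., 2.])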

tests/test_load_optimizers.py (1 addition, 1 deletion)

@@ -16,4 +16,4 @@ def test_load_optimizers_invalid(invalid_optimizer_names):
 
 
 def test_get_supported_optimizers():
-    assert len(get_supported_optimizers()) == 51
+    assert len(get_supported_optimizers()) == 54

tests/test_optimizers.py (21 additions, 0 deletions)

@@ -322,6 +322,27 @@ def test_rectified_optimizer(optimizer_name):
     optimizer.step()
 
 
+@pytest.mark.parametrize('optimizer_name', ['sophiah', 'adahessian'])
+def test_hessian_optimizer(optimizer_name):
+    param = simple_parameter()
+
+    def sphere_loss(x) -> torch.Tensor:
+        return (x ** 2).sum()
+
+    parameters = {'hessian_distribution': 'gaussian', 'n_samples': 2}
+    optimizer = load_optimizer(optimizer_name)([param], **parameters)
+    optimizer.zero_grad(set_to_none=True)
+
+    # Hutchinson (internal) estimator
+    sphere_loss(param).backward(create_graph=True)
+    optimizer.step()
+    optimizer.zero_grad(set_to_none=True)
+
+    # External estimator
+    sphere_loss(param).backward()
+    optimizer.step(hessian=torch.zeros_like(param).unsqueeze(0))
+
+
 @pytest.mark.parametrize('optimizer_config', OPTIMIZERS + ADANORM_SUPPORTED_OPTIMIZERS, ids=ids)
 def test_reset(optimizer_config):
     optimizer_class, config, _ = optimizer_config
