
Commit 5f1ef59

Merge pull request #44 from kozistr/fix/adahessian
[Test] Add FP16 & SAM test cases
2 parents: 75463dc + e874336

File tree (7 files changed, +107 / -37 lines changed)

  pytorch_optimizer/adahessian.py
  pytorch_optimizer/agc.py
  pytorch_optimizer/gc.py
  pytorch_optimizer/madgrad.py
  pytorch_optimizer/ranger21.py
  pytorch_optimizer/sam.py
  tests/test_optimizers.py

pytorch_optimizer/adahessian.py

Lines changed: 12 additions & 12 deletions
@@ -34,17 +34,17 @@ def __init__(
         average_conv_kernel: bool = False,
         adamd_debias_term: bool = False,
         eps: float = 1e-8,
-        seed: int = 2147483647,
+        seed: int = 1337,
     ):
-        """
+        """AdaHessian
         :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups
-        :param lr: float. learning rate.
+        :param lr: float. learning rate
         :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace
         :param weight_decay: float. weight decay (L2 penalty)
         :param hessian_power: float. exponent of the hessian trace
         :param update_each: int. compute the hessian trace approximation only after *this* number of steps
         :param num_samples: int. how many times to sample `z` for the approximation of the hessian trace
-        :param average_conv_kernel: bool. average out the hessian traces of convolutional kernels as in the paper.
+        :param average_conv_kernel: bool. average out the hessian traces of convolutional kernels as in the paper
         :param adamd_debias_term: bool. Only correct the denominator to avoid inflating step sizes early in training
         :param eps: float. term added to the denominator to improve numerical stability
         :param seed: int.
@@ -103,16 +103,17 @@ def zero_hessian(self):
             if not isinstance(p.hess, float) and self.state[p]['hessian_step'] % self.update_each == 0:
                 p.hess.zero_()

-    @torch.no_grad()
     def set_hessian(self):
-        """Computes the Hutchinson approximation of the hessian trace
-        and accumulates it for each trainable parameter
-        """
+        """Computes the Hutchinson approximation of the hessian trace and accumulates it for each trainable parameter"""
         params = []
-        for p in filter(lambda param: param.grad is not None, self.get_params()):
+        for p in self.get_params():
+            if p.grad is None:
+                continue
+
             # compute the trace only each `update_each` step
             if self.state[p]['hessian_step'] % self.update_each == 0:
                 params.append(p)
+
             self.state[p]['hessian_step'] += 1

         if len(params) == 0:
@@ -126,7 +127,7 @@ def set_hessian(self):

         for i in range(self.num_samples):
             # Rademacher distribution {-1.0, 1.0}
-            zs = [torch.randint(0, 2, p.size(), generator=self.generator, device=p.device) * 2.0 - 1.0 for p in params]
+            zs = [2.0 * torch.randint(0, 2, p.size()).float().requires_grad_(True) - 1.0 for p in params]

             # note that, possible memory leak due to retrain_graph=True
             h_zs = torch.autograd.grad(
@@ -141,7 +142,6 @@ def set_hessian(self):
                 # approximate the expected values of z * (H@z)
                 p.hess += h_z * z / self.num_samples

-    @torch.no_grad()
     def step(self, closure: CLOSURE = None) -> LOSS:
         loss: LOSS = None
         if closure is not None:
@@ -156,7 +156,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                     continue

                 if self.average_conv_kernel and p.dim() == 4:
-                    p.hess = torch.abs(p.hess).mean(dim=[2, 3], keepdim=True).expand_as(p.hess).clone()
+                    p.hess = torch.abs(p.hess).mean(dim=(2, 3), keepdim=True).expand_as(p.hess).clone()

                 # Perform correct step-weight decay as in AdamW
                 p.mul_(1.0 - group['lr'] * group['weight_decay'])
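For readers unfamiliar with the estimator that `set_hessian` keeps here: for a Rademacher vector z (entries drawn from {-1, +1}), E[z * (Hz)] equals the diagonal of the Hessian H, and Hz can be obtained by differentiating the gradient a second time, which is why AdaHessian-style optimizers require the loss to be backpropagated with create_graph=True. A minimal, self-contained sketch of the idea (illustrative only, not the optimizer's code):

import torch

# Toy objective: f(x) = sum(x^4), whose exact Hessian diagonal is 12 * x^2.
x = torch.randn(5, requires_grad=True)
loss = (x ** 4).sum()

# First differentiation with create_graph=True so the gradient can be differentiated again.
(grad,) = torch.autograd.grad(loss, x, create_graph=True)

num_samples = 200
hess_diag = torch.zeros_like(x)
for _ in range(num_samples):
    # Rademacher vector in {-1.0, +1.0}, as in the `zs` line of the diff above.
    z = 2.0 * torch.randint(0, 2, x.size()).float() - 1.0
    # Hessian-vector product H @ z via d(grad . z)/dx.
    (h_z,) = torch.autograd.grad(grad, x, grad_outputs=z, retain_graph=True)
    # Hutchinson estimate: average of z * (H @ z).
    hess_diag += z * h_z / num_samples

print(hess_diag)               # close to 12 * x^2
print(12.0 * x.detach() ** 2)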

pytorch_optimizer/agc.py

Lines changed: 4 additions & 4 deletions
@@ -4,10 +4,10 @@


 def agc(p: torch.Tensor, agc_eps: float, agc_clip_val: float, eps: float = 1e-6):
-    """Clip gradient values in excess of the unit-wise norm.
-    :param p: parameter.
-    :param agc_eps: float.
-    :param agc_clip_val: float.
+    """Clip gradient values in excess of the unit-wise norm
+    :param p: parameter. parameter
+    :param agc_eps: float. epsilon
+    :param agc_clip_val: float. norm clip
     :param eps: float. simple stop from div by zero and no relation to standard optimizer eps
     """
     p_norm = unit_norm(p).clamp_(agc_eps)
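For context, adaptive gradient clipping rescales a gradient whenever its norm exceeds `agc_clip_val` times the norm of the corresponding parameter, with `agc_eps` guarding against near-zero parameters. A rough sketch of that rule, using plain per-tensor norms instead of the library's unit-wise `unit_norm` helper (names and default values here are illustrative):

import torch

def agc_sketch(param: torch.Tensor, grad: torch.Tensor,
               agc_eps: float = 1e-3, agc_clip_val: float = 1e-2,
               eps: float = 1e-6) -> torch.Tensor:
    # Largest gradient norm allowed relative to the parameter norm.
    p_norm = param.norm().clamp(min=agc_eps)
    g_norm = grad.norm()
    max_norm = p_norm * agc_clip_val

    # Rescale only when the gradient is too large; eps avoids division by zero.
    if g_norm > max_norm:
        grad = grad * (max_norm / g_norm.clamp(min=eps))
    return grad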

pytorch_optimizer/gc.py

Lines changed: 2 additions & 2 deletions
@@ -3,8 +3,8 @@

 def centralize_gradient(x: torch.Tensor, gc_conv_only: bool = False) -> torch.Tensor:
     """Gradient Centralization (GC)
-    :param x: torch.Tensor. gradient.
-    :param gc_conv_only: bool. 'False' for both conv & fc layers.
+    :param x: torch.Tensor. gradient
+    :param gc_conv_only: bool. 'False' for both conv & fc layers
     :return: torch.Tensor. GC-ed gradient
     """
     size: int = x.dim()
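Gradient Centralization itself is a one-line idea: subtract from the gradient its mean taken over every dimension except the first, so the gradient of each output filter (or each row of an fc weight) sums to zero. A minimal sketch of the operation (illustrative, not the library function):

import torch

def centralize_gradient_sketch(grad: torch.Tensor, gc_conv_only: bool = False) -> torch.Tensor:
    size: int = grad.dim()
    # conv kernels have 4 dims (out, in, kh, kw); fc weights have 2 (out, in)
    apply_gc: bool = size > 3 if gc_conv_only else size > 1
    if apply_gc:
        # remove the per-filter mean, keeping dims so the result broadcasts back
        grad = grad - grad.mean(dim=tuple(range(1, size)), keepdim=True)
    return grad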

pytorch_optimizer/madgrad.py

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@ def __init__(
     ):
         """A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic (slightly modified)
         :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups
-        :param lr: float. learning rate.
+        :param lr: float. learning rate
         :param eps: float. term added to the denominator to improve numerical stability
         :param weight_decay: float. weight decay (L2 penalty)
             MADGRAD optimizer requires less weight decay than other methods, often as little as zero

pytorch_optimizer/ranger21.py

Lines changed: 10 additions & 10 deletions
@@ -1,12 +1,3 @@
-__AUTHORS__ = [
-    '@lessw2020',
-    '@NestorDemeure',
-    # with contributions from :
-    '@BrianPugh',
-    '@Kayuksel',
-    '@TheZothen',
-]
-
 import math
 from typing import Optional

@@ -19,6 +10,15 @@
 from pytorch_optimizer.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS, STATE
 from pytorch_optimizer.utils import normalize_gradient, unit_norm

+__AUTHORS__ = [
+    '@lessw2020',
+    '@NestorDemeure',
+    # with contributions from :
+    '@BrianPugh',
+    '@Kayuksel',
+    '@TheZothen',
+]
+

 class Ranger21(Optimizer):
     """
@@ -185,7 +185,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
         param_size: int = 0
         variance_ma_sum: float = 1.0

-        # Phase 1 - Accumulate all of the variance_ma_sum to use in stable weight decay
+        # Phase 1 - Accumulate all the variance_ma_sum to use in stable weight decay
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:

pytorch_optimizer/sam.py

Lines changed: 2 additions & 2 deletions
@@ -57,9 +57,9 @@ def __init__(
         adaptive: bool = False,
         **kwargs,
     ):
-        """
+        """SAM
         :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups
-        :param base_optimizer: Optimizer.
+        :param base_optimizer: Optimizer. base optimizer
         :param rho: float. size of the neighborhood for computing the max loss
         :param adaptive: bool. element-wise Adaptive SAM
         :param kwargs: Dict. parameters for optimizer.
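SAM needs two forward/backward passes per update: `first_step` perturbs the weights toward the worst-case point inside the rho-neighborhood, and `second_step` applies the base optimizer's update using gradients computed at that perturbed point. The new `test_sam_optimizers` test below exercises exactly this pattern; a condensed sketch of a training loop with the wrapper, assuming a plain SGD base optimizer and toy data, looks like:

import torch
from torch import nn
from pytorch_optimizer import SAM

model = nn.Linear(2, 1)
loss_fn = nn.BCEWithLogitsLoss()
x_data = torch.randn(64, 2)
y_data = torch.randint(0, 2, (64, 1)).float()

# SAM wraps a base optimizer class; extra kwargs (here lr) are forwarded to it.
optimizer = SAM(model.parameters(), torch.optim.SGD, rho=0.05, lr=1e-1)

for _ in range(100):
    # pass 1: gradient at the current weights, then climb to the local worst case
    loss_fn(model(x_data), y_data).backward()
    optimizer.first_step(zero_grad=True)

    # pass 2: gradient at the perturbed weights drives the actual update
    loss_fn(model(x_data), y_data).backward()
    optimizer.second_step(zero_grad=True)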

tests/test_optimizers.py

Lines changed: 76 additions & 6 deletions
@@ -8,6 +8,7 @@

 from pytorch_optimizer import (
     MADGRAD,
+    SAM,
     SGDP,
     AdaBelief,
     AdaBound,
@@ -19,6 +20,7 @@
     RAdam,
     Ranger,
     Ranger21,
+    SafeFP16Optimizer,
 )

 __REFERENCE__ = 'https://github.com/jettify/pytorch-optimizer/blob/master/tests/test_optimizer_with_nn.py'
@@ -66,7 +68,7 @@ def build_lookahead(*parameters, **kwargs):
     return Lookahead(AdamP(*parameters, **kwargs))


-OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [
+FP32_OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [
     (build_lookahead, {'lr': 1e-2, 'weight_decay': 1e-3}, 200),
     (AdaBelief, {'lr': 1e-2, 'weight_decay': 1e-3}, 200),
     (AdaBound, {'lr': 1e-2, 'gamma': 0.1, 'weight_decay': 1e-3}, 200),
@@ -78,21 +80,34 @@ def build_lookahead(*parameters, **kwargs):
     (RAdam, {'lr': 1e-1, 'weight_decay': 1e-3}, 200),
     (SGDP, {'lr': 1e-1, 'weight_decay': 1e-3}, 200),
     (Ranger, {'lr': 1e-1, 'weight_decay': 1e-3}, 200),
-    (Ranger21, {'lr': 5e-1, 'weight_decay': 1e-3, 'num_iterations': 1000}, 500),
-    # (AdaHessian, {'lr': 1e-2, 'weight_decay': 1e-3}, 200),
+    (Ranger21, {'lr': 5e-1, 'weight_decay': 1e-3, 'num_iterations': 500}, 500),
 ]

+FP16_OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [
+    (build_lookahead, {'lr': 5e-1, 'weight_decay': 1e-3}, 500),
+    (AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3}, 200),
+    (AdaBound, {'lr': 5e-1, 'gamma': 0.1, 'weight_decay': 1e-3}, 200),
+    (AdamP, {'lr': 5e-1, 'weight_decay': 1e-3}, 500),
+    (DiffGrad, {'lr': 15 - 1, 'weight_decay': 1e-3}, 500),
+    (DiffRGrad, {'lr': 1e-1, 'weight_decay': 1e-3}, 200),
+    (Lamb, {'lr': 1e-1, 'weight_decay': 1e-3}, 200),
+    (RAdam, {'lr': 1e-1, 'weight_decay': 1e-3}, 200),
+    (SGDP, {'lr': 5e-1, 'weight_decay': 1e-3}, 500),
+    (Ranger, {'lr': 5e-1, 'weight_decay': 1e-3}, 200),
+    (Ranger21, {'lr': 5e-1, 'weight_decay': 1e-3, 'num_iterations': 500}, 500),
+]

-@pytest.mark.parametrize('optimizer_config', OPTIMIZERS, ids=ids)
-def test_optimizers(optimizer_config):
+
+@pytest.mark.parametrize('optimizer_fp32_config', FP32_OPTIMIZERS, ids=ids)
+def test_f32_optimizers(optimizer_fp32_config):
     torch.manual_seed(42)

     x_data, y_data = make_dataset()

     model: nn.Module = LogisticRegression()
     loss_fn: nn.Module = nn.BCEWithLogitsLoss()

-    optimizer_class, config, iterations = optimizer_config
+    optimizer_class, config, iterations = optimizer_fp32_config
     optimizer = optimizer_class(model.parameters(), **config)

     loss: float = np.inf
@@ -111,3 +126,58 @@ def test_optimizers(optimizer_config):
         optimizer.step()

     assert init_loss > 2.0 * loss
+
+
+@pytest.mark.parametrize('optimizer_fp16_config', FP16_OPTIMIZERS, ids=ids)
+def test_f16_optimizers(optimizer_fp16_config):
+    torch.manual_seed(42)
+
+    x_data, y_data = make_dataset()
+
+    model: nn.Module = LogisticRegression()
+    loss_fn: nn.Module = nn.BCEWithLogitsLoss()
+
+    optimizer_class, config, iterations = optimizer_fp16_config
+    optimizer = SafeFP16Optimizer(optimizer_class(model.parameters(), **config))
+
+    loss: float = np.inf
+    init_loss: float = np.inf
+    for _ in range(1000):
+        optimizer.zero_grad()
+
+        y_pred = model(x_data)
+        loss = loss_fn(y_pred, y_data)
+
+        if init_loss == np.inf:
+            init_loss = loss
+
+        loss.backward()
+
+        optimizer.step()
+
+    assert init_loss - 0.01 > loss
+
+
+@pytest.mark.parametrize('optimizer_config', FP32_OPTIMIZERS, ids=ids)
+def test_sam_optimizers(optimizer_config):
+    torch.manual_seed(42)
+
+    x_data, y_data = make_dataset()
+
+    model: nn.Module = LogisticRegression()
+    loss_fn: nn.Module = nn.BCEWithLogitsLoss()
+
+    optimizer_class, config, iterations = optimizer_config
+    optimizer = SAM(model.parameters(), optimizer_class, **config)
+
+    loss: float = np.inf
+    init_loss: float = np.inf
+    for _ in range(iterations):
+        loss = loss_fn(y_data, model(x_data))
+        loss.backward()
+        optimizer.first_step(zero_grad=True)
+
+        loss_fn(y_data, model(x_data)).backward()
+        optimizer.second_step(zero_grad=True)
+
+    assert init_loss > 2.0 * loss
