
Commit 42655e0

Merge pull request #42 from kozistr/feature/load-optimizers
[Test] Add test case for load_optimizers
2 parents 6343cca + b974419 commit 42655e0

File tree (6 files changed, +64 / -18 lines): pytorch_optimizer/adabound.py, pytorch_optimizer/diffgrad.py, pytorch_optimizer/diffrgrad.py, pytorch_optimizer/optimizers.py, pytorch_optimizer/radam.py, tests/test_load_optimizers.py


pytorch_optimizer/adabound.py

Lines changed: 2 additions & 2 deletions
@@ -129,8 +129,8 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
                 beta1, beta2 = group['betas']
 
-                exp_avg.mul_(beta1).add_(1 - beta1, grad)
-                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
+                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                 if group['amsbound']:
                     max_exp_avg_sq = torch.max(max_exp_avg_sq, exp_avg_sq)
                     denom = max_exp_avg_sq.sqrt().add_(group['eps'])
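The only change here is replacing the deprecated positional overloads of Tensor.add_ and Tensor.addcmul_ with their keyword forms (alpha= / value=); the arithmetic of the Adam moment updates is unchanged. A minimal sketch with stand-in tensors (not the optimizer's real state) showing what the new calls compute:

import torch

beta1, beta2 = 0.9, 0.999
grad = torch.randn(4)
exp_avg, exp_avg_sq = torch.zeros(4), torch.zeros(4)

# m <- beta1 * m + (1 - beta1) * g
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
# v <- beta2 * v + (1 - beta2) * g * g
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

# starting from zero state, the results are just the scaled gradient terms
assert torch.allclose(exp_avg, (1 - beta1) * grad)
assert torch.allclose(exp_avg_sq, (1 - beta2) * grad * grad)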

pytorch_optimizer/diffgrad.py

Lines changed: 4 additions & 4 deletions
@@ -90,14 +90,14 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 exp_avg, exp_avg_sq, previous_grad = state['exp_avg'], state['exp_avg_sq'], state['previous_grad']
 
                 if group['weight_decay'] != 0:
-                    grad.add_(group['weight_decay'], p.data)
+                    grad.add_(p.data, alpha=group['weight_decay'])
 
                 state['step'] += 1
                 beta1, beta2 = group['betas']
 
                 # Decay the first and second moment running average coefficient
-                exp_avg.mul_(beta1).add_(1 - beta1, grad)
-                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
+                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                 denom = exp_avg_sq.sqrt().add_(group['eps'])
 
                 bias_correction1 = 1 - beta1 ** state['step']
@@ -116,6 +116,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 else:
                     step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
 
-                p.data.addcdiv_(-step_size, exp_avg1, denom)
+                p.data.addcdiv_(exp_avg1, denom, value=-step_size)
 
         return loss
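The parameter update likewise moves the scalar into addcdiv_'s value= keyword. A small sketch with dummy tensors (not diffGrad's real state) of what that in-place call does:

import torch

step_size = 1e-3
p = torch.ones(3)
exp_avg1 = torch.full((3,), 0.5)
denom = torch.full((3,), 2.0)

expected = p - step_size * exp_avg1 / denom    # the usual Adam-style step

p.addcdiv_(exp_avg1, denom, value=-step_size)  # in-place: p += value * exp_avg1 / denom
assert torch.allclose(p, expected)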

pytorch_optimizer/diffrgrad.py

Lines changed: 6 additions & 6 deletions
@@ -121,8 +121,8 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
                 bias_correction1 = 1 - beta1 ** state['step']
 
-                exp_avg.mul_(beta1).add_(1 - beta1, grad)
-                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
+                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
 
                 # compute diffGrad coefficient (dfc)
                 diff = abs(previous_grad - grad)
@@ -164,18 +164,18 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
                 if n_sma >= self.n_sma_threshold:
                     if group['weight_decay'] != 0:
-                        p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
+                        p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
 
                     denom = exp_avg_sq.sqrt().add_(group['eps'])
 
                     # update momentum with dfc
-                    p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg * dfc.float(), denom)
+                    p_data_fp32.addcdiv_(exp_avg * dfc.float(), denom, value=-step_size * group['lr'])
                     p.data.copy_(p_data_fp32)
                 elif step_size > 0:
                     if group['weight_decay'] != 0:
-                        p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
+                        p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
 
-                    p_data_fp32.add_(-step_size * group['lr'], exp_avg)
+                    p_data_fp32.add_(exp_avg, alpha=-step_size * group['lr'])
                     p.data.copy_(p_data_fp32)
 
         return loss
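In the keyword form, the weight-decay line p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr']) adds a scaled copy of the parameter to itself, which is just an in-place multiply by (1 - weight_decay * lr). A minimal sketch with a dummy parameter tensor:

import torch

lr, weight_decay = 1e-2, 1e-4
p_data_fp32 = torch.randn(5)
expected = p_data_fp32 * (1 - weight_decay * lr)

p_data_fp32.add_(p_data_fp32, alpha=-weight_decay * lr)  # p <- p + (-wd * lr) * p
assert torch.allclose(p_data_fp32, expected)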

pytorch_optimizer/optimizers.py

Lines changed: 3 additions & 0 deletions
@@ -5,6 +5,7 @@
 from pytorch_optimizer.diffgrad import DiffGrad
 from pytorch_optimizer.diffrgrad import DiffRGrad
 from pytorch_optimizer.fp16 import SafeFP16Optimizer
+from pytorch_optimizer.lamb import Lamb
 from pytorch_optimizer.madgrad import MADGRAD
 from pytorch_optimizer.radam import RAdam
 from pytorch_optimizer.ranger import Ranger
@@ -39,6 +40,8 @@ def load_optimizers(optimizer: str, use_fp16: bool = False):
         opt = DiffGrad
     elif optimizer == 'adahessian':
         opt = AdaHessian
+    elif optimizer == 'lamb':
+        opt = Lamb
     else:
         raise NotImplementedError(f'[-] not implemented optimizer : {optimizer}')
 
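With the new branch, load_optimizers('lamb') appears to resolve the name to the Lamb class itself (opt = Lamb) rather than an instance, like the other entries. A hedged usage sketch; the exact constructor arguments Lamb accepts beyond params and lr are an assumption here:

import torch
from pytorch_optimizer import load_optimizers

model = torch.nn.Linear(10, 2)

optimizer_class = load_optimizers('lamb')                 # returns the Lamb class, not an instance
optimizer = optimizer_class(model.parameters(), lr=1e-3)  # assumed (params, lr=...) signature

# unknown names still raise, as before
try:
    load_optimizers('lookahead')
except NotImplementedError as e:
    print(e)  # [-] not implemented optimizer : lookahead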

pytorch_optimizer/radam.py

Lines changed: 6 additions & 6 deletions
@@ -118,8 +118,8 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
                 bias_correction1 = 1 - beta1 ** state['step']
 
-                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
-                exp_avg.mul_(beta1).add_(1 - beta1, grad)
+                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
 
                 state['step'] += 1
                 buffered = group['buffer'][int(state['step'] % 10)]
@@ -155,14 +155,14 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
                 if n_sma >= self.n_sma_threshold:
                     if group['weight_decay'] != 0:
-                        p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
+                        p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
                     denom = exp_avg_sq.sqrt().add_(group['eps'])
-                    p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom)
+                    p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
                     p.data.copy_(p_data_fp32)
                 elif step_size > 0:
                     if group['weight_decay'] != 0:
-                        p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
-                    p_data_fp32.add_(-step_size * group['lr'], exp_avg)
+                        p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
+                    p_data_fp32.add_(exp_avg, alpha=-step_size * group['lr'])
                     p.data.copy_(p_data_fp32)
 
         return loss
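Besides the keyword-argument migration, the first radam.py hunk also swaps the order of the two moment updates so exp_avg is updated before exp_avg_sq. Each update reads only grad and its own previous value, so the swap does not change the result; a quick check with stand-in tensors:

import torch

beta1, beta2 = 0.9, 0.999
grad = torch.randn(4)
m_a, v_a = torch.rand(4), torch.rand(4)
m_b, v_b = m_a.clone(), v_a.clone()

# old order: second moment first
v_a.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
m_a.mul_(beta1).add_(grad, alpha=1 - beta1)

# new order: first moment first
m_b.mul_(beta1).add_(grad, alpha=1 - beta1)
v_b.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

assert torch.equal(m_a, m_b) and torch.equal(v_a, v_b)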

tests/test_load_optimizers.py

Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
+from typing import List
+
+import pytest
+
+from pytorch_optimizer import load_optimizers
+
+VALID_OPTIMIZER_NAMES: List[str] = [
+    'adamp',
+    'sgdp',
+    'madgrad',
+    'ranger',
+    'ranger21',
+    'radam',
+    'adabound',
+    'adahessian',
+    'adabelief',
+    'diffgrad',
+    'diffrgrad',
+    'lamb',
+]
+
+INVALID_OPTIMIZER_NAMES: List[str] = [
+    'asam',
+    'sam',
+    'pcgrad',
+    'adamd',
+    'lookahead',
+    'chebyshev_schedule',
+]
+
+
+@pytest.mark.parametrize('valid_optimizer_names', VALID_OPTIMIZER_NAMES)
+def test_load_optimizers_valid(valid_optimizer_names):
+    load_optimizers(valid_optimizer_names)
+
+
+@pytest.mark.parametrize('invalid_optimizer_names', INVALID_OPTIMIZER_NAMES)
+def test_load_optimizers_invalid(invalid_optimizer_names):
+    try:
+        load_optimizers(invalid_optimizer_names)
+    except NotImplementedError:
+        return True
+    return False
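One thing to note about the new test: test_load_optimizers_invalid returns True or False instead of asserting, and pytest ignores a test's return value, so this test passes even if no exception is raised. A more idiomatic sketch of the same check, using pytest.raises and the same optimizer names as the file above:

import pytest

from pytorch_optimizer import load_optimizers

INVALID_OPTIMIZER_NAMES = ['asam', 'sam', 'pcgrad', 'adamd', 'lookahead', 'chebyshev_schedule']


@pytest.mark.parametrize('invalid_optimizer_name', INVALID_OPTIMIZER_NAMES)
def test_load_optimizers_invalid(invalid_optimizer_name):
    # the call must raise NotImplementedError, otherwise the test fails
    with pytest.raises(NotImplementedError):
        load_optimizers(invalid_optimizer_name)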
