
Commit 8439f15

Merge pull request #51 from kozistr/feature/improve-perf
[Feature] Improve overall performance of the optimizers
2 parents (3cd5158 + 658d3a9), commit 8439f15

23 files changed: +477 -387 lines

Pipfile

Lines changed: 2 additions & 2 deletions
@@ -6,8 +6,8 @@ verify_ssl = false
 [dev-packages]
 isort = "==5.10.1"
 black = "==21.12b0"
-pylint = "==3.0.0a4"
-pytest = "==6.2.5"
+pylint = "==2.11.1"
+pytest = "==7.0.1"
 pytest-cov = "==3.0.0"

 [packages]

Pipfile.lock

Lines changed: 119 additions & 72 deletions
Generated lockfile; diff not rendered.

pytorch_optimizer/adabelief.py

Lines changed: 31 additions & 36 deletions
@@ -3,8 +3,7 @@
 import torch
 from torch.optim.optimizer import Optimizer

-from pytorch_optimizer.types import BETAS, BUFFER, CLOSURE, DEFAULTS, LOSS, PARAMETERS, STATE
-from pytorch_optimizer.utils import is_valid_parameters
+from pytorch_optimizer.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS, STATE


 class AdaBelief(Optimizer):
@@ -65,21 +64,14 @@ def __init__(

         self.check_valid_parameters()

-        buffer: BUFFER = [[None, None, None] for _ in range(10)]
-
-        if is_valid_parameters(params):
-            for param in params:
-                if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]):
-                    param['buffer'] = buffer
-
         defaults: DEFAULTS = dict(
             lr=lr,
             betas=betas,
             eps=eps,
             weight_decay=weight_decay,
             amsgrad=amsgrad,
             adamd_debias_term=adamd_debias_term,
-            buffer=buffer,
+            buffer=[[None, None, None] for _ in range(10)],
         )
         super().__init__(params, defaults)

@@ -101,53 +93,57 @@ def __setstate__(self, state: STATE):
             group.setdefault('amsgrad', False)
             group.setdefault('adamd_debias_term', False)

+    @torch.no_grad()
     def reset(self):
         for group in self.param_groups:
             for p in group['params']:
                 state = self.state[p]

                 state['step'] = 0
-                state['exp_avg'] = torch.zeros_like(p.data)
-                state['exp_avg_var'] = torch.zeros_like(p.data)
+                state['exp_avg'] = torch.zeros_like(p)
+                state['exp_avg_var'] = torch.zeros_like(p)
                 if group['amsgrad']:
-                    state['max_exp_avg_var'] = torch.zeros_like(p.data)
+                    state['max_exp_avg_var'] = torch.zeros_like(p)

+    @torch.no_grad()
     def step(self, closure: CLOSURE = None) -> LOSS:
         loss: LOSS = None
         if closure is not None:
-            loss = closure()
+            with torch.enable_grad():
+                loss = closure()

         for group in self.param_groups:
             for p in group['params']:
                 if p.grad is None:
                     continue

-                half_precision: bool = False
-                if p.data.dtype == torch.float16:
-                    half_precision = True
-                    p.data = p.data.float()
-                    p.grad = p.grad.float()
-
-                grad = p.grad.data
+                grad = p.grad
                 if grad.is_sparse:
                     raise RuntimeError('AdaBelief does not support sparse gradients')

+                if grad.dtype in (torch.float16, torch.bfloat16):
+                    grad = grad.float()
+
+                p_fp32 = p
+                if p.dtype in {torch.float16, torch.bfloat16}:
+                    p_fp32 = p_fp32.float()
+
                 state = self.state[p]
                 if len(state) == 0:
                     state['step'] = 0
-                    state['exp_avg'] = torch.zeros_like(p.data)
-                    state['exp_avg_var'] = torch.zeros_like(p.data)
+                    state['exp_avg'] = torch.zeros_like(p)
+                    state['exp_avg_var'] = torch.zeros_like(p)
                     if group['amsgrad']:
-                        state['max_exp_avg_var'] = torch.zeros_like(p.data)
+                        state['max_exp_avg_var'] = torch.zeros_like(p)

                 if self.weight_decouple:
                     if not self.fixed_decay:
-                        p.data.mul_(1.0 - group['lr'] * group['weight_decay'])
+                        p_fp32.mul_(1.0 - group['lr'] * group['weight_decay'])
                     else:
-                        p.data.mul_(1.0 - group['weight_decay'])
+                        p_fp32.mul_(1.0 - group['weight_decay'])
                 else:
                     if group['weight_decay'] != 0:
-                        grad.add_(p.data, alpha=group['weight_decay'])
+                        grad.add_(p_fp32, alpha=group['weight_decay'])

                 exp_avg, exp_avg_var = state['exp_avg'], state['exp_avg_var']

@@ -170,15 +166,15 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                         out=max_exp_avg_var,
                     )

-                    denom = (max_exp_avg_var.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
+                    de_nom = (max_exp_avg_var.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
                 else:
-                    denom = (exp_avg_var.add_(group['eps']).sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
+                    de_nom = (exp_avg_var.add_(group['eps']).sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])

                 if not self.rectify:
                     step_size = group['lr']
                     if not group['adamd_debias_term']:
                         step_size /= bias_correction1
-                    p.data.addcdiv_(exp_avg, denom, value=-step_size)
+                    p_fp32.addcdiv_(exp_avg, de_nom, value=-step_size)
                 else:
                     buffered = group['buffer'][int(state['step'] % 10)]
                     if state['step'] == buffered[0]:
@@ -212,13 +208,12 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                         buffered[2] = step_size

                     if n_sma >= self.n_sma_threshold:
-                        denom = exp_avg_var.sqrt().add_(group['eps'])
-                        p.data.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
+                        de_nom = exp_avg_var.sqrt().add_(group['eps'])
+                        p_fp32.addcdiv_(exp_avg, de_nom, value=-step_size * group['lr'])
                     elif step_size > 0:
-                        p.data.add_(exp_avg, alpha=-step_size * group['lr'])
+                        p_fp32.add_(exp_avg, alpha=-step_size * group['lr'])

-                if half_precision:
-                    p.data = p.data.half()
-                    p.grad = p.grad.half()
+                if p.dtype in {torch.float16, torch.bfloat16}:
+                    p.copy_(p_fp32)

         return loss
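
The biggest behavioral change in this file is the half-precision path: instead of temporarily swapping p.data to fp32 and casting back to half at the end, step now does its arithmetic on a detached fp32 working copy (p_fp32) and writes it back once. A minimal sketch of that pattern in isolation, assuming a plain SGD-style update (the function name and learning rate are illustrative, not part of the library):

import torch


@torch.no_grad()
def fp32_update(p: torch.Tensor, lr: float = 1e-3) -> None:
    # Illustrative update using the fp32-working-copy pattern from this diff.
    grad = p.grad
    if grad.dtype in (torch.float16, torch.bfloat16):
        grad = grad.float()  # do the math in fp32 even when the gradient is half precision

    p_fp32 = p
    if p.dtype in {torch.float16, torch.bfloat16}:
        p_fp32 = p_fp32.float()  # .float() returns a new tensor, so p itself stays untouched

    p_fp32.add_(grad, alpha=-lr)  # all in-place arithmetic happens on the fp32 copy

    if p.dtype in {torch.float16, torch.bfloat16}:
        p.copy_(p_fp32)  # cast the result back into the original half-precision storage

For fp32 parameters p_fp32 is p, so the update lands on the parameter directly and the final copy is skipped; this is why the old half_precision flag and the p.data = p.data.half() round trip are gone.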

pytorch_optimizer/adabound.py

Lines changed: 19 additions & 21 deletions
@@ -91,17 +91,19 @@ def __setstate__(self, state: STATE):
             group.setdefault('amsbound', False)
             group.setdefault('adamd_debias_term', False)

+    @torch.no_grad()
     def step(self, closure: CLOSURE = None) -> LOSS:
         loss: LOSS = None
         if closure is not None:
-            loss = closure()
+            with torch.enable_grad():
+                loss = closure()

         for group, base_lr in zip(self.param_groups, self.base_lrs):
             for p in group['params']:
                 if p.grad is None:
                     continue

-                grad = p.grad.data
+                grad = p.grad
                 if grad.is_sparse:
                     raise RuntimeError('AdaBound does not support sparse gradients')

@@ -114,46 +116,42 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                     if group['amsbound']:
                         state['max_exp_avg_sq'] = torch.zeros_like(p)

-                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
-                if group['amsbound']:
-                    max_exp_avg_sq = state['max_exp_avg_sq']
-
                 state['step'] += 1
+                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                 if self.weight_decouple:
                     if not self.fixed_decay:
-                        p.data.mul_(1.0 - group['lr'] * group['weight_decay'])
+                        p.mul_(1.0 - group['lr'] * group['weight_decay'])
                     else:
-                        p.data.mul_(1.0 - group['weight_decay'])
+                        p.mul_(1.0 - group['weight_decay'])
                 else:
                     if group['weight_decay'] != 0:
-                        grad.add_(p.data, alpha=group['weight_decay'])
+                        grad.add_(p, alpha=group['weight_decay'])

                 beta1, beta2 = group['betas']

-                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
-                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
+                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
                 if group['amsbound']:
-                    max_exp_avg_sq = torch.max(max_exp_avg_sq, exp_avg_sq)
-                    denom = max_exp_avg_sq.sqrt().add_(group['eps'])
+                    max_exp_avg_sq = torch.max(state['max_exp_avg_sq'], exp_avg_sq)
+                    de_nom = max_exp_avg_sq.sqrt().add_(group['eps'])
                 else:
-                    denom = exp_avg_sq.sqrt().add_(group['eps'])
+                    de_nom = exp_avg_sq.sqrt().add_(group['eps'])

                 bias_correction1 = 1 - beta1 ** state['step']
                 bias_correction2 = 1 - beta2 ** state['step']

-                if group['adamd_debias_term']:
-                    step_size = group['lr'] * math.sqrt(bias_correction2)
-                else:
-                    step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
+                step_size = group['lr'] * math.sqrt(bias_correction2)
+                if not group['adamd_debias_term']:
+                    step_size /= bias_correction1

                 final_lr = group['final_lr'] * group['lr'] / base_lr
                 lower_bound = final_lr * (1 - 1 / (group['gamma'] * state['step'] + 1))
                 upper_bound = final_lr * (1 + 1 / (group['gamma'] * state['step']))

-                step_size = torch.full_like(denom, step_size)
-                step_size.div_(denom).clamp_(lower_bound, upper_bound).mul_(exp_avg)
+                step_size = torch.full_like(de_nom, step_size)
+                step_size.div_(de_nom).clamp_(lower_bound, upper_bound).mul_(exp_avg)

-                p.data.add_(-step_size)
+                p.add_(-step_size)

         return loss
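
The @torch.no_grad() decorator plus the torch.enable_grad() wrapper around the closure is applied uniformly to every optimizer in this PR: the parameter update itself should not be tracked by autograd, while the closure still needs gradients because it re-runs forward and backward. A minimal sketch of the pattern on a toy optimizer (the class is hypothetical; only the decorator/closure structure mirrors the diff):

import torch
from torch.optim.optimizer import Optimizer


class SketchSGD(Optimizer):  # hypothetical optimizer, for illustration only
    def __init__(self, params, lr: float = 1e-3):
        super().__init__(params, dict(lr=lr))

    @torch.no_grad()  # the update below must not be recorded by autograd
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():  # re-enable grad only while the closure recomputes the loss
                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                p.add_(p.grad, alpha=-group['lr'])  # safe without .data because grad is disabled here

        return loss

This is also what makes the .data removals throughout the diff safe: with gradients disabled inside step, updating p in place is legal and avoids the untracked aliasing that .data introduces.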

pytorch_optimizer/adahessian.py

Lines changed: 10 additions & 9 deletions
@@ -142,10 +142,12 @@ def set_hessian(self):
                 # approximate the expected values of z * (H@z)
                 p.hess += h_z * z / self.num_samples

+    @torch.no_grad()
     def step(self, closure: CLOSURE = None) -> LOSS:
         loss: LOSS = None
         if closure is not None:
-            loss = closure()
+            with torch.enable_grad():
+                loss = closure()

         self.zero_hessian()
         self.set_hessian()
@@ -164,8 +166,8 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 state = self.state[p]
                 if len(state) == 1:
                     state['step'] = 0
-                    state['exp_avg'] = torch.zeros_like(p.data)
-                    state['exp_hessian_diag_sq'] = torch.zeros_like(p.data)
+                    state['exp_avg'] = torch.zeros_like(p)
+                    state['exp_hessian_diag_sq'] = torch.zeros_like(p)

                 exp_avg, exp_hessian_diag_sq = state['exp_avg'], state['exp_hessian_diag_sq']

@@ -180,13 +182,12 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 bias_correction2 = 1 - beta2 ** state['step']

                 hessian_power = group['hessian_power']
-                denom = (exp_hessian_diag_sq / bias_correction2).pow_(hessian_power / 2).add_(group['eps'])
+                de_nom = (exp_hessian_diag_sq / bias_correction2).pow_(hessian_power / 2.0).add_(group['eps'])

-                if group['adamd_debias_term']:
-                    step_size = group['lr']
-                else:
-                    step_size = group['lr'] / bias_correction1
+                step_size = group['lr']
+                if not group['adamd_debias_term']:
+                    step_size /= bias_correction1

-                p.addcdiv_(exp_avg, denom, value=-step_size)
+                p.addcdiv_(exp_avg, de_nom, value=-step_size)

         return loss
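
The adamd_debias_term refactor shown here (and in adabound.py and adamp.py) only restructures the branch; the resulting step size is identical. A small numeric check of the equivalence using AdaBound's form of the step size, with arbitrarily chosen values:

import math

lr, beta1, beta2, step = 1e-2, 0.9, 0.999, 25
bias_correction1 = 1.0 - beta1 ** step
bias_correction2 = 1.0 - beta2 ** step

for adamd_debias_term in (False, True):
    # old form: two assignments behind an if/else
    if adamd_debias_term:
        old = lr * math.sqrt(bias_correction2)
    else:
        old = lr * math.sqrt(bias_correction2) / bias_correction1

    # new form: compute the common factor once, then divide by the
    # first-moment bias correction only when AdamD-style debiasing is off
    new = lr * math.sqrt(bias_correction2)
    if not adamd_debias_term:
        new /= bias_correction1

    assert math.isclose(old, new)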

pytorch_optimizer/adamp.py

Lines changed: 24 additions & 20 deletions
@@ -117,58 +117,58 @@ def projection(
         wd: float = 1.0
         expand_size: List[int] = [-1] + [1] * (len(p.shape) - 1)
         for view_func in (self.channel_view, self.layer_view):
-            cosine_sim = self.cosine_similarity(grad, p.data, eps, view_func)
+            cosine_sim = self.cosine_similarity(grad, p, eps, view_func)

-            if cosine_sim.max() < delta / math.sqrt(view_func(p.data).size()[1]):
-                p_n = p.data / view_func(p.data).norm(dim=1).view(expand_size).add_(eps)
+            if cosine_sim.max() < delta / math.sqrt(view_func(p).size()[1]):
+                p_n = p / view_func(p).norm(dim=1).view(expand_size).add_(eps)
                 perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view(expand_size)
                 wd = wd_ratio
                 return perturb, wd

         return perturb, wd

+    @torch.no_grad()
     def step(self, closure: CLOSURE = None) -> LOSS:
         loss: LOSS = None
         if closure is not None:
-            loss = closure()
+            with torch.enable_grad():
+                loss = closure()

         for group in self.param_groups:
             for p in group['params']:
                 if p.grad is None:
                     continue

+                grad = p.grad
+                if grad.is_sparse:
+                    raise RuntimeError('AdamP does not support sparse gradients')
+
                 state = self.state[p]
                 if len(state) == 0:
                     state['step'] = 0
-                    state['exp_avg'] = torch.zeros_like(p.data)
-                    state['exp_avg_sq'] = torch.zeros_like(p.data)
+                    state['exp_avg'] = torch.zeros_like(p)
+                    state['exp_avg_sq'] = torch.zeros_like(p)

                 exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                 state['step'] += 1
                 beta1, beta2 = group['betas']

-                bias_correction1 = 1 - beta1 ** state['step']
-                bias_correction2 = 1 - beta2 ** state['step']
-
-                grad = p.grad.data
+                bias_correction1 = 1.0 - beta1 ** state['step']
+                bias_correction2 = 1.0 - beta2 ** state['step']

                 if self.use_gc:
                     grad = centralize_gradient(grad, gc_conv_only=False)

                 exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
-                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

-                denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
-                if group['adamd_debias_term']:
-                    step_size = group['lr']
-                else:
-                    step_size = group['lr'] / bias_correction1
+                de_nom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])

                 if group['nesterov']:
-                    perturb = (beta1 * exp_avg + (1 - beta1) * grad) / denom
+                    perturb = (beta1 * exp_avg + (1.0 - beta1) * grad) / de_nom
                 else:
-                    perturb = exp_avg / denom
+                    perturb = exp_avg / de_nom

                 wd_ratio: float = 1
                 if len(p.shape) > 1:
@@ -182,8 +182,12 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                     )

                 if group['weight_decay'] > 0:
-                    p.data.mul_(1.0 - group['lr'] * group['weight_decay'] * wd_ratio)
+                    p.mul_(1.0 - group['lr'] * group['weight_decay'] * wd_ratio)
+
+                step_size = group['lr']
+                if not group['adamd_debias_term']:
+                    step_size /= bias_correction1

-                p.data.add_(perturb, alpha=-step_size)
+                p.add_(perturb, alpha=-step_size)

         return loss
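
From the caller's side nothing changes: the closure that step now wraps in torch.enable_grad() is the standard PyTorch closure. A minimal usage sketch (the toy model, data, and the exact import path are assumptions, not taken from this diff):

import torch
from pytorch_optimizer import AdamP  # assumed public import path

model = torch.nn.Linear(4, 1)
optimizer = AdamP(model.parameters(), lr=1e-3)
x, y = torch.randn(8, 4), torch.randn(8, 1)


def closure():
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)  # forward builds a graph because step re-enables grad
    loss.backward()
    return loss


final_loss = optimizer.step(closure)  # the update itself still runs under torch.no_grad()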
