
Commit 0567ae9

Merge pull request #96 from kozistr/refactor/optimizers
[Refactor] Cleanup the codes
2 parents ce56167 + c155279 commit 0567ae9

File tree

16 files changed: +141 -207 lines changed

pytorch_optimizer/optimizer/adabelief.py

Lines changed: 48 additions & 58 deletions
@@ -95,6 +95,8 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
         for group in self.param_groups:
             beta1, beta2 = group['betas']
+            weight_decay: float = group['weight_decay']
+
             if self.rectify:
                 n_sma_max: float = 2.0 / (1.0 - beta2) - 1.0
 
@@ -106,13 +108,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 if grad.is_sparse:
                     raise NoSparseGradientError(self.__name__)
 
-                if grad.dtype in (torch.float16, torch.bfloat16):
-                    grad = grad.float()
-
-                p_fp32 = p
-                if p.dtype in (torch.float16, torch.bfloat16):
-                    p_fp32 = p_fp32.float()
-
                 state = self.state[p]
                 if len(state) == 0:
                     state['step'] = 0
@@ -122,70 +117,65 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                         state['max_exp_avg_var'] = torch.zeros_like(p)
 
                 if self.weight_decouple:
-                    decay: float = (
-                        group['lr'] * group['weight_decay'] if not self.fixed_decay else group['weight_decay']
-                    )
-                    p_fp32.mul_(1.0 - decay)
-                elif group['weight_decay'] != 0:
-                    grad.add_(p_fp32, alpha=group['weight_decay'])
-
-                exp_avg, exp_avg_var = state['exp_avg'], state['exp_avg_var']
+                    p.mul_(1.0 - (group['lr'] * weight_decay if not self.fixed_decay else weight_decay))
+                elif weight_decay > 0.0:
+                    grad.add_(p, alpha=weight_decay)
 
                 state['step'] += 1
+                exp_avg, exp_avg_var = state['exp_avg'], state['exp_avg_var']
 
                 bias_correction1 = 1.0 - beta1 ** state['step']
-                bias_correction2 = 1.0 - beta2 ** state['step']
+                bias_correction2_sq = math.sqrt(1.0 - beta2 ** state['step'])
 
                 exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                 grad_residual = grad - exp_avg
-                exp_avg_var.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1.0 - beta2)
-                exp_avg_var.add_(group['eps'])
-                if group['amsgrad']:
-                    torch.max(state['max_exp_avg_var'], exp_avg_var, out=exp_avg_var)
+                exp_avg_var.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1.0 - beta2).add_(group['eps'])
 
-                de_nom = (exp_avg_var.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
+                if group['amsgrad']:
+                    max_exp_avg_var = state['max_exp_avg_var']
+                    torch.max(max_exp_avg_var, exp_avg_var, out=max_exp_avg_var)
+                    de_nom = max_exp_avg_var.sqrt()
+                else:
+                    de_nom = exp_avg_var.sqrt()
+                de_nom.div_(bias_correction2_sq).add_(group['eps'])
 
                 if not self.rectify:
-                    step_size = group['lr']
-                    if not group['adamd_debias_term']:
-                        step_size /= bias_correction1
-                    p_fp32.addcdiv_(exp_avg, de_nom, value=-step_size)
+                    step_size: float = group['lr'] if group['adamd_debias_term'] else group['lr'] / bias_correction1
+                    p.addcdiv_(exp_avg, de_nom, value=-step_size)
+                    continue
+
+                buffered = group['buffer'][state['step'] % 10]
+                if state['step'] == buffered[0]:
+                    n_sma, step_size = buffered[1], buffered[2]
                 else:
-                    buffered = group['buffer'][state['step'] % 10]
-                    if state['step'] == buffered[0]:
-                        n_sma, step_size = buffered[1], buffered[2]
-                    else:
-                        buffered[0] = state['step']
-                        beta2_t = beta2 ** state['step']
-                        n_sma = n_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
-                        buffered[1] = n_sma
-
-                        if n_sma >= self.n_sma_threshold:
-                            step_size = math.sqrt(
-                                (1 - beta2_t)
-                                * (n_sma - 4)
-                                / (n_sma_max - 4)
-                                * (n_sma - 2)
-                                / n_sma
-                                * n_sma_max
-                                / (n_sma_max - 2)
-                            )
-                            if not group['adamd_debias_term']:
-                                step_size /= bias_correction1
-                        elif self.degenerated_to_sgd:
-                            step_size = 1.0 / bias_correction1
-                        else:
-                            step_size = -1
-
-                        buffered[2] = step_size
+                    buffered[0] = state['step']
+                    beta2_t = beta2 ** state['step']
+                    n_sma = n_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
+                    buffered[1] = n_sma
 
                     if n_sma >= self.n_sma_threshold:
-                        de_nom = exp_avg_var.sqrt().add_(group['eps'])
-                        p_fp32.addcdiv_(exp_avg, de_nom, value=-step_size * group['lr'])
-                    elif step_size > 0:
-                        p_fp32.add_(exp_avg, alpha=-step_size * group['lr'])
+                        step_size = math.sqrt(
+                            (1 - beta2_t)
+                            * (n_sma - 4)
+                            / (n_sma_max - 4)
+                            * (n_sma - 2)
+                            / n_sma
+                            * n_sma_max
+                            / (n_sma_max - 2)
+                        )
+                        if not group['adamd_debias_term']:
+                            step_size /= bias_correction1
+                    elif self.degenerated_to_sgd:
+                        step_size = 1.0 / bias_correction1
+                    else:
+                        step_size = -1
+
+                    buffered[2] = step_size
 
-                if p.dtype in (torch.float16, torch.bfloat16):
-                    p.copy_(p_fp32)
+                if n_sma >= self.n_sma_threshold:
+                    de_nom = exp_avg_var.sqrt().add_(group['eps'])
+                    p.addcdiv_(exp_avg, de_nom, value=-step_size * group['lr'])
+                elif step_size > 0:
+                    p.add_(exp_avg, alpha=-step_size * group['lr'])
 
         return loss

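The reworked denominator above precomputes the square root of the bias correction once (bias_correction2_sq) and then updates de_nom in place. A minimal sketch, using made-up tensors and hyperparameters rather than the optimizer's real state, to check that the new two-step form matches the old exp_avg_var.sqrt() / math.sqrt(bias_correction2) expression on the non-AMSGrad path:

import math

import torch

# hypothetical stand-ins for the optimizer state; not taken from the library
beta2, eps, step = 0.999, 1e-16, 10
exp_avg_var = torch.rand(8) + eps  # eps is already folded in by addcmul_(...).add_(eps)

# old form: divide by math.sqrt(bias_correction2) on the fly
bias_correction2 = 1.0 - beta2 ** step
de_nom_old = (exp_avg_var.sqrt() / math.sqrt(bias_correction2)).add_(eps)

# new form: precompute the square root once, then update in place
bias_correction2_sq = math.sqrt(1.0 - beta2 ** step)
de_nom_new = exp_avg_var.sqrt()
de_nom_new.div_(bias_correction2_sq).add_(eps)

assert torch.allclose(de_nom_old, de_nom_new)
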
pytorch_optimizer/optimizer/adabound.py

Lines changed: 4 additions & 5 deletions
@@ -93,6 +93,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
         for group, base_lr in zip(self.param_groups, self.base_lrs):
             beta1, beta2 = group['betas']
+            weight_decay: float = group['weight_decay']
             for p in group['params']:
                 if p.grad is None:
                     continue
@@ -113,13 +114,11 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 state['step'] += 1
                 exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
 
-                if group['weight_decay'] > 0.0:
+                if weight_decay > 0.0:
                     if self.weight_decouple:
-                        p.mul_(
-                            1.0 - (group['weight_decay'] if self.fixed_decay else group['lr'] * group['weight_decay'])
-                        )
+                        p.mul_(1.0 - (weight_decay if self.fixed_decay else group['lr'] * weight_decay))
                     else:
-                        grad.add_(p, alpha=group['weight_decay'])
+                        grad.add_(p, alpha=weight_decay)
 
                 exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

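The hunk above hoists weight_decay out of the inner loop and keeps both decay styles. A toy sketch of the two branches, with hypothetical tensors and hyperparameter values (lr, weight_decay, fixed_decay, weight_decouple mirror the group options but are made up here):

import torch

lr, weight_decay = 1e-3, 1e-2
fixed_decay, weight_decouple = False, True
p, grad = torch.rand(4), torch.rand(4)

if weight_decay > 0.0:
    if weight_decouple:
        # decoupled (AdamW-style): shrink the parameter directly
        p.mul_(1.0 - (weight_decay if fixed_decay else lr * weight_decay))
    else:
        # coupled (L2-style): fold the decay into the gradient instead
        grad.add_(p, alpha=weight_decay)
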
pytorch_optimizer/optimizer/adamp.py

Lines changed: 4 additions & 7 deletions
@@ -108,15 +108,15 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
 
                 bias_correction1 = 1.0 - beta1 ** state['step']
-                bias_correction2 = 1.0 - beta2 ** state['step']
+                bias_correction2_sq = math.sqrt(1.0 - beta2 ** state['step'])
 
                 if self.use_gc:
                     grad = centralize_gradient(grad, gc_conv_only=False)
 
                 exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
 
-                inv_de_nom = 1.0 / (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
+                inv_de_nom = 1.0 / (exp_avg_sq.sqrt() / bias_correction2_sq).add_(group['eps'])
 
                 perturb = exp_avg.clone()
                 if group['nesterov']:
@@ -125,7 +125,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 else:
                     perturb.mul_(inv_de_nom)
 
-                wd_ratio: float = 1
+                wd_ratio: float = 1.0
                 if len(p.shape) > 1:
                     perturb, wd_ratio = projection(
                         p,
@@ -139,10 +139,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 if group['weight_decay'] > 0.0:
                     p.mul_(1.0 - group['lr'] * group['weight_decay'] * wd_ratio)
 
-                step_size = group['lr']
-                if not group['adamd_debias_term']:
-                    step_size /= bias_correction1
-
+                step_size: float = group['lr'] if group['adamd_debias_term'] else group['lr'] / bias_correction1
                 p.add_(perturb, alpha=-step_size)
 
         return loss

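Several of these diffs collapse the old three-line step-size computation into a single conditional expression. A quick check, with made-up numbers, that the two forms agree:

beta1, lr, step = 0.9, 1e-3, 5
adamd_debias_term = False

bias_correction1 = 1.0 - beta1 ** step

# old, multi-line form
step_size = lr
if not adamd_debias_term:
    step_size /= bias_correction1

# new, single-expression form
assert step_size == (lr if adamd_debias_term else lr / bias_correction1)
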
pytorch_optimizer/optimizer/adan.py

Lines changed: 2 additions & 1 deletion
@@ -138,7 +138,8 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 if self.use_gc:
                     grad = centralize_gradient(grad, gc_conv_only=False)
 
-                grad_diff = grad - state['previous_grad']
+                grad_diff = -state['previous_grad']
+                grad_diff.add_(grad)
                 state['previous_grad'].copy_(grad)
 
                 update = grad + beta2 * grad_diff

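A quick equivalence check for the rewritten grad_diff computation, using random tensors in place of the optimizer state:

import torch

grad, previous_grad = torch.rand(4), torch.rand(4)

# old: single expression allocating the difference
grad_diff_old = grad - previous_grad

# new: negate a copy of previous_grad, then add grad in place
grad_diff_new = -previous_grad
grad_diff_new.add_(grad)

assert torch.allclose(grad_diff_old, grad_diff_new)
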
pytorch_optimizer/optimizer/adapnm.py

Lines changed: 4 additions & 7 deletions
@@ -111,25 +111,22 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 state['step'] += 1
 
                 bias_correction1 = 1 - beta1 ** state['step']
-                bias_correction2 = 1 - beta2 ** state['step']
+                bias_correction2_sq = math.sqrt(1 - beta2 ** state['step'])
 
                 exp_avg_sq = state['exp_avg_sq']
                 if state['step'] % 2 == 1:
                     exp_avg, neg_exp_avg = state['exp_avg'], state['neg_exp_avg']
                 else:
                     exp_avg, neg_exp_avg = state['neg_exp_avg'], state['exp_avg']
 
-                exp_avg.mul_(beta1 ** 2).add_(grad, alpha=1 - beta1 ** 2)  # fmt: skip
+                exp_avg.mul_(beta1 ** 2).add_(grad, alpha=1.0 - beta1 ** 2)  # fmt: skip
                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
                 if group['amsgrad']:
                     torch.max(state['max_exp_avg_sq'], exp_avg_sq, out=exp_avg_sq)
 
-                de_nom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
-
-                step_size = group['lr']
-                if not group['adamd_debias_term']:
-                    step_size /= bias_correction1
+                de_nom = (exp_avg_sq.sqrt() / bias_correction2_sq).add_(group['eps'])
 
+                step_size: float = group['lr'] if group['adamd_debias_term'] else group['lr'] / bias_correction1
                 pn_momentum = exp_avg.mul(1.0 + beta3).add(neg_exp_avg, alpha=-beta3).mul(1.0 / noise_norm)
                 p.addcdiv_(pn_momentum, de_nom, value=-step_size)

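The hunk above alternates which exponential moving average plays the positive role on odd and even steps. A toy illustration of that swap, with placeholder strings standing in for the momentum tensors:

state = {'exp_avg': 'ema_a', 'neg_exp_avg': 'ema_b'}

for step in range(1, 5):
    if step % 2 == 1:
        exp_avg, neg_exp_avg = state['exp_avg'], state['neg_exp_avg']
    else:
        exp_avg, neg_exp_avg = state['neg_exp_avg'], state['exp_avg']
    print(step, exp_avg, neg_exp_avg)
# 1 ema_a ema_b
# 2 ema_b ema_a
# 3 ema_a ema_b
# 4 ema_b ema_a
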
pytorch_optimizer/optimizer/diffgrad.py

Lines changed: 8 additions & 9 deletions
@@ -85,32 +85,31 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                     state['exp_avg_sq'] = torch.zeros_like(p)
                     state['previous_grad'] = torch.zeros_like(p)
 
+                state['step'] += 1
                 exp_avg, exp_avg_sq, previous_grad = state['exp_avg'], state['exp_avg_sq'], state['previous_grad']
 
-                if group['weight_decay'] != 0:
+                if group['weight_decay'] > 0.0:
                     grad.add_(p, alpha=group['weight_decay'])
 
-                state['step'] += 1
-
                 # Decay the first and second moment running average coefficient
                 exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
 
                 de_nom = exp_avg_sq.sqrt().add_(group['eps'])
 
                 bias_correction1 = 1.0 - beta1 ** state['step']
-                bias_correction2 = 1.0 - beta2 ** state['step']
+                bias_correction2_sq = math.sqrt(1.0 - beta2 ** state['step'])
 
                 # compute diffGrad coefficient (dfc)
-                diff = abs(previous_grad - grad)
-                dfc = 1.0 / (1.0 + torch.exp(-diff))
+                dfc = previous_grad.clone()
+                dfc.sub_(grad).abs_().sigmoid_().mul_(exp_avg)
                 state['previous_grad'].copy_(grad)
 
-                step_size = group['lr'] * math.sqrt(bias_correction2)
+                step_size = group['lr'] * bias_correction2_sq
                 if not group['adamd_debias_term']:
                     step_size /= bias_correction1
 
-                # update momentum with dfc (exp_avg * dfc)
-                p.addcdiv_(exp_avg * dfc, de_nom, value=-step_size)
+                # update momentum with dfc
+                p.addcdiv_(dfc, de_nom, value=-step_size)
 
         return loss

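The diffGrad coefficient is now built in place, with exp_avg folded in before the final addcdiv_. A small check, on random tensors, that the in-place chain equals the old sigmoid expression multiplied by exp_avg:

import torch

grad, previous_grad, exp_avg = torch.rand(6), torch.rand(6), torch.rand(6)

# old form: sigmoid of the absolute gradient change, multiplied by exp_avg at update time
diff = abs(previous_grad - grad)
old = (1.0 / (1.0 + torch.exp(-diff))) * exp_avg

# new form: same quantity, computed in place
dfc = previous_grad.clone()
dfc.sub_(grad).abs_().sigmoid_().mul_(exp_avg)

assert torch.allclose(old, dfc)
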
pytorch_optimizer/optimizer/diffrgrad.py

Lines changed: 10 additions & 29 deletions
@@ -90,23 +90,12 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 if grad.is_sparse:
                     raise NoSparseGradientError(self.__name__)
 
-                if grad.dtype in (torch.float16, torch.bfloat16):
-                    grad = grad.float()
-
-                p_fp32 = p
-                if p.dtype in (torch.float16, torch.bfloat16):
-                    p_fp32 = p_fp32.float()
-
                 state = self.state[p]
                 if len(state) == 0:
                     state['step'] = 0
-                    state['exp_avg'] = torch.zeros_like(p_fp32)
-                    state['exp_avg_sq'] = torch.zeros_like(p_fp32)
-                    state['previous_grad'] = torch.zeros_like(p_fp32)
-                else:
-                    state['exp_avg'] = state['exp_avg'].type_as(p_fp32)
-                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_fp32)
-                    state['previous_grad'] = state['previous_grad'].type_as(p_fp32)
+                    state['exp_avg'] = torch.zeros_like(p)
+                    state['exp_avg_sq'] = torch.zeros_like(p)
+                    state['previous_grad'] = torch.zeros_like(p)
 
                 state['step'] += 1
                 exp_avg, exp_avg_sq, previous_grad = state['exp_avg'], state['exp_avg_sq'], state['previous_grad']
@@ -117,8 +106,8 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
 
                 # compute diffGrad coefficient (dfc)
-                diff = abs(previous_grad - grad)
-                dfc = 1.0 / (1.0 + torch.exp(-diff))
+                dfc = previous_grad.clone()
+                dfc.sub_(grad).abs_().sigmoid_().mul_(exp_avg)
                 state['previous_grad'].copy_(grad)
 
                 buffered = group['buffer'][state['step'] % 10]
@@ -149,21 +138,13 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
                     buffered[2] = step_size
 
-                if n_sma >= self.n_sma_threshold:
-                    if group['weight_decay'] != 0:
-                        p_fp32.add_(p_fp32, alpha=-group['weight_decay'] * group['lr'])
+                if group['weight_decay'] > 0.0:
+                    p.add_(p, alpha=-group['weight_decay'] * group['lr'])
 
+                if n_sma >= self.n_sma_threshold:
                     de_nom = exp_avg_sq.sqrt().add_(group['eps'])
-
-                    # update momentum with dfc
-                    p_fp32.addcdiv_(exp_avg * dfc.float(), de_nom, value=-step_size * group['lr'])
+                    p.addcdiv_(dfc, de_nom, value=-step_size * group['lr'])
                 elif step_size > 0:
-                    if group['weight_decay'] != 0:
-                        p_fp32.add_(p_fp32, alpha=-group['weight_decay'] * group['lr'])
-
-                    p_fp32.add_(exp_avg, alpha=-step_size * group['lr'])
-
-                if p.dtype in (torch.float16, torch.bfloat16):
-                    p.copy_(p_fp32)
+                    p.add_(exp_avg, alpha=-step_size * group['lr'])
 
         return loss

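The weight-decay line above uses p.add_(p, alpha=-wd * lr), which applies the same scaling as multiplying the parameter by (1 - wd * lr). A two-line check with a hypothetical parameter tensor:

import torch

lr, weight_decay = 1e-3, 1e-2
p = torch.rand(4)

p_ref = p.clone().mul_(1.0 - weight_decay * lr)
p.add_(p, alpha=-weight_decay * lr)

assert torch.allclose(p, p_ref)
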
pytorch_optimizer/optimizer/gsam.py

Lines changed: 1 addition & 1 deletion
@@ -140,7 +140,7 @@ def gradient_decompose(self, alpha: float = 0.0):
                 if p.grad is None:
                     continue
 
-                vertical = self.state[p]['old_g'] - cosine * old_grad_norm * p.grad.data / (
+                vertical = self.state[p]['old_g'] - cosine * old_grad_norm * p.grad / (
                     new_grad_norm + self.perturb_eps
                 )
                 p.grad.add_(vertical, alpha=-alpha)

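The 'vertical' term above removes from the old gradient its projection onto the new one. A toy sketch that treats each gradient as a single flattened vector, standing in for the norms and cosine computed elsewhere in gradient_decompose, and checks that the remainder is orthogonal to the new gradient:

import torch

old_g = torch.tensor([1.0, 2.0, 3.0])  # stand-in for self.state[p]['old_g']
new_g = torch.tensor([0.5, 0.0, 1.0])  # stand-in for p.grad
eps = 1e-12                            # stand-in for self.perturb_eps

old_norm, new_norm = old_g.norm(), new_g.norm()
cosine = old_g.dot(new_g) / (old_norm * new_norm + eps)

vertical = old_g - cosine * old_norm * new_g / (new_norm + eps)

assert torch.isclose(vertical.dot(new_g), torch.tensor(0.0), atol=1e-5)
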
pytorch_optimizer/optimizer/lars.py

Lines changed: 4 additions & 4 deletions
@@ -104,13 +104,13 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 if 'momentum_buffer' not in param_state:
                     param_state['momentum_buffer'] = grad.clone().detach()
 
-                mu = param_state['momentum_buffer']
-                mu.mul_(group['momentum']).add_(grad, alpha=1.0 - group['dampening'])
+                mb = param_state['momentum_buffer']
+                mb.mul_(group['momentum']).add_(grad, alpha=1.0 - group['dampening'])
 
                 if group['nesterov']:
-                    grad.add_(mu, alpha=group['momentum'])
+                    grad.add_(mb, alpha=group['momentum'])
                 else:
-                    grad.copy_(mu)
+                    grad.copy_(mb)
 
                 p.add_(grad, alpha=-group['lr'])

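For reference, a self-contained sketch of the momentum update in the hunk above, with hypothetical tensors and hyperparameters in place of the parameter group:

import torch

lr, momentum, dampening, nesterov = 0.1, 0.9, 0.0, True
p, grad = torch.rand(4), torch.rand(4)
mb = grad.clone().detach()  # momentum buffer, seeded from the first gradient

mb.mul_(momentum).add_(grad, alpha=1.0 - dampening)

if nesterov:
    grad.add_(mb, alpha=momentum)  # blend the buffer back into the gradient
else:
    grad.copy_(mb)                 # plain heavy-ball momentum

p.add_(grad, alpha=-lr)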