Commit 26fca7b

refactor: optimizers
1 parent 19c2136 commit 26fca7b


11 files changed: +25 -36 lines


pytorch_optimizer/optimizer/adabelief.py

Lines changed: 2 additions & 4 deletions
@@ -95,6 +95,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
         for group in self.param_groups:
             beta1, beta2 = group['betas']
+            n_sma_max: float = 2 / (1 - beta2) - 1
             for p in group['params']:
                 if p.grad is None:
                     continue
@@ -154,12 +155,11 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 else:
                     buffered[0] = state['step']
                     beta2_t = beta2 ** state['step']
-                    n_sma_max = 2 / (1 - beta2) - 1
                     n_sma = n_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
                     buffered[1] = n_sma
 
                     if n_sma >= self.n_sma_threshold:
-                        rt = math.sqrt(
+                        step_size = math.sqrt(
                             (1 - beta2_t)
                             * (n_sma - 4)
                             / (n_sma_max - 4)
@@ -168,8 +168,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                             * n_sma_max
                             / (n_sma_max - 2)
                         )
-
-                        step_size = rt
                         if not group['adamd_debias_term']:
                             step_size /= bias_correction1
                     elif self.degenerated_to_sgd:
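
The hoisted line works because `n_sma_max` is a pure function of `beta2`, which is fixed for a parameter group, so computing it once at the top of the group loop is equivalent to recomputing it inside the buffered-step branch. A minimal sketch of that loop invariance (illustrative values, not the optimizer's code):

    beta2 = 0.999
    n_sma_max = 2 / (1 - beta2) - 1  # 1999.0 -- the same value on every step

    for step in (1, 10, 100):
        beta2_t = beta2 ** step
        # n_sma depends on the step; n_sma_max does not, so it can live outside the loop
        n_sma = n_sma_max - 2 * step * beta2_t / (1 - beta2_t)
        print(step, n_sma)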

pytorch_optimizer/optimizer/adabound.py

Lines changed: 2 additions & 2 deletions
@@ -113,7 +113,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 state['step'] += 1
                 exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
 
-                if group['weight_decay'] != 0:
+                if group['weight_decay'] > 0.0:
                     if self.weight_decouple:
                         p.mul_(
                             1.0 - (group['weight_decay'] if self.fixed_decay else group['lr'] * group['weight_decay'])
@@ -124,7 +124,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
                 if group['amsbound']:
-                    exp_avg_sq = torch.max(state['max_exp_avg_sq'], exp_avg_sq)
+                    torch.max(state['max_exp_avg_sq'], exp_avg_sq, out=exp_avg_sq)
 
                 de_nom = exp_avg_sq.sqrt().add_(group['eps'])
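
In the AMSBound branch, the old line rebound the local name `exp_avg_sq` to a fresh tensor returned by `torch.max`; the new line writes the element-wise maximum into the existing `exp_avg_sq` buffer via `out=`, so no new tensor is allocated per step. A small sketch of the difference, using illustrative tensors rather than the optimizer's state:

    import torch

    max_exp_avg_sq = torch.tensor([1.0, 5.0, 2.0])
    exp_avg_sq = torch.tensor([3.0, 4.0, 2.5])

    # before: allocates a result tensor and rebinds the local name
    rebound = torch.max(max_exp_avg_sq, exp_avg_sq)

    # after: stores the element-wise maximum in the existing buffer
    torch.max(max_exp_avg_sq, exp_avg_sq, out=exp_avg_sq)

    assert torch.equal(rebound, exp_avg_sq)  # same values, written in place

The same change appears below in adapnm.py for its `amsgrad` branch.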

pytorch_optimizer/optimizer/adai.py

Lines changed: 1 addition & 1 deletion
@@ -113,7 +113,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
                 bias_correction2 = 1.0 - beta2 ** state['step']
 
-                if group['weight_decay'] != 0:
+                if group['weight_decay'] > 0.0:
                     if self.weight_decouple:
                         p.mul_(1.0 - group['lr'] * group['weight_decay'])
                     else:
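
This is the same guard normalization applied across the commit (adabound, adamp, lamb, madgrad, and radam as well): the weight-decay check becomes `> 0.0`, so a zero or negative value now skips the decay branch entirely. A standalone sketch of the decoupled path that the guard protects; the function name and arguments here are illustrative, not part of the library:

    import torch

    def apply_decoupled_weight_decay(p: torch.Tensor, lr: float, weight_decay: float) -> None:
        # with the new guard, weight_decay <= 0.0 leaves the parameter untouched
        if weight_decay > 0.0:
            p.mul_(1.0 - lr * weight_decay)

    param = torch.ones(3)
    apply_decoupled_weight_decay(param, lr=1e-3, weight_decay=1e-2)
    print(param)  # every entry scaled by (1 - 1e-5)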

pytorch_optimizer/optimizer/adamp.py

Lines changed: 1 addition & 1 deletion
@@ -136,7 +136,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                     group['eps'],
                 )
 
-                if group['weight_decay'] > 0:
+                if group['weight_decay'] > 0.0:
                     p.mul_(1.0 - group['lr'] * group['weight_decay'] * wd_ratio)
 
                 step_size = group['lr']

pytorch_optimizer/optimizer/adapnm.py

Lines changed: 1 addition & 1 deletion
@@ -122,7 +122,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 exp_avg.mul_(beta1 ** 2).add_(grad, alpha=1 - beta1 ** 2)  # fmt: skip
                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
                 if group['amsgrad']:
-                    exp_avg_sq = torch.max(state['max_exp_avg_sq'], exp_avg_sq)
+                    torch.max(state['max_exp_avg_sq'], exp_avg_sq, out=exp_avg_sq)
 
                 de_nom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])

pytorch_optimizer/optimizer/diffrgrad.py

Lines changed: 3 additions & 6 deletions
@@ -81,6 +81,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
         for group in self.param_groups:
             beta1, beta2 = group['betas']
+            n_sma_max: float = 2.0 / (1.0 - beta2) - 1.0
             for p in group['params']:
                 if p.grad is None:
                     continue
@@ -107,9 +108,8 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                     state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_fp32)
                     state['previous_grad'] = state['previous_grad'].type_as(p_fp32)
 
-                exp_avg, exp_avg_sq, previous_grad = state['exp_avg'], state['exp_avg_sq'], state['previous_grad']
-
                 state['step'] += 1
+                exp_avg, exp_avg_sq, previous_grad = state['exp_avg'], state['exp_avg_sq'], state['previous_grad']
 
                 bias_correction1 = 1.0 - beta1 ** state['step']
 
@@ -127,12 +127,11 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 else:
                     buffered[0] = state['step']
                     beta2_t = beta2 ** state['step']
-                    n_sma_max = 2.0 / (1.0 - beta2) - 1.0
                     n_sma = n_sma_max - 2.0 * state['step'] * beta2_t / (1.0 - beta2_t)
                     buffered[1] = n_sma
 
                     if n_sma >= self.n_sma_threshold:
-                        rt = math.sqrt(
+                        step_size = math.sqrt(
                             (1 - beta2_t)
                             * (n_sma - 4)
                             / (n_sma_max - 4)
@@ -141,8 +140,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                             * n_sma_max
                             / (n_sma_max - 2)
                         )
-
-                        step_size = rt
                         if not group['adamd_debias_term']:
                             step_size /= bias_correction1
                     elif self.degenerated_to_sgd:

pytorch_optimizer/optimizer/lamb.py

Lines changed: 1 addition & 1 deletion
@@ -142,7 +142,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
 
                 adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps'])
-                if group['weight_decay'] != 0:
+                if group['weight_decay'] > 0.0:
                     adam_step.add_(p, alpha=group['weight_decay'])
 
                 weight_norm = p.norm(2).clamp(0, self.clamp)

pytorch_optimizer/optimizer/madgrad.py

Lines changed: 7 additions & 7 deletions
@@ -86,7 +86,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
         for group in self.param_groups:
             eps = group['eps']
             lr = group['lr'] + eps
-            decay = group['weight_decay']
+            weight_decay = group['weight_decay']
             momentum = group['momentum']
 
             ck: float = 1.0 - momentum
@@ -111,15 +111,15 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 grad_sum_sq = state['grad_sum_sq']
                 s = state['s']
 
-                if decay != 0 and not self.decouple_decay:
+                if weight_decay > 0.0 and not self.decouple_decay:
                     if grad.is_sparse:
                         raise NoSparseGradientError(self.__name__, note='weight_decay')
 
                     # original implementation
-                    grad.add_(p, alpha=decay)
+                    grad.add_(p, alpha=weight_decay)
 
                     # Apply weight decay - L2 / AdamW style
-                    # p.mul_(1.0 - lr * decay)
+                    # p.mul_(1.0 - lr * weight_decay)
 
                 if grad.is_sparse:
                     grad = grad.coalesce()
@@ -167,7 +167,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
                 s.add_(grad, alpha=_lambda)
 
-                if decay != 0 and self.decouple_decay:
+                if weight_decay > 0.0 and self.decouple_decay:
                     p_old = p.clone()
 
                 if momentum == 0.0:
@@ -176,8 +176,8 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                     z = x0.addcdiv(s, rms, value=-1)
                     p.mul_(1.0 - ck).add_(z, alpha=ck)
 
-                if decay != 0 and self.decouple_decay:
-                    p.add_(p_old, alpha=-lr * decay)
+                if weight_decay > 0.0 and self.decouple_decay:
+                    p.add_(p_old, alpha=-lr * weight_decay)
 
         self.state['k'] += 1
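
Beyond renaming `decay` to `weight_decay`, the hunks above show the two decay paths MADGRAD keeps: the coupled (L2-style) path folds the decay into the gradient before the update, while the decoupled path subtracts `lr * weight_decay` times the pre-update parameter afterwards. A simplified contrast of the two styles on a plain gradient step; the helper names here are illustrative, not the library's API:

    import torch

    def coupled_step(p: torch.Tensor, grad: torch.Tensor, lr: float, weight_decay: float) -> None:
        # L2 regularization: the decay term enters through the gradient
        grad = grad.add(p, alpha=weight_decay)
        p.add_(grad, alpha=-lr)

    def decoupled_step(p: torch.Tensor, grad: torch.Tensor, lr: float, weight_decay: float) -> None:
        # AdamW-style: the decay is applied against the old parameter, independent of the gradient
        p_old = p.clone()
        p.add_(grad, alpha=-lr)
        p.add_(p_old, alpha=-lr * weight_decay)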

pytorch_optimizer/optimizer/radam.py

Lines changed: 3 additions & 5 deletions
@@ -81,6 +81,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
         for group in self.param_groups:
             beta1, beta2 = group['betas']
+            n_sma_max: float = 2.0 / (1.0 - beta2) - 1.0
             for p in group['params']:
                 if p.grad is None:
                     continue
@@ -120,12 +121,11 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 else:
                     buffered[0] = state['step']
                     beta2_t = beta2 ** state['step']
-                    n_sma_max = 2.0 / (1.0 - beta2) - 1.0
                     n_sma = n_sma_max - 2.0 * state['step'] * beta2_t / (1.0 - beta2_t)
                     buffered[1] = n_sma
 
                     if n_sma >= self.n_sma_threshold:
-                        rt = math.sqrt(
+                        step_size = math.sqrt(
                             (1 - beta2_t)
                             * (n_sma - 4)
                             / (n_sma_max - 4)
@@ -134,8 +134,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                             * n_sma_max
                             / (n_sma_max - 2)
                         )
-
-                        step_size = rt
                         if not group['adamd_debias_term']:
                             step_size /= bias_correction1
                     elif self.degenerated_to_sgd:
@@ -144,7 +142,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                         step_size = -1
                     buffered[2] = step_size
 
-                if group['weight_decay'] != 0 and (n_sma >= self.n_sma_threshold or step_size > 0):
+                if group['weight_decay'] > 0.0 and (n_sma >= self.n_sma_threshold or step_size > 0):
                     p_fp32.add_(p_fp32, alpha=-group['weight_decay'] * group['lr'])
 
                 if n_sma >= self.n_sma_threshold:
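
With the temporary `rt` gone, the rectification factor is assigned straight to `step_size` and then optionally divided by `bias_correction1`. The hunks above elide a couple of lines in the middle of the product, so the sketch below uses the standard RAdam rectification term rather than a verbatim copy of the file; treat it as illustrative only:

    import math

    def rectified_step_size(step: int, beta1: float, beta2: float, adamd_debias_term: bool = False) -> float:
        beta2_t = beta2 ** step
        n_sma_max = 2.0 / (1.0 - beta2) - 1.0
        n_sma = n_sma_max - 2.0 * step * beta2_t / (1.0 - beta2_t)

        # the optimizer only takes this path when n_sma >= n_sma_threshold
        step_size = math.sqrt(
            (1 - beta2_t) * (n_sma - 4) / (n_sma_max - 4) * (n_sma - 2) / n_sma * n_sma_max / (n_sma_max - 2)
        )
        if not adamd_debias_term:
            step_size /= 1.0 - beta1 ** step  # bias_correction1
        return step_size

    print(rectified_step_size(step=100, beta1=0.9, beta2=0.999))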

pytorch_optimizer/optimizer/ralamb.py

Lines changed: 2 additions & 4 deletions
@@ -104,6 +104,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
         for group in self.param_groups:
             beta1, beta2 = group['betas']
+            n_sma_max: float = 2.0 / (1.0 - beta2) - 1.0
             for p in group['params']:
                 if p.grad is None:
                     continue
@@ -147,13 +148,12 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 else:
                     buffered[0] = state['step']
                     beta2_t = beta2 ** state['step']
-                    n_sma_max = 2 / (1 - beta2) - 1
                     n_sma = n_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
                     buffered[1] = n_sma
 
                     # more conservative since it's an approximated value
                     if n_sma >= self.n_sma_threshold:
-                        rt = math.sqrt(
+                        step_size = math.sqrt(
                             (1 - beta2_t)
                             * (n_sma - 4)
                             / (n_sma_max - 4)
@@ -162,8 +162,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                             * n_sma_max
                             / (n_sma_max - 2)
                         )
-
-                        step_size = rt
                         if not group['adamd_debias_term']:
                             step_size /= bias_correction1
                     elif self.degenerated_to_sgd:
