Skip to content

Commit 8585308

Browse files
committed
refactor: step
1 parent 9ce4fd4 commit 8585308

File tree

1 file changed

+6
-9
lines changed

1 file changed

+6
-9
lines changed

pytorch_optimizer/adabelief.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -161,11 +161,9 @@ def step(self, closure: CLOSURE = None) -> LOSS:
161161
denom = (exp_avg_var.add_(group['eps']).sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
162162

163163
if not self.rectify:
164-
if group['adamd_debias_term']:
165-
step_size = group['lr']
166-
else:
167-
step_size = group['lr'] / bias_correction1
168-
164+
step_size = group['lr']
165+
if not group['adamd_debias_term']:
166+
step_size /= bias_correction1
169167
p.data.addcdiv_(exp_avg, denom, value=-step_size)
170168
else:
171169
buffered = group['buffer'][int(state['step'] % 10)]
@@ -189,10 +187,9 @@ def step(self, closure: CLOSURE = None) -> LOSS:
189187
/ (n_sma_max - 2)
190188
)
191189

192-
if group['adamd_debias_term']:
193-
step_size = rt
194-
else:
195-
step_size = rt / bias_correction1
190+
step_size = rt
191+
if not group['adamd_debias_term']:
192+
step_size /= bias_correction1
196193
elif self.degenerated_to_sgd:
197194
step_size = 1.0 / bias_correction1
198195
else:

0 commit comments

Comments
 (0)