Commit a57113e

refactor: add_, addcmul_, addcdiv_
1 parent 3d36221 commit a57113e


pytorch_optimizer/radam.py

Lines changed: 6 additions & 6 deletions
@@ -118,8 +118,8 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
                 bias_correction1 = 1 - beta1 ** state['step']
 
-                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
-                exp_avg.mul_(beta1).add_(1 - beta1, grad)
+                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
 
                 state['step'] += 1
                 buffered = group['buffer'][int(state['step'] % 10)]
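
The first hunk swaps the old positional overloads, where the scalar multiplier came first, for the keyword forms add_(tensor, alpha=...) and addcmul_(tensor1, tensor2, value=...). A minimal standalone sketch (made-up tensors and betas, not code from the repo) checking that the keyword-style calls produce the same moving averages as plain tensor arithmetic:

```python
import torch

# Standalone check (assumed shapes/values, not from radam.py): the keyword
# forms used after this commit match the usual EMA update formulas.
beta1, beta2 = 0.9, 0.999
grad = torch.randn(4)
exp_avg = torch.randn(4)
exp_avg_sq = torch.rand(4)

# Reference values via plain tensor arithmetic.
expected_avg = exp_avg * beta1 + grad * (1 - beta1)
expected_avg_sq = exp_avg_sq * beta2 + grad * grad * (1 - beta2)

# New-style in-place calls, as in the diff above.
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

assert torch.allclose(exp_avg, expected_avg)
assert torch.allclose(exp_avg_sq, expected_avg_sq)
```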
@@ -155,14 +155,14 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
                 if n_sma >= self.n_sma_threshold:
                     if group['weight_decay'] != 0:
-                        p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
+                        p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
                     denom = exp_avg_sq.sqrt().add_(group['eps'])
-                    p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom)
+                    p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
                     p.data.copy_(p_data_fp32)
                 elif step_size > 0:
                     if group['weight_decay'] != 0:
-                        p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
-                    p_data_fp32.add_(-step_size * group['lr'], exp_avg)
+                        p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
+                    p_data_fp32.add_(exp_avg, alpha=-step_size * group['lr'])
                     p.data.copy_(p_data_fp32)
 
         return loss
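
The second hunk applies the same migration to the parameter update: the scaling factor now goes through alpha= / value= instead of the leading positional argument. Another standalone sketch (hypothetical hyperparameters, not repo code) confirming that the keyword-style weight-decay and update calls match the explicit formula:

```python
import torch

# Standalone check (hypothetical hyperparameters, not from radam.py): the
# keyword forms of add_ / addcdiv_ reproduce
#   p <- p * (1 - wd * lr) - step_size * lr * exp_avg / (sqrt(exp_avg_sq) + eps)
lr, weight_decay, step_size, eps = 1e-3, 1e-2, 1.0, 1e-8
p_data_fp32 = torch.randn(4)
exp_avg = torch.randn(4)
exp_avg_sq = torch.rand(4) + 0.1

# Reference value via plain tensor arithmetic, taken before the in-place ops.
expected = p_data_fp32 * (1 - weight_decay * lr)
expected = expected - step_size * lr * exp_avg / (exp_avg_sq.sqrt() + eps)

# New-style in-place calls, as in the diff above.
p_data_fp32.add_(p_data_fp32, alpha=-weight_decay * lr)
denom = exp_avg_sq.sqrt().add_(eps)
p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * lr)

assert torch.allclose(p_data_fp32, expected)
```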
