Skip to content

Commit 3d36221

Browse files
committed
refactor: add_, addcmul_, addcdiv_
1 parent 0b82d95 commit 3d36221

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

pytorch_optimizer/diffrgrad.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,8 @@ def step(self, closure: CLOSURE = None) -> LOSS:
121121

122122
bias_correction1 = 1 - beta1 ** state['step']
123123

124-
exp_avg.mul_(beta1).add_(1 - beta1, grad)
125-
exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
124+
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
125+
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
126126

127127
# compute diffGrad coefficient (dfc)
128128
diff = abs(previous_grad - grad)
@@ -164,18 +164,18 @@ def step(self, closure: CLOSURE = None) -> LOSS:
164164

165165
if n_sma >= self.n_sma_threshold:
166166
if group['weight_decay'] != 0:
167-
p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
167+
p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
168168

169169
denom = exp_avg_sq.sqrt().add_(group['eps'])
170170

171171
# update momentum with dfc
172-
p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg * dfc.float(), denom)
172+
p_data_fp32.addcdiv_(exp_avg * dfc.float(), denom, value=-step_size * group['lr'])
173173
p.data.copy_(p_data_fp32)
174174
elif step_size > 0:
175175
if group['weight_decay'] != 0:
176-
p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
176+
p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
177177

178-
p_data_fp32.add_(-step_size * group['lr'], exp_avg)
178+
p_data_fp32.add_(exp_avg, alpha=-step_size * group['lr'])
179179
p.data.copy_(p_data_fp32)
180180

181181
return loss

0 commit comments

Comments
 (0)