@@ -121,8 +121,8 @@ def step(self, closure: CLOSURE = None) -> LOSS:
121121
122122 bias_correction1 = 1 - beta1 ** state ['step' ]
123123
124- exp_avg .mul_ (beta1 ).add_ (1 - beta1 , grad )
125- exp_avg_sq .mul_ (beta2 ).addcmul_ (1 - beta2 , grad , grad )
124+ exp_avg .mul_ (beta1 ).add_ (grad , alpha = 1 - beta1 )
125+ exp_avg_sq .mul_ (beta2 ).addcmul_ (grad , grad , value = 1 - beta2 )
126126
127127 # compute diffGrad coefficient (dfc)
128128 diff = abs (previous_grad - grad )
@@ -164,18 +164,18 @@ def step(self, closure: CLOSURE = None) -> LOSS:
164164
165165 if n_sma >= self .n_sma_threshold :
166166 if group ['weight_decay' ] != 0 :
167- p_data_fp32 .add_ (- group ['weight_decay' ] * group ['lr' ], p_data_fp32 )
167+ p_data_fp32 .add_ (p_data_fp32 , alpha = - group ['weight_decay' ] * group ['lr' ])
168168
169169 denom = exp_avg_sq .sqrt ().add_ (group ['eps' ])
170170
171171 # update momentum with dfc
172- p_data_fp32 .addcdiv_ (- step_size * group [ 'lr' ], exp_avg * dfc .float (), denom )
172+ p_data_fp32 .addcdiv_ (exp_avg * dfc .float (), denom , value = - step_size * group [ 'lr' ] )
173173 p .data .copy_ (p_data_fp32 )
174174 elif step_size > 0 :
175175 if group ['weight_decay' ] != 0 :
176- p_data_fp32 .add_ (- group ['weight_decay' ] * group ['lr' ], p_data_fp32 )
176+ p_data_fp32 .add_ (p_data_fp32 , alpha = - group ['weight_decay' ] * group ['lr' ])
177177
178- p_data_fp32 .add_ (- step_size * group ['lr' ], exp_avg )
178+ p_data_fp32 .add_ (exp_avg , alpha = - step_size * group ['lr' ])
179179 p .data .copy_ (p_data_fp32 )
180180
181181 return loss
0 commit comments