@@ -118,8 +118,8 @@ def step(self, closure: CLOSURE = None) -> LOSS:
118118
119119 bias_correction1 = 1 - beta1 ** state ['step' ]
120120
121- exp_avg_sq .mul_ (beta2 ). addcmul_ ( 1 - beta2 , grad , grad )
122- exp_avg .mul_ (beta1 ). add_ ( 1 - beta1 , grad )
121+ exp_avg .mul_ (beta1 ). add_ ( grad , alpha = 1 - beta1 )
122+ exp_avg_sq .mul_ (beta2 ). addcmul_ ( grad , grad , value = 1 - beta2 )
123123
124124 state ['step' ] += 1
125125 buffered = group ['buffer' ][int (state ['step' ] % 10 )]
@@ -155,14 +155,14 @@ def step(self, closure: CLOSURE = None) -> LOSS:
155155
156156 if n_sma >= self .n_sma_threshold :
157157 if group ['weight_decay' ] != 0 :
158- p_data_fp32 .add_ (- group ['weight_decay' ] * group ['lr' ], p_data_fp32 )
158+ p_data_fp32 .add_ (p_data_fp32 , alpha = - group ['weight_decay' ] * group ['lr' ])
159159 denom = exp_avg_sq .sqrt ().add_ (group ['eps' ])
160- p_data_fp32 .addcdiv_ (- step_size * group ['lr' ], exp_avg , denom )
160+ p_data_fp32 .addcdiv_ (exp_avg , denom , value = - step_size * group ['lr' ])
161161 p .data .copy_ (p_data_fp32 )
162162 elif step_size > 0 :
163163 if group ['weight_decay' ] != 0 :
164- p_data_fp32 .add_ (- group ['weight_decay' ] * group ['lr' ], p_data_fp32 )
165- p_data_fp32 .add_ (- step_size * group ['lr' ], exp_avg )
164+ p_data_fp32 .add_ (p_data_fp32 , alpha = - group ['weight_decay' ] * group ['lr' ])
165+ p_data_fp32 .add_ (exp_avg , alpha = - step_size * group ['lr' ])
166166 p .data .copy_ (p_data_fp32 )
167167
168168 return loss
0 commit comments