File tree Expand file tree Collapse file tree 1 file changed +3
-4
lines changed Expand file tree Collapse file tree 1 file changed +3
-4
lines changed Original file line number Diff line number Diff line change @@ -152,14 +152,13 @@ def step(self, closure: CLOSURE = None) -> LOSS:
152152 step_size = - 1
153153 buffered [2 ] = step_size
154154
155+ if (n_sma >= self .n_sma_threshold or step_size > 0 ) and group ['weight_decay' ] != 0 :
156+ p_fp32 .add_ (p_fp32 , alpha = - group ['weight_decay' ] * group ['lr' ])
157+
155158 if n_sma >= self .n_sma_threshold :
156- if group ['weight_decay' ] != 0 :
157- p_fp32 .add_ (p_fp32 , alpha = - group ['weight_decay' ] * group ['lr' ])
158159 de_nom = exp_avg_sq .sqrt ().add_ (group ['eps' ])
159160 p_fp32 .addcdiv_ (exp_avg , de_nom , value = - step_size * group ['lr' ])
160161 elif step_size > 0 :
161- if group ['weight_decay' ] != 0 :
162- p_fp32 .add_ (p_fp32 , alpha = - group ['weight_decay' ] * group ['lr' ])
163162 p_fp32 .add_ (exp_avg , alpha = - step_size * group ['lr' ])
164163
165164 if p .dtype in (torch .float16 , torch .bfloat16 ):
You can’t perform that action at this time.
0 commit comments