File tree Expand file tree Collapse file tree 1 file changed +9
-6
lines changed
pytorch_optimizer/optimizer Expand file tree Collapse file tree 1 file changed +9
-6
lines changed Original file line number Diff line number Diff line change @@ -228,7 +228,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
228228 if 'mask' in state :
229229 grad = grad [state ['mask' ]]
230230
231- if len ( state ) == 0 :
231+ if 'exp_avg' not in state :
232232 state ['exp_avg' ] = torch .zeros_like (grad )
233233 state ['exp_avg_sq' ] = torch .zeros_like (grad )
234234
@@ -258,11 +258,14 @@ def step(self, closure: CLOSURE = None) -> LOSS:
258258 else :
259259 p .addcdiv_ (exp_avg , de_nom , value = - step_size * scale_factor )
260260
261- if group ['weight_decay' ] > 0 :
262- if 'mask' in state :
263- p [state ['mask' ]].add_ (p [state ['mask' ]], alpha = - group ['lr' ] * group ['weight_decay' ])
264- else :
265- p .add_ (p , alpha = - group ['lr' ] * group ['weight_decay' ])
261+ self .apply_weight_decay (
262+ p [state ['mask' ]] if 'mask' in state else p ,
263+ grad = None ,
264+ lr = group ['lr' ],
265+ weight_decay = group ['weight_decay' ],
266+ weight_decouple = True ,
267+ fixed_decay = False ,
268+ )
266269
267270 self .state ['total_step' ] += 1
268271 self .state ['current_step' ] += 1
You can’t perform that action at this time.
0 commit comments