File tree Expand file tree Collapse file tree 1 file changed +3
-4
lines changed
pytorch_optimizer/optimizer Expand file tree Collapse file tree 1 file changed +3
-4
lines changed Original file line number Diff line number Diff line change @@ -158,7 +158,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:
158158 if len (params ) == 0 :
159159 continue
160160
161- lr = group ['lr' ]
162161 momentum = group ['momentum' ]
163162
164163 total_params : int = sum (p .numel () for p in params )
@@ -196,14 +195,14 @@ def step(self, closure: CLOSURE = None) -> LOSS:
196195
197196 self .apply_weight_decay (
198197 p ,
199- g ,
200- lr = lr ,
198+ grad = g ,
199+ lr = group [ 'lr' ] ,
201200 weight_decay = group ['weight_decay' ],
202201 weight_decouple = group ['weight_decouple' ],
203202 fixed_decay = False ,
204203 )
205204
206- lr : float = self .adjust_lr_for_muon (lr , p .size ()) if group ['use_adjusted_lr' ] else lr
205+ lr : float = self .adjust_lr_for_muon (group [ 'lr' ] , p .size ()) if group ['use_adjusted_lr' ] else group [ 'lr' ]
207206
208207 p .add_ (g , alpha = - lr * (max (1.0 , p .size (- 2 ) / p .size (- 1 )) ** 0.5 ))
209208 curr_idx += p .numel ()
You can’t perform that action at this time.
0 commit comments