Skip to content

Commit 2cf72e5

Browse files
committed
fix: lr
1 parent 8dbf0f0 commit 2cf72e5

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

pytorch_optimizer/optimizer/muon.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:
158158
if len(params) == 0:
159159
continue
160160

161-
lr = group['lr']
162161
momentum = group['momentum']
163162

164163
total_params: int = sum(p.numel() for p in params)
@@ -196,14 +195,14 @@ def step(self, closure: CLOSURE = None) -> LOSS:
196195

197196
self.apply_weight_decay(
198197
p,
199-
g,
200-
lr=lr,
198+
grad=g,
199+
lr=group['lr'],
201200
weight_decay=group['weight_decay'],
202201
weight_decouple=group['weight_decouple'],
203202
fixed_decay=False,
204203
)
205204

206-
lr: float = self.adjust_lr_for_muon(lr, p.size()) if group['use_adjusted_lr'] else lr
205+
lr: float = self.adjust_lr_for_muon(group['lr'], p.size()) if group['use_adjusted_lr'] else group['lr']
207206

208207
p.add_(g, alpha=-lr * (max(1.0, p.size(-2) / p.size(-1)) ** 0.5))
209208
curr_idx += p.numel()

0 commit comments

Comments
 (0)