Skip to content

Commit 4e3c130

Browse files
committed
fix: lr ratio
1 parent cf01ecd commit 4e3c130

File tree

1 file changed

+15
-6
lines changed

1 file changed

+15
-6
lines changed

pytorch_optimizer/optimizer/muon.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import math
22
import os
3-
from typing import List, Optional
3+
from typing import List, Optional, Tuple
44

55
import torch
66
from torch.distributed import ReduceOp, all_reduce
@@ -131,9 +131,18 @@ def reset(self):
131131
state['moment2'] = torch.zeros_like(p)
132132

133133
@staticmethod
134-
def adjust_lr_for_muon(lr: float, param_shape) -> float:
135-
adjusted_ratio: float = 0.2 * math.sqrt(max(param_shape[0], param_shape[1]))
136-
return lr * adjusted_ratio
134+
def get_adjusted_lr(lr: float, param_shape: Tuple[float, ...], use_adjusted_lr: bool = False) -> float:
135+
r"""Get the adjust learning rate."""
136+
output_shape, *input_shape = param_shape
137+
input_shape = math.prod(input_shape)
138+
139+
ratio: float = (
140+
math.pow(max(1.0, output_shape / input_shape), 0.5)
141+
if use_adjusted_lr
142+
else 0.2 * math.sqrt(max(output_shape, input_shape))
143+
)
144+
145+
return lr * ratio
137146

138147
@torch.no_grad()
139148
def step(self, closure: CLOSURE = None) -> LOSS:
@@ -202,9 +211,9 @@ def step(self, closure: CLOSURE = None) -> LOSS:
202211
fixed_decay=False,
203212
)
204213

205-
lr: float = self.adjust_lr_for_muon(group['lr'], p.size()) if group['use_adjusted_lr'] else group['lr']
214+
lr: float = self.get_adjusted_lr(group['lr'], p.size(), group['use_adjusted_lr'])
206215

207-
p.add_(g, alpha=-lr * (max(1.0, p.size(-2) / p.size(-1)) ** 0.5))
216+
p.add_(g, alpha=-lr)
208217
curr_idx += p.numel()
209218

210219
params = [p for p in group['params'] if p.grad is not None and not self.state[p]['use_muon']]

0 commit comments

Comments
 (0)