Skip to content

Commit 0ad115c

Browse files
authored
Merge pull request #373 from kozistr/fix/muon-optimizer
[Fix] Correct the learning rate ratio in `Muon` optimizer
2 parents d4e7564 + 40238f3 commit 0ad115c

File tree

6 files changed

+206
-129
lines changed

6 files changed

+206
-129
lines changed

docs/changelogs/v3.5.1.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,7 @@
77
### Update
88

99
* Update `SCION` optimizer based on the official implementation. (#369)
10+
11+
### Fix
12+
13+
* Correct the learning rate ratio in the `Muon` optimizer properly. (#371, #372, #373)

poetry.lock

Lines changed: 172 additions & 111 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pytorch_optimizer/optimizer/muon.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import math
22
import os
3-
from typing import List, Optional
3+
from typing import List, Optional, Tuple
44

55
import torch
66
from torch.distributed import ReduceOp, all_reduce
@@ -131,9 +131,18 @@ def reset(self):
131131
state['moment2'] = torch.zeros_like(p)
132132

133133
@staticmethod
134-
def adjust_lr_for_muon(lr: float, param_shape) -> float:
135-
adjusted_ratio: float = 0.2 * math.sqrt(max(param_shape[0], param_shape[1]))
136-
return lr * adjusted_ratio
134+
def get_adjusted_lr(lr: float, param_shape: Tuple[float, ...], use_adjusted_lr: bool = False) -> float:
135+
r"""Get the adjusted learning rate."""
136+
output_shape, *input_shape = param_shape
137+
input_shape = math.prod(input_shape)
138+
139+
ratio: float = (
140+
math.pow(max(1.0, output_shape / input_shape), 0.5)
141+
if use_adjusted_lr
142+
else 0.2 * math.sqrt(max(output_shape, input_shape))
143+
)
144+
145+
return lr * ratio
137146

138147
@torch.no_grad()
139148
def step(self, closure: CLOSURE = None) -> LOSS:
@@ -202,9 +211,9 @@ def step(self, closure: CLOSURE = None) -> LOSS:
202211
fixed_decay=False,
203212
)
204213

205-
lr: float = self.adjust_lr_for_muon(group['lr'], p.size()) if group['use_adjusted_lr'] else group['lr']
214+
lr: float = self.get_adjusted_lr(group['lr'], p.size(), group['use_adjusted_lr'])
206215

207-
p.add_(g, alpha=-lr * (max(1.0, p.size(-2) / p.size(-1)) ** 0.5))
216+
p.add_(g, alpha=-lr)
208217
curr_idx += p.numel()
209218

210219
params = [p for p in group['params'] if p.grad is not None and not self.state[p]['use_muon']]

requirements-dev.txt

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,13 @@ black==25.1.0 ; python_version >= "3.9"
55
click==8.1.8 ; python_version >= "3.8"
66
colorama==0.4.6 ; python_version >= "3.8" and (sys_platform == "win32" or platform_system == "Windows")
77
coverage[toml]==7.6.1 ; python_version == "3.8"
8-
coverage[toml]==7.6.12 ; python_version >= "3.9"
8+
coverage[toml]==7.8.0 ; python_version >= "3.9"
99
exceptiongroup==1.2.2 ; python_version < "3.11" and python_version >= "3.8"
1010
filelock==3.16.1 ; python_version == "3.8"
1111
filelock==3.18.0 ; python_version >= "3.9"
12-
fsspec==2025.3.0 ; python_version >= "3.8"
13-
iniconfig==2.0.0 ; python_version >= "3.8"
12+
fsspec==2025.3.0 ; python_version == "3.8"
13+
fsspec==2025.3.2 ; python_version >= "3.9"
14+
iniconfig==2.1.0 ; python_version >= "3.8"
1415
isort==5.13.2 ; python_version == "3.8"
1516
isort==6.0.1 ; python_version >= "3.9"
1617
jinja2==3.1.6 ; python_version >= "3.8"
@@ -22,17 +23,18 @@ networkx==3.1 ; python_version == "3.8"
2223
networkx==3.2.1 ; python_version >= "3.9"
2324
numpy==1.24.4 ; python_version == "3.8"
2425
numpy==2.0.2 ; python_version >= "3.9"
25-
packaging==24.2 ; python_version >= "3.8"
26+
packaging==25.0 ; python_version >= "3.8"
2627
pathspec==0.12.1 ; python_version >= "3.8"
27-
platformdirs==4.3.6 ; python_version >= "3.8"
28+
platformdirs==4.3.6 ; python_version == "3.8"
29+
platformdirs==4.3.7 ; python_version >= "3.9"
2830
pluggy==1.5.0 ; python_version >= "3.8"
2931
pytest-cov==5.0.0 ; python_version >= "3.8"
3032
pytest==8.3.5 ; python_version >= "3.8"
31-
ruff==0.11.0 ; python_version >= "3.8"
32-
setuptools==76.0.0 ; python_version >= "3.12"
33+
ruff==0.11.6 ; python_version >= "3.8"
34+
setuptools==79.0.0 ; python_version >= "3.12"
3335
sympy==1.13.1 ; python_version >= "3.9"
3436
sympy==1.13.3 ; python_version == "3.8"
3537
tomli==2.2.1 ; python_full_version <= "3.11.0a6" and python_version >= "3.8"
3638
torch==2.4.1+cpu ; python_version == "3.8"
3739
torch==2.6.0+cpu ; python_version >= "3.9"
38-
typing-extensions==4.12.2 ; python_version >= "3.8"
40+
typing-extensions==4.13.2 ; python_version >= "3.8"

requirements.txt

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33
filelock==3.16.1 ; python_version == "3.8"
44
filelock==3.18.0 ; python_version >= "3.9"
5-
fsspec==2025.3.0 ; python_version >= "3.8"
5+
fsspec==2025.3.0 ; python_version == "3.8"
6+
fsspec==2025.3.2 ; python_version >= "3.9"
67
jinja2==3.1.6 ; python_version >= "3.8"
78
markupsafe==2.1.5 ; python_version == "3.8"
89
markupsafe==3.0.2 ; python_version >= "3.9"
@@ -11,9 +12,9 @@ networkx==3.1 ; python_version == "3.8"
1112
networkx==3.2.1 ; python_version >= "3.9"
1213
numpy==1.24.4 ; python_version == "3.8"
1314
numpy==2.0.2 ; python_version >= "3.9"
14-
setuptools==76.0.0 ; python_version >= "3.12"
15+
setuptools==79.0.0 ; python_version >= "3.12"
1516
sympy==1.13.1 ; python_version >= "3.9"
1617
sympy==1.13.3 ; python_version == "3.8"
1718
torch==2.4.1+cpu ; python_version == "3.8"
1819
torch==2.6.0+cpu ; python_version >= "3.9"
19-
typing-extensions==4.12.2 ; python_version >= "3.8"
20+
typing-extensions==4.13.2 ; python_version >= "3.8"

tests/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -525,7 +525,7 @@
525525
(ADOPT, {'lr': 1e0}, 5),
526526
(FTRL, {'lr': 1e0, 'beta': 0.0, 'lambda_1': 0.0, 'lambda_2': 0.0}, 5),
527527
(Muon, {'lr': 5e0, 'use_adjusted_lr': True, 'adam_lr': 1e0, 'adamw_wd': 1e-2}, 5),
528-
(Muon, {'lr': 1e0, 'adam_lr': 1e0, 'adamw_wd': 1e-2, 'nesterov': False}, 5),
528+
(Muon, {'lr': 2e0, 'adam_lr': 1e0, 'adamw_wd': 1e-2, 'nesterov': False}, 5),
529529
(LaProp, {'lr': 1e0, 'weight_decay': 1e-3}, 5),
530530
(LaProp, {'lr': 1e0, 'centered': True, 'weight_decay': 1e-3}, 11),
531531
(LaProp, {'lr': 1e0, 'ams_bound': True, 'weight_decay': 1e-3}, 5),

0 commit comments

Comments
 (0)