
Commit 6071805

[Feature] Support cautious variant for Muon optimizer (#417)
* docs: v3.8.0 changelog
* feature: support cautious variant
* update: test recipe
* docs: v3.8.0 changelog
* update: visualize_optimizers
* docs: v3.8.0 changelog
* docs: visualization
* build(deps): mkdocs plugins
* refactor: awesome-nav
* ci: update labeler
1 parent 6704a22 commit 6071805

12 files changed: +192, -153 lines

.github/labeler.yml

Lines changed: 3 additions & 0 deletions
@@ -12,6 +12,9 @@ dependencies:
 - changed-files:
   - any-glob-to-any-file:
     - pyproject.toml
+    - requirements.txt
+    - requirements-dev.txt
+    - requirements-docs.txt

 optimizer:
 - changed-files:

docs/.nav.yml

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+preserve_directory_names: true
+
+flatten_single_child_sections: false
+
+sort:
+  direction: asc
+  sections: last
+  type: alphabetical
+  ignore_case: false
+
+nav:
+  - Home: index.md
+  - Optimizer: optimizer.md
+  - LR Scheduler: lr_scheduler.md
+  - Loss Function: loss.md
+  - Utilization: util.md
+  - Base: base.md
+  - Visualization: visualization.md
+  - Change Logs:
+      - changelogs/*.md
+  - Q&A: qa.md

docs/changelogs/v3.8.0.md

Lines changed: 9 additions & 2 deletions
@@ -6,17 +6,24 @@
 * [Through the River: Understanding the Benefit of Schedule-Free Methods for Language Model Training](https://arxiv.org/abs/2507.09846)
     * You can use this variant by setting `decoupling_c` parameter in the `ScheduleFreeAdamW` optimizer.
 * Add more built-in optimizers, `NAdam`, `RMSProp`, and `LBFGS` optimizers. (#415)
+* Support `cautious` variant for `Muon` optimizer. (#417)

 ### Update

 * Re-implement `Muon` and `AdaMuon` optimizers based on the recent official implementation. (#408, #410)
     * Their definitions have changed from the previous version, so please check out the documentation!
 * Update the missing optimizers from `__init__.py`. (#415)
+* Add the HuggingFace Trainer example. (#415)
+* Optimize the visualization outputs and change the visualization document to a table layout. (#416)
+
+### Dependency
+
+* Update `mkdocs` dependencies. (#417)

 ### CI

 * Add some GitHub actions to automate some processes. (#411, #412, #413)

-### Example
+## Contributions

-* Add the HuggingFace Trainer example. (#415)
+thanks to @AidinHamedi
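As a quick illustration of the new option, here is a hedged usage sketch, not taken from the project docs: the hyperparameter names follow the test recipe added in `tests/constants.py` below, and `cautious=True` reaches the parameter groups through the `super().__init__(params, kwargs)` change in `muon.py`. Depending on the re-implemented API, parameters may need to be pre-grouped as the updated tests suggest.

```python
import torch

# Assumes Muon is exported from the package top level, as the v3.8.0 notes imply.
from pytorch_optimizer import Muon

model = torch.nn.Linear(8, 8)

optimizer = Muon(
    model.parameters(),   # may need pre-grouped parameters; see tests/test_optimizer_variants.py
    lr=5e-1,
    weight_decay=1e-3,
    use_adjusted_lr=True,
    cautious=True,        # new in this commit: mask update components that disagree with the gradient
)
```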

docs/visualization.md

Lines changed: 104 additions & 104 deletions
Large diffs are not rendered by default.

examples/visualize_optimizers.py

Lines changed: 3 additions & 4 deletions
@@ -16,7 +16,6 @@

 filterwarnings('ignore', category=UserWarning)

-IMG_FORMAT: str = 'jpg'
 OPTIMIZERS_IGNORE: Tuple[str, ...] = (
     'lomo',
     'adalomo',
@@ -42,8 +41,8 @@
 SEARCH_SEED: int = 42
 LOSS_MIN_THRESHOLD: float = 0.0

-DEFAULT_SEARCH_SPACES: Dict[str, object] = {'lr': hp.uniform('lr', 0, 2)}
-SPECIAL_SEARCH_SPACES: Dict[str, Dict[str, object]] = {
+DEFAULT_SEARCH_SPACES: Dict = {'lr': hp.uniform('lr', 0, 2)}
+SPECIAL_SEARCH_SPACES: Dict = {
     'adafactor': {'lr': hp.uniform('lr', 0, 10)},
     'adams': {'lr': hp.uniform('lr', 0, 10)},
     'dadaptadagrad': {'lr': hp.uniform('lr', 0, 10)},
@@ -407,7 +406,7 @@ def execute_experiments(
     """
     for i, (optimizer_class, search_space) in enumerate(optimizers, start=1):
         optimizer_name = optimizer_class.__name__
-        output_path = output_dir / f'{experiment_name}_{optimizer_name}.{IMG_FORMAT}'
+        output_path = output_dir / f'{experiment_name}_{optimizer_name}.jpg'
         if output_path.exists():
             continue

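For context, the `DEFAULT_SEARCH_SPACES` / `SPECIAL_SEARCH_SPACES` dictionaries above are `hyperopt` search spaces. A minimal, self-contained sketch of how such a space drives a search; the objective below is a hypothetical stand-in for the script's actual training loop:

```python
from hyperopt import fmin, hp, tpe

search_space = {'lr': hp.uniform('lr', 0, 2)}  # same shape as DEFAULT_SEARCH_SPACES

def objective(params: dict) -> float:
    # hypothetical stand-in for the real objective: pretend the best lr is 0.5
    return (params['lr'] - 0.5) ** 2

best = fmin(fn=objective, space=search_space, algo=tpe.suggest, max_evals=25)
print(best)  # e.g. {'lr': 0.49...}
```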

mkdocs.yml

Lines changed: 1 addition & 11 deletions
@@ -2,16 +2,6 @@ site_name: pytorch-optimizer
 site_description: 'optimizer & lr scheduler & loss function collections in PyTorch'
 repo_name: 'kozistr/pytorch-optimizer'
 repo_url: 'https://github.com/kozistr/pytorch_optimizer'
-nav:
-  - index.md
-  - base.md
-  - optimizer.md
-  - lr_scheduler.md
-  - loss.md
-  - util.md
-  - visualization.md
-  - ... | changelogs/*.md
-  - qa.md
 theme:
   name: material
   highlightjs: true
@@ -21,7 +11,7 @@ extra_javascript:
   - javascripts/tables.js
 plugins:
   - search
-  - awesome-pages
+  - awesome-nav
   - mkdocstrings:
       handlers:
         python:

pytorch_optimizer/optimizer/muon.py

Lines changed: 5 additions & 2 deletions
@@ -123,7 +123,7 @@ def __init__(

             group['weight_decouple'] = group.get('weight_decouple', weight_decouple)

-        super().__init__(params, {})
+        super().__init__(params, kwargs)

     def __str__(self) -> str:
         return 'Muon'
@@ -192,6 +192,9 @@ def step(self, closure: CLOSURE = None) -> LOSS:

                 update = zero_power_via_newton_schulz_5(update, num_steps=group['ns_steps'])

+                if group.get('cautious'):
+                    self.apply_cautious(update, grad)
+
                 lr: float = get_adjusted_lr(group['lr'], p.size(), use_adjusted_lr=group['use_adjusted_lr'])

                 p.add_(update.reshape(p.shape), alpha=-lr)
@@ -308,7 +311,7 @@ def __init__(
             group['weight_decouple'] = group.get('weight_decouple', weight_decouple)
             group['eps'] = group.get('eps', eps)

-        super().__init__(params, {})
+        super().__init__(params, kwargs)

     def __str__(self) -> str:
         return 'AdaMuon'
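`apply_cautious` is inherited from the shared base optimizer and is not part of this diff. As a rough sketch, the cautious rule (Cautious Optimizers, arXiv:2411.16085) zeroes update components whose sign disagrees with the gradient and rescales the survivors; the code below is an assumed form for illustration, not the library's exact implementation:

```python
import torch

def cautious_mask(update: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
    # keep only components where the update and the gradient point the same way
    mask = (update * grad > 0).to(update.dtype)
    # rescale so the surviving components keep roughly the original average magnitude
    mask = mask / mask.mean().clamp(min=1e-3)
    return update * mask
```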

requirements-docs.txt

Lines changed: 4 additions & 4 deletions
@@ -1,12 +1,12 @@
 --index-url https://pypi.org/simple
 --extra-index-url https://download.pytorch.org/whl/cpu
 numpy<2.0
-torch==2.6.0
+torch==2.8.0
 mkdocs==1.6.1
-mkdocs-material==9.5.45
-pymdown-extensions==10.12
+mkdocs-material==9.6.16
+pymdown-extensions==10.16.1
 mkdocstrings-python==1.12.2
 markdown-include==0.8.1
 mdx_truly_sane_lists==1.3
-mkdocs-awesome-pages-plugin==2.9.3
+mkdocs-awesome-nav==3.1.2
 griffe==1.5.1

tests/constants.py

Lines changed: 12 additions & 0 deletions
@@ -721,6 +721,18 @@
     (AdaShift, {'lr': 1e1, 'keep_num': 1}, 3),
     (MARS, {'lr': 5e-1, 'lr_1d': 5e-1, 'weight_decay': 1e-3}, 3),
     (MARS, {'lr': 5e-1, 'lr_1d': 5e-1, 'weight_decay': 1e-3, 'optimize_1d': True}, 3),
+    (
+        Muon,
+        {
+            'lr': 5e-1,
+            'weight_decay': 1e-3,
+            'use_adjusted_lr': True,
+            'adamw_lr': 5e-1,
+            'adamw_betas': (0.9, 0.98),
+            'adamw_wd': 1e-2,
+        },
+        7,
+    ),
 ]
 STABLE_ADAMW_SUPPORTED_OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [
     (ADOPT, {'lr': 1e0, 'weight_decay': 1e-3, 'stable_adamw': True}, 5),
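The `adamw_lr`, `adamw_betas`, and `adamw_wd` entries in the recipe above reflect that this Muon implementation, like the official one it follows, keeps an internal AdamW branch for parameters that Muon's matrix update does not cover. A purely hypothetical sketch of how parameters are commonly split for such optimizers; the helper name and group flag here are invented for illustration, and the project's actual contract is defined by the library and its tests:

```python
from typing import Dict, Iterable, List

import torch

def split_params_for_muon(parameters: Iterable[torch.nn.Parameter]) -> List[Dict]:
    """Hypothetical helper: route 2D+ weights to the Muon branch, the rest to AdamW."""
    params = list(parameters)
    return [
        {'params': [p for p in params if p.ndim >= 2], 'use_muon': True},   # invented group flag
        {'params': [p for p in params if p.ndim < 2], 'use_muon': False},   # handled by the AdamW branch
    ]
```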

tests/test_optimizer_variants.py

Lines changed: 5 additions & 2 deletions
@@ -8,7 +8,7 @@
     COPT_SUPPORTED_OPTIMIZERS,
     STABLE_ADAMW_SUPPORTED_OPTIMIZERS,
 )
-from tests.utils import build_model, ids, simple_parameter, tensor_to_numpy
+from tests.utils import build_model, build_optimizer_parameter, ids, simple_parameter, tensor_to_numpy


 @pytest.mark.parametrize('optimizer_config', ADANORM_SUPPORTED_OPTIMIZERS, ids=ids)
@@ -80,11 +80,14 @@ def test_adamd_variant(optimizer_config, environment):
 @pytest.mark.parametrize('optimizer_config', COPT_SUPPORTED_OPTIMIZERS, ids=ids)
 def test_cautious_variant(optimizer_config, environment):
     x_data, y_data = environment
+
     model, loss_fn = build_model()

     optimizer_class, config, num_iterations = optimizer_config

-    optimizer = optimizer_class(model.parameters(), **config, cautious=True)
+    parameters, config = build_optimizer_parameter(model.parameters(), optimizer_class.__name__, config)
+
+    optimizer = optimizer_class(parameters, **config, cautious=True)

     init_loss, loss = np.inf, np.inf
     for _ in range(num_iterations):