Skip to content

Commit dede2ed

Browse files
committed
update: test_muon_rank
1 parent 756d7ea commit dede2ed

File tree

1 file changed

+23
-0
lines changed

1 file changed

+23
-0
lines changed

tests/test_optimizers.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import os
2+
13
import numpy as np
24
import pytest
35
import torch
@@ -13,6 +15,7 @@
1315
WSAM,
1416
DynamicLossScaler,
1517
Lookahead,
18+
Muon,
1619
PCGrad,
1720
load_optimizer,
1821
)
@@ -772,3 +775,23 @@ def test_muon_zero_power_via_newton_schulz_5():
772775

773776
with pytest.raises(ValueError):
774777
_ = zero_power_via_newton_schulz_5(x[0], num_steps=6)
778+
779+
780+
@pytest.mark.parametrize('rank', ['1', '0'])
781+
def test_muon_rank(rank):
782+
os.environ['RANK'] = rank
783+
784+
model = nn.Sequential(
785+
nn.Conv1d(1, 1, 1),
786+
nn.Conv1d(1, 1, 1),
787+
nn.Conv1d(1, 1, 1),
788+
)
789+
790+
optimizer = Muon(model.parameters())
791+
optimizer.zero_grad()
792+
793+
model[0].weight.grad = torch.randn(1, 1, 1)
794+
model[1].weight.grad = torch.randn(1, 1, 1)
795+
model[2].weight.grad = torch.randn(1, 1, 1)
796+
797+
optimizer.step()

0 commit comments

Comments
 (0)