Skip to content

Commit 28405f2

Browse files
committed
fix muon with syrk test
Signed-off-by: Hao Wu <[email protected]>
1 parent 3f81996 commit 28405f2

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

tests/test_orthogonalized_optimizer.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -171,22 +171,22 @@ def test_smoke(self, shape) -> None:
171171
muon_opt = muon.Muon([test_param])
172172
muon_opt.step()
173173

174-
def test_use_syrk(self) -> None:
174+
def test_use_syrk_match_without_syrk(self) -> None:
175175
shape = (32, 32)
176176
test_param = nn.Parameter(torch.randint(-5, 5, shape, dtype=torch.float32, device="cuda"))
177-
ref_param = test_param.clone()
177+
ref_param = nn.Parameter(torch.randint(-5, 5, shape, dtype=torch.float32, device="cuda"))
178+
ref_param.data.copy_(test_param.data)
178179
test_param.grad = torch.randint_like(test_param, -5, 5)
180+
ref_param.grad = test_param.grad.clone()
179181

180-
muon_opt = muon.Muon([test_param], use_syrk=True)
181-
ref_muon_opt = muon.Muon([test_param], use_syrk=False)
182+
muon_opt = muon.Muon([test_param], num_ns_steps=1, coefficient_type="simple", use_syrk=True)
183+
ref_muon_opt = muon.Muon([ref_param], num_ns_steps=1, coefficient_type="simple", use_syrk=False)
182184
muon_opt.step()
183185
ref_muon_opt.step()
184186

185187
torch.testing.assert_close(
186188
test_param.data,
187189
ref_param.data,
188-
atol=0,
189-
rtol=0,
190190
)
191191

192192

0 commit comments

Comments
 (0)