File tree Expand file tree Collapse file tree 1 file changed +23
-0
lines changed Expand file tree Collapse file tree 1 file changed +23
-0
lines changed Original file line number Diff line number Diff line change 1+ import os
2+
13import numpy as np
24import pytest
35import torch
1315 WSAM ,
1416 DynamicLossScaler ,
1517 Lookahead ,
18+ Muon ,
1619 PCGrad ,
1720 load_optimizer ,
1821)
@@ -772,3 +775,23 @@ def test_muon_zero_power_via_newton_schulz_5():
772775
773776 with pytest .raises (ValueError ):
774777 _ = zero_power_via_newton_schulz_5 (x [0 ], num_steps = 6 )
778+
779+
780+ @pytest .mark .parametrize ('rank' , ['1' , '0' ])
781+ def test_muon_rank (rank ):
782+ os .environ ['RANK' ] = rank
783+
784+ model = nn .Sequential (
785+ nn .Conv1d (1 , 1 , 1 ),
786+ nn .Conv1d (1 , 1 , 1 ),
787+ nn .Conv1d (1 , 1 , 1 ),
788+ )
789+
790+ optimizer = Muon (model .parameters ())
791+ optimizer .zero_grad ()
792+
793+ model [0 ].weight .grad = torch .randn (1 , 1 , 1 )
794+ model [1 ].weight .grad = torch .randn (1 , 1 , 1 )
795+ model [2 ].weight .grad = torch .randn (1 , 1 , 1 )
796+
797+ optimizer .step ()
You can’t perform that action at this time.
0 commit comments