1919from absl .testing import absltest , parameterized
2020
2121from emerging_optimizers import utils
22- from emerging_optimizers .orthogonalized_optimizers import muon_utils
23- from emerging_optimizers .orthogonalized_optimizers .muon import Muon , get_muon_scale_factor
24- from emerging_optimizers .orthogonalized_optimizers .muon_utils import _COEFFICIENT_SETS , newton_schulz
22+ from emerging_optimizers .orthogonalized_optimizers import muon , muon_utils
2523
2624
2725def newton_schulz_ref (x : torch .Tensor , coefficient_sets : list [tuple [float , float , float ]]) -> torch .Tensor :
@@ -70,7 +68,7 @@ def tearDown(self):
7068 def test_newtonschulz5_svd_close (self , dim1 , dim2 ):
7169 shape = (dim1 , dim2 )
7270 x = torch .randn (* shape , device = "cuda" , dtype = torch .float32 )
73- out_zeropowerns = newton_schulz (x , steps = 5 , coefficient_type = "quintic" )
71+ out_zeropowerns = muon_utils . newton_schulz (x , steps = 5 , coefficient_type = "quintic" )
7472 U , _ , V = torch .linalg .svd (x , full_matrices = False )
7573 out_zeropower_svd = (U @ V ).float ()
7674 # Check that the outputs are close.
@@ -91,10 +89,10 @@ def test_newtonschulz5_svd_close(self, dim1, dim2):
9189 )
9290 def test_newtonschulz5_close_to_reference (self , dim1 , dim2 ):
9391 x = torch .randn (dim1 , dim2 , device = "cuda" , dtype = torch .float32 )
94- out_zeropower_test = newton_schulz (x , steps = 5 , coefficient_type = "quintic" )
92+ out_zeropower_test = muon_utils . newton_schulz (x , steps = 5 , coefficient_type = "quintic" )
9593 out_zeropowerns_ref = newton_schulz_ref (
9694 x ,
97- coefficient_sets = _COEFFICIENT_SETS ["quintic" ],
95+ coefficient_sets = muon_utils . _COEFFICIENT_SETS ["quintic" ],
9896 )
9997
10098 torch .testing .assert_close (
@@ -116,7 +114,7 @@ def test_newtonschulz_custom_coeff_close_to_reference(self, dim1, dim2):
116114 (3 , 5 , 7 ),
117115 (11 , 13 , 17 ),
118116 ]
119- out_zeropower_test = newton_schulz (
117+ out_zeropower_test = muon_utils . newton_schulz (
120118 x ,
121119 steps = 2 ,
122120 coefficient_type = "custom" ,
@@ -159,8 +157,8 @@ def test_polar_express_better_than_quintic(self, dim1, dim2):
159157
160158 # Compare polar express vs quintic Newton-Schulz methods
161159 out_svd = (u @ v .T ).float ()
162- out_polar_express = newton_schulz (x , steps = 8 , coefficient_type = "polar_express" )
163- out_quintic = newton_schulz (x , steps = 5 , coefficient_type = "quintic" )
160+ out_polar_express = muon_utils . newton_schulz (x , steps = 8 , coefficient_type = "polar_express" )
161+ out_quintic = muon_utils . newton_schulz (x , steps = 5 , coefficient_type = "quintic" )
164162
165163 l2_norm_diff_polar = torch .norm (out_polar_express .float () - out_svd .float (), p = 2 )
166164 l2_norm_diff_quintic = torch .norm (out_quintic .float () - out_svd .float (), p = 2 )
@@ -180,7 +178,7 @@ def test_polar_express_better_than_quintic(self, dim1, dim2):
180178 )
181179 def test_get_scale_factor (self , size_pairs , mode ):
182180 size_out , size_in = size_pairs
183- scale = get_muon_scale_factor (size_out , size_in , mode )
181+ scale = muon . get_muon_scale_factor (size_out , size_in , mode )
184182 if mode == "shape_scaling" :
185183 self .assertEqual (scale , math .sqrt (max (1 , size_out / size_in )))
186184 elif mode == "spectral" :
@@ -196,17 +194,17 @@ def test_qkv_split_shapes_validation(self):
196194 dummy_args = dict (split_qkv = True , is_qkv_fn = lambda x : True )
197195 # Test non-integer values
198196 with self .assertRaises (ValueError ) as cm :
199- Muon ([dummy_param ], ** dummy_args , qkv_split_shapes = (512.5 , 256 , 256 ))
197+ muon . Muon ([dummy_param ], ** dummy_args , qkv_split_shapes = (512.5 , 256 , 256 ))
200198 self .assertIn ("must be integers" , str (cm .exception ))
201199
202200 # Test negative values
203201 with self .assertRaises (ValueError ) as cm :
204- Muon ([dummy_param ], ** dummy_args , qkv_split_shapes = (512 , - 256 , 256 ))
202+ muon . Muon ([dummy_param ], ** dummy_args , qkv_split_shapes = (512 , - 256 , 256 ))
205203 self .assertIn ("must be positive" , str (cm .exception ))
206204
207205 # Test wrong number of elements
208206 with self .assertRaises (ValueError ) as cm :
209- Muon ([dummy_param ], ** dummy_args , qkv_split_shapes = (512 , 256 ))
207+ muon . Muon ([dummy_param ], ** dummy_args , qkv_split_shapes = (512 , 256 ))
210208 self .assertIn ("tuple of 3 integers" , str (cm .exception ))
211209
212210
0 commit comments