Skip to content

Commit ef94309

Browse files
committed
import module instead of function
Signed-off-by: Hao Wu <[email protected]>
1 parent 28405f2 commit ef94309

File tree

1 file changed

+11
-13
lines changed

1 file changed

+11
-13
lines changed

tests/test_muon_utils.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,7 @@
1919
from absl.testing import absltest, parameterized
2020

2121
from emerging_optimizers import utils
22-
from emerging_optimizers.orthogonalized_optimizers import muon_utils
23-
from emerging_optimizers.orthogonalized_optimizers.muon import Muon, get_muon_scale_factor
24-
from emerging_optimizers.orthogonalized_optimizers.muon_utils import _COEFFICIENT_SETS, newton_schulz
22+
from emerging_optimizers.orthogonalized_optimizers import muon, muon_utils
2523

2624

2725
def newton_schulz_ref(x: torch.Tensor, coefficient_sets: list[tuple[float, float, float]]) -> torch.Tensor:
@@ -70,7 +68,7 @@ def tearDown(self):
7068
def test_newtonschulz5_svd_close(self, dim1, dim2):
7169
shape = (dim1, dim2)
7270
x = torch.randn(*shape, device="cuda", dtype=torch.float32)
73-
out_zeropowerns = newton_schulz(x, steps=5, coefficient_type="quintic")
71+
out_zeropowerns = muon_utils.newton_schulz(x, steps=5, coefficient_type="quintic")
7472
U, _, V = torch.linalg.svd(x, full_matrices=False)
7573
out_zeropower_svd = (U @ V).float()
7674
# Check that the outputs are close.
@@ -91,10 +89,10 @@ def test_newtonschulz5_svd_close(self, dim1, dim2):
9189
)
9290
def test_newtonschulz5_close_to_reference(self, dim1, dim2):
9391
x = torch.randn(dim1, dim2, device="cuda", dtype=torch.float32)
94-
out_zeropower_test = newton_schulz(x, steps=5, coefficient_type="quintic")
92+
out_zeropower_test = muon_utils.newton_schulz(x, steps=5, coefficient_type="quintic")
9593
out_zeropowerns_ref = newton_schulz_ref(
9694
x,
97-
coefficient_sets=_COEFFICIENT_SETS["quintic"],
95+
coefficient_sets=muon_utils._COEFFICIENT_SETS["quintic"],
9896
)
9997

10098
torch.testing.assert_close(
@@ -116,7 +114,7 @@ def test_newtonschulz_custom_coeff_close_to_reference(self, dim1, dim2):
116114
(3, 5, 7),
117115
(11, 13, 17),
118116
]
119-
out_zeropower_test = newton_schulz(
117+
out_zeropower_test = muon_utils.newton_schulz(
120118
x,
121119
steps=2,
122120
coefficient_type="custom",
@@ -159,8 +157,8 @@ def test_polar_express_better_than_quintic(self, dim1, dim2):
159157

160158
# Compare polar express vs quintic Newton-Schulz methods
161159
out_svd = (u @ v.T).float()
162-
out_polar_express = newton_schulz(x, steps=8, coefficient_type="polar_express")
163-
out_quintic = newton_schulz(x, steps=5, coefficient_type="quintic")
160+
out_polar_express = muon_utils.newton_schulz(x, steps=8, coefficient_type="polar_express")
161+
out_quintic = muon_utils.newton_schulz(x, steps=5, coefficient_type="quintic")
164162

165163
l2_norm_diff_polar = torch.norm(out_polar_express.float() - out_svd.float(), p=2)
166164
l2_norm_diff_quintic = torch.norm(out_quintic.float() - out_svd.float(), p=2)
@@ -180,7 +178,7 @@ def test_polar_express_better_than_quintic(self, dim1, dim2):
180178
)
181179
def test_get_scale_factor(self, size_pairs, mode):
182180
size_out, size_in = size_pairs
183-
scale = get_muon_scale_factor(size_out, size_in, mode)
181+
scale = muon.get_muon_scale_factor(size_out, size_in, mode)
184182
if mode == "shape_scaling":
185183
self.assertEqual(scale, math.sqrt(max(1, size_out / size_in)))
186184
elif mode == "spectral":
@@ -196,17 +194,17 @@ def test_qkv_split_shapes_validation(self):
196194
dummy_args = dict(split_qkv=True, is_qkv_fn=lambda x: True)
197195
# Test non-integer values
198196
with self.assertRaises(ValueError) as cm:
199-
Muon([dummy_param], **dummy_args, qkv_split_shapes=(512.5, 256, 256))
197+
muon.Muon([dummy_param], **dummy_args, qkv_split_shapes=(512.5, 256, 256))
200198
self.assertIn("must be integers", str(cm.exception))
201199

202200
# Test negative values
203201
with self.assertRaises(ValueError) as cm:
204-
Muon([dummy_param], **dummy_args, qkv_split_shapes=(512, -256, 256))
202+
muon.Muon([dummy_param], **dummy_args, qkv_split_shapes=(512, -256, 256))
205203
self.assertIn("must be positive", str(cm.exception))
206204

207205
# Test wrong number of elements
208206
with self.assertRaises(ValueError) as cm:
209-
Muon([dummy_param], **dummy_args, qkv_split_shapes=(512, 256))
207+
muon.Muon([dummy_param], **dummy_args, qkv_split_shapes=(512, 256))
210208
self.assertIn("tuple of 3 integers", str(cm.exception))
211209

212210

0 commit comments

Comments
 (0)