Skip to content

Commit a351a4d

Browse files
committed
update: test_scion_lmo_types
1 parent 4f62e30 commit a351a4d

File tree

1 file changed

+25
-0
lines changed

1 file changed

+25
-0
lines changed

tests/test_optimizers.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -981,6 +981,31 @@ def test_kron_optimizer():
981981
optimizer.step()
982982

983983

984+
def test_scion_lmo_types():
985+
grad = torch.ones(2, 2)
986+
987+
expected = torch.FloatTensor([[0.3438, 0.3438], [0.3438, 0.3438]]).bfloat16()
988+
actual = load_optimizer('scion').get_lmo_direction(grad, 'spectral')
989+
990+
torch.testing.assert_close(expected, actual, rtol=1e-5, atol=1e-5)
991+
992+
expected = torch.FloatTensor([[0.5, 0.5], [0.5, 0.5]])
993+
actual = load_optimizer('scion').get_lmo_direction(grad, 'sign')
994+
torch.testing.assert_close(actual, expected, rtol=1e-5, atol=1e-5)
995+
996+
expected = torch.FloatTensor([[0.7071, 0.7071], [0.7071, 0.7071]])
997+
actual = load_optimizer('scion').get_lmo_direction(grad, 'row_norm')
998+
torch.testing.assert_close(actual, expected, rtol=1e-5, atol=1e-5)
999+
1000+
expected = torch.FloatTensor([[0.7071, 0.7071], [0.7071, 0.7071]])
1001+
actual = load_optimizer('scion').get_lmo_direction(grad, 'col_norm')
1002+
torch.testing.assert_close(actual, expected, rtol=1e-5, atol=1e-5)
1003+
1004+
expected = torch.FloatTensor([[0.5, 0.5], [0.5, 0.5]])
1005+
actual = load_optimizer('scion').get_lmo_direction(grad, 'asdf')
1006+
torch.testing.assert_close(actual, expected, rtol=1e-5, atol=1e-5)
1007+
1008+
9841009
def test_schedulefree_wrapper():
9851010
model = Example()
9861011

0 commit comments

Comments
 (0)