@@ -981,6 +981,31 @@ def test_kron_optimizer():
981981 optimizer .step ()
982982
983983
def test_scion_lmo_types():
    """Check ``get_lmo_direction`` of the 'scion' optimizer for each LMO type.

    A 2x2 all-ones gradient is passed through every LMO type exercised here
    ('spectral', 'sign', 'row_norm', 'col_norm') plus an unrecognised name
    ('asdf'), and the result is compared against hard-coded reference values.
    """
    grad = torch.ones(2, 2)

    # Resolve the optimizer once; `get_lmo_direction` is invoked on the object
    # returned by `load_optimizer` without constructing an optimizer instance.
    scion = load_optimizer('scion')

    # The 'spectral' reference is stored in bfloat16, so the returned tensor
    # must be bfloat16 too (assert_close compares dtypes by default).
    expected = torch.full((2, 2), 0.3438, dtype=torch.float32).bfloat16()
    actual = scion.get_lmo_direction(grad, 'spectral')
    # assert_close takes (actual, expected): rtol is scaled by `expected` and
    # the failure message labels the tensors by position.
    torch.testing.assert_close(actual, expected, rtol=1e-5, atol=1e-5)

    expected = torch.full((2, 2), 0.5, dtype=torch.float32)
    actual = scion.get_lmo_direction(grad, 'sign')
    torch.testing.assert_close(actual, expected, rtol=1e-5, atol=1e-5)

    expected = torch.full((2, 2), 0.7071, dtype=torch.float32)
    actual = scion.get_lmo_direction(grad, 'row_norm')
    torch.testing.assert_close(actual, expected, rtol=1e-5, atol=1e-5)

    expected = torch.full((2, 2), 0.7071, dtype=torch.float32)
    actual = scion.get_lmo_direction(grad, 'col_norm')
    torch.testing.assert_close(actual, expected, rtol=1e-5, atol=1e-5)

    # An unrecognised LMO type is expected to produce the same values as 'sign'.
    expected = torch.full((2, 2), 0.5, dtype=torch.float32)
    actual = scion.get_lmo_direction(grad, 'asdf')
    torch.testing.assert_close(actual, expected, rtol=1e-5, atol=1e-5)
1007+
1008+
9841009def test_schedulefree_wrapper ():
9851010 model = Example ()
9861011
0 commit comments