Skip to content

Commit 54b8919

Browse files
committed
update: test_scion_lmo_types
1 parent 5a0cfac commit 54b8919

File tree

1 file changed

+9
-0
lines changed

1 file changed

+9
-0
lines changed

tests/test_optimizers.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -988,6 +988,11 @@ def test_build_lmo_types(lmo_type):
988988

989989

990990
def test_scion_lmo_types():
991+
model = Example()
992+
993+
load_optimizer('scion')(model.parameters()).init()
994+
load_optimizer('scionlight')(model.parameters()).init()
995+
991996
grad_1d = torch.ones(1)
992997
grad_2d = torch.ones(1, 1)
993998
grad_4d = torch.ones(1, 1, 1, 1)
@@ -1012,6 +1017,10 @@ def test_scion_lmo_types():
10121017
norm.init(grad_2d)
10131018
norm.lmo(grad_2d)
10141019

1020+
norm = build_lmo_norm(norm_type=4, zero_init=False)
1021+
norm.init(grad_2d)
1022+
norm.lmo(grad_2d)
1023+
10151024
norm = build_lmo_norm(norm_type=6, normalized=True, transpose=True)
10161025
norm.init(grad_2d)
10171026
norm.lmo(grad_2d)

0 commit comments

Comments
 (0)