We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 5a0cfac commit 54b8919Copy full SHA for 54b8919
tests/test_optimizers.py
@@ -988,6 +988,11 @@ def test_build_lmo_types(lmo_type):
988
989
990
def test_scion_lmo_types():
991
+ model = Example()
992
+
993
+ load_optimizer('scion')(model.parameters()).init()
994
+ load_optimizer('scionlight')(model.parameters()).init()
995
996
grad_1d = torch.ones(1)
997
grad_2d = torch.ones(1, 1)
998
grad_4d = torch.ones(1, 1, 1, 1)
@@ -1012,6 +1017,10 @@ def test_scion_lmo_types():
1012
1017
norm.init(grad_2d)
1013
1018
norm.lmo(grad_2d)
1014
1019
1020
+ norm = build_lmo_norm(norm_type=4, zero_init=False)
1021
+ norm.init(grad_2d)
1022
+ norm.lmo(grad_2d)
1023
1015
1024
norm = build_lmo_norm(norm_type=6, normalized=True, transpose=True)
1016
1025
1026
0 commit comments