Skip to content

Commit 8ff7eab

Browse files
committed
update: test cases
1 parent e8802c8 commit 8ff7eab

File tree

2 files changed

+9
-1
lines changed

2 files changed

+9
-1
lines changed

tests/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -564,7 +564,7 @@
564564
(FOCUS, {'lr': 1e-1, 'weight_decay': 1e-3}, 5),
565565
(Kron, {'lr': 1e0, 'weight_decay': 1e-3}, 3),
566566
(EXAdam, {'lr': 1e-1, 'weight_decay': 1e-3}, 5),
567-
(SCION, {'lr': 5e-1, 'constraint': False, 'weight_decay': 1e-3}, 5),
567+
(SCION, {'lr': 5e-1, 'constraint': False, 'weight_decay': 1e-3}, 10),
568568
(SCION, {'lr': 1e-1, 'constraint': True}, 10),
569569
(Ranger25, {'lr': 1e-1}, 3),
570570
(Ranger25, {'lr': 1e-1, 't_alpha_beta3': 5}, 3),

tests/test_optimizer_parameters.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,3 +303,11 @@ def test_load_wrapper_optimizer(optimizer_instance):
303303

304304
state = optimizer.state_dict()
305305
optimizer.load_state_dict(state)
306+
307+
308+
def test_scion_lmo_direction():
309+
x = torch.zeros((1, 1), dtype=torch.float32)
310+
311+
optimizer_instance = load_optimizer('SCION')
312+
for lmo_direction in ('spectral', 'sign', 'col_norm', 'row_norm'):
313+
optimizer_instance.get_lmo_direction(x, lmo_direction)

0 commit comments

Comments
 (0)