Skip to content

Commit 5326483

Browse files
committed
update: test_schedule_free_methods
1 parent e1ab493 commit 5326483

File tree

1 file changed

+6
-12
lines changed

1 file changed

+6
-12
lines changed

tests/test_optimizers.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -639,18 +639,12 @@ def test_dynamic_scaler():
639639
scaler.update_scale(overflow=False)
640640

641641

642-
def test_schedule_free_train_mode():
643-
param = simple_parameter(True)
644-
645-
opt = load_optimizer('ScheduleFreeAdamW')([param])
646-
opt.reset()
647-
opt.eval()
648-
opt.train()
649-
650-
opt = load_optimizer('ScheduleFreeSGD')([param])
651-
opt.reset()
652-
opt.eval()
653-
opt.train()
642+
@pytest.mark.parametrize('optimizer_name', ['ScheduleFreeAdamW', 'ScheduleFreeSGD', 'ScheduleFreeRAdam'])
643+
def test_schedule_free_methods(optimizer_name):
644+
optimizer = load_optimizer(optimizer_name)([simple_parameter(True)])
645+
optimizer.reset()
646+
optimizer.eval()
647+
optimizer.train()
654648

655649

656650
@pytest.mark.parametrize('filter_type', ['mean', 'sum'])

0 commit comments

Comments
 (0)