Commit f8aa431

refactor: test_d_adapt_no_progress

1 parent 97f8ee8 · commit f8aa431

2 files changed: +10 −10 lines


tests/test_gradients.py (10 additions, 0 deletions)

@@ -133,6 +133,16 @@ def test_sam_no_gradient():
     optimizer.second_step(zero_grad=True)


+@pytest.mark.parametrize('optimizer_name', ['DAdaptAdaGrad', 'DAdaptAdam', 'DAdaptSGD', 'DAdaptAdan'])
+def test_d_adapt_no_progress(optimizer_name):
+    param = simple_parameter(True)
+    param.grad = None
+
+    optimizer = load_optimizer(optimizer_name)([param])
+    optimizer.zero_grad()
+    optimizer.step()
+
+
 @pytest.mark.parametrize('optimizer_name', ['DAdaptAdaGrad', 'DAdaptAdam', 'DAdaptSGD', 'DAdaptAdan'])
 def test_d_adapt_2nd_stage_gradient(optimizer_name):
     p1 = simple_parameter(require_grad=False)
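
For context, the moved test drives each D-Adaptation optimizer through a step while no gradient is present. The following is a minimal sketch of how it runs in isolation; load_optimizer is the package's real name-to-class lookup, while the body of simple_parameter is an assumption (only its name and call signature come from the diff):

import torch
from pytorch_optimizer import load_optimizer  # resolves an optimizer class by name

def simple_parameter(require_grad: bool = True) -> torch.nn.Parameter:
    # Assumed shape of the test-suite helper: one tiny parameter.
    return torch.nn.Parameter(torch.zeros(1, 1), requires_grad=require_grad)

param = simple_parameter(True)
param.grad = None  # no backward pass has run, so there is no gradient to adapt to

optimizer = load_optimizer('DAdaptAdam')([param])
optimizer.zero_grad()
optimizer.step()  # expected to be a harmless no-op ("no progress"), not an error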

tests/test_optimizers.py (0 additions, 10 deletions)

@@ -359,16 +359,6 @@ def test_d_adapt_reset(require_gradient, sparse_gradient, optimizer_name):
     optimizer.reset()


-@pytest.mark.parametrize('optimizer_name', ['DAdaptAdaGrad', 'DAdaptAdam', 'DAdaptSGD', 'DAdaptAdan'])
-def test_d_adapt_no_progress(optimizer_name):
-    param = simple_parameter(True)
-    param.grad = None
-
-    optimizer = load_optimizer(optimizer_name)([param])
-    optimizer.zero_grad()
-    optimizer.step()
-
-
 @pytest.mark.parametrize('pre_conditioner_type', [0, 1, 2])
 def test_scalable_shampoo_pre_conditioner_with_svd(pre_conditioner_type):
     (x_data, y_data), _, loss_fn = build_environment()
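
After this move the test lives in the gradients suite. One way to select just the relocated test is standard pytest filtering (not part of the commit itself):

pytest tests/test_gradients.py -k test_d_adapt_no_progress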
