
Commit 1881cdd

update: use fixture
1 parent 40bd21b commit 1881cdd

File tree

1 file changed: +26 -26 lines changed

tests/test_optimizers.py

Lines changed: 26 additions & 26 deletions
@@ -27,15 +27,20 @@
 )


+@pytest.fixture(scope='function')
+def environment():
+    return build_environment()
+
+
 @pytest.mark.parametrize('optimizer_fp32_config', OPTIMIZERS, ids=ids)
-def test_f32_optimizers(optimizer_fp32_config):
+def test_f32_optimizers(optimizer_fp32_config, environment):
     def closure(x):
         def _closure() -> float:
             return x

         return _closure

-    (x_data, y_data), model, loss_fn = build_environment()
+    (x_data, y_data), model, loss_fn = environment

     optimizer_class, config, iterations = optimizer_fp32_config
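
Context for the new fixture: pytest matches the `environment` argument name to the fixture and rebuilds it for every test (function scope), replacing the repeated `build_environment()` calls below. A minimal runnable sketch of the pattern, with a hypothetical `build_environment` standing in for the repo's real helper (its definition is not part of this diff):

    import pytest
    import torch
    from torch import nn


    def build_environment():
        # hypothetical stand-in: toy data, a tiny model, and a loss function,
        # matching the `(x_data, y_data), model, loss_fn` unpacking in the tests
        torch.manual_seed(42)
        x_data = torch.randn(64, 2)
        y_data = (x_data.sum(dim=1, keepdim=True) > 0.0).float()
        model = nn.Sequential(nn.Linear(2, 4), nn.ReLU(), nn.Linear(4, 1))
        return (x_data, y_data), model, nn.BCEWithLogitsLoss()


    @pytest.fixture(scope='function')
    def environment():
        # function scope: each test gets a fresh environment, so model or
        # optimizer state cannot leak between parametrized cases
        return build_environment()


    def test_uses_fixture(environment):
        # pytest injects the fixture's return value via the argument name
        (x_data, y_data), model, loss_fn = environment
        assert torch.isfinite(loss_fn(model(x_data), y_data))

Note that 'function' is already pytest's default scope, so the explicit argument mainly documents the intent that state is rebuilt per test.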

@@ -62,17 +67,14 @@ def _closure() -> float:

         loss.backward()

-        if optimizer_name == 'AliG':
-            optimizer.step(closure(loss))
-        else:
-            optimizer.step()
+        optimizer.step(closure(loss) if optimizer_name == 'AliG' else None)

     assert tensor_to_numpy(init_loss) > 1.5 * tensor_to_numpy(loss)


 @pytest.mark.parametrize('pullback_momentum', PULLBACK_MOMENTUM)
-def test_lookahead(pullback_momentum):
-    (x_data, y_data), model, loss_fn = build_environment()
+def test_lookahead(pullback_momentum, environment):
+    (x_data, y_data), model, loss_fn = environment

     optimizer = Lookahead(load_optimizer('adamp')(model.parameters(), lr=5e-1), pullback_momentum=pullback_momentum)
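
A note on the collapsed `step` call: `torch.optim.Optimizer.step` has the signature `step(closure=None)`, so passing `None` explicitly is the same as calling `step()` with no arguments, which is what makes the ternary form safe. A quick sketch:

    import torch

    param = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.SGD([param], lr=0.1)
    param.grad = torch.ones(1)

    needs_closure = False  # stands in for `optimizer_name == 'AliG'`
    optimizer.step((lambda: 0.0) if needs_closure else None)
    assert param.item() < 0.0  # the plain step ran: 0 - lr * grad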

@@ -94,8 +96,8 @@ def test_lookahead(pullback_momentum):


 @pytest.mark.parametrize('adaptive', ADAPTIVE_FLAGS)
-def test_sam_optimizers(adaptive):
-    (x_data, y_data), model, loss_fn = build_environment()
+def test_sam_optimizers(adaptive, environment):
+    (x_data, y_data), model, loss_fn = environment

     optimizer = SAM(model.parameters(), load_optimizer('asgd'), lr=5e-1, adaptive=adaptive)

@@ -115,8 +117,8 @@ def test_sam_optimizers(adaptive):


 @pytest.mark.parametrize('adaptive', ADAPTIVE_FLAGS)
-def test_sam_optimizers_with_closure(adaptive):
-    (x_data, y_data), model, loss_fn = build_environment()
+def test_sam_optimizers_with_closure(adaptive, environment):
+    (x_data, y_data), model, loss_fn = environment

     optimizer = SAM(model.parameters(), load_optimizer('adamp'), lr=5e-1, adaptive=adaptive)

@@ -140,14 +142,10 @@ def closure():


 @pytest.mark.parametrize('adaptive', ADAPTIVE_FLAGS)
-def test_gsam_optimizers(adaptive):
+def test_gsam_optimizers(adaptive, environment):
     pytest.skip('skip GSAM optimizer')

-    (x_data, y_data), model, loss_fn = build_environment()
-
-    x_data = x_data.cuda()
-    y_data = y_data.cuda()
-    model.cuda()
+    (x_data, y_data), model, loss_fn = environment

     lr: float = 5e-1
     num_iterations: int = 25
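
Worth noting for this hunk: `pytest.skip()` raises immediately, so nothing after it in `test_gsam_optimizers` ever runs; the deleted `.cuda()` calls were unreachable anyway. A minimal demonstration:

    import pytest


    def test_skipped():
        pytest.skip('skip GSAM optimizer')
        raise AssertionError('never reached')  # the test reports SKIPPED, not FAILED
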
@@ -174,8 +172,8 @@ def test_gsam_optimizers(adaptive):


 @pytest.mark.parametrize('optimizer_config', ADANORM_SUPPORTED_OPTIMIZERS, ids=ids)
-def test_adanorm_optimizers(optimizer_config):
-    (x_data, y_data), model, loss_fn = build_environment()
+def test_adanorm_optimizers(optimizer_config, environment):
+    (x_data, y_data), model, loss_fn = environment

     optimizer_class, config, num_iterations = optimizer_config
     if optimizer_class.__name__ == 'Ranger21':
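
As in the other parametrized tests here, `ids=ids` turns each `(class, config, iterations)` tuple into a readable test ID. A sketch assuming `ids` is a small labelling helper; the real helper's definition is not shown in this diff:

    import pytest

    CASES = [('AdamP', {'lr': 5e-1}, 100), ('ASGD', {'lr': 5e-1}, 100)]


    def ids(case) -> str:
        return case[0]  # hypothetical: label each case by its first element


    @pytest.mark.parametrize('optimizer_config', CASES, ids=ids)
    def test_named_cases(optimizer_config):
        # collected as test_named_cases[AdamP] and test_named_cases[ASGD]
        name, config, num_iterations = optimizer_config
        assert num_iterations > 0
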
@@ -215,8 +213,8 @@ def test_adanorm_condition(optimizer_config):


 @pytest.mark.parametrize('optimizer_config', ADAMD_SUPPORTED_OPTIMIZERS, ids=ids)
-def test_adamd_optimizers(optimizer_config):
-    (x_data, y_data), model, loss_fn = build_environment()
+def test_adamd_optimizers(optimizer_config, environment):
+    (x_data, y_data), model, loss_fn = environment

     optimizer_class, config, num_iterations = optimizer_config
     if optimizer_class.__name__ == 'Ranger21':

@@ -343,13 +341,14 @@ def test_d_adapt_reset(require_gradient, sparse_gradient, optimizer_name):
         param.grad = None

     optimizer = load_optimizer(optimizer_name)([param])
-    assert str(optimizer) == optimizer_name
     optimizer.reset()

+    assert str(optimizer) == optimizer_name
+

 @pytest.mark.parametrize('pre_conditioner_type', [0, 1, 2])
-def test_scalable_shampoo_pre_conditioner_with_svd(pre_conditioner_type):
-    (x_data, y_data), _, loss_fn = build_environment()
+def test_scalable_shampoo_pre_conditioner_with_svd(pre_conditioner_type, environment):
+    (x_data, y_data), _, loss_fn = environment

     model = nn.Sequential(
         nn.Linear(2, 4096),
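
The relocated assertion works because these optimizers override `__str__` to report their own name; running it after `reset()` also checks that the representation survives the call. A minimal sketch of the convention with a hypothetical optimizer (not the library's actual class):

    import torch


    class MyOpt(torch.optim.SGD):
        def __str__(self) -> str:
            return 'MyOpt'


    opt = MyOpt([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
    assert str(opt) == 'MyOpt'
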
@@ -382,5 +381,6 @@ def test_sm3_make_sparse():

 def test_sm3_rank0():
     optimizer = load_optimizer('sm3')([simple_zero_rank_parameter(True)])
-    assert str(optimizer) == 'SM3'
     optimizer.step()
+
+    assert str(optimizer) == 'SM3'
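
For the rank-0 test: a zero-rank parameter is a 0-dimensional (scalar) tensor, a shape optimizers can mishandle when they index into dimensions. A hypothetical stand-in for `simple_zero_rank_parameter`, whose real definition is not shown here:

    import torch


    def simple_zero_rank_parameter(requires_grad: bool = True) -> torch.nn.Parameter:
        # a 0-dim tensor with a gradient attached, so optimizer.step() has work to do
        param = torch.nn.Parameter(torch.tensor(1.0), requires_grad=requires_grad)
        param.grad = torch.tensor(0.5)
        return param


    assert simple_zero_rank_parameter().dim() == 0  # rank 0: no dimensions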
