
Commit 9301d51 (parent b905938)

update: test_adamd_optimizers

1 file changed: tests/test_optimizers.py (+50 −6)

@@ -116,6 +116,20 @@ def build_lookahead(*parameters, **kwargs):
     (Ranger21, {'lr': 5e-1, 'weight_decay': 1e-3, 'num_iterations': 500}, 500),
 ]
 
+ADAMD_SUPPORTED_OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [
+    (build_lookahead, {'lr': 5e-1, 'weight_decay': 1e-3, 'adamd_debias_term': True}, 500),
+    (AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3, 'adamd_debias_term': True}, 200),
+    (AdaBound, {'lr': 5e-1, 'gamma': 0.1, 'weight_decay': 1e-3, 'adamd_debias_term': True}, 200),
+    (AdaBound, {'lr': 1e-2, 'gamma': 0.1, 'weight_decay': 1e-3, 'amsbound': True, 'adamd_debias_term': True}, 200),
+    (AdamP, {'lr': 5e-1, 'weight_decay': 1e-3, 'adamd_debias_term': True}, 500),
+    (DiffGrad, {'lr': 1e-1, 'weight_decay': 1e-3, 'adamd_debias_term': True}, 500),
+    (DiffRGrad, {'lr': 1e-1, 'weight_decay': 1e-3, 'adamd_debias_term': True}, 200),
+    (Lamb, {'lr': 1e-1, 'weight_decay': 1e-3, 'adamd_debias_term': True}, 200),
+    (RaLamb, {'lr': 1e-1, 'weight_decay': 1e-3, 'adamd_debias_term': True}, 500),
+    (RAdam, {'lr': 1e-1, 'weight_decay': 1e-3, 'adamd_debias_term': True}, 200),
+    (Ranger, {'lr': 5e-1, 'weight_decay': 1e-3, 'adamd_debias_term': True}, 200),
+]
+
 
 @pytest.mark.parametrize('optimizer_fp32_config', FP32_OPTIMIZERS, ids=ids)
 def test_f32_optimizers(optimizer_fp32_config):
@@ -177,16 +191,16 @@ def test_f16_optimizers(optimizer_fp16_config):
     assert init_loss - 0.01 > loss
 
 
-@pytest.mark.parametrize('optimizer_config', FP32_OPTIMIZERS, ids=ids)
-def test_sam_optimizers(optimizer_config):
+@pytest.mark.parametrize('optimizer_sam_config', FP32_OPTIMIZERS, ids=ids)
+def test_sam_optimizers(optimizer_sam_config):
     torch.manual_seed(42)
 
     x_data, y_data = make_dataset()
 
     model: nn.Module = LogisticRegression()
     loss_fn: nn.Module = nn.BCEWithLogitsLoss()
 
-    optimizer_class, config, iterations = optimizer_config
+    optimizer_class, config, iterations = optimizer_sam_config
     optimizer = SAM(model.parameters(), optimizer_class, **config)
 
     loss: float = np.inf
@@ -205,8 +219,8 @@ def test_sam_optimizers(optimizer_config):
     assert init_loss > 2.0 * loss
 
 
-@pytest.mark.parametrize('optimizer_config', FP32_OPTIMIZERS, ids=ids)
-def test_pc_grad_optimizers(optimizer_config):
+@pytest.mark.parametrize('optimizer_pc_grad_config', FP32_OPTIMIZERS, ids=ids)
+def test_pc_grad_optimizers(optimizer_pc_grad_config):
     torch.manual_seed(42)
 
     x_data, y_data = make_dataset()
@@ -215,7 +229,7 @@ def test_pc_grad_optimizers(optimizer_config):
     loss_fn_1: nn.Module = nn.BCEWithLogitsLoss()
     loss_fn_2: nn.Module = nn.L1Loss()
 
-    optimizer_class, config, iterations = optimizer_config
+    optimizer_class, config, iterations = optimizer_pc_grad_config
     optimizer = PCGrad(optimizer_class(model.parameters(), **config))
 
     loss: float = np.inf
@@ -233,3 +247,33 @@ def test_pc_grad_optimizers(optimizer_config):
         optimizer.step()
 
     assert init_loss > 2.0 * loss
+
+
+@pytest.mark.parametrize('optimizer_adamd_config', ADAMD_SUPPORTED_OPTIMIZERS, ids=ids)
+def test_adamd_optimizers(optimizer_adamd_config):
+    torch.manual_seed(42)
+
+    x_data, y_data = make_dataset()
+
+    model: nn.Module = LogisticRegression()
+    loss_fn: nn.Module = nn.BCEWithLogitsLoss()
+
+    optimizer_class, config, iterations = optimizer_adamd_config
+    optimizer = optimizer_class(model.parameters(), **config)
+
+    loss: float = np.inf
+    init_loss: float = np.inf
+    for _ in range(iterations):
+        optimizer.zero_grad()
+
+        y_pred = model(x_data)
+        loss = loss_fn(y_pred, y_data)
+
+        if init_loss == np.inf:
+            init_loss = loss
+
+        loss.backward()
+
+        optimizer.step()
+
+    assert init_loss > 2.0 * loss
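
Note on the flag under test: below is a minimal sketch of what adamd_debias_term is assumed to toggle, following the AdamD proposal of keeping Adam's second-moment bias correction while dropping the first-moment bias-correction divisor from the step size. The helper adam_step_size is illustrative only, not part of this repository, and each optimizer in the list may wire the flag slightly differently.

import math

def adam_step_size(lr: float, step: int, beta1: float = 0.9, beta2: float = 0.999,
                   adamd_debias_term: bool = False) -> float:
    """Effective Adam step size after bias correction (illustrative sketch).

    With adamd_debias_term=True the 1 / (1 - beta1 ** step) factor is omitted,
    which shrinks early updates instead of inflating them (a built-in warmup effect).
    """
    bias_correction1 = 1.0 - beta1 ** step
    bias_correction2_sq = math.sqrt(1.0 - beta2 ** step)
    if adamd_debias_term:
        return lr * bias_correction2_sq
    return lr * bias_correction2_sq / bias_correction1

# At step 1, plain Adam scales lr by ~0.316 while the AdamD-style variant scales it by ~0.0316.
print(adam_step_size(1e-3, step=1), adam_step_size(1e-3, step=1, adamd_debias_term=True))

Running pytest tests/test_optimizers.py -k adamd should select only the new parametrized cases via pytest's keyword filter.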
