Skip to content

Commit 1445e85

Browse files
committed
update: test_closure
1 parent 4d21dcf commit 1445e85

File tree

1 file changed

+17
-0
lines changed

1 file changed

+17
-0
lines changed

tests/test_optimizers.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import torch
66
from torch import nn
77
from torch.nn import functional as F
8+
from pytorch_optimizer.types import LOSS
89

910
from pytorch_optimizer import (
1011
LARS,
@@ -78,6 +79,10 @@ def ids(v) -> str:
7879
return f'{v[0].__name__}_{v[1:]}'
7980

8081

82+
def dummy_closure() -> LOSS:
83+
return 1.0
84+
85+
8186
def build_lookahead(*parameters, **kwargs):
8287
return Lookahead(AdamP(*parameters, **kwargs))
8388

@@ -87,6 +92,7 @@ def build_lookahead(*parameters, **kwargs):
8792
(AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3}, 200),
8893
(AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3, 'amsgrad': True}, 200),
8994
(AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3, 'weight_decouple': False}, 200),
95+
(AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3, 'fixed_decay': True}, 200),
9096
(AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3, 'rectify': False}, 200),
9197
(AdaBound, {'lr': 5e-1, 'gamma': 0.1, 'weight_decay': 1e-3}, 200),
9298
(AdaBound, {'lr': 5e-1, 'gamma': 0.1, 'weight_decay': 1e-3, 'amsbound': True}, 200),
@@ -140,6 +146,17 @@ def build_environment(use_gpu: bool = False) -> Tuple[Tuple[torch.Tensor, torch.
140146
return (x_data, y_data), model, loss_fn
141147

142148

149+
@pytest.mark.parametrize('optimizer_config', OPTIMIZERS, ids=ids)
150+
def test_closure(optimizer_config):
151+
_, model, _ = build_environment()
152+
153+
optimizer_class, config, _ = optimizer_config
154+
optimizer = optimizer_class(model.parameters(), **config)
155+
156+
optimizer.zero_grad()
157+
optimizer.step(closure=dummy_closure)
158+
159+
143160
@pytest.mark.parametrize('optimizer_fp32_config', OPTIMIZERS, ids=ids)
144161
def test_f32_optimizers(optimizer_fp32_config):
145162
(x_data, y_data), model, loss_fn = build_environment()

0 commit comments

Comments
 (0)