55import torch
66from torch import nn
77from torch .nn import functional as F
8+ from pytorch_optimizer .types import LOSS
89
910from pytorch_optimizer import (
1011 LARS ,
@@ -78,6 +79,10 @@ def ids(v) -> str:
7879 return f'{ v [0 ].__name__ } _{ v [1 :]} '
7980
8081
def dummy_closure() -> LOSS:
    """Stub closure handed to ``optimizer.step(closure=...)``.

    Returns a fixed dummy loss value; the optimizers under test only need
    *a* callable here, not a real re-evaluation of the model.
    """
    return 1.0
85+
def build_lookahead(*parameters, **kwargs):
    """Build a Lookahead optimizer with AdamP as its inner optimizer.

    Forwards all positional/keyword arguments to ``AdamP`` so this helper
    can sit in the parametrized optimizer table like a plain optimizer class.
    """
    inner_optimizer = AdamP(*parameters, **kwargs)
    return Lookahead(inner_optimizer)
8388
@@ -87,6 +92,7 @@ def build_lookahead(*parameters, **kwargs):
8792 (AdaBelief , {'lr' : 5e-1 , 'weight_decay' : 1e-3 }, 200 ),
8893 (AdaBelief , {'lr' : 5e-1 , 'weight_decay' : 1e-3 , 'amsgrad' : True }, 200 ),
8994 (AdaBelief , {'lr' : 5e-1 , 'weight_decay' : 1e-3 , 'weight_decouple' : False }, 200 ),
95+ (AdaBelief , {'lr' : 5e-1 , 'weight_decay' : 1e-3 , 'fixed_decay' : True }, 200 ),
9096 (AdaBelief , {'lr' : 5e-1 , 'weight_decay' : 1e-3 , 'rectify' : False }, 200 ),
9197 (AdaBound , {'lr' : 5e-1 , 'gamma' : 0.1 , 'weight_decay' : 1e-3 }, 200 ),
9298 (AdaBound , {'lr' : 5e-1 , 'gamma' : 0.1 , 'weight_decay' : 1e-3 , 'amsbound' : True }, 200 ),
@@ -140,6 +146,17 @@ def build_environment(use_gpu: bool = False) -> Tuple[Tuple[torch.Tensor, torch.
140146 return (x_data , y_data ), model , loss_fn
141147
142148
@pytest.mark.parametrize('optimizer_config', OPTIMIZERS, ids=ids)
def test_closure(optimizer_config):
    """Smoke test: every optimizer's ``step`` must accept a ``closure`` argument."""
    optimizer_class, config, _ = optimizer_config

    _, model, _ = build_environment()
    optimizer = optimizer_class(model.parameters(), **config)

    optimizer.zero_grad()
    optimizer.step(closure=dummy_closure)
159+
143160@pytest .mark .parametrize ('optimizer_fp32_config' , OPTIMIZERS , ids = ids )
144161def test_f32_optimizers (optimizer_fp32_config ):
145162 (x_data , y_data ), model , loss_fn = build_environment ()
0 commit comments