 
 from pytorch_optimizer import (
     MADGRAD,
+    SAM,
     SGDP,
     AdaBelief,
     AdaBound,
     AdamP,
     DiffGrad,
     DiffRGrad,
     Lamb,
     Lookahead,
     RAdam,
     Ranger,
     Ranger21,
+    SafeFP16Optimizer,
 )

 __REFERENCE__ = 'https://github.com/jettify/pytorch-optimizer/blob/master/tests/test_optimizer_with_nn.py'
@@ -66,7 +68,7 @@ def build_lookahead(*parameters, **kwargs):
     return Lookahead(AdamP(*parameters, **kwargs))


-OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [
+FP32_OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [
     (build_lookahead, {'lr': 1e-2, 'weight_decay': 1e-3}, 200),
     (AdaBelief, {'lr': 1e-2, 'weight_decay': 1e-3}, 200),
     (AdaBound, {'lr': 1e-2, 'gamma': 0.1, 'weight_decay': 1e-3}, 200),
@@ -78,21 +80,34 @@ def build_lookahead(*parameters, **kwargs):
     (RAdam, {'lr': 1e-1, 'weight_decay': 1e-3}, 200),
     (SGDP, {'lr': 1e-1, 'weight_decay': 1e-3}, 200),
     (Ranger, {'lr': 1e-1, 'weight_decay': 1e-3}, 200),
-    (Ranger21, {'lr': 5e-1, 'weight_decay': 1e-3, 'num_iterations': 1000}, 500),
-    # (AdaHessian, {'lr': 1e-2, 'weight_decay': 1e-3}, 200),
+    (Ranger21, {'lr': 5e-1, 'weight_decay': 1e-3, 'num_iterations': 500}, 500),
 ]

+FP16_OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [
+    (build_lookahead, {'lr': 5e-1, 'weight_decay': 1e-3}, 500),
+    (AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3}, 200),
+    (AdaBound, {'lr': 5e-1, 'gamma': 0.1, 'weight_decay': 1e-3}, 200),
+    (AdamP, {'lr': 5e-1, 'weight_decay': 1e-3}, 500),
+    (DiffGrad, {'lr': 5e-1, 'weight_decay': 1e-3}, 500),
+    (DiffRGrad, {'lr': 1e-1, 'weight_decay': 1e-3}, 200),
+    (Lamb, {'lr': 1e-1, 'weight_decay': 1e-3}, 200),
+    (RAdam, {'lr': 1e-1, 'weight_decay': 1e-3}, 200),
+    (SGDP, {'lr': 5e-1, 'weight_decay': 1e-3}, 500),
+    (Ranger, {'lr': 5e-1, 'weight_decay': 1e-3}, 200),
+    (Ranger21, {'lr': 5e-1, 'weight_decay': 1e-3, 'num_iterations': 500}, 500),
+]

-@pytest.mark.parametrize('optimizer_config', OPTIMIZERS, ids=ids)
-def test_optimizers(optimizer_config):
+
+@pytest.mark.parametrize('optimizer_fp32_config', FP32_OPTIMIZERS, ids=ids)
+def test_f32_optimizers(optimizer_fp32_config):
     torch.manual_seed(42)

     x_data, y_data = make_dataset()

     model: nn.Module = LogisticRegression()
     loss_fn: nn.Module = nn.BCEWithLogitsLoss()

-    optimizer_class, config, iterations = optimizer_config
+    optimizer_class, config, iterations = optimizer_fp32_config
     optimizer = optimizer_class(model.parameters(), **config)

     loss: float = np.inf
@@ -111,3 +126,58 @@ def test_optimizers(optimizer_config):
         optimizer.step()

     assert init_loss > 2.0 * loss
+
+
+@pytest.mark.parametrize('optimizer_fp16_config', FP16_OPTIMIZERS, ids=ids)
+def test_f16_optimizers(optimizer_fp16_config):
+    torch.manual_seed(42)
+
+    x_data, y_data = make_dataset()
+
+    model: nn.Module = LogisticRegression()
+    loss_fn: nn.Module = nn.BCEWithLogitsLoss()
+
+    optimizer_class, config, iterations = optimizer_fp16_config
+    optimizer = SafeFP16Optimizer(optimizer_class(model.parameters(), **config))
+
+    loss: float = np.inf
+    init_loss: float = np.inf
+    for _ in range(1000):
+        optimizer.zero_grad()
+
+        y_pred = model(x_data)
+        loss = loss_fn(y_pred, y_data)
+
+        if init_loss == np.inf:
+            init_loss = loss
+
+        loss.backward()
+
+        optimizer.step()
+
+    assert init_loss - 0.01 > loss
+
+
+
+@pytest.mark.parametrize('optimizer_config', FP32_OPTIMIZERS, ids=ids)
+def test_sam_optimizers(optimizer_config):
+    torch.manual_seed(42)
+
+    x_data, y_data = make_dataset()
+
+    model: nn.Module = LogisticRegression()
+    loss_fn: nn.Module = nn.BCEWithLogitsLoss()
+
+    optimizer_class, config, iterations = optimizer_config
+    optimizer = SAM(model.parameters(), optimizer_class, **config)
+
+    loss: float = np.inf
+    init_loss: float = np.inf
+    for _ in range(iterations):
+        # first forward-backward pass: gradients at the current weights, then perturb
+        loss = loss_fn(model(x_data), y_data)
+
+        if init_loss == np.inf:
+            init_loss = loss
+
+        loss.backward()
+        optimizer.first_step(zero_grad=True)
+
+        # second forward-backward pass: gradients at the perturbed weights, then update
+        loss_fn(model(x_data), y_data).backward()
+        optimizer.second_step(zero_grad=True)
+
+    assert init_loss > 2.0 * loss
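
Note: the tests above rely on make_dataset() and a LogisticRegression module defined earlier in this file and not shown in this diff. Purely as an assumption, a minimal sketch of what such helpers could look like so the training loops above can be reproduced in isolation:

import torch
from torch import nn


def make_dataset(num_samples: int = 100, dims: int = 2):
    # hypothetical toy dataset: two Gaussian blobs with 0/1 targets,
    # shaped (N, dims) / (N, 1) to match nn.BCEWithLogitsLoss in the tests above
    g = torch.Generator().manual_seed(42)
    x_data = torch.cat([
        torch.randn(num_samples, dims, generator=g) + 2.0,
        torch.randn(num_samples, dims, generator=g) - 2.0,
    ])
    y_data = torch.cat([torch.ones(num_samples, 1), torch.zeros(num_samples, 1)])
    return x_data, y_data


class LogisticRegression(nn.Module):
    # hypothetical single-layer model returning raw logits;
    # BCEWithLogitsLoss applies the sigmoid internally
    def __init__(self, dims: int = 2):
        super().__init__()
        self.fc = nn.Linear(dims, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)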