@@ -116,6 +116,20 @@ def build_lookahead(*parameters, **kwargs):
     (Ranger21, {'lr': 5e-1, 'weight_decay': 1e-3, 'num_iterations': 500}, 500),
 ]
 
+ADAMD_SUPPORTED_OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [
+    (build_lookahead, {'lr': 5e-1, 'weight_decay': 1e-3, 'adamd_debias_term': True}, 500),
+    (AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3, 'adamd_debias_term': True}, 200),
+    (AdaBound, {'lr': 5e-1, 'gamma': 0.1, 'weight_decay': 1e-3, 'adamd_debias_term': True}, 200),
+    (AdaBound, {'lr': 1e-2, 'gamma': 0.1, 'weight_decay': 1e-3, 'amsbound': True, 'adamd_debias_term': True}, 200),
+    (AdamP, {'lr': 5e-1, 'weight_decay': 1e-3, 'adamd_debias_term': True}, 500),
+    (DiffGrad, {'lr': 1e-1, 'weight_decay': 1e-3, 'adamd_debias_term': True}, 500),
+    (DiffRGrad, {'lr': 1e-1, 'weight_decay': 1e-3, 'adamd_debias_term': True}, 200),
+    (Lamb, {'lr': 1e-1, 'weight_decay': 1e-3, 'adamd_debias_term': True}, 200),
+    (RaLamb, {'lr': 1e-1, 'weight_decay': 1e-3, 'adamd_debias_term': True}, 500),
+    (RAdam, {'lr': 1e-1, 'weight_decay': 1e-3, 'adamd_debias_term': True}, 200),
+    (Ranger, {'lr': 5e-1, 'weight_decay': 1e-3, 'adamd_debias_term': True}, 200),
+]
+
 
 @pytest.mark.parametrize('optimizer_fp32_config', FP32_OPTIMIZERS, ids=ids)
 def test_f32_optimizers(optimizer_fp32_config):
@@ -177,16 +191,16 @@ def test_f16_optimizers(optimizer_fp16_config):
     assert init_loss - 0.01 > loss
 
 
-@pytest.mark.parametrize('optimizer_config', FP32_OPTIMIZERS, ids=ids)
-def test_sam_optimizers(optimizer_config):
+@pytest.mark.parametrize('optimizer_sam_config', FP32_OPTIMIZERS, ids=ids)
+def test_sam_optimizers(optimizer_sam_config):
     torch.manual_seed(42)
 
     x_data, y_data = make_dataset()
 
     model: nn.Module = LogisticRegression()
     loss_fn: nn.Module = nn.BCEWithLogitsLoss()
 
-    optimizer_class, config, iterations = optimizer_config
+    optimizer_class, config, iterations = optimizer_sam_config
     optimizer = SAM(model.parameters(), optimizer_class, **config)
 
     loss: float = np.inf
@@ -205,8 +219,8 @@ def test_sam_optimizers(optimizer_config):
     assert init_loss > 2.0 * loss
 
 
-@pytest.mark.parametrize('optimizer_config', FP32_OPTIMIZERS, ids=ids)
-def test_pc_grad_optimizers(optimizer_config):
+@pytest.mark.parametrize('optimizer_pc_grad_config', FP32_OPTIMIZERS, ids=ids)
+def test_pc_grad_optimizers(optimizer_pc_grad_config):
     torch.manual_seed(42)
 
     x_data, y_data = make_dataset()
@@ -215,7 +229,7 @@ def test_pc_grad_optimizers(optimizer_config):
     loss_fn_1: nn.Module = nn.BCEWithLogitsLoss()
     loss_fn_2: nn.Module = nn.L1Loss()
 
-    optimizer_class, config, iterations = optimizer_config
+    optimizer_class, config, iterations = optimizer_pc_grad_config
     optimizer = PCGrad(optimizer_class(model.parameters(), **config))
 
     loss: float = np.inf
@@ -233,3 +247,33 @@ def test_pc_grad_optimizers(optimizer_config):
         optimizer.step()
 
     assert init_loss > 2.0 * loss
+
+
+@pytest.mark.parametrize('optimizer_adamd_config', ADAMD_SUPPORTED_OPTIMIZERS, ids=ids)
+def test_adamd_optimizers(optimizer_adamd_config):
+    torch.manual_seed(42)
+
+    x_data, y_data = make_dataset()
+
+    model: nn.Module = LogisticRegression()
+    loss_fn: nn.Module = nn.BCEWithLogitsLoss()
+
+    optimizer_class, config, iterations = optimizer_adamd_config
+    optimizer = optimizer_class(model.parameters(), **config)
+
+    loss: float = np.inf
+    init_loss: float = np.inf
+    for _ in range(iterations):
+        optimizer.zero_grad()
+
+        y_pred = model(x_data)
+        loss = loss_fn(y_pred, y_data)
+
+        if init_loss == np.inf:
+            init_loss = loss
+
+        loss.backward()
+
+        optimizer.step()
+
+    assert init_loss > 2.0 * loss
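For context on the flag exercised by the new test: `adamd_debias_term` follows the AdamD proposal of dropping the first-moment bias correction from the step size while keeping the second-moment correction in the denominator, so the earliest updates are not inflated. The sketch below illustrates the idea on a generic Adam-style update; `adam_style_step` and its argument names are illustrative assumptions for this note, not code from this repository.

```python
# Minimal sketch (assumed names, not the library's implementation) of how an
# Adam-style update can honor an `adamd_debias_term` flag.
import torch


def adam_style_step(p, grad, exp_avg, exp_avg_sq, step, lr=1e-3,
                    betas=(0.9, 0.999), eps=1e-8, adamd_debias_term=False):
    beta1, beta2 = betas

    # Update the biased first and second moment estimates.
    exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

    bias_correction1 = 1.0 - beta1 ** step
    bias_correction2 = 1.0 - beta2 ** step

    # The denominator is always de-biased.
    denom = (exp_avg_sq / bias_correction2).sqrt_().add_(eps)

    # AdamD-style behavior: skip dividing the step size by the first-moment
    # bias correction, which would otherwise enlarge the very first updates.
    step_size = lr if adamd_debias_term else lr / bias_correction1

    p.data.addcdiv_(exp_avg, denom, value=-step_size)


# Tiny usage example: one step on a dummy parameter.
p = torch.zeros(3)
adam_style_step(p, torch.ones(3), torch.zeros(3), torch.zeros(3), step=1,
                adamd_debias_term=True)
```

Because `exp_avg` starts at zero, the effective step size under `adamd_debias_term=True` ramps up gradually instead of being scaled by `1 / bias_correction1`; the test only asserts that each supported optimizer still at least halves the loss within its configured number of iterations.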