 from pytorch_optimizer.base.exception import NoSparseGradientError
 from pytorch_optimizer.base.optimizer import BaseOptimizer
 from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
-from pytorch_optimizer.optimizer.utils import to_real
+from pytorch_optimizer.optimizer.utils import get_global_gradient_norm, to_real
 
 
 class DAdaptAdaGrad(Optimizer, BaseOptimizer):
@@ -23,7 +23,6 @@ class DAdaptAdaGrad(Optimizer, BaseOptimizer):
     :param momentum: float. momentum.
     :param d0: float. initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
     :param growth_rate: float. prevent the D estimate from growing faster than this multiplicative rate.
-        Default is inf, for unrestricted.
     :param weight_decay: float. weight decay (L2 penalty).
     :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
     :param fixed_decay: bool. fix weight decay.
@@ -253,11 +252,10 @@ class DAdaptAdam(Optimizer, BaseOptimizer):
     :param betas: BETAS. betas.
     :param d0: float. initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
     :param growth_rate: float. prevent the D estimate from growing faster than this multiplicative rate.
-        Default is inf, for unrestricted.
     :param weight_decay: float. weight decay (L2 penalty).
     :param weight_decouple: bool. use AdamW style weight decay.
     :param fixed_decay: bool. fix weight decay.
-    :param bias_correction: bool. Turn on Adam's bias correction.
+    :param adam_debias: bool. Only correct the denominator to avoid inflating step sizes early in training.
     :param eps: float. term added to the denominator to improve numerical stability.
     """
 
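Note: a minimal sketch of the intended adam_debias semantics, assuming the BaseOptimizer helper apply_adam_debias mirrors its call site further down in this diff (step_size=d * lr, bias_correction1=1.0 - beta1 ** (step + 1)); it is written here as a free function for illustration and is not the library's exact implementation.

def apply_adam_debias(adam_debias: bool, step_size: float, bias_correction1: float) -> float:
    # With adam_debias=False the step size receives the usual Adam correction of the
    # first-moment EMA (divide by 1 - beta1 ** t). With adam_debias=True only the
    # denominator (second moment) stays bias-corrected, so early steps are not inflated.
    return step_size if adam_debias else step_size / bias_correction1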
@@ -271,7 +269,7 @@ def __init__(
         weight_decay: float = 0.0,
         weight_decouple: bool = False,
         fixed_decay: bool = False,
-        bias_correction: bool = False,
+        adam_debias: bool = False,
         eps: float = 0.0,
     ):
         self.validate_learning_rate(lr)
@@ -287,8 +285,8 @@ def __init__(
             'weight_decay': weight_decay,
             'weight_decouple': weight_decouple,
             'fixed_decay': fixed_decay,
-            'bias_correction': bias_correction,
-            'k': 0,
+            'adam_debias': adam_debias,
+            'step': 0,
             'eps': eps,
         }
         super().__init__(params, defaults)
@@ -299,13 +297,13 @@ def __str__(self) -> str:
     @torch.no_grad()
     def reset(self):
         for group in self.param_groups:
+            group['step'] = 0
             for p in group['params']:
                 if p.grad is None:
                     continue
 
                 state = self.state[p]
 
-                state['step'] = 0
                 state['s'] = torch.zeros_like(p)
                 state['exp_avg'] = torch.zeros_like(p)
                 state['exp_avg_sq'] = torch.zeros_like(p)
@@ -318,26 +316,25 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 loss = closure()
 
         group = self.param_groups[0]
+        device = group['params'][0].device
 
         beta1, beta2 = group['betas']
-        k: int = group['k']
 
         beta2_sq: float = math.sqrt(beta2)
 
         d: float = group['d']
-        lr: float = max(group['lr'] for group in self.param_groups)
-        bias_correction: float = (
-            ((1.0 - beta2 ** (k + 1)) ** 0.5) / (1.0 - beta1 ** (k + 1)) if group['bias_correction'] else 1.0
-        )
-        d_lr = float(d * lr * bias_correction)
+        lr: float = group['lr']
+
+        bias_correction: float = 1.0 - pow(beta1, group['step'] + 1)
+        d_lr: float = self.apply_adam_debias(group['adam_debias'], step_size=d * lr, bias_correction1=bias_correction)
+
+        sk_l1 = torch.tensor([0.0], device=device)
+        numerator_acc = torch.tensor([0.0], device=device)
 
         if 'numerator_weighted' not in group:
-            group['numerator_weighted'] = torch.tensor([0.0], device=group['params'][0].device)
+            group['numerator_weighted'] = torch.tensor([0.0], device=device)
         numerator_weighted = group['numerator_weighted']
 
-        sk_l1 = torch.tensor([0.0], device=group['params'][0].device)
-        numerator_acc = torch.tensor([0.0], device=group['params'][0].device)
-
         for group in self.param_groups:
             for p in group['params']:
                 if p.grad is None:
@@ -349,19 +346,17 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
                 state = self.state[p]
                 if 'step' not in state:
-                    state['step'] = 0
                     state['s'] = torch.zeros_like(p)
                     state['exp_avg'] = torch.zeros_like(p)
                     state['exp_avg_sq'] = torch.zeros_like(p)
 
-                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
-                s = state['s']
+                exp_avg, exp_avg_sq, s = state['exp_avg'], state['exp_avg_sq'], state['s']
 
                 de_nom = exp_avg_sq.sqrt().add_(group['eps'])
                 numerator_acc.add_(torch.dot(grad.flatten(), s.div(de_nom).flatten()), alpha=d_lr)
 
                 exp_avg.mul_(beta1).add_(grad, alpha=d_lr * (1.0 - beta1))
-                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, alpha=1.0 - beta2)
+                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
 
                 s.mul_(beta2_sq).add_(grad, alpha=d_lr * (1.0 - beta2_sq))
 
@@ -374,19 +369,18 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
         if lr > 0.0:
             d_hat = numerator_weighted / (1.0 - beta2_sq) * sk_l1
-            d = max(d, min(d_hat, d * group['growth_rate']))
+            d = max(d, min(d_hat.item(), d * group['growth_rate']))
 
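Note: purely illustrative numbers for the growth clamp above; with the default growth_rate of float('inf') any larger d_hat is accepted immediately, while a finite rate caps the per-step increase, and d never shrinks.

d, growth_rate = 1e-6, 1.02  # hypothetical values; the optimizer's default growth_rate is float('inf')
for d_hat in (5e-7, 2e-6, 1e-3):
    d = max(d, min(d_hat, d * growth_rate))
    print(d)  # 1e-06 (d never shrinks), 1.02e-06 (capped at +2%), 1.0404e-06 (still capped)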
         for group in self.param_groups:
             group['numerator_weighted'] = numerator_weighted
             group['d'] = d
+
             for p in group['params']:
                 if p.grad is None:
                     continue
 
                 state = self.state[p]
 
-                state['step'] += 1
-
                 exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
 
                 de_nom = exp_avg_sq.sqrt().add_(group['eps'])
@@ -400,22 +394,19 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                     fixed_decay=group['fixed_decay'],
                 )
 
-                p.addcdiv_(exp_avg, de_nom, value=-1)
-
-            group['k'] += 1
+                p.addcdiv_(exp_avg, de_nom, value=-1.0)
 
         return loss
 
 
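Note: a short usage sketch for DAdaptAdam after this change; the top-level import path is assumed (the package is expected to re-export the class), and as with the other D-Adaptation optimizers the learning rate is normally left at 1.0 while d0 does the adapting.

import torch
from pytorch_optimizer import DAdaptAdam  # assuming the package re-exports the class at top level

model = torch.nn.Linear(10, 1)
optimizer = DAdaptAdam(model.parameters(), lr=1.0, betas=(0.9, 0.999), adam_debias=False)

x, y = torch.randn(8, 10), torch.randn(8, 1)
optimizer.zero_grad()
torch.nn.functional.mse_loss(model(x), y).backward()
optimizer.step()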
 class DAdaptSGD(Optimizer, BaseOptimizer):
-    r"""SGD with D-Adaptation. Leave LR set to 1 unless you encounter instability.
+    r"""SGD with D-Adaptation. Leave LR set to 1 unless you encounter instability. This implementation is based on V3.
 
     :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
     :param lr: float. learning rate.
     :param momentum: float. momentum.
     :param d0: float. initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
     :param growth_rate: float. prevent the D estimate from growing faster than this multiplicative rate.
-        Default is inf, for unrestricted.
     :param weight_decay: float. weight decay (L2 penalty).
     :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
     :param fixed_decay: bool. fix weight decay.
@@ -425,15 +416,15 @@ def __init__(
         self,
         params: PARAMETERS,
         lr: float = 1.0,
-        momentum: float = 0.0,
+        momentum: float = 0.9,
         d0: float = 1e-6,
         growth_rate: float = float('inf'),
         weight_decay: float = 0.0,
         weight_decouple: bool = False,
         fixed_decay: bool = False,
     ):
         self.validate_learning_rate(lr)
-        self.validate_range(momentum, 'momentum', 0.0, 1.0)
+        self.validate_range(momentum, 'momentum', 0.0, 1.0, range_type='[)')
         self.validate_non_negative(weight_decay, 'weight_decay')
 
         defaults: DEFAULTS = {
@@ -444,7 +435,7 @@ def __init__(
             'weight_decay': weight_decay,
             'weight_decouple': weight_decouple,
             'fixed_decay': fixed_decay,
-            'k': 0,
+            'step': 0,
         }
         super().__init__(params, defaults)
 
@@ -454,16 +445,16 @@ def __str__(self) -> str:
     @torch.no_grad()
     def reset(self):
         for group in self.param_groups:
+            group['step'] = 0
             for p in group['params']:
                 if p.grad is None:
                     continue
 
                 state = self.state[p]
 
-                state['step'] = 0
+                state['z'] = p.clone()
                 state['s'] = torch.zeros_like(p)
-                state['exp_avg'] = torch.zeros_like(p)
-                state['exp_avg_sq'] = torch.zeros_like(p)
+                state['x0'] = p.clone()
 
     @torch.no_grad()
     def step(self, closure: CLOSURE = None) -> LOSS:
@@ -473,14 +464,22 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 loss = closure()
 
         group = self.param_groups[0]
+        device = group['params'][0].device
 
-        growth_rate = group['growth_rate']
+        sk_sq = torch.tensor([0.0], device=device)
+        if 'numerator_weighted' not in group:
+            group['numerator_weighted'] = torch.tensor([0.0], device=device)
+        numerator_weighted = group['numerator_weighted']
 
-        g_sq = torch.tensor([0.0], device=group['params'][0].device)
-        sk_sq = torch.tensor([0.0], device=group['params'][0].device)
-        if 'gsq_weighted' not in group:
-            group['gsq_weighted'] = torch.tensor([0.0], device=group['params'][0].device)
-        gsq_weighted = group['gsq_weighted']
+        if group['step'] == 0:
+            group['g0_norm'] = get_global_gradient_norm(self.param_groups, device).sqrt_().item()
+        g0_norm = group['g0_norm']
+
+        if g0_norm == 0:
+            return loss
+
+        d, lr = group['d'], group['lr']
+        d_lr: float = d * lr / g0_norm
 
         for group in self.param_groups:
             for p in group['params']:
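Note: get_global_gradient_norm is imported from pytorch_optimizer.optimizer.utils above; the .sqrt_() call a few lines up implies it returns the summed squared gradient norm as a tensor. A rough, hypothetical stand-in for illustration (torch is assumed to be imported, as it is in this module):

def global_grad_norm_sq(param_groups, device):  # hypothetical stand-in, not the library helper
    total = torch.zeros((1,), device=device)
    for group in param_groups:
        for p in group['params']:
            if p.grad is not None:
                total.add_(p.grad.pow(2).sum())
    return total  # the caller applies .sqrt_() to obtain the global l2 norm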
@@ -491,57 +490,39 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 if grad.is_sparse:
                     raise NoSparseGradientError(str(self))
 
+                state = self.state[p]
+                if len(state) == 0:
+                    state['z'] = p.clone()
+                    state['s'] = torch.zeros_like(p)
+                    state['x0'] = p.clone()
+
                 self.apply_weight_decay(
                     p=p,
-                    grad=grad,
-                    lr=group['lr'],
+                    grad=None,
+                    lr=d_lr,
                     weight_decay=group['weight_decay'],
                     weight_decouple=group['weight_decouple'],
                     fixed_decay=group['fixed_decay'],
                 )
 
-                state = self.state[p]
-                if 'z' not in state:
-                    state['z'] = torch.clone(p)
-                    state['s'] = torch.zeros_like(p)
-                    state['x0'] = torch.clone(p)
-
-                g_sq.add_(grad.pow(2).sum())
-
-        if g_sq == 0:
-            return loss
-
-        group = self.param_groups[0]
-
-        if group['k'] == 0:
-            group['g0_norm'] = g_sq.sqrt().item()
-        g0_norm = group['g0_norm']
-
-        d, lr = group['d'], group['lr']
-        d_lr = float(d * lr) / g0_norm
-
-        for group in self.param_groups:
-            for p in group['params']:
-                if p.grad is None:
-                    continue
-
-                state = self.state[p]
-
                 s = state['s']
-                s.add_(p.grad, alpha=d_lr)
+                numerator_weighted.add_(torch.dot(grad.flatten(), s.flatten()), alpha=d_lr)
 
+                s.add_(grad, alpha=d_lr)
                 sk_sq.add_(s.pow(2).sum())
 
-        gsq_weighted.add_(g_sq, alpha=d_lr ** 2)  # fmt: skip
-
         if lr > 0.0:
-            d_hat = (sk_sq - gsq_weighted) / sk_sq.sqrt()
-            d = max(d, min(d_hat, d * growth_rate))
+            d_hat = 2.0 * numerator_weighted / sk_sq.sqrt()
+            d = max(d, min(d_hat.item(), d * group['growth_rate']))
 
         for group in self.param_groups:
-            group['gsq_weighted'] = gsq_weighted
+            if 'step' in group:
+                group['step'] += 1
+            else:
+                group['step'] = 1
+
+            group['numerator_weighted'] = numerator_weighted
             group['d'] = d
-            group['g0_norm'] = g0_norm
 
             for p in group['params']:
                 if p.grad is None:
@@ -554,8 +535,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
                 p.mul_(group['momentum']).add_(z, alpha=1.0 - group['momentum'])
 
-            group['k'] += 1
-
         return loss
 
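Note: a short usage sketch for the reworked DAdaptSGD; the top-level import path is assumed, lr stays at 1.0, and the new momentum default of 0.9 drives the averaged iterate update p = momentum * p + (1 - momentum) * z seen above.

import torch
from pytorch_optimizer import DAdaptSGD  # assuming the package re-exports the class at top level

model = torch.nn.Linear(10, 1)
optimizer = DAdaptSGD(model.parameters(), lr=1.0, momentum=0.9, weight_decay=1e-4)

optimizer.zero_grad()
torch.nn.functional.mse_loss(model(torch.randn(8, 10)), torch.randn(8, 1)).backward()
optimizer.step()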