66from pytorch_optimizer .base .exception import NoSparseGradientError , ZeroParameterSizeError
77from pytorch_optimizer .base .optimizer import BaseOptimizer
88from pytorch_optimizer .base .types import BETAS , CLOSURE , DEFAULTS , LOSS , PARAMETERS
9- from pytorch_optimizer .optimizer .gc import centralize_gradient
109
1110
1211class AdamS (Optimizer , BaseOptimizer ):
@@ -16,7 +15,6 @@ class AdamS(Optimizer, BaseOptimizer):
1615 :param lr: float. learning rate.
1716 :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
1817 :param weight_decay: float. weight decay (L2 penalty).
19- :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
2018 :param amsgrad: bool. whether to use the AMSGrad variant of this algorithm from the paper.
2119 :param adamd_debias_term: bool. Only correct the denominator to avoid inflating step sizes early in training.
2220 :param eps: float. term added to the denominator to improve numerical stability.
@@ -138,11 +136,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
138136 bias_correction1 = 1.0 - beta1 ** state ['step' ]
139137 bias_correction2 = 1.0 - beta2 ** state ['step' ]
140138
141- if self .amsgrad :
142- max_exp_avg_sq = state ['max_exp_avg_sq' ]
143- exp_avg_sq_hat = max_exp_avg_sq
144- else :
145- exp_avg_sq_hat = exp_avg_sq
139+ exp_avg_sq_hat = state ['max_exp_avg_sq' ] if self .amsgrad else exp_avg_sq
146140 exp_avg_sq_hat .div_ (bias_correction2 )
147141
148142 de_nom = exp_avg_sq_hat .sqrt ().add (group ['eps' ])
0 commit comments