Skip to content

Commit 0464b21

Browse files
committed
update: AdamS optimizer
1 parent 9a2e427 commit 0464b21

File tree

1 file changed: +1 addition, -7 deletions

pytorch_optimizer/optimizer/adams.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
 from pytorch_optimizer.base.exception import NoSparseGradientError, ZeroParameterSizeError
 from pytorch_optimizer.base.optimizer import BaseOptimizer
 from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
-from pytorch_optimizer.optimizer.gc import centralize_gradient


 class AdamS(Optimizer, BaseOptimizer):
@@ -16,7 +15,6 @@ class AdamS(Optimizer, BaseOptimizer):
     :param lr: float. learning rate.
     :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
     :param weight_decay: float. weight decay (L2 penalty).
-    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
     :param amsgrad: bool. whether to use the AMSGrad variant of this algorithm from the paper.
     :param adamd_debias_term: bool. Only correct the denominator to avoid inflating step sizes early in training.
     :param eps: float. term added to the denominator to improve numerical stability.
@@ -138,11 +136,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 bias_correction1 = 1.0 - beta1 ** state['step']
                 bias_correction2 = 1.0 - beta2 ** state['step']

-                if self.amsgrad:
-                    max_exp_avg_sq = state['max_exp_avg_sq']
-                    exp_avg_sq_hat = max_exp_avg_sq
-                else:
-                    exp_avg_sq_hat = exp_avg_sq
+                exp_avg_sq_hat = state['max_exp_avg_sq'] if self.amsgrad else exp_avg_sq
                 exp_avg_sq_hat.div_(bias_correction2)

                 de_nom = exp_avg_sq_hat.sqrt().add(group['eps'])

0 commit comments

Comments
 (0)