File tree Expand file tree Collapse file tree 1 file changed +1
-7
lines changed
pytorch_optimizer/optimizer Expand file tree Collapse file tree 1 file changed +1
-7
lines changed Original file line number Diff line number Diff line change @@ -28,15 +28,13 @@ def __init__(
2828 lr : float = 1e-3 ,
2929 betas : BETAS = (0.1 , 0.99 ),
3030 weight_decay : float = 1e-4 ,
31- weight_decouple : bool = False ,
3231 amsgrad : bool = False ,
3332 adamd_debias_term : bool = False ,
3433 eps : float = 1e-8 ,
3534 ):
3635 self .lr = lr
3736 self .betas = betas
3837 self .weight_decay = weight_decay
39- self .weight_decouple = weight_decouple
4038 self .amsgrad = amsgrad
4139 self .adamd_debias_term = adamd_debias_term
4240 self .eps = eps
@@ -130,14 +128,10 @@ def step(self, closure: CLOSURE = None) -> LOSS:
130128 if p .grad is None :
131129 continue
132130
133- grad = p .grad
134131 state = self .state [p ]
135132
136133 if group ['weight_decay' ] > 0.0 :
137- if self .weight_decouple :
138- p .mul_ (1.0 - group ['lr' ] * group ['weight_decay' ])
139- else :
140- grad .add_ (p , alpha = group ['weight_decay' ])
134+ p .mul_ (1.0 - group ['lr' ] * group ['weight_decay' ] / exp_avg_sq_hat_mean )
141135
142136 exp_avg , exp_avg_sq = state ['exp_avg' ], state ['exp_avg_sq' ]
143137
You can’t perform that action at this time.
0 commit comments