@@ -86,7 +86,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
8686 for group in self .param_groups :
8787 eps = group ['eps' ]
8888 lr = group ['lr' ] + eps
89- decay = group ['weight_decay' ]
89+ weight_decay = group ['weight_decay' ]
9090 momentum = group ['momentum' ]
9191
9292 ck : float = 1.0 - momentum
@@ -111,15 +111,15 @@ def step(self, closure: CLOSURE = None) -> LOSS:
111111 grad_sum_sq = state ['grad_sum_sq' ]
112112 s = state ['s' ]
113113
114- if decay != 0 and not self .decouple_decay :
114+ if weight_decay > 0. 0 and not self .decouple_decay :
115115 if grad .is_sparse :
116116 raise NoSparseGradientError (self .__name__ , note = 'weight_decay' )
117117
118118 # original implementation
119- grad .add_ (p , alpha = decay )
119+ grad .add_ (p , alpha = weight_decay )
120120
121121 # Apply weight decay - L2 / AdamW style
122- # p.mul_(1.0 - lr * decay )
122+ # p.mul_(1.0 - lr * weight_decay )
123123
124124 if grad .is_sparse :
125125 grad = grad .coalesce ()
@@ -167,7 +167,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
167167
168168 s .add_ (grad , alpha = _lambda )
169169
170- if decay != 0 and self .decouple_decay :
170+ if weight_decay > 0. 0 and self .decouple_decay :
171171 p_old = p .clone ()
172172
173173 if momentum == 0.0 :
@@ -176,8 +176,8 @@ def step(self, closure: CLOSURE = None) -> LOSS:
176176 z = x0 .addcdiv (s , rms , value = - 1 )
177177 p .mul_ (1.0 - ck ).add_ (z , alpha = ck )
178178
179- if decay != 0 and self .decouple_decay :
180- p .add_ (p_old , alpha = - lr * decay )
179+ if weight_decay > 0. 0 and self .decouple_decay :
180+ p .add_ (p_old , alpha = - lr * weight_decay )
181181
182182 self .state ['k' ] += 1
183183
0 commit comments