@@ -200,6 +200,11 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
         # Phase 1 - Accumulate all the variance_ma_sum to use in stable weight decay
         for group in self.param_groups:
+            if 'step' in group:
+                group['step'] += 1
+            else:
+                group['step'] = 1
+
             beta1, beta2 = group['betas']
             for p in group['params']:
                 if p.grad is None:
@@ -216,7 +221,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
                 state = self.state[p]
                 if len(state) == 0:
-                    state['step'] = 0
                     state['grad_ma'] = torch.zeros_like(p)
                     state['variance_ma'] = torch.zeros_like(p)
                     state['lookahead_params'] = torch.empty_like(p)
@@ -229,17 +233,14 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 grad = centralize_gradient(grad, gc_conv_only=False)
                 grad = normalize_gradient(grad)
 
-                state['step'] += 1
-
-                bias_correction2 = 1.0 - beta2 ** state['step']
+                bias_correction2 = 1.0 - beta2 ** group['step']
 
                 # second moment estimation
                 # using positive-negative momentum and bias correction
                 variance_ma = state['variance_ma']
                 variance_ma.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
                 variance_ma_sum += (variance_ma / bias_correction2).sum()
 
-        # stable weight decay
         if param_size == 0:
             raise ZeroParameterSizeError()
 
@@ -249,41 +250,32 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
         # Phase 2 - Apply weight decay and step
         for group in self.param_groups:
-            if len(self.state) == 0:
-                continue
-
-            p = next(iter(self.state.keys()))
-
             lr = group['lr']
-            step = self.state[p]['step']
-
             beta1, beta2 = group['betas']
-            bias_correction1 = 1.0 - beta1 ** step  # fmt: skip
-            bias_correction2_sq = math.sqrt(1.0 - beta2 ** step)  # fmt: skip
 
-            noise_norm: float = math.sqrt((1.0 + beta2) ** 2 + beta2 ** 2)  # fmt: skip
+            bias_correction1 = 1.0 - beta1 ** group['step']  # fmt: skip
+            bias_correction2_sq = math.sqrt(1.0 - beta2 ** group['step'])  # fmt: skip
 
-            # warm up
-            lr = self.warm_up_dampening(lr, step)
-
-            # warm down
-            lr = self.warm_down(lr, step)
-
-            # stable decay
-            decay = group['weight_decay']
-            if decay:
-                p.mul_(1.0 - decay * lr / variance_normalized)
+            noise_norm: float = math.sqrt((1.0 + beta2) ** 2 + beta2 ** 2)  # fmt: skip
 
-            # norm loss
-            correction = 2.0 * self.norm_loss_factor * (1.0 - torch.div(1, unit_norm(p) + self.eps))
-            p.mul_(1.0 - lr * correction)
+            # warm up & down
+            lr = self.warm_up_dampening(lr, group['step'])
+            lr = self.warm_down(lr, group['step'])
 
             for p in group['params']:
                 if p.grad is None:
                     continue
 
+                # stable weight decay
+                if group['weight_decay']:
+                    p.mul_(1.0 - group['weight_decay'] * lr / variance_normalized)
+
+                # norm loss
+                correction = 2.0 * self.norm_loss_factor * (1.0 - torch.div(1, unit_norm(p) + self.eps))
+                p.mul_(1.0 - lr * correction)
+
                 state = self.state[p]
-                if state['step'] % 2 == 1:
+                if group['step'] % 2 == 1:
                     grad_ma, neg_grad_ma = state['grad_ma'], state['neg_grad_ma']
                 else:
                     grad_ma, neg_grad_ma = state['neg_grad_ma'], state['grad_ma']
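
For reference, a minimal sketch (illustrative only, not the repository's code) of the bookkeeping pattern this commit adopts: the step counter now lives on the param group rather than in per-parameter state, so the bias corrections in both phases read the same group['step'] value. The group dict and helper names below are assumptions for illustration.

def begin_step(group: dict) -> None:
    # Increment once per optimizer step, shared by every parameter in the group;
    # initialized lazily on first use, mirroring the "if 'step' in group" check above.
    group['step'] = group.get('step', 0) + 1


def bias_corrections(group: dict) -> tuple:
    # Both corrections derive from the same per-group counter.
    beta1, beta2 = group['betas']
    return 1.0 - beta1 ** group['step'], 1.0 - beta2 ** group['step']


group = {'betas': (0.9, 0.999)}  # hypothetical param group
begin_step(group)
print(group['step'], bias_corrections(group))  # 1 (~0.1, ~0.001)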