@@ -182,6 +182,10 @@ def get_warm_down(self, lr: float, iteration: int) -> float:
182182
183183 new_lr : float = self .starting_lr - self .warm_down_lr_delta * warm_down_pct
184184 new_lr = max (new_lr , self .min_lr )
185+
186+ if new_lr < 0.0 :
187+ raise NegativeLRError (new_lr )
188+
185189 self .current_lr = new_lr
186190
187191 return new_lr
@@ -249,9 +253,11 @@ def step(self, closure: CLOSURE = None) -> LOSS:
249253 for group in self .param_groups :
250254 if len (self .state ) == 0 :
251255 continue
256+
252257 p = next (iter (self .state .keys ()))
253- lr = group ["lr" ]
254- step = self .state [p ]["step" ]
258+
259+ lr = group ['lr' ]
260+ step = self .state [p ]['step' ]
255261
256262 beta1 , beta2 = group ['betas' ]
257263 bias_correction1 = 1.0 - beta1 ** step # fmt: skip
@@ -264,17 +270,14 @@ def step(self, closure: CLOSURE = None) -> LOSS:
264270
265271 # warm down
266272 lr = self .get_warm_down (lr , step )
267- if lr < 0.0 :
268- raise NegativeLRError (lr )
269273
270274 # stable decay
271275 decay = group ['weight_decay' ]
272276 if decay :
273277 p .mul_ (1.0 - decay * lr / variance_normalized )
274278
275279 # norm loss
276- u_norm = unit_norm (p )
277- correction = 2.0 * self .norm_loss_factor * (1.0 - torch .div (1 , u_norm + self .eps ))
280+ correction = 2.0 * self .norm_loss_factor * (1.0 - torch .div (1 , unit_norm (p ) + self .eps ))
278281 p .mul_ (1.0 - lr * correction )
279282
280283 for p in group ['params' ]:
0 commit comments