@@ -481,11 +481,11 @@ def step(self, closure: CLOSURE = None) -> LOSS:
481481 p .data .mul_ (1 - decay * _lambda / variance_normalized )
482482
483483 if self .norm_loss_active :
484- unorm = unit_norm (p .data )
484+ u_norm = unit_norm (p .data )
485485 correction = (
486486 2
487487 * self .norm_loss_factor
488- * (1 - torch .div (1 , unorm + self .eps ))
488+ * (1 - torch .div (1 , u_norm + self .eps ))
489489 )
490490 p .mul_ (1 - lr * correction )
491491
@@ -534,24 +534,20 @@ def step(self, closure: CLOSURE = None) -> LOSS:
534534 if self .softplus :
535535 rms = F .softplus (rms , beta = self .beta_softplus )
536536
537- # Update s
538537 s .data .add_ (inner_grad , alpha = _lambda )
539538
540- # Step
541539 if momentum == 0 :
542540 p .data .copy_ (x0 .addcdiv (s , rms , value = - 1 ))
543541 else :
544542 z = x0 .addcdiv (s , rms , value = - 1 )
545543
546544 # p is a moving average of z
547545 p .data .mul_ (1 - ck ).add_ (z , alpha = ck )
548- else : # adam with pnm core
546+ else :
549547 grad = p .grad
550548 beta1 , beta2 = group ['betas' ]
551549 grad_ma = state ['grad_ma' ]
552550 variance_ma = state ['variance_ma' ]
553- if self .use_adabelief :
554- variance_ma_belief = state ['variance_ma_belief' ]
555551
556552 if self .momentum_pnm :
557553 max_variance_ma = state ['max_variance_ma' ]
0 commit comments