Skip to content

Commit 8fb1be6

Browse files
committed
refactor: Ranger21
1 parent 05c737d commit 8fb1be6

File tree

1 file changed

+3
-7
lines changed

1 file changed

+3
-7
lines changed

pytorch_optimizer/ranger21.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -481,11 +481,11 @@ def step(self, closure: CLOSURE = None) -> LOSS:
481481
p.data.mul_(1 - decay * _lambda / variance_normalized)
482482

483483
if self.norm_loss_active:
484-
unorm = unit_norm(p.data)
484+
u_norm = unit_norm(p.data)
485485
correction = (
486486
2
487487
* self.norm_loss_factor
488-
* (1 - torch.div(1, unorm + self.eps))
488+
* (1 - torch.div(1, u_norm + self.eps))
489489
)
490490
p.mul_(1 - lr * correction)
491491

@@ -534,24 +534,20 @@ def step(self, closure: CLOSURE = None) -> LOSS:
534534
if self.softplus:
535535
rms = F.softplus(rms, beta=self.beta_softplus)
536536

537-
# Update s
538537
s.data.add_(inner_grad, alpha=_lambda)
539538

540-
# Step
541539
if momentum == 0:
542540
p.data.copy_(x0.addcdiv(s, rms, value=-1))
543541
else:
544542
z = x0.addcdiv(s, rms, value=-1)
545543

546544
# p is a moving average of z
547545
p.data.mul_(1 - ck).add_(z, alpha=ck)
548-
else: # adam with pnm core
546+
else:
549547
grad = p.grad
550548
beta1, beta2 = group['betas']
551549
grad_ma = state['grad_ma']
552550
variance_ma = state['variance_ma']
553-
if self.use_adabelief:
554-
variance_ma_belief = state['variance_ma_belief']
555551

556552
if self.momentum_pnm:
557553
max_variance_ma = state['max_variance_ma']

0 commit comments

Comments (0)