Skip to content

Commit e898142

Browse files
committed
Fix ranger 21 update phase
1 parent 373e1b5 commit e898142

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

pytorch_optimizer/optimizer/ranger21.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -247,12 +247,11 @@ def step(self, closure: CLOSURE = None) -> LOSS:
247247

248248
# Phase 2 - Apply weight decay and step
249249
for group in self.param_groups:
250-
p = group['params'][0]
251-
if p.grad is None:
250+
if len(self.state) == 0:
252251
continue
253-
254-
lr = group['lr']
255-
step = self.state[group['params'][0]]['step']
252+
p = next(iter(self.state.keys()))
253+
lr = group["lr"]
254+
step = self.state[p]["step"]
256255

257256
beta1, beta2 = group['betas']
258257
bias_correction1 = 1.0 - beta1 ** step # fmt: skip

0 commit comments

Comments
 (0)