Skip to content

Commit 9f55ec6

Browse files
committed
update: ranger21
1 parent 1445e85 commit 9f55ec6

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

pytorch_optimizer/ranger21.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -241,11 +241,11 @@ def step(self, closure: CLOSURE = None) -> LOSS:
241241
variance_ma.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
242242
variance_ma_sum += (variance_ma / bias_correction2).sum()
243243

244-
if not self.param_size:
244+
if self.param_size == 0:
245245
self.param_size = param_size
246246

247247
# stable weight decay
248-
variance_normalized = math.sqrt(variance_ma_sum / param_size)
248+
variance_normalized = math.sqrt(variance_ma_sum / self.param_size)
249249
if math.isnan(variance_normalized):
250250
raise RuntimeError('hit nan for variance_normalized')
251251

0 commit comments

Comments
 (0)