Skip to content

Commit 9d4a1a4

Browse files
committed
update: ranger21
1 parent 9f55ec6 commit 9d4a1a4

File tree

1 file changed

+1
-6
lines changed

1 file changed

+1
-6
lines changed

pytorch_optimizer/ranger21.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,6 @@ def __init__(
102102
self.current_lr = lr
103103
self.min_lr = warm_down_min_lr
104104

105-
self.param_size: int = 0
106-
107105
defaults: DEFAULTS = dict(
108106
lr=lr,
109107
betas=betas,
@@ -241,11 +239,8 @@ def step(self, closure: CLOSURE = None) -> LOSS:
241239
variance_ma.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
242240
variance_ma_sum += (variance_ma / bias_correction2).sum()
243241

244-
if self.param_size == 0:
245-
self.param_size = param_size
246-
247242
# stable weight decay
248-
variance_normalized = math.sqrt(variance_ma_sum / self.param_size)
243+
variance_normalized = math.sqrt(variance_ma_sum / param_size)
249244
if math.isnan(variance_normalized):
250245
raise RuntimeError('hit nan for variance_normalized')
251246

0 commit comments

Comments
 (0)