Skip to content

Commit 87c17b1

Browse files
committed
update: initialize variance_ma_sum to 1.0 to prevent a division-by-zero exception
1 parent ab20924 commit 87c17b1

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

pytorch_optimizer/ranger21.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -144,24 +144,26 @@ def __init__(
144144
self.total_iterations: int = num_epochs * num_batches_per_epoch
145145
if not self.total_iterations:
146146
raise ValueError(
147-
'missing total iterations, which is calculated from num epochs and num iterations per epoch param'
147+
'missing total iterations, '
148+
'calculated from num epochs and num iterations per epoch param'
148149
)
149150

150-
# lr
151151
self.starting_lr = lr
152152
self.current_lr = lr
153153

154-
# warmup - we'll use default recommended in Ma/Yarats unless user specifies num iterations
154+
# warmup - we'll use default recommended in Ma/Yarats
155+
# unless user specifies num iterations
155156
self.use_warmup = use_warmup
156-
self.warmup_complete = False
157157
self.warmup_type = warmup_type
158158
self.warmup_pct_default = warmup_pct_default
159+
self.warmup_complete: bool = False
159160

160161
if num_warmup_iterations is None:
161162
beta_warmup_iterations: int = math.ceil((2 / (1 - betas[1])))
162163
beta_pct: float = beta_warmup_iterations / self.total_iterations
163164

164-
# this can be unreasonable for short runs...so let's compare vs warmup pct % of total epochs
165+
# this can be unreasonable for short runs...
166+
# so let's compare vs warmup pct % of total epochs
165167
if beta_pct > 0.45:
166168
warmup_auto_pct = int(
167169
self.warmup_pct_default * self.total_iterations
@@ -351,7 +353,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
351353
loss = closure()
352354

353355
param_size: float = 0
354-
variance_ma_sum: float = 0.0
356+
variance_ma_sum: float = 1.0
355357

356358
# phase 1 - accumulate all of the variance_ma_sum to use in stable weight decay
357359
for i, group in enumerate(self.param_groups):

0 commit comments

Comments (0)