Skip to content

Commit b3f7b42

Browse files
committed
refactor: get_warm_down
1 parent 84a24df commit b3f7b42

File tree

1 file changed: 9 additions (+), 6 deletions (−)

pytorch_optimizer/optimizer/ranger21.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,10 @@ def get_warm_down(self, lr: float, iteration: int) -> float:
182182

183183
new_lr: float = self.starting_lr - self.warm_down_lr_delta * warm_down_pct
184184
new_lr = max(new_lr, self.min_lr)
185+
186+
if new_lr < 0.0:
187+
raise NegativeLRError(new_lr)
188+
185189
self.current_lr = new_lr
186190

187191
return new_lr
@@ -249,9 +253,11 @@ def step(self, closure: CLOSURE = None) -> LOSS:
249253
for group in self.param_groups:
250254
if len(self.state) == 0:
251255
continue
256+
252257
p = next(iter(self.state.keys()))
253-
lr = group["lr"]
254-
step = self.state[p]["step"]
258+
259+
lr = group['lr']
260+
step = self.state[p]['step']
255261

256262
beta1, beta2 = group['betas']
257263
bias_correction1 = 1.0 - beta1 ** step # fmt: skip
@@ -264,17 +270,14 @@ def step(self, closure: CLOSURE = None) -> LOSS:
264270

265271
# warm down
266272
lr = self.get_warm_down(lr, step)
267-
if lr < 0.0:
268-
raise NegativeLRError(lr)
269273

270274
# stable decay
271275
decay = group['weight_decay']
272276
if decay:
273277
p.mul_(1.0 - decay * lr / variance_normalized)
274278

275279
# norm loss
276-
u_norm = unit_norm(p)
277-
correction = 2.0 * self.norm_loss_factor * (1.0 - torch.div(1, u_norm + self.eps))
280+
correction = 2.0 * self.norm_loss_factor * (1.0 - torch.div(1, unit_norm(p) + self.eps))
278281
p.mul_(1.0 - lr * correction)
279282

280283
for p in group['params']:

0 commit comments (Comments: 0)