
Commit b85e0ad

feature: disable_lr_scheduler
1 parent 024e611 commit b85e0ad

1 file changed: +11 −5 lines changed

pytorch_optimizer/optimizer/ranger21.py

Lines changed: 11 additions & 5 deletions
@@ -2,7 +2,7 @@
 from typing import Optional

 import torch
-from torch.nn import functional as f
+from torch.nn.functional import softplus

 from pytorch_optimizer.base.exception import NoSparseGradientError, ZeroParameterSizeError
 from pytorch_optimizer.base.optimizer import BaseOptimizer
@@ -38,6 +38,7 @@ class Ranger21(BaseOptimizer):
     :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
     :param use_softplus: bool. use softplus to smooth.
     :param beta_softplus: float. beta.
+    :param disable_lr_scheduler: bool. whether to disable learning rate schedule.
     :param num_warm_up_iterations: Optional[int]. number of warm-up iterations. Ranger21 performs linear learning rate
         warmup.
     :param num_warm_down_iterations: Optional[int]. number of warm-down iterations. Ranger21 performs Explore-exploit
@@ -65,6 +66,7 @@ def __init__(  # pylint: disable=R0913
         betas: BETAS = (0.9, 0.999),
         use_softplus: bool = True,
         beta_softplus: float = 50.0,
+        disable_lr_scheduler: bool = False,
         num_warm_up_iterations: Optional[int] = None,
         num_warm_down_iterations: Optional[int] = None,
         warm_down_min_lr: float = 3e-5,
@@ -93,6 +95,7 @@ def __init__(  # pylint: disable=R0913
         self.min_lr = warm_down_min_lr
         self.use_softplus = use_softplus
         self.beta_softplus = beta_softplus
+        self.disable_lr_scheduler = disable_lr_scheduler
         self.agc_clipping_value = agc_clipping_value
         self.agc_eps = agc_eps
         self.centralize_gradients = centralize_gradients
@@ -245,8 +248,11 @@ def step(self, closure: CLOSURE = None) -> LOSS:
             noise_norm: float = math.sqrt((1.0 + beta2) ** 2 + beta2 ** 2)  # fmt: skip

             # warm up & down
-            lr: float = self.warm_up_dampening(group['lr'], group['step'])
-            lr = self.warm_down(lr, group['step'])
+            if self.disable_lr_scheduler:
+                lr: float = group['lr']
+            else:
+                lr: float = self.warm_up_dampening(group['lr'], group['step'])
+                lr = self.warm_down(lr, group['step'])

             for p in group['params']:
                 if p.grad is None:
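
Note: when `disable_lr_scheduler` is set, the new branch leaves `group['lr']` untouched instead of applying the warm-up dampening and warm-down schedule, which is useful when an external scheduler (or a constant rate) should drive the learning rate. A minimal usage sketch follows; the `num_iterations`, `lr`, and import path are assumptions about the surrounding package, not part of this diff:

```python
import torch
from pytorch_optimizer import Ranger21  # import path assumed

model = torch.nn.Linear(10, 1)
optimizer = Ranger21(
    model.parameters(),
    num_iterations=1_000,        # assumed required argument; not shown in this diff
    lr=1e-3,
    disable_lr_scheduler=True,   # new flag from this commit: use group['lr'] as-is
)
```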
@@ -279,7 +285,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 de_nom = (variance_ma.sqrt() / bias_correction2_sq).add_(group['eps'])

                 if self.use_softplus:
-                    de_nom = f.softplus(de_nom, beta=self.beta_softplus)
+                    de_nom = softplus(de_nom, beta=self.beta_softplus)

                 grad = p.grad
                 centralize_gradient(grad, gc_conv_only=False)
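
Aside from the direct `softplus` import (behavior unchanged), it may help to see why softplus is applied to the denominator: with a large `beta` it is nearly the identity for ordinary magnitudes but softly floors near-zero values around `ln(2)/beta`, which is gentler than a hard clamp. A standalone illustration, not part of this commit:

```python
import torch
from torch.nn.functional import softplus

de_nom = torch.tensor([1e-8, 1e-2, 1.0])
print(softplus(de_nom, beta=50.0))
# ~tensor([0.0139, 0.0195, 1.0000]): tiny entries are floored near ln(2)/50,
# while values with beta * x above the internal threshold pass through unchanged.
```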
@@ -289,7 +295,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:

                 step_size: float = self.apply_adam_debias(group['adam_debias'], lr, bias_correction1)

-                pn_momentum = grad_ma.mul(1.0 + 1.0).add(neg_grad_ma, alpha=-1.0).mul(1.0 / noise_norm)
+                pn_momentum = grad_ma.mul(1.0 + 1.0).add_(neg_grad_ma, alpha=-1.0).mul_(1.0 / noise_norm)
                 p.addcdiv_(pn_momentum, de_nom, value=-step_size)

        self.lookahead_process_step()
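
The momentum line now uses the in-place `add_`/`mul_` variants, saving two temporary tensors per parameter. This is safe because the leading out-of-place `.mul()` already returns a fresh tensor, so the running average `grad_ma` is never mutated and the result is numerically identical. A quick equivalence check with illustrative values only:

```python
import torch

grad_ma = torch.tensor([1.0, 2.0])
neg_grad_ma = torch.tensor([0.5, -0.5])
noise_norm = 2.0

old = grad_ma.mul(1.0 + 1.0).add(neg_grad_ma, alpha=-1.0).mul(1.0 / noise_norm)    # three allocations
new = grad_ma.mul(1.0 + 1.0).add_(neg_grad_ma, alpha=-1.0).mul_(1.0 / noise_norm)  # one allocation

assert torch.equal(old, new)                           # same values
assert torch.equal(grad_ma, torch.tensor([1.0, 2.0]))  # running average untouched
```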
