 from typing import Optional

 import torch
-from torch.nn import functional as f
+from torch.nn.functional import softplus

 from pytorch_optimizer.base.exception import NoSparseGradientError, ZeroParameterSizeError
 from pytorch_optimizer.base.optimizer import BaseOptimizer
@@ -38,6 +38,7 @@ class Ranger21(BaseOptimizer):
     :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
     :param use_softplus: bool. use softplus to smooth.
     :param beta_softplus: float. beta.
+    :param disable_lr_scheduler: bool. whether to disable the learning rate schedule.
     :param num_warm_up_iterations: Optional[int]. number of warm-up iterations. Ranger21 performs linear learning rate
         warmup.
     :param num_warm_down_iterations: Optional[int]. number of warm-down iterations. Ranger21 performs Explore-exploit
@@ -65,6 +66,7 @@ def __init__( # pylint: disable=R0913
         betas: BETAS = (0.9, 0.999),
         use_softplus: bool = True,
         beta_softplus: float = 50.0,
+        disable_lr_scheduler: bool = False,
         num_warm_up_iterations: Optional[int] = None,
         num_warm_down_iterations: Optional[int] = None,
         warm_down_min_lr: float = 3e-5,
@@ -93,6 +95,7 @@ def __init__( # pylint: disable=R0913
         self.min_lr = warm_down_min_lr
         self.use_softplus = use_softplus
         self.beta_softplus = beta_softplus
+        self.disable_lr_scheduler = disable_lr_scheduler
         self.agc_clipping_value = agc_clipping_value
         self.agc_eps = agc_eps
         self.centralize_gradients = centralize_gradients
@@ -245,8 +248,11 @@ def step(self, closure: CLOSURE = None) -> LOSS:
             noise_norm: float = math.sqrt((1.0 + beta2) ** 2 + beta2 ** 2)  # fmt: skip

             # warm up & down
-            lr: float = self.warm_up_dampening(group['lr'], group['step'])
-            lr = self.warm_down(lr, group['step'])
+            if self.disable_lr_scheduler:
+                lr: float = group['lr']
+            else:
+                lr: float = self.warm_up_dampening(group['lr'], group['step'])
+                lr = self.warm_down(lr, group['step'])

             for p in group['params']:
                 if p.grad is None:
@@ -279,7 +285,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 de_nom = (variance_ma.sqrt() / bias_correction2_sq).add_(group['eps'])

                 if self.use_softplus:
-                    de_nom = f.softplus(de_nom, beta=self.beta_softplus)
+                    de_nom = softplus(de_nom, beta=self.beta_softplus)

                 grad = p.grad
                 centralize_gradient(grad, gc_conv_only=False)
@@ -289,7 +295,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:

                 step_size: float = self.apply_adam_debias(group['adam_debias'], lr, bias_correction1)

-                pn_momentum = grad_ma.mul(1.0 + 1.0).add(neg_grad_ma, alpha=-1.0).mul(1.0 / noise_norm)
+                pn_momentum = grad_ma.mul(1.0 + 1.0).add_(neg_grad_ma, alpha=-1.0).mul_(1.0 / noise_norm)
                 p.addcdiv_(pn_momentum, de_nom, value=-step_size)

         self.lookahead_process_step()
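
Usage note (a minimal sketch, not part of the patch): the example below shows how the new disable_lr_scheduler flag might be combined with an external scheduler once the built-in warm-up/warm-down is turned off. It assumes the top-level pytorch_optimizer.Ranger21 export and a required num_iterations argument; the toy model, data, and hyperparameters are illustrative only.

import torch
from torch import nn

from pytorch_optimizer import Ranger21  # assumed top-level export

# Toy model and data, purely illustrative.
model = nn.Linear(10, 1)
x, y = torch.randn(64, 10), torch.randn(64, 1)

# With disable_lr_scheduler=True the built-in linear warm-up and
# explore-exploit warm-down are skipped, so group['lr'] is applied as-is
# and an external scheduler (or a constant lr) stays in charge.
optimizer = Ranger21(
    model.parameters(),
    num_iterations=100,        # assumed required by the constructor
    lr=1e-3,
    disable_lr_scheduler=True,
    use_softplus=True,         # de_nom is smoothed via softplus(beta=beta_softplus)
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

for _ in range(100):
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()
    scheduler.step()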