    '@TheZothen',
]

-import collections
import math
from typing import Dict, List, Optional


class Ranger21(Optimizer):
    """
-    Reference : https://github.com/lessw2020/Ranger21/blob/main/ranger21/ranger21.py
+    Reference : https://github.com/lessw2020/Ranger21
    Example :
        from pytorch_optimizer import Ranger21
        ...
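The Example block in the docstring is truncated here; a minimal, self-contained usage sketch of the pattern it points at follows. The toy model, data, and lr value are illustrative assumptions, and depending on the version the constructor may require further arguments (for example, the total number of training iterations).

import torch

from pytorch_optimizer import Ranger21

# toy setup, purely illustrative
model = torch.nn.Linear(10, 1)
criterion = torch.nn.MSELoss()
optimizer = Ranger21(model.parameters(), lr=1e-3)

for _ in range(100):
    optimizer.zero_grad()
    x, y = torch.randn(32, 10), torch.randn(32, 1)
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()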
@@ -82,16 +81,18 @@ def __init__(
        decay_type: str = 'stable',
        warmup_type: str = 'linear',
        warmup_pct_default: float = 0.22,
-        logging_active: bool = False,
    ):
-        """Ranger optimizer (RAdam + Lookahead + Gradient Centralization, combined into one optimizer)
-        :param params: PARAMS. iterable of parameters to optimize or dicts defining parameter groups
+        """
+        :param params: PARAMS. iterable of parameters to optimize
+            or dicts defining parameter groups
        :param lr: float. learning rate.
-        :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace
-        :param eps: float. term added to the denominator to improve numerical stability
+        :param betas: BETAS. coefficients used for computing running averages
+            of gradient and the squared hessian trace
+        :param eps: float. term added to the denominator
+            to improve numerical stability
        :param weight_decay: float. weight decay (L2 penalty)
-        :param use_gc: bool. use Gradient Centralization (both convolution & fc layers)
-        :param gc_conv_only: bool. use Gradient Centralization (only convolution layer)
+        :param use_gc: bool. use GC both convolution & fc layers
+        :param gc_conv_only: bool. use GC only convolution layer
        """
        defaults: DEFAULT_PARAMETERS = dict(
            lr=lr,
@@ -102,8 +103,6 @@ def __init__(
        )
        super().__init__(params, defaults)

-        self.logging = logging_active
-
        self.use_madgrad = use_madgrad
        self.core_engine: str = self.get_core_engine(self.use_madgrad)

@@ -207,9 +206,6 @@ def __init__(
        self.param_size: int = 0

        self.tracking_lr: List[float] = []
-        if self.logging:
-            self.tracking_variance_sum: List[float] = []
-            self.tracking_variance_normalized = []

    @staticmethod
    def get_core_engine(use_madgrad: bool = False) -> str:
@@ -255,27 +251,18 @@ def warmup_dampening(self, lr: float, step: int) -> float:

        if step > warmup:
            if not self.warmup_complete:
-                if not self.warmup_curr_pct == 1.0:
-                    print(
-                        f'Error | lr did not achieve full set point from warmup, currently {self.warmup_curr_pct}'
-                    )
-
                self.warmup_complete = True
-                print(
-                    f'\n** Ranger21 update | Warmup complete - lr set to {lr}\n'
-                )
-
            return lr

        if style == 'linear':
            self.warmup_curr_pct = min(1.0, (step / warmup))
            new_lr: float = lr * self.warmup_curr_pct
            self.current_lr = new_lr
            return new_lr
-        else:
-            raise NotImplementedError(
-                f'warmup style {style} is not supported yet :('
-            )
+
+        raise NotImplementedError(
+            f'warmup style {style} is not supported yet :('
+        )

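With the console prints removed, warmup_dampening reduces to a plain linear ramp: until the warmup horizon is reached the base lr is scaled by the fraction of warmup completed, afterwards the base lr is returned unchanged. A standalone sketch of the same schedule, with illustrative names:

def linear_warmup_lr(base_lr: float, step: int, warmup_steps: int) -> float:
    # after warmup, use the base learning rate as-is
    if step > warmup_steps:
        return base_lr
    # otherwise scale by the completed fraction, capped at 1.0
    return base_lr * min(1.0, step / warmup_steps)

# e.g. base_lr=1e-3, warmup_steps=1000: step 250 gives 2.5e-4, step 2000 gives 1e-3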
    def get_warm_down(self, lr: float, iteration: int) -> float:
        if iteration < self.start_warm_down:
@@ -284,21 +271,18 @@ def get_warm_down(self, lr: float, iteration: int) -> float:
        if iteration > self.start_warm_down - 1:
            # start iteration from 1, not 0
            warm_down_iteration: int = (iteration + 1) - self.start_warm_down
-            if warm_down_iteration < 1:
-                warm_down_iteration = 1
+            warm_down_iteration = max(warm_down_iteration, 1)

            warm_down_pct: float = warm_down_iteration / (
                self.warm_down_total_iterations + 1
            )
-            if warm_down_pct > 1.00:
-                warm_down_pct = 1.00
+            warm_down_pct = min(warm_down_pct, 1.0)

            lr_range: float = self.warm_down_lr_delta
            reduction: float = lr_range * warm_down_pct
-            new_lr: float = self.starting_lr - reduction
-            if new_lr < self.min_lr:
-                new_lr = self.min_lr

+            new_lr: float = self.starting_lr - reduction
+            new_lr = max(new_lr, self.min_lr)
            self.current_lr = new_lr

            return new_lr
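The clamps above replace the if-branches but keep the same linear warm-down: starting from the initial lr, the rate is reduced toward a floor over the warm-down window. A standalone sketch, assuming warm_down_lr_delta is simply the gap between the starting lr and the minimum lr:

def linear_warm_down_lr(
    starting_lr: float,
    min_lr: float,
    iteration: int,
    start_warm_down: int,
    warm_down_total_iterations: int,
) -> float:
    # before the warm-down window begins, keep the current lr
    if iteration < start_warm_down:
        return starting_lr

    warm_down_iteration = max((iteration + 1) - start_warm_down, 1)
    warm_down_pct = min(warm_down_iteration / (warm_down_total_iterations + 1), 1.0)

    reduction = (starting_lr - min_lr) * warm_down_pct
    return max(starting_lr - reduction, min_lr)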
@@ -323,21 +307,13 @@ def get_chebyshev_lr(self, lr: float, iteration: int) -> float:
            self.current_epoch = current_epoch

        index: int = current_epoch - 2
-        if index < 0:
-            index = 0
-        if index > len(self.chebyshev_schedule) - 1:
-            index = len(self.chebyshev_schedule) - 1
+        index = max(0, index)
+        index = min(index, len(self.chebyshev_schedule) - 1)

        chebyshev_value = self.chebyshev_schedule[index]

-        if self.cheb_logging[:-1] != chebyshev_value:
-            self.cheb_logging.append(chebyshev_value)
-
        return lr * chebyshev_value

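The two added lines are just an index clamp into [0, len(schedule) - 1] before the Chebyshev schedule lookup; the same idiom in isolation:

def clamp_index(index: int, schedule_len: int) -> int:
    # keep the lookup index inside the schedule bounds
    return min(max(0, index), schedule_len - 1)

# clamp_index(-1, 10) == 0; clamp_index(42, 10) == 9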
-    def get_variance(self):
-        return self.tracking_variance_sum
-
    @staticmethod
    def get_state_values(group, state):
        beta1, beta2 = group['betas']
@@ -348,16 +324,16 @@ def get_state_values(group, state):
    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
-        if closure is not None and isinstance(closure, collections.Callable):
+        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        param_size: float = 0
        variance_ma_sum: float = 1.0

        # phase 1 - accumulate all of the variance_ma_sum to use in stable weight decay
-        for i, group in enumerate(self.param_groups):
-            for j, p in enumerate(group['params']):
+        for group in self.param_groups:
+            for p in group['params']:
                if p.grad is None:
                    continue

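Dropping the isinstance(closure, collections.Callable) check leaves the standard PyTorch closure contract: when a closure is passed, it is re-evaluated under torch.enable_grad() so that step can return a fresh loss. A self-contained usage sketch of that contract (shown with SGD for brevity; the same call shape applies here):

import torch

model = torch.nn.Linear(4, 1)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
inputs, targets = torch.randn(8, 4), torch.randn(8, 1)

def closure():
    # re-compute the loss and gradients; step() calls this under enable_grad()
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    return loss

loss = optimizer.step(closure)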
@@ -369,7 +345,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                    )

                grad = p.grad
-
                if grad.is_sparse:
                    raise RuntimeError('sparse matrix not supported atm')

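Per the phase 1 comment above, the lines elided between these hunks accumulate variance_ma_sum for stable weight decay; the guard shown here only skips parameters without gradients and rejects sparse ones. A rough sketch of that accumulation pattern (the 'variance_ma' state key and the square-root normalization are assumptions, not necessarily this file's exact code):

import math

import torch

def stable_wd_denominator(optimizer: torch.optim.Optimizer) -> float:
    # sum the second-moment moving averages across all parameters,
    # then normalize by the total parameter count
    variance_ma_sum: float = 0.0
    param_size: int = 0
    for group in optimizer.param_groups:
        for p in group['params']:
            if p.grad is None or p.grad.is_sparse:
                continue
            state = optimizer.state[p]
            if 'variance_ma' not in state:  # assumed state key
                continue
            param_size += p.numel()
            variance_ma_sum += state['variance_ma'].sum().item()
    # stable weight decay scales the decay term by this normalized quantity
    return math.sqrt(variance_ma_sum / max(param_size, 1))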
@@ -443,11 +418,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:
        if math.isnan(variance_normalized):
            raise RuntimeError('hit nan for variance_normalized')

-        # debugging/logging
-        if self.logging:
-            self.tracking_variance_sum.append(variance_ma_sum.item())
-            self.tracking_variance_normalized.append(variance_normalized)
-
        # phase 2 - apply weight decay and step
        for group in self.param_groups:
            step = state['step']
@@ -464,7 +434,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
            # warm-down
            if self.warm_down_active:
                lr = self.get_warm_down(lr, step)
-                if 0 > lr:
+                if lr < 0.0:
                    raise ValueError(f'{lr} went negative')

            # MADGRAD outer