@@ -82,7 +82,7 @@ def __init__(
         decay_type: str = 'stable',
         warmup_type: str = 'linear',
         warmup_pct_default: float = 0.22,
-        logging_active: bool = True,
+        logging_active: bool = False,
     ):
         """Ranger optimizer (RAdam + Lookahead + Gradient Centralization, combined into one optimizer)
         :param params: PARAMS. iterable of parameters to optimize or dicts defining parameter groups
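
Since this hunk flips the default of logging_active from True to False, callers that relied on the built-in logging now have to opt in explicitly. A minimal usage sketch follows; the class name Ranger21, the import path, and the lr parameter are assumptions for illustration, not taken from this diff:

# Hypothetical usage sketch: class name, import path, and lr are assumed.
import torch
from ranger21 import Ranger21  # assumed import path

model = torch.nn.Linear(16, 4)
optimizer = Ranger21(
    model.parameters(),
    lr=1e-3,               # assumed parameter name
    logging_active=True,   # opt back in; the default is now False
)
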
@@ -361,10 +361,8 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 if p.grad is None:
                     continue

-                # if not self.param_size:
                 param_size += p.numel()

-                # apply agc if enabled
                 if self.agc_active:
                     agc(
                         p, agc_eps=self.agc_eps, agc_clip_val=self.agc_clip_val
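
For context on the agc() call kept above, adaptive gradient clipping rescales a gradient whose norm grows too large relative to the parameter norm. A per-tensor sketch, assuming agc_eps and agc_clip_val carry the same meaning as in the call; the actual helper in this repository may operate unit-wise rather than per tensor:

import torch

def agc_sketch(p: torch.Tensor, agc_eps: float = 1e-3, agc_clip_val: float = 1e-2) -> None:
    # Clip p.grad in place so its norm stays within agc_clip_val times the parameter norm.
    if p.grad is None:
        return
    p_norm = p.detach().norm().clamp(min=agc_eps)
    g_norm = p.grad.detach().norm()
    max_norm = p_norm * agc_clip_val
    if g_norm > max_norm:
        p.grad.detach().mul_(max_norm / g_norm.clamp(min=1e-6))
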
@@ -377,43 +375,27 @@ def step(self, closure: CLOSURE = None) -> LOSS:

                 state = self.state[p]

-                # State initialization
                 if len(state) == 0:
                     state['step'] = 0
-
-                    # Exponential moving average of gradient values
-                    state['grad_ma'] = torch.zeros_like(
-                        p, memory_format=torch.preserve_format
-                    )
-                    # Exponential moving average of squared gradient values
-                    state['variance_ma'] = torch.zeros_like(
-                        p, memory_format=torch.preserve_format
-                    )
+                    state['grad_ma'] = torch.zeros_like(p)
+                    state['variance_ma'] = torch.zeros_like(p)

                     if self.lookahead_active:
                         state['lookahead_params'] = torch.zeros_like(p.data)
                         state['lookahead_params'].copy_(p.data)

                     if self.use_adabelief:
-                        state['variance_ma_belief'] = torch.zeros_like(
-                            p, memory_format=torch.preserve_format
-                        )
+                        state['variance_ma_belief'] = torch.zeros_like(p)
                     if self.momentum_pnm:
-                        state['neg_grad_ma'] = torch.zeros_like(
-                            p, memory_format=torch.preserve_format
-                        )
-
-                        # Maintains max of all exp. moving avg. of sq. grad. values
-                        state['max_variance_ma'] = torch.zeros_like(
-                            p, memory_format=torch.preserve_format
-                        )
+                        state['neg_grad_ma'] = torch.zeros_like(p)
+                        state['max_variance_ma'] = torch.zeros_like(p)

-                # centralize gradients
                 if self.use_gc:
                     grad = centralize_gradient(
                         grad,
                         gc_conv_only=self.gc_conv_only,
                     )
+
                 if self.use_gc_norm:
                     grad = normalize_gradient(grad)

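
One note on the simplified state initialization above: torch.preserve_format is already the default memory_format for torch.zeros_like, so dropping the explicit keyword should be behavior-preserving. A small standalone check of that assumption:

import torch

p = torch.randn(4, 3)
a = torch.zeros_like(p, memory_format=torch.preserve_format)
b = torch.zeros_like(p)  # memory_format defaults to torch.preserve_format
assert torch.equal(a, b) and a.dtype == b.dtype and a.device == b.device
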