@@ -14,11 +14,15 @@ class Nero(Optimizer, BaseOptimizer):
     :param lr: float. learning rate.
     :param beta: float. coefficients used for computing running averages of gradient and the squared hessian trace.
     :param constraints: bool. whether to constrain each neuron to zero mean and unit norm.
+    :param eps: float. term added to the denominator to improve numerical stability.
     """
 
-    def __init__(self, params: PARAMETERS, lr: float = 0.01, beta: float = 0.999, constraints: bool = True):
+    def __init__(
+        self, params: PARAMETERS, lr: float = 0.01, beta: float = 0.999, constraints: bool = True, eps: float = 1e-8
+    ):
         self.lr = lr
         self.beta = beta
+        self.eps = eps
 
         self.validate_parameters()
 
@@ -28,6 +32,7 @@ def __init__(self, params: PARAMETERS, lr: float = 0.01, beta: float = 0.999, co
     def validate_parameters(self):
         self.validate_learning_rate(self.lr)
         self.validate_beta(self.beta)
+        self.validate_epsilon(self.eps)
 
     def __str__(self) -> str:
         return 'Nero'
@@ -38,7 +43,7 @@ def reset(self):
             for p in group['params']:
                 if group['constraints'] and p.dim() > 1:
                     p.sub_(neuron_mean(p))
-                    p.div_(neuron_norm(p))
+                    p.div_(neuron_norm(p) + self.eps)
 
                 state = self.state[p]
 
@@ -69,7 +74,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 if len(state) == 0:
                     if group['constraints'] and p.dim() > 1:
                         p.sub_(neuron_mean(p))
-                        p.div_(neuron_norm(p))
+                        p.div_(neuron_norm(p) + self.eps)
 
                     state['step'] = 0
                     state['exp_avg_sq'] = torch.zeros_like(neuron_norm(p))
@@ -79,16 +84,20 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
                 state['step'] += 1
 
+                grad_norm = neuron_norm(grad)
+
+                exp_avg_sq = state['exp_avg_sq']
+                exp_avg_sq.mul_(self.beta).addcmul_(grad_norm, grad_norm, value=1.0 - self.beta)
+
                 bias_correction: float = 1.0 - self.beta ** state['step']
-                state['exp_avg_sq'] = self.beta * state['exp_avg_sq'] + (1.0 - self.beta) * neuron_norm(grad) ** 2
 
-                grad_normed = grad / (state['exp_avg_sq'] / bias_correction).sqrt()
-                grad_normed[torch.isnan(grad_normed)] = 0.0
+                grad_normed = grad / ((exp_avg_sq / bias_correction).sqrt() + self.eps)
+                torch.nan_to_num(grad_normed, nan=0.0, out=grad_normed)
 
                 p.sub_(group['lr'] * state['scale'] * grad_normed)
 
                 if group['constraints'] and p.dim() > 1:
                     p.sub_(neuron_mean(p))
-                    p.div_(neuron_norm(p))
+                    p.div_(neuron_norm(p) + self.eps)
 
         return loss
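
The rewritten step folds the second-moment update into in-place ops and guards every division with eps. As a sanity check, here is a minimal sketch of that normalization in isolation. Note that neuron_norm is reconstructed here from how the diff uses it (a per-neuron norm with broadcastable shape) and is an assumption, not the library's exact helper.

import torch

def neuron_norm(x: torch.Tensor) -> torch.Tensor:
    # Assumed shape of Nero's helper: norm over each output neuron's weights,
    # kept broadcastable against x; falls back to abs() for 1-D tensors.
    if x.dim() > 1:
        flat = x.view(x.size(0), -1)
        return flat.norm(dim=1).view(-1, *([1] * (x.dim() - 1)))
    return x.abs()

def normalized_grad(grad, exp_avg_sq, step, beta=0.999, eps=1e-8):
    # Mirrors the diff: in-place EMA of squared per-neuron gradient norms,
    # bias correction, then an eps-guarded divide and an in-place NaN scrub.
    grad_norm = neuron_norm(grad)
    exp_avg_sq.mul_(beta).addcmul_(grad_norm, grad_norm, value=1.0 - beta)
    bias_correction = 1.0 - beta ** step
    grad_normed = grad / ((exp_avg_sq / bias_correction).sqrt() + eps)
    return torch.nan_to_num(grad_normed, nan=0.0, out=grad_normed)

Compared with the deleted lines, mul_/addcmul_ avoids allocating a fresh exp_avg_sq tensor each step, and the eps in the denominator makes the isnan-style masking a last resort rather than the primary guard.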
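For context, a usage sketch, assuming the new eps keyword is exposed through pytorch_optimizer's public Nero export as this commit suggests:

import torch
from pytorch_optimizer import Nero  # assumes the package's public export

model = torch.nn.Linear(10, 2)
# eps is the new knob from this commit; the other defaults match the diff.
optimizer = Nero(model.parameters(), lr=0.01, beta=0.999, constraints=True, eps=1e-8)

loss = model(torch.randn(4, 10)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()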