@@ -11,8 +11,7 @@ class Kate(Optimizer, BaseOptimizer):
 
     :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
     :param lr: float. learning rate.
-    :param eta: float. eta.
-    :param delta: float. delta. 0.0 or 1e-8.
+    :param delta: float. delta.
     :param weight_decay: float. weight decay (L2 penalty).
     :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
     :param fixed_decay: bool. fix weight decay.
@@ -23,22 +22,19 @@ def __init__(
         self,
         params: PARAMETERS,
         lr: float = 1e-3,
-        eta: float = 0.9,
         delta: float = 0.0,
         weight_decay: float = 0.0,
         weight_decouple: bool = True,
         fixed_decay: bool = False,
         eps: float = 1e-8,
     ):
         self.validate_learning_rate(lr)
-        self.validate_range(eta, 'eta', 0.0, 1.0, '[)')
         self.validate_range(delta, 'delta', 0.0, 1.0, '[)')
         self.validate_non_negative(weight_decay, 'weight_decay')
         self.validate_non_negative(eps, 'eps')
 
         defaults: DEFAULTS = {
             'lr': lr,
-            'eta': eta,
             'delta': delta,
             'weight_decay': weight_decay,
             'weight_decouple': weight_decouple,
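For orientation, a minimal usage sketch of the constructor as it stands after this hunk (eta removed, delta kept) follows. The import path is an assumption not shown in the diff; only the keyword arguments mirror the defaults defined above.

# Usage sketch only: the import path is assumed, not confirmed by this diff.
import torch
from pytorch_optimizer import Kate

model = torch.nn.Linear(10, 1)
optimizer = Kate(model.parameters(), lr=1e-3, delta=0.0, weight_decay=0.0, eps=1e-8)

loss = model(torch.randn(8, 10)).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()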
@@ -97,14 +93,12 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                     fixed_decay=group['fixed_decay'],
                 )
 
-                grad_p2 = grad * grad
+                grad_p2 = torch.mul(grad, grad)
 
                 m, b = state['m'], state['b']
-                b.mul_(b).add_(grad_p2)
+                b.mul_(b).add_(grad_p2).add_(group['eps'])
 
-                de_nom = b.add(group['eps'])
-
-                m.mul_(m).add_(grad_p2, alpha=group['eta']).add_(grad / de_nom).sqrt_()
+                m.mul_(m).add_(grad_p2, alpha=group['delta']).add_(grad_p2 / b).sqrt_()
 
                 update = m.mul(grad).div_(b)
 
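Taken together, the added lines implement the KATE-style accumulators b_k^2 = b_{k-1}^2 + g_k^2 + eps and m_k^2 = m_{k-1}^2 + delta * g_k^2 + g_k^2 / b_k^2, with the parameter step proportional to m_k * g_k / b_k^2. Below is a standalone sketch of one such step on plain tensors; the final parameter update and the trailing b.sqrt_() (returning b to its un-squared form between steps) are assumptions, since they fall outside the lines shown in this hunk.

# Sketch of a single KATE-style step, mirroring the '+' lines above.
import torch


def kate_step(
    p: torch.Tensor,
    grad: torch.Tensor,
    m: torch.Tensor,
    b: torch.Tensor,
    lr: float = 1e-3,
    delta: float = 0.0,
    eps: float = 1e-8,
) -> None:
    grad_p2 = torch.mul(grad, grad)

    # b holds b_{k-1}; squaring it and adding g_k^2 (+ eps) yields b_k^2.
    b.mul_(b).add_(grad_p2).add_(eps)

    # m_k = sqrt(m_{k-1}^2 + delta * g_k^2 + g_k^2 / b_k^2)
    m.mul_(m).add_(grad_p2, alpha=delta).add_(grad_p2 / b).sqrt_()

    # Step direction: m_k * g_k / b_k^2.
    update = m.mul(grad).div_(b)

    # Assumed (not visible in this hunk): apply the step and return b to its
    # un-squared form so the next call can re-square it.
    p.add_(update, alpha=-lr)
    b.sqrt_()


# Tiny smoke test with zero-initialized state tensors.
p = torch.zeros(3)
g = torch.tensor([0.1, -0.2, 0.3])
m = torch.zeros(3)
b = torch.zeros(3)
kate_step(p, g, m, b, lr=1e-3, delta=0.0)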