
Commit 522dd44

fix: Kate optimizer
1 parent ecc4530 commit 522dd44

1 file changed: +4 -10 lines changed

pytorch_optimizer/optimizer/kate.py

Lines changed: 4 additions & 10 deletions
@@ -11,8 +11,7 @@ class Kate(Optimizer, BaseOptimizer):
 
     :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
     :param lr: float. learning rate.
-    :param eta: float. eta.
-    :param delta: float. delta. 0.0 or 1e-8.
+    :param delta: float. delta.
     :param weight_decay: float. weight decay (L2 penalty).
     :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
     :param fixed_decay: bool. fix weight decay.
@@ -23,22 +22,19 @@ def __init__(
         self,
         params: PARAMETERS,
         lr: float = 1e-3,
-        eta: float = 0.9,
         delta: float = 0.0,
         weight_decay: float = 0.0,
         weight_decouple: bool = True,
         fixed_decay: bool = False,
         eps: float = 1e-8,
     ):
         self.validate_learning_rate(lr)
-        self.validate_range(eta, 'eta', 0.0, 1.0, '[)')
         self.validate_range(delta, 'delta', 0.0, 1.0, '[)')
         self.validate_non_negative(weight_decay, 'weight_decay')
         self.validate_non_negative(eps, 'eps')
 
         defaults: DEFAULTS = {
             'lr': lr,
-            'eta': eta,
             'delta': delta,
             'weight_decay': weight_decay,
             'weight_decouple': weight_decouple,
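
After this hunk, the constructor no longer accepts eta, so any call site passing it will need updating. A minimal construction call consistent with the new signature might look like the sketch below; the import path is inferred from the file location in this commit, and the nn.Linear module is only a placeholder:

import torch.nn as nn

from pytorch_optimizer.optimizer.kate import Kate  # path inferred from pytorch_optimizer/optimizer/kate.py

model = nn.Linear(10, 1)  # placeholder module

# eta is gone; delta and the remaining keyword arguments keep the defaults from the new signature.
optimizer = Kate(model.parameters(), lr=1e-3, delta=0.0, weight_decay=0.0)
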
@@ -97,14 +93,12 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                     fixed_decay=group['fixed_decay'],
                 )
 
-                grad_p2 = grad * grad
+                grad_p2 = torch.mul(grad, grad)
 
                 m, b = state['m'], state['b']
-                b.mul_(b).add_(grad_p2)
+                b.mul_(b).add_(grad_p2).add_(group['eps'])
 
-                de_nom = b.add(group['eps'])
-
-                m.mul_(m).add_(grad_p2, alpha=group['eta']).add_(grad / de_nom).sqrt_()
+                m.mul_(m).add_(grad_p2, alpha=group['delta']).add_(grad_p2 / b).sqrt_()
 
                 update = m.mul(grad).div_(b)
 