
Commit b8938e1

update: FAdam optimizer
1 parent f7c5ec0 commit b8938e1

File tree

1 file changed (+5, -5 lines changed)


pytorch_optimizer/optimizer/fadam.py

Lines changed: 5 additions & 5 deletions
@@ -99,13 +99,13 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 momentum, fim = state['momentum'], state['fim']
                 fim.mul_(curr_beta2).addcmul_(grad, grad, value=1.0 - curr_beta2)
 
-                rms_grad = torch.pow(grad, 2).mean().sqrt_()
+                rms_grad = grad.pow(2).mean().sqrt_()
                 curr_eps = min(rms_grad, 1) * group['eps']
 
-                fim_base = torch.pow(fim, group['p']).add_(curr_eps)
-                grad_nat = torch.div(grad, fim_base)
+                fim_base = fim.pow(group['p']).add_(curr_eps)
+                grad_nat = grad / fim_base
 
-                rms = torch.pow(grad_nat, 2).mean().sqrt_()
+                rms = grad_nat.pow(2).mean().sqrt_()
                 divisor = max(1, rms) / group['clip']
                 grad_nat.div_(divisor)
 
@@ -119,6 +119,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
                 grad_weights.mul_(group['weight_decay']).add_(momentum)
 
-                p.add_(-grad_weights, alpha=group['lr'])
+                p.add_(grad_weights, alpha=-group['lr'])
 
        return loss
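
For context, the lines touched by the first hunk implement FAdam's preconditioning and clipping of the gradient. A rough standalone sketch of that computation (not part of the commit; the hyper-parameter values eps, p_power, and clip below are made up for illustration):

import torch

# Standalone sketch of the first hunk's computation; hyper-parameter
# values are illustrative, not taken from the repository defaults.
eps, p_power, clip = 1e-8, 0.5, 1.0
grad = torch.randn(4, 3)
fim = torch.rand(4, 3)  # running Fisher-information estimate

rms_grad = grad.pow(2).mean().sqrt_()        # RMS of the raw gradient
curr_eps = min(rms_grad, 1) * eps            # eps scaled down for tiny gradients
fim_base = fim.pow(p_power).add_(curr_eps)   # FIM^p + eps, the preconditioner
grad_nat = grad / fim_base                   # preconditioned ("natural") gradient
rms = grad_nat.pow(2).mean().sqrt_()         # RMS of the preconditioned gradient
divisor = max(1, rms) / clip                 # only rescale when RMS exceeds the clip
grad_nat.div_(divisor)

The grad.pow(2) and fim.pow(group['p']) spellings introduced here are behaviorally identical to the previous torch.pow(...) calls; that part of the change is stylistic.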

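The second hunk folds the sign of the learning rate into alpha instead of negating the update tensor. A minimal sketch (not from the repository; lr and the tensor shapes are made up) showing that the two forms produce the same in-place update, while the new spelling avoids materializing the temporary created by -grad_weights:

import torch

lr = 1e-3
grad_weights = torch.randn(4, 3)
p_old = torch.randn(4, 3)
p_new = p_old.clone()

# old form: negating grad_weights allocates an intermediate tensor
p_old.add_(-grad_weights, alpha=lr)
# new form: the sign is folded into alpha, so no extra tensor is created
p_new.add_(grad_weights, alpha=-lr)

# both perform p <- p - lr * grad_weights
assert torch.allclose(p_old, p_new)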