
Commit 0c0f48f

[Update] Eliminate numpy usage and update learning rate application from EXAdam (#438)
Removed the numpy dependency and adjusted the step size calculation based on the latest update of the EXAdam paper: https://arxiv.org/abs/2412.20302
1 parent 9dd83d0 commit 0c0f48f

File tree

1 file changed: +1 addition, -6 deletions


pytorch_optimizer/optimizer/exadam.py

Lines changed: 1 addition & 6 deletions
@@ -1,4 +1,3 @@
-import numpy as np
 import torch
 
 from pytorch_optimizer.base.exception import NoSparseGradientError
@@ -38,8 +37,6 @@ def __init__(
 
         self.maximize = maximize
 
-        self.sq2: float = np.sqrt(2)
-
         defaults: DEFAULTS = {
             'lr': lr,
             'betas': betas,
@@ -88,8 +85,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:
             bias_correction1: float = self.debias(beta1, group['step'])
             bias_correction2: float = self.debias(beta2, group['step'])
 
-            step_size: float = group['lr'] * np.log(np.sqrt(group['step'] + 1) * self.sq2)
-
             for p in group['params']:
                 if p.grad is None:
                     continue
@@ -128,6 +123,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
                 update = (m_tilde + g_tilde) / v_tilde.sqrt().add_(group['eps'])
 
-                p.add_(update, alpha=-step_size)
+                p.add_(update, alpha=-group['lr'])
 
         return loss
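
For context, a minimal standalone sketch of what this commit changes in the parameter update. The names lr, step, p, and update below are hypothetical stand-ins for the optimizer state inside step(), not taken from the file above; only the two alpha expressions mirror the removed and added lines.

import math

import torch

# Hypothetical stand-ins for the state inside ExAdam.step().
lr: float = 1e-3
step: int = 10
p = torch.zeros(3)
update = torch.ones(3)

# Before this commit: the learning rate was scaled by a log factor of the step
# count, previously computed with numpy as np.log(np.sqrt(step + 1) * np.sqrt(2)).
step_size_old = lr * math.log(math.sqrt(step + 1) * math.sqrt(2))
p.add_(update, alpha=-step_size_old)

# After this commit: the raw learning rate is applied directly,
# following the updated EXAdam paper.
p.add_(update, alpha=-lr)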
