
Commit 373e1b5

Merge pull request #121 from kozistr/update/nero-eps
[Update] Add epsilon to stabilize the optimization
2 parents df9e78d + fda5140

6 files changed: +31 -35 lines changed

.github/workflows/ci.yml

Lines changed: 0 additions & 1 deletion
@@ -35,7 +35,6 @@ jobs:
         run: make check
       - name: Check test
         env:
-          LD_PRELOAD: /usr/lib/x86_64-linux-gnu/libtcmalloc.so.4
           PYTHONDONTWRITEBYTECODE: 1
         run: make test
       - name: Check codecov

pytorch_optimizer/optimizer/lookahead.py

Lines changed: 10 additions & 18 deletions
@@ -69,7 +69,8 @@ def update(self, group: Dict):
                 param_state['slow_mom'] = torch.zeros_like(fast)
 
             slow = param_state['slow_param']
-            slow += (fast - slow) * self.alpha
+            slow.add_(fast - slow, alpha=self.alpha)
+
             fast.copy_(slow)
 
             if 'momentum_buffer' not in self.optimizer.state[fast]:
@@ -98,30 +99,21 @@ def step(self, closure: CLOSURE = None) -> LOSS:
         return loss
 
     def state_dict(self) -> STATE:
-        fast_state_dict: STATE = self.optimizer.state_dict()
-        fast_state = fast_state_dict['state']
-        param_groups = fast_state_dict['param_groups']
-
+        fast_state: STATE = self.optimizer.state_dict()
         slow_state: STATE = {(id(k) if isinstance(k, torch.Tensor) else k): v for k, v in self.state.items()}
 
         return {
-            'fast_state': fast_state,
+            'fast_state': fast_state['state'],
             'slow_state': slow_state,
-            'param_groups': param_groups,
+            'param_groups': fast_state['param_groups'],
         }
 
-    def load_state_dict(self, state_dict: STATE):
-        slow_state_dict: STATE = {
-            'state': state_dict['slow_state'],
-            'param_groups': state_dict['param_groups'],
-        }
-        fast_state_dict: STATE = {
-            'state': state_dict['fast_state'],
-            'param_groups': state_dict['param_groups'],
-        }
-        super().load_state_dict(slow_state_dict)
+    def load_state_dict(self, state: STATE):
+        slow_state: STATE = {'state': state['slow_state'], 'param_groups': state['param_groups']}
+        fast_state: STATE = {'state': state['fast_state'], 'param_groups': state['param_groups']}
+        super().load_state_dict(slow_state)
 
-        self.optimizer.load_state_dict(fast_state_dict)
+        self.optimizer.load_state_dict(fast_state)
         self.fast_state = self.optimizer.state
 
     def add_param_group(self, param_group):
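
The behavioral change in the first hunk is small: the slow-weight interpolation slow <- slow + alpha * (fast - slow) now runs in-place via add_(). A minimal sketch of that update on bare tensors; alpha and the values below are illustrative, not taken from the library:

    import torch

    alpha = 0.5  # lookahead interpolation factor (illustrative value)
    fast = torch.tensor([1.0, 2.0, 3.0])  # weights after k inner-optimizer steps
    slow = torch.zeros(3)                 # lookahead's slow copy of the weights

    # slow <- slow + alpha * (fast - slow), in-place, without building a new slow tensor
    slow.add_(fast - slow, alpha=alpha)

    # lookahead then resets the fast weights to the interpolated point
    fast.copy_(slow)
    print(slow)  # tensor([0.5000, 1.0000, 1.5000])

The state_dict()/load_state_dict() hunks are a pure refactor: the same 'fast_state', 'slow_state', and 'param_groups' keys are produced, just without the intermediate dictionaries.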

pytorch_optimizer/optimizer/nero.py

Lines changed: 16 additions & 7 deletions
@@ -14,11 +14,15 @@ class Nero(Optimizer, BaseOptimizer):
     :param lr: float. learning rate.
     :param beta: float. coefficients used for computing running averages of gradient and the squared hessian trace.
     :param constraints: bool.
+    :param eps: float. term added to the denominator to improve numerical stability.
     """
 
-    def __init__(self, params: PARAMETERS, lr: float = 0.01, beta: float = 0.999, constraints: bool = True):
+    def __init__(
+        self, params: PARAMETERS, lr: float = 0.01, beta: float = 0.999, constraints: bool = True, eps: float = 1e-8
+    ):
         self.lr = lr
         self.beta = beta
+        self.eps = eps
 
         self.validate_parameters()
 
@@ -28,6 +32,7 @@ def __init__(self, params: PARAMETERS, lr: float = 0.01, beta: float = 0.999, co
     def validate_parameters(self):
         self.validate_learning_rate(self.lr)
         self.validate_beta(self.beta)
+        self.validate_epsilon(self.eps)
 
     def __str__(self) -> str:
         return 'Nero'
@@ -38,7 +43,7 @@ def reset(self):
            for p in group['params']:
                if group['constraints'] and p.dim() > 1:
                    p.sub_(neuron_mean(p))
-                   p.div_(neuron_norm(p))
+                   p.div_(neuron_norm(p) + self.eps)
 
                state = self.state[p]
 
@@ -69,7 +74,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                if len(state) == 0:
                    if group['constraints'] and p.dim() > 1:
                        p.sub_(neuron_mean(p))
-                       p.div_(neuron_norm(p))
+                       p.div_(neuron_norm(p) + self.eps)
 
                    state['step'] = 0
                    state['exp_avg_sq'] = torch.zeros_like(neuron_norm(p))
@@ -79,16 +84,20 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
                state['step'] += 1
 
+               grad_norm = neuron_norm(grad)
+
+               exp_avg_sq = state['exp_avg_sq']
+               exp_avg_sq.mul_(self.beta).addcmul_(grad_norm, grad_norm, value=1.0 - self.beta)
+
                bias_correction: float = 1.0 - self.beta ** state['step']
-               state['exp_avg_sq'] = self.beta * state['exp_avg_sq'] + (1.0 - self.beta) * neuron_norm(grad) ** 2
 
-               grad_normed = grad / (state['exp_avg_sq'] / bias_correction).sqrt()
-               grad_normed[torch.isnan(grad_normed)] = 0.0
+               grad_normed = grad / ((exp_avg_sq / bias_correction).sqrt() + self.eps)
+               torch.nan_to_num(grad_normed, nan=0.0, out=grad_normed)
 
                p.sub_(group['lr'] * state['scale'] * grad_normed)
 
                if group['constraints'] and p.dim() > 1:
                    p.sub_(neuron_mean(p))
-                   p.div_(neuron_norm(p))
+                   p.div_(neuron_norm(p) + self.eps)
 
        return loss
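
The core of the fix is the denominator: the second moment of the per-neuron gradient norms is now updated in-place, and eps keeps the division finite when a neuron's gradient is all zeros (previously a 0/0 that had to be patched by the isnan mask). A rough sketch of the new normalization path on bare tensors; beta, eps, and the shapes are illustrative, and neuron_norm is inlined as a per-row norm rather than imported:

    import torch

    beta, eps, step = 0.999, 1e-8, 1
    grad = torch.zeros(4, 3)  # an all-zero gradient: the old NaN-producing case

    # per-neuron (per-row) gradient norm, shaped for broadcasting
    grad_norm = grad.view(grad.shape[0], -1).norm(dim=1).view(-1, 1)

    # in-place EMA of squared norms, as in the updated step()
    exp_avg_sq = torch.zeros_like(grad_norm)
    exp_avg_sq.mul_(beta).addcmul_(grad_norm, grad_norm, value=1.0 - beta)

    bias_correction = 1.0 - beta ** step
    grad_normed = grad / ((exp_avg_sq / bias_correction).sqrt() + eps)
    torch.nan_to_num(grad_normed, nan=0.0, out=grad_normed)  # in-place safety net

    print(grad_normed.isnan().any())  # tensor(False)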

pytorch_optimizer/optimizer/utils.py

Lines changed: 4 additions & 6 deletions
@@ -199,21 +199,19 @@ def neuron_norm(x: torch.Tensor) -> torch.Tensor:
     if x.dim() <= 1:
         return x.abs()
 
-    view_shape = [x.shape[0]] + [1] * (x.dim() - 1)
-    x = x.view(x.shape[0], -1)
+    view_shape: List[int] = [x.shape[0]] + [1] * (x.dim() - 1)
 
-    return x.norm(dim=1).view(*view_shape)
+    return channel_view(x).norm(dim=1).view(*view_shape)
 
 
 def neuron_mean(x: torch.Tensor) -> torch.Tensor:
     r"""Get mean of the tensor."""
     if x.dim() <= 1:
         raise ValueError('[-] neuron_mean not defined on 1D tensors.')
 
-    view_shape = [x.shape[0]] + [1] * (x.dim() - 1)
-    x = x.view(x.shape[0], -1)
+    view_shape: List[int] = [x.shape[0]] + [1] * (x.dim() - 1)
 
-    return x.mean(dim=1).view(*view_shape)
+    return channel_view(x).mean(dim=1).view(*view_shape)
 
 
 def disable_running_stats(model):
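
Both helpers now delegate the flattening to channel_view instead of reshaping inline, which also stops them from rebinding the x argument. A self-contained sketch; the channel_view body below is an assumption about the helper's behavior (flatten each output neuron to one row), not copied from the library:

    from typing import List

    import torch


    def channel_view(x: torch.Tensor) -> torch.Tensor:
        # assumed behavior: one row per output channel / neuron
        return x.view(x.shape[0], -1)


    def neuron_norm(x: torch.Tensor) -> torch.Tensor:
        if x.dim() <= 1:
            return x.abs()
        view_shape: List[int] = [x.shape[0]] + [1] * (x.dim() - 1)
        return channel_view(x).norm(dim=1).view(*view_shape)


    w = torch.randn(8, 3, 3, 3)  # e.g. a conv weight
    print(neuron_norm(w).shape)  # torch.Size([8, 1, 1, 1]): one norm per neuron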

tests/test_gradients.py

Lines changed: 0 additions & 2 deletions
@@ -39,8 +39,6 @@ def test_sparse_not_supported(no_sparse_optimizer):
     opt = load_optimizer(optimizer=no_sparse_optimizer)
     optimizer = opt([param], num_iterations=1) if no_sparse_optimizer == 'ranger21' else opt([param])
 
-    optimizer.zero_grad()
-
     with pytest.raises(NoSparseGradientError):
         optimizer.step(lambda: 0.1)
 

tests/test_optimizer_parameters.py

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ def test_learning_rate(optimizer_name):
 
 @pytest.mark.parametrize('optimizer_name', VALID_OPTIMIZER_NAMES)
 def test_epsilon(optimizer_name):
-    if optimizer_name in ('nero', 'shampoo', 'scalableshampoo', 'dadaptsgd', 'adafactor', 'lion'):
+    if optimizer_name in ('shampoo', 'scalableshampoo', 'dadaptsgd', 'adafactor', 'lion'):
         pytest.skip(f'skip {optimizer_name} optimizer')
 
     optimizer = load_optimizer(optimizer_name)
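
With eps now validated in its constructor, Nero is exercised by test_epsilon instead of being skipped. A hypothetical version of such a check; the exact exception type raised by validate_epsilon is an assumption, so the sketch only asserts that a negative eps is rejected:

    import pytest
    import torch

    from pytorch_optimizer import Nero


    def test_nero_rejects_negative_epsilon():
        param = torch.zeros(1, requires_grad=True)
        with pytest.raises(Exception):  # concrete exception type not confirmed here
            Nero([param], eps=-1e-6)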
