Skip to content

Commit e66435a

Browse files
authored
Merge pull request #123 from kozistr/fix/pt2-lookahead
[Fix] `Lookahead` optimizer
2 parents 79afa2e + ccf301b commit e66435a

File tree

17 files changed

+496
-260
lines changed

17 files changed

+496
-260
lines changed

poetry.lock

Lines changed: 312 additions & 114 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "pytorch_optimizer"
3-
version = "2.5.1"
3+
version = "2.5.2"
44
description = "optimizer & lr scheduler implementations in PyTorch with clean-code, strict types. Also, including useful optimization ideas."
55
license = "Apache-2.0"
66
authors = ["kozistr <[email protected]>"]
@@ -37,13 +37,19 @@ numpy = [
3737
{ version = "=1.21.1", python = ">=3.7,<3.8" },
3838
{ version = "*", python = ">=3.8" },
3939
]
40-
torch = { version = ">=1.10", source = "torch" }
40+
torch = [
41+
{ version = ">=1.10,>=2.0", python = ">=3.8", source = "torch" },
42+
{ version = "^1.10", python = ">=3.7,<3.8", source = "torch" },
43+
]
4144

4245
[tool.poetry.dev-dependencies]
43-
isort = "^5.11.5"
44-
black = "^23.1.0"
45-
ruff = "^0.0.244"
46-
pytest = "^7.2.1"
46+
isort = [
47+
{ version = "==5.11.5", python = ">=3.7,<3.8"},
48+
{ version = "^5.12.0", python = ">=3.8"}
49+
]
50+
black = "^23.3.0"
51+
ruff = "^0.0.260"
52+
pytest = "^7.2.2"
4753
pytest-cov = "^4.0.0"
4854

4955
[[tool.poetry.source]]

pytorch_optimizer/base/exception.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ class NoSparseGradientError(Exception):
66
"""
77

88
def __init__(self, optimizer_name: str, note: str = ''):
9-
self.note: str = ' ' if note == '' else f' w/ {note} '
9+
self.note: str = ' ' if not note else f' w/ {note} '
1010
self.message: str = f'[-] {optimizer_name}{self.note}does not support sparse gradient.'
1111
super().__init__(self.message)
1212

@@ -31,7 +31,7 @@ class NegativeLRError(Exception):
3131
"""Raised when learning rate is negative."""
3232

3333
def __init__(self, lr: float, lr_type: str = ''):
34-
self.note: str = 'learning rate' if lr_type == '' else lr_type
34+
self.note: str = lr_type if lr_type else 'learning rate'
3535
self.message: str = f'[-] {self.note} must be positive. ({lr} > 0)'
3636
super().__init__(self.message)
3737

@@ -40,6 +40,6 @@ class NegativeStepError(Exception):
4040
"""Raised when step is negative."""
4141

4242
def __init__(self, num_steps: int, step_type: str = ''):
43-
self.note: str = 'step' if step_type == '' else step_type
43+
self.note: str = step_type if step_type else 'step'
4444
self.message: str = f'[-] {self.note} must be positive. ({num_steps} > 0)'
4545
super().__init__(self.message)

pytorch_optimizer/lr_scheduler/cosine_anealing.py

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -91,23 +91,20 @@ def step(self, epoch: Optional[int] = None):
9191
self.cur_cycle_steps = (
9292
int((self.cur_cycle_steps - self.warmup_steps) * self.cycle_mult) + self.warmup_steps
9393
)
94-
else:
95-
if epoch >= self.first_cycle_steps:
96-
if self.cycle_mult == 1.0:
97-
self.step_in_cycle = epoch % self.first_cycle_steps
98-
self.cycle = epoch // self.first_cycle_steps
99-
else:
100-
n: int = int(
101-
math.log((epoch / self.first_cycle_steps * (self.cycle_mult - 1) + 1), self.cycle_mult)
102-
)
103-
self.cycle = n
104-
self.step_in_cycle = epoch - int(
105-
self.first_cycle_steps * (self.cycle_mult ** n - 1) / (self.cycle_mult - 1)
106-
) # fmt: skip
107-
self.cur_cycle_steps = self.first_cycle_steps * self.cycle_mult ** n # fmt: skip
94+
elif epoch >= self.first_cycle_steps:
95+
if self.cycle_mult == 1.0:
96+
self.step_in_cycle = epoch % self.first_cycle_steps
97+
self.cycle = epoch // self.first_cycle_steps
10898
else:
109-
self.cur_cycle_steps = self.first_cycle_steps
110-
self.step_in_cycle = epoch
99+
n: int = int(math.log((epoch / self.first_cycle_steps * (self.cycle_mult - 1) + 1), self.cycle_mult))
100+
self.cycle = n
101+
self.step_in_cycle = epoch - int(
102+
self.first_cycle_steps * (self.cycle_mult ** n - 1) / (self.cycle_mult - 1)
103+
) # fmt: skip
104+
self.cur_cycle_steps = self.first_cycle_steps * self.cycle_mult ** n # fmt: skip
105+
else:
106+
self.cur_cycle_steps = self.first_cycle_steps
107+
self.step_in_cycle = epoch
111108

112109
self.max_lr = self.base_max_lr * (self.gamma ** self.cycle) # fmt: skip
113110
self.last_epoch = math.floor(epoch)

pytorch_optimizer/optimizer/dadapt.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -77,11 +77,7 @@ def reset(self):
7777

7878
state = self.state[p]
7979

80-
try:
81-
state['alpha_k'] = torch.full_like(p, fill_value=1e-6)
82-
except NotImplementedError: # there's no fill_() op for SpareTensorCPU
83-
state['alpha_k'] = torch.zeros_like(p)
84-
80+
state['alpha_k'] = torch.full_like(p, fill_value=1e-6)
8581
state['sk'] = torch.zeros_like(p)
8682
state['x0'] = torch.clone(p)
8783
if p.grad.is_sparse:
@@ -119,11 +115,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
119115

120116
state = self.state[p]
121117
if 'alpha_k' not in state:
122-
try:
123-
state['alpha_k'] = torch.full_like(p, fill_value=1e-6)
124-
except NotImplementedError: # there's no fill_() op for SpareTensorCPU
125-
state['alpha_k'] = torch.zeros_like(p)
126-
118+
state['alpha_k'] = torch.full_like(p, fill_value=1e-6)
127119
state['sk'] = torch.zeros_like(p)
128120
state['x0'] = torch.clone(p)
129121
if grad.is_sparse:

pytorch_optimizer/optimizer/lookahead.py

Lines changed: 63 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,12 @@
22
from typing import Dict
33

44
import torch
5-
from torch.optim import Optimizer
65

76
from pytorch_optimizer.base.optimizer import BaseOptimizer
8-
from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, OPTIMIZER, STATE
7+
from pytorch_optimizer.base.types import CLOSURE, LOSS, OPTIMIZER, STATE
98

109

11-
class Lookahead(Optimizer, BaseOptimizer):
10+
class Lookahead(BaseOptimizer):
1211
r"""k steps forward, 1 step back.
1312
1413
:param optimizer: OPTIMIZER. base optimizer.
@@ -17,7 +16,7 @@ class Lookahead(Optimizer, BaseOptimizer):
1716
:param pullback_momentum: str. change to inner optimizer momentum on interpolation update.
1817
"""
1918

20-
def __init__( # pylint: disable=super-init-not-called
19+
def __init__(
2120
self,
2221
optimizer: OPTIMIZER,
2322
k: int = 5,
@@ -32,62 +31,90 @@ def __init__( # pylint: disable=super-init-not-called
3231
self.validate_parameters()
3332

3433
self.param_groups = self.optimizer.param_groups
35-
self.fast_state: STATE = self.optimizer.state
3634
self.state: STATE = defaultdict(dict)
37-
self.reset()
3835

39-
self.defaults: DEFAULTS = optimizer.defaults
40-
self.defaults.update(
41-
{
42-
'k': k,
43-
'alpha': alpha,
44-
'pullback_momentum': pullback_momentum,
45-
}
46-
)
36+
for group in self.param_groups:
37+
if 'counter' not in group:
38+
group['counter'] = 0
39+
40+
for p in group['params']:
41+
state = self.state[p]
42+
state['slow_params'] = torch.empty_like(p)
43+
state['slow_params'].copy_(p)
44+
if self.pullback_momentum == 'pullback':
45+
state['slow_momentum'] = torch.zeros_like(p)
4746

4847
def validate_parameters(self):
4948
self.validate_lookahead_k(self.k)
5049
self.validate_alpha(self.alpha)
5150
self.validate_pullback_momentum(self.pullback_momentum)
5251

52+
def __getstate__(self):
53+
return {
54+
'state': self.state,
55+
'optimizer': self.optimizer,
56+
'alpha': self.alpha,
57+
'k': self.k,
58+
'pullback_momentum': self.pullback_momentum,
59+
}
60+
5361
@torch.no_grad()
5462
def reset(self):
5563
for group in self.param_groups:
5664
group['counter'] = 0
5765

66+
def backup_and_load_cache(self):
67+
r"""Backup cache parameters."""
68+
for group in self.param_groups:
69+
for p in group['params']:
70+
state = self.state[p]
71+
state['backup_params'] = torch.empty_like(p)
72+
state['backup_params'].copy_(p)
73+
p.data.copy_(state['slow_params'])
74+
75+
def clear_and_load_backup(self):
76+
r"""Load backup parameters."""
77+
for group in self.param_groups:
78+
for p in group['params']:
79+
state = self.state[p]
80+
p.data.copy_(state['backup_params'])
81+
del state['backup_params']
82+
83+
def state_dict(self) -> STATE:
84+
return self.optimizer.state_dict()
85+
86+
def load_state_dict(self, state: STATE):
87+
r"""Load state."""
88+
self.optimizer.load_state_dict(state)
89+
90+
@torch.no_grad()
91+
def zero_grad(self):
92+
self.optimizer.zero_grad(set_to_none=True)
93+
5894
@torch.no_grad()
5995
def update(self, group: Dict):
60-
for fast in group['params']:
61-
if fast.grad is None:
96+
for p in group['params']:
97+
if p.grad is None:
6298
continue
6399

64-
param_state = self.state[fast]
65-
if 'slow_param' not in param_state:
66-
param_state['slow_param'] = torch.empty_like(fast)
67-
param_state['slow_param'].copy_(fast)
68-
if self.pullback_momentum == 'pullback':
69-
param_state['slow_mom'] = torch.zeros_like(fast)
100+
state = self.state[p]
70101

71-
slow = param_state['slow_param']
72-
slow.add_(fast - slow, alpha=self.alpha)
102+
slow = state['slow_params']
73103

74-
fast.copy_(slow)
104+
p.mul_(self.alpha).add_(slow, alpha=1.0 - self.alpha)
105+
slow.copy_(p)
75106

76-
if 'momentum_buffer' not in self.optimizer.state[fast]:
77-
self.optimizer.state[fast]['momentum_buffer'] = torch.zeros_like(fast)
107+
if 'momentum_buffer' not in self.optimizer.state[p]:
108+
self.optimizer.state[p]['momentum_buffer'] = torch.zeros_like(p)
78109

79110
if self.pullback_momentum == 'pullback':
80-
internal_momentum = self.optimizer.state[fast]['momentum_buffer']
81-
self.optimizer.state[fast]['momentum_buffer'] = internal_momentum.mul_(self.alpha).add_(
82-
param_state['slow_mom'], alpha=1.0 - self.alpha
111+
internal_momentum = self.optimizer.state[p]['momentum_buffer']
112+
self.optimizer.state[p]['momentum_buffer'] = internal_momentum.mul_(self.alpha).add_(
113+
state['slow_momentum'], alpha=1.0 - self.alpha
83114
)
84-
param_state['slow_mom'] = self.optimizer.state[fast]['momentum_buffer']
115+
state['slow_momentum'] = self.optimizer.state[p]['momentum_buffer']
85116
elif self.pullback_momentum == 'reset':
86-
self.optimizer.state[fast]['momentum_buffer'] = torch.zeros_like(fast)
87-
88-
def update_lookahead(self):
89-
for group in self.param_groups:
90-
self.update(group)
117+
self.optimizer.state[p]['momentum_buffer'] = torch.zeros_like(p)
91118

92119
def step(self, closure: CLOSURE = None) -> LOSS:
93120
loss: LOSS = self.optimizer.step(closure)
@@ -97,25 +124,3 @@ def step(self, closure: CLOSURE = None) -> LOSS:
97124
group['counter'] = 0
98125
self.update(group)
99126
return loss
100-
101-
def state_dict(self) -> STATE:
102-
fast_state: STATE = self.optimizer.state_dict()
103-
slow_state: STATE = {(id(k) if isinstance(k, torch.Tensor) else k): v for k, v in self.state.items()}
104-
105-
return {
106-
'fast_state': fast_state['state'],
107-
'slow_state': slow_state,
108-
'param_groups': fast_state['param_groups'],
109-
}
110-
111-
def load_state_dict(self, state: STATE):
112-
slow_state: STATE = {'state': state['slow_state'], 'param_groups': state['param_groups']}
113-
fast_state: STATE = {'state': state['fast_state'], 'param_groups': state['param_groups']}
114-
super().load_state_dict(slow_state)
115-
116-
self.optimizer.load_state_dict(fast_state)
117-
self.fast_state = self.optimizer.state
118-
119-
def add_param_group(self, param_group):
120-
param_group['counter'] = 0
121-
self.optimizer.add_param_group(param_group)

pytorch_optimizer/optimizer/madgrad.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,6 @@ def reset(self):
6969

7070
@torch.no_grad()
7171
def step(self, closure: CLOSURE = None) -> LOSS:
72-
# pylint: disable=W0212
73-
7472
loss: LOSS = None
7573
if closure is not None:
7674
with torch.enable_grad():
@@ -80,13 +78,11 @@ def step(self, closure: CLOSURE = None) -> LOSS:
8078
if 'k' not in self.state:
8179
self.state['k'] = torch.tensor([0], dtype=torch.long, requires_grad=False)
8280

83-
k = self.state['k']
84-
8581
for group in self.param_groups:
8682
weight_decay, momentum, eps = group['weight_decay'], group['momentum'], group['eps']
8783
lr = group['lr'] + eps
8884

89-
_lambda = lr * math.pow(k + 1, 0.5)
85+
_lambda = lr * math.pow(self.state['k'] + 1, 0.5)
9086

9187
for p in group['params']:
9288
if p.grad is None:
@@ -105,7 +101,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:
105101
raise NoSparseGradientError(str(self), note='momentum > 0.0')
106102

107103
grad_sum_sq, s = state['grad_sum_sq'], state['s']
108-
109104
if weight_decay > 0.0 and not self.decouple_decay:
110105
if grad.is_sparse:
111106
raise NoSparseGradientError(str(self), note='weight_decay')
@@ -120,11 +115,9 @@ def step(self, closure: CLOSURE = None) -> LOSS:
120115
grad_sum_sq_masked = grad_sum_sq.sparse_mask(grad)
121116
s_masked = s.sparse_mask(grad)
122117

123-
# Compute x_0 from other known quantities
124118
rms_masked_values = grad_sum_sq_masked._values().pow(1 / 3).add_(eps)
125119
x0_masked_values = p_masked._values().addcdiv(s_masked._values(), rms_masked_values, value=1)
126120

127-
# Dense + sparse op
128121
grad_sq = grad * grad
129122
grad_sum_sq.add_(grad_sq, alpha=_lambda)
130123
grad_sum_sq_masked.add_(grad_sq, alpha=_lambda)

pytorch_optimizer/optimizer/pcgrad.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def set_grad(self, grads: List[torch.Tensor]):
4444
idx += 1
4545

4646
def retrieve_grad(self) -> Tuple[List[torch.Tensor], List[int], List[torch.Tensor]]:
47-
r"""get the gradient of the parameters of the network with specific objective."""
47+
r"""Get the gradient of the parameters of the network with specific objective."""
4848
grad, shape, has_grad = [], [], []
4949
for group in self.optimizer.param_groups:
5050
for p in group['params']:
@@ -61,7 +61,7 @@ def retrieve_grad(self) -> Tuple[List[torch.Tensor], List[int], List[torch.Tenso
6161
return grad, shape, has_grad
6262

6363
def pack_grad(self, objectives: Iterable) -> Tuple[List[torch.Tensor], List[List[int]], List[torch.Tensor]]:
64-
r"""pack the gradient of the parameters of the network for each objective.
64+
r"""Pack the gradient of the parameters of the network for each objective.
6565
6666
:param objectives: Iterable[nn.Module]. a list of objectives.
6767
:return: torch.Tensor. packed gradients.
@@ -80,7 +80,7 @@ def pack_grad(self, objectives: Iterable) -> Tuple[List[torch.Tensor], List[List
8080
return grads, shapes, has_grads
8181

8282
def project_conflicting(self, grads: List[torch.Tensor], has_grads: List[torch.Tensor]) -> torch.Tensor:
83-
r"""project conflicting.
83+
r"""Project conflicting.
8484
8585
:param grads: a list of the gradient of the parameters.
8686
:param has_grads: a list of mask represent whether the parameter has gradient.
@@ -89,12 +89,12 @@ def project_conflicting(self, grads: List[torch.Tensor], has_grads: List[torch.T
8989
shared: torch.Tensor = torch.stack(has_grads).prod(0).bool()
9090

9191
pc_grad: List[torch.Tensor] = deepcopy(grads)
92-
for g_i in pc_grad:
92+
for i, g_i in enumerate(pc_grad):
9393
random.shuffle(grads)
9494
for g_j in grads:
9595
g_i_g_j: torch.Tensor = torch.dot(g_i, g_j)
9696
if g_i_g_j < 0:
97-
g_i -= g_i_g_j * g_j / (g_j.norm() ** 2)
97+
pc_grad[i] -= g_i_g_j * g_j / (g_j.norm() ** 2)
9898

9999
merged_grad: torch.Tensor = torch.zeros_like(grads[0], device=grads[0].device)
100100

@@ -109,7 +109,7 @@ def project_conflicting(self, grads: List[torch.Tensor], has_grads: List[torch.T
109109
return merged_grad
110110

111111
def pc_backward(self, objectives: Iterable[nn.Module]):
112-
r"""calculate the gradient of the parameters.
112+
r"""Calculate the gradient of the parameters.
113113
114114
:param objectives: Iterable[nn.Module]. a list of objectives.
115115
"""

pytorch_optimizer/optimizer/ranger21.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
217217
param_size += p.numel()
218218

219219
# Apply Adaptive Gradient Clipping (AGC)
220-
p = agc(p, agc_eps=self.agc_eps, agc_clip_val=self.agc_clipping_value)
220+
p = agc(p, agc_eps=self.agc_eps, agc_clip_val=self.agc_clipping_value) # noqa: PLW2901
221221

222222
state = self.state[p]
223223
if len(state) == 0:

0 commit comments

Comments
 (0)