Commit 6ba157d
Merge pull request #56 from kozistr/feature/pnm-optimizer
[Feature] PNM optimizer
2 parents 34e6fd0 + 1723f3a commit 6ba157d

File tree

15 files changed: +380, -85 lines changed

Pipfile

Lines changed: 2 additions & 1 deletion

@@ -6,13 +6,14 @@ verify_ssl = false
 [dev-packages]
 isort = "==5.10.1"
 black = "==21.12b0"
+click = "==8.0.4"
 pylint = "==2.11.1"
 pytest = "==7.0.1"
 pytest-cov = "==3.0.0"

 [packages]
 numpy = "==1.21.4"
-torch = "==1.10.1"
+torch = "==1.11.0"

 [requires]
 python_version = "3"
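
A quick sanity check one might run after re-syncing the environment; the expected version prefixes come from the pins above, and the check itself is just an illustrative sketch:

import click
import torch

# Version prefixes follow the pins above: click ==8.0.4, torch ==1.11.0.
assert click.__version__.startswith("8.0"), click.__version__
assert torch.__version__.startswith("1.11"), torch.__version__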

Pipfile.lock

Lines changed: 70 additions & 69 deletions
Generated file; diff not rendered by default.

pytorch_optimizer/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -3,6 +3,7 @@
 from pytorch_optimizer.adabound import AdaBound
 from pytorch_optimizer.adahessian import AdaHessian
 from pytorch_optimizer.adamp import AdamP
+from pytorch_optimizer.adapnm import AdaPNM
 from pytorch_optimizer.agc import agc
 from pytorch_optimizer.chebyshev_schedule import get_chebyshev_schedule
 from pytorch_optimizer.diffgrad import DiffGrad
@@ -15,6 +16,7 @@
 from pytorch_optimizer.madgrad import MADGRAD
 from pytorch_optimizer.optimizers import load_optimizers
 from pytorch_optimizer.pcgrad import PCGrad
+from pytorch_optimizer.pnm import PNM
 from pytorch_optimizer.radam import RAdam
 from pytorch_optimizer.ralamb import RaLamb
 from pytorch_optimizer.ranger import Ranger
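
For reference, a minimal usage sketch of the two optimizers exported above; the toy model and the PNM constructor arguments are assumptions for illustration, since PNM's signature is not shown in this diff:

import torch

from pytorch_optimizer import AdaPNM, PNM

model = torch.nn.Linear(10, 1)

# AdaPNM's signature matches the new adapnm.py below.
adapnm = AdaPNM(model.parameters(), lr=1e-3, betas=(0.9, 0.999, 1.0))

# PNM is assumed to follow the same construction pattern.
pnm = PNM(model.parameters(), lr=1e-3)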

pytorch_optimizer/adapnm.py

Lines changed: 136 additions & 0 deletions

@@ -0,0 +1,136 @@
import math

import torch
from torch.optim.optimizer import Optimizer

from pytorch_optimizer.base_optimizer import BaseOptimizer
from pytorch_optimizer.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS


class AdaPNM(Optimizer, BaseOptimizer):
    """
    Reference : https://github.com/zeke-xie/Positive-Negative-Momentum
    Example :
        from pytorch_optimizer import AdaPNM
        ...
        model = YourModel()
        optimizer = AdaPNM(model.parameters())
        ...
        for input, output in data:
            optimizer.zero_grad()
            loss = loss_function(output, model(input))
            loss.backward()
            optimizer.step()
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.999, 1.0),
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        amsgrad: bool = True,
        eps: float = 1e-8,
    ):
        """AdaPNM optimizer
        :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups
        :param lr: float. learning rate
        :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace
        :param weight_decay: float. weight decay (L2 penalty)
        :param weight_decouple: bool. use weight_decouple
        :param amsgrad: bool. whether to use the AMSGrad variant of this algorithm from the paper
        :param eps: float. term added to the denominator to improve numerical stability
        """
        self.lr = lr
        self.betas = betas
        self.weight_decay = weight_decay
        self.weight_decouple = weight_decouple
        self.amsgrad = amsgrad
        self.eps = eps

        self.validate_parameters()

        defaults: DEFAULTS = dict(
            lr=lr, betas=betas, weight_decay=weight_decay, weight_decouple=weight_decouple, amsgrad=amsgrad, eps=eps
        )
        super().__init__(params, defaults)

    def validate_parameters(self):
        self.validate_learning_rate(self.lr)
        self.validate_betas(self.betas)
        self.validate_weight_decay(self.weight_decay)
        self.validate_epsilon(self.eps)

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]

                state['step'] = 0
                state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                state['neg_exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                if group['amsgrad']:
                    state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError('AdaPNM does not support sparse gradients')

                if group['weight_decouple']:
                    p.mul_(1.0 - group['lr'] * group['weight_decay'])
                else:
                    grad.add_(p, alpha=group['weight_decay'])

                state = self.state[p]
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    state['neg_exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                    if group['amsgrad']:
                        state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                state['step'] += 1
                beta1, beta2, beta3 = group['betas']

                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                exp_avg_sq = state['exp_avg_sq']
                if state['step'] % 2 == 1:
                    exp_avg, neg_exp_avg = state['exp_avg'], state['neg_exp_avg']
                else:
                    exp_avg, neg_exp_avg = state['neg_exp_avg'], state['exp_avg']

                exp_avg.mul_(beta1 ** 2).add_(grad, alpha=1 - beta1 ** 2)
                noise_norm = math.sqrt((1 + beta3) ** 2 + beta3 ** 2)

                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
                if group['amsgrad']:
                    exp_avg_sq = torch.max(state['max_exp_avg_sq'], exp_avg_sq)

                denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])

                step_size = group['lr'] / bias_correction1

                pn_momentum = exp_avg.mul(1 + beta3).add(neg_exp_avg, alpha=-beta3).mul(1.0 / noise_norm)
                p.addcdiv_(pn_momentum, denom, value=-step_size)

        return loss
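
To make the alternating buffers concrete, here is a small standalone sketch of just the positive-negative momentum combination from step() above, with weight decay, bias correction, and the second-moment denominator stripped out; the gradients and hyper-parameter values are made up for illustration:

import math

import torch

beta1, beta3 = 0.9, 1.0
buf_a = torch.zeros(3)  # momentum accumulated on odd steps
buf_b = torch.zeros(3)  # momentum accumulated on even steps
noise_norm = math.sqrt((1 + beta3) ** 2 + beta3 ** 2)

fake_grads = [torch.tensor([0.10, -0.20, 0.30]), torch.tensor([0.20, -0.10, 0.25])]
for step, grad in enumerate(fake_grads, start=1):
    # Swap which buffer plays the "positive" role by step parity, as in step() above.
    pos, neg = (buf_a, buf_b) if step % 2 == 1 else (buf_b, buf_a)

    # Each buffer only sees every other gradient, hence the beta1 ** 2 decay.
    pos.mul_(beta1 ** 2).add_(grad, alpha=1 - beta1 ** 2)

    # Amplify the fresh momentum, subtract the stale one, and renormalize.
    pn_momentum = pos.mul(1 + beta3).add(neg, alpha=-beta3).mul(1.0 / noise_norm)
    print(step, pn_momentum)

With the default betas = (0.9, 0.999, 1.0), beta3 is 1.0, so the update direction reduces to (2 * m_t - m_prev) / sqrt(5) before it is divided by the Adam-style denominator.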
