
Commit bc22b00

Merge pull request #27 from kozistr/feature/diffrgrad-optimizer
[Feature] Implement DiffRGrad optimizer
2 parents 3c952d0 + 9b3c3d2 commit bc22b00

File tree

4 files changed: +177 -3 lines changed


pytorch_optimizer/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -6,6 +6,7 @@
 from pytorch_optimizer.agc import agc
 from pytorch_optimizer.chebyshev_schedule import get_chebyshev_schedule
 from pytorch_optimizer.diffgrad import DiffGrad
+from pytorch_optimizer.diffrgrad import DiffRGrad
 from pytorch_optimizer.gc import centralize_gradient
 from pytorch_optimizer.lookahead import Lookahead
 from pytorch_optimizer.madgrad import MADGRAD
@@ -15,4 +16,4 @@
 from pytorch_optimizer.sam import SAM
 from pytorch_optimizer.sgdp import SGDP

-__VERSION__ = '0.0.8'
+__VERSION__ = '0.0.9'
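With the new top-level export in place, the optimizer can be imported either from the package root or directly from its module. A minimal sketch, assuming version 0.0.9 of pytorch_optimizer is installed:

import pytorch_optimizer
from pytorch_optimizer import DiffRGrad                      # new top-level export added above
from pytorch_optimizer.diffrgrad import DiffRGrad as Direct  # module path, equivalent

assert DiffRGrad is Direct
print(pytorch_optimizer.__VERSION__)  # '0.0.9'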

pytorch_optimizer/diffrgrad.py

Lines changed: 172 additions & 0 deletions
@@ -0,0 +1,172 @@
import math

import torch
from torch.optim.optimizer import Optimizer

from pytorch_optimizer.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS, STATE


class DiffRGrad(Optimizer):
    """
    Reference 1 : https://github.com/shivram1987/diffGrad
    Reference 2 : https://github.com/LiyuanLucasLiu/RAdam
    Reference 3 : https://github.com/lessw2020/Best-Deep-Learning-Optimizers/blob/master/diffgrad/diff_rgrad.py
    Example :
        from pytorch_optimizer import DiffRGrad
        ...
        model = YourModel()
        optimizer = DiffRGrad(model.parameters())
        ...
        for input, output in data:
            optimizer.zero_grad()
            loss = loss_function(output, model(input))
            loss.backward()
            optimizer.step()
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.999),
        weight_decay: float = 0.0,
        n_sma_threshold: int = 5,
        degenerated_to_sgd: bool = True,
        eps: float = 1e-8,
    ):
        """Blend RAdam with DiffGrad
        :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups
        :param lr: float. learning rate
        :param betas: BETAS. coefficients used for computing running averages of gradient and its square
        :param weight_decay: float. weight decay (L2 penalty)
        :param n_sma_threshold: int. SMA length above which the rectified update is used (recommended is 5)
        :param degenerated_to_sgd: bool. perform an SGD-style (un-rectified) update when the variance of the
            gradient is not yet tractable
        :param eps: float. term added to the denominator to improve numerical stability
        """
        self.lr = lr
        self.betas = betas
        self.weight_decay = weight_decay
        self.n_sma_threshold = n_sma_threshold
        self.degenerated_to_sgd = degenerated_to_sgd
        self.eps = eps

        self.check_valid_parameters()

        if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict):
            for param in params:
                if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]):
                    param['buffer'] = [[None, None, None] for _ in range(10)]

        defaults: DEFAULTS = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            buffer=[[None, None, None] for _ in range(10)],
        )

        super().__init__(params, defaults)

    def check_valid_parameters(self):
        if self.lr < 0.0:
            raise ValueError(f'Invalid learning rate : {self.lr}')
        if self.weight_decay < 0.0:
            raise ValueError(f'Invalid weight_decay : {self.weight_decay}')
        if not 0.0 <= self.betas[0] < 1.0:
            raise ValueError(f'Invalid beta_0 : {self.betas[0]}')
        if not 0.0 <= self.betas[1] < 1.0:
            raise ValueError(f'Invalid beta_1 : {self.betas[1]}')
        if self.eps < 0.0:
            raise ValueError(f'Invalid eps : {self.eps}')

    def __setstate__(self, state: STATE):
        super().__setstate__(state)

    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError('diffGrad does not support sparse gradients')

                p_data_fp32 = p.data.float()
                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
                    state['previous_grad'] = torch.zeros_like(p_data_fp32)
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
                    state['previous_grad'] = state['previous_grad'].type_as(p_data_fp32)

                exp_avg, exp_avg_sq, previous_grad = (
                    state['exp_avg'],
                    state['exp_avg_sq'],
                    state['previous_grad'],
                )
                beta1, beta2 = group['betas']

                state['step'] += 1

                # Adam-style first and second moment estimates
                exp_avg.mul_(beta1).add_(1 - beta1, grad)
                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)

                # compute diffGrad coefficient (dfc)
                diff = abs(previous_grad - grad)
                dfc = 1.0 / (1.0 + torch.exp(-diff))

                state['previous_grad'] = grad.clone()

                # RAdam: cache the SMA length (n_sma) and rectified step size per step index (mod 10)
                buffered = group['buffer'][int(state['step'] % 10)]
                if state['step'] == buffered[0]:
                    n_sma, step_size = buffered[1], buffered[2]
                else:
                    buffered[0] = state['step']
                    beta2_t = beta2 ** state['step']
                    n_sma_max = 2.0 / (1.0 - beta2) - 1.0
                    n_sma = n_sma_max - 2.0 * state['step'] * beta2_t / (1.0 - beta2_t)
                    buffered[1] = n_sma

                    if n_sma >= self.n_sma_threshold:
                        step_size = math.sqrt(
                            (1 - beta2_t)
                            * (n_sma - 4)
                            / (n_sma_max - 4)
                            * (n_sma - 2)
                            / n_sma
                            * n_sma_max
                            / (n_sma_max - 2)
                        ) / (1.0 - beta1 ** state['step'])
                    elif self.degenerated_to_sgd:
                        step_size = 1.0 / (1 - beta1 ** state['step'])
                    else:
                        step_size = -1
                    buffered[2] = step_size

                if n_sma >= self.n_sma_threshold:
                    if group['weight_decay'] != 0:
                        p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)

                    denom = exp_avg_sq.sqrt().add_(group['eps'])

                    # update momentum with dfc
                    p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg * dfc.float(), denom)
                    p.data.copy_(p_data_fp32)
                elif step_size > 0:
                    # degenerated SGD-style update (no adaptive denominator)
                    if group['weight_decay'] != 0:
                        p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)

                    p_data_fp32.add_(-step_size * group['lr'], exp_avg)
                    p.data.copy_(p_data_fp32)

        return loss
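The core of the new optimizer is the diffGrad friction coefficient computed in step(): dfc = 1 / (1 + exp(-|previous_grad - grad|)), which scales the first-moment (momentum) term of an otherwise RAdam-style rectified update. A small standalone sketch of that formula, using toy tensors that are not part of the commit:

import torch

previous_grad = torch.tensor([0.10, 0.10, 0.10])
grad = torch.tensor([0.10, 0.50, 5.00])

# Same expression as in DiffRGrad.step(): a stable gradient gives dfc ~ 0.5
# (momentum is damped), while a rapidly changing gradient pushes dfc toward 1.0
# (momentum passes through almost unchanged).
diff = abs(previous_grad - grad)
dfc = 1.0 / (1.0 + torch.exp(-diff))
print(dfc)  # tensor([0.5000, 0.5987, 0.9926])

Because dfc never drops below 0.5, the update keeps at least half of the momentum even in flat regions, while the RAdam machinery (the n_sma buffer and the rectified step_size) is left untouched by it.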

pytorch_optimizer/radam.py

Lines changed: 2 additions & 2 deletions
@@ -28,19 +28,19 @@ def __init__(
         params: PARAMETERS,
         lr: float = 1e-3,
         betas: BETAS = (0.9, 0.999),
-        eps: float = 1e-8,
         weight_decay: float = 0.0,
         n_sma_threshold: int = 5,
         degenerated_to_sgd: bool = False,
+        eps: float = 1e-8,
     ):
         """
         :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups
         :param lr: float. learning rate.
         :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace
-        :param eps: float. term added to the denominator to improve numerical stability
         :param weight_decay: float. weight decay (L2 penalty)
         :param n_sma_threshold: int. (recommended is 5)
         :param degenerated_to_sgd: float.
+        :param eps: float. term added to the denominator to improve numerical stability
         """
         self.lr = lr
         self.betas = betas
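Note that this reorder moves eps from the fourth positional slot to the last one in RAdam's __init__, so a caller that passed it positionally would now be setting weight_decay instead; keyword arguments avoid the problem. A minimal sketch, assuming the class in pytorch_optimizer/radam.py is named RAdam and using a toy model:

import torch
from pytorch_optimizer.radam import RAdam

model = torch.nn.Linear(10, 1)
# Keyword arguments are unaffected by the signature reordering above.
optimizer = RAdam(model.parameters(), lr=1e-3, weight_decay=1e-2, eps=1e-8)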

setup.py

Lines changed: 1 addition & 0 deletions
@@ -56,6 +56,7 @@ def read_version() -> str:
         'sam',
         'asam',
         'diffgrad',
+        'diffrgrad',
     ]
 )
