
Commit 50f7934

feature: initial DiffRGrad optimizer (copied from DiffGrad)
1 parent 224e54f commit 50f7934

File tree

1 file changed: +124, -0 lines changed

pytorch_optimizer/diffrgrad.py

Lines changed: 124 additions & 0 deletions
import math

import torch
from torch.optim.optimizer import Optimizer

from pytorch_optimizer.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS, STATE


class DiffRGrad(Optimizer):
    """
    Reference 1 : https://github.com/shivram1987/diffGrad
    Reference 2 : https://github.com/LiyuanLucasLiu/RAdam
    Reference 3 : https://github.com/lessw2020/Best-Deep-Learning-Optimizers/blob/master/diffgrad/diff_rgrad.py

    Example :
        from pytorch_optimizer import DiffRGrad
        ...
        model = YourModel()
        optimizer = DiffRGrad(model.parameters())
        ...
        for input, output in data:
            optimizer.zero_grad()
            loss = loss_function(output, model(input))
            loss.backward()
            optimizer.step()
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.999),
        weight_decay: float = 0.0,
        degenerated_to_sgd: bool = True,
        eps: float = 1e-8,
    ):
        """Blend RAdam with DiffGrad
        :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups
        :param lr: float. learning rate
        :param betas: BETAS. coefficients used for computing running averages of the gradient and its square
        :param weight_decay: float. weight decay (L2 penalty)
        :param degenerated_to_sgd: bool. perform the SGD update when the variance of the gradient is high
        :param eps: float. term added to the denominator to improve numerical stability
        """
        self.lr = lr
        self.betas = betas
        self.weight_decay = weight_decay
        self.degenerated_to_sgd = degenerated_to_sgd
        self.eps = eps

        self.check_valid_parameters()

        defaults: DEFAULTS = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults)

    def check_valid_parameters(self):
        if self.lr < 0.0:
            raise ValueError(f'Invalid learning rate : {self.lr}')
        if self.weight_decay < 0.0:
            raise ValueError(f'Invalid weight_decay : {self.weight_decay}')
        if not 0.0 <= self.betas[0] < 1.0:
            raise ValueError(f'Invalid beta_0 : {self.betas[0]}')
        if not 0.0 <= self.betas[1] < 1.0:
            raise ValueError(f'Invalid beta_1 : {self.betas[1]}')
        if self.eps < 0.0:
            raise ValueError(f'Invalid eps : {self.eps}')

    def __setstate__(self, state: STATE):
        super().__setstate__(state)

    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('DiffRGrad does not support sparse gradients')

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p.data)
                    state['exp_avg_sq'] = torch.zeros_like(p.data)
                    state['previous_grad'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq, previous_grad = (
                    state['exp_avg'],
                    state['exp_avg_sq'],
                    state['previous_grad'],
                )
                beta1, beta2 = group['betas']

                state['step'] += 1

                if group['weight_decay'] != 0:
                    grad.add_(p.data, alpha=group['weight_decay'])

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                denom = exp_avg_sq.sqrt().add_(group['eps'])

                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                # compute diffGrad coefficient (dfc): sigmoid of the absolute change in gradient
                diff = abs(previous_grad - grad)
                dfc = 1.0 / (1.0 + torch.exp(-diff))
                state['previous_grad'] = grad.clone()

                # update momentum with dfc
                exp_avg1 = exp_avg * dfc

                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1

                p.data.addcdiv_(exp_avg1, denom, value=-step_size)

        return loss
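
As a quick sanity check of the new optimizer, a minimal sketch along the lines of the docstring example could look like the following (assuming DiffRGrad is re-exported from the pytorch_optimizer package, as the docstring example implies; the toy model and data are illustrative only):

import torch
from torch import nn
from torch.nn import functional as F

from pytorch_optimizer import DiffRGrad  # import path taken from the docstring example

# toy regression problem: fit y = 2x with a single linear layer
model = nn.Linear(1, 1)
optimizer = DiffRGrad(model.parameters(), lr=1e-2, weight_decay=1e-4)

x = torch.randn(64, 1)
y = 2.0 * x

for _ in range(100):
    optimizer.zero_grad()
    loss = F.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()

Because the diffGrad coefficient dfc = 1 / (1 + exp(-|previous_grad - grad|)) is a sigmoid of a non-negative quantity, it stays in [0.5, 1.0): it damps the first moment most when consecutive gradients barely change and approaches a plain Adam-style step when they change a lot.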
