Commit 4300b0d

feature: DiffRGrad optimizer
1 parent 50f7934 commit 4300b0d

File tree

1 file changed: +67 -19 lines changed

pytorch_optimizer/diffrgrad.py

Lines changed: 67 additions & 19 deletions
@@ -30,6 +30,7 @@ def __init__(
         lr: float = 1e-3,
         betas: BETAS = (0.9, 0.999),
         weight_decay: float = 0.0,
+        n_sma_threshold: int = 5,
         degenerated_to_sgd: bool = True,
         eps: float = 1e-8,
     ):
@@ -38,18 +39,32 @@ def __init__(
         :param lr: float. learning rate.
         :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace
         :param weight_decay: float. weight decay (L2 penalty)
+        :param n_sma_threshold: int. (recommended is 5)
         :param degenerated_to_sgd: float.
         :param eps: float. term added to the denominator to improve numerical stability
         """
         self.lr = lr
         self.betas = betas
         self.weight_decay = weight_decay
+        self.n_sma_threshold = n_sma_threshold
         self.degenerated_to_sgd = degenerated_to_sgd
         self.eps = eps

         self.check_valid_parameters()

-        defaults: DEFAULTS = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
+        if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict):
+            for param in params:
+                if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]):
+                    param['buffer'] = [[None, None, None] for _ in range(10)]
+
+        defaults: DEFAULTS = dict(
+            lr=lr,
+            betas=betas,
+            eps=eps,
+            weight_decay=weight_decay,
+            buffer=[[None, None, None] for _ in range(10)],
+        )
+
         super().__init__(params, defaults)

     def check_valid_parameters(self):
@@ -77,17 +92,22 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 if p.grad is None:
                     continue

-                grad = p.grad.data
+                grad = p.grad.data.float()
                 if grad.is_sparse:
                     raise RuntimeError('diffGrad does not support sparse gradients')

+                p_data_fp32 = p.data.float()
                 state = self.state[p]

                 if len(state) == 0:
                     state['step'] = 0
-                    state['exp_avg'] = torch.zeros_like(p.data)
-                    state['exp_avg_sq'] = torch.zeros_like(p.data)
-                    state['previous_grad'] = torch.zeros_like(p.data)
+                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
+                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
+                    state['previous_grad'] = torch.zeros_like(p_data_fp32)
+                else:
+                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
+                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
+                    state['previous_grad'] = state['previous_grad'].type_as(p_data_fp32)

                 exp_avg, exp_avg_sq, previous_grad = (
                     state['exp_avg'],
@@ -98,27 +118,55 @@ def step(self, closure: CLOSURE = None) -> LOSS:

                 state['step'] += 1

-                if group['weight_decay'] != 0:
-                    grad.add_(group['weight_decay'], p.data)
-
-                # Decay the first and second moment running average coefficient
                 exp_avg.mul_(beta1).add_(1 - beta1, grad)
                 exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
-                denom = exp_avg_sq.sqrt().add_(group['eps'])
-
-                bias_correction1 = 1 - beta1 ** state['step']
-                bias_correction2 = 1 - beta2 ** state['step']

                 # compute diffGrad coefficient (dfc)
                 diff = abs(previous_grad - grad)
                 dfc = 1.0 / (1.0 + torch.exp(-diff))
-                state['previous_grad'] = grad.clone()
-
-                # update momentum with dfc
-                exp_avg1 = exp_avg * dfc

-                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
+                state['previous_grad'] = grad.clone()

-                p.data.addcdiv_(-step_size, exp_avg1, denom)
+                buffered = group['buffer'][int(state['step'] % 10)]
+                if state['step'] == buffered[0]:
+                    n_sma, step_size = buffered[1], buffered[2]
+                else:
+                    buffered[0] = state['step']
+                    beta2_t = beta2 ** state['step']
+                    n_sma_max = 2.0 / (1.0 - beta2) - 1.0
+                    n_sma = n_sma_max - 2.0 * state['step'] * beta2_t / (1.0 - beta2_t)
+                    buffered[1] = n_sma
+
+                    if n_sma >= self.n_sma_threshold:
+                        step_size = math.sqrt(
+                            (1 - beta2_t)
+                            * (n_sma - 4)
+                            / (n_sma_max - 4)
+                            * (n_sma - 2)
+                            / n_sma
+                            * n_sma_max
+                            / (n_sma_max - 2)
+                        ) / (1.0 - beta1 ** state['step'])
+                    elif self.degenerated_to_sgd:
+                        step_size = 1.0 / (1 - beta1 ** state['step'])
+                    else:
+                        step_size = -1
+                    buffered[2] = step_size
+
+                if n_sma >= self.n_sma_threshold:
+                    if group['weight_decay'] != 0:
+                        p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
+
+                    denom = exp_avg_sq.sqrt().add_(group['eps'])
+
+                    # update momentum with dfc
+                    p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg * dfc.float(), denom)
+                    p.data.copy_(p_data_fp32)
+                elif step_size > 0:
+                    if group['weight_decay'] != 0:
+                        p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
+
+                    p_data_fp32.add_(-step_size * group['lr'], exp_avg)
+                    p.data.copy_(p_data_fp32)

         return loss
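
For context, a minimal usage sketch of the optimizer touched by this commit. The class name `DiffRGrad` and the import path are assumptions inferred from the file name `pytorch_optimizer/diffrgrad.py` and the commit message; the constructor arguments are the ones shown in the diff above.

import torch
from torch import nn

# assumed class name / import path, not shown in the diff itself
from pytorch_optimizer.diffrgrad import DiffRGrad

# toy model and data, for illustration only
model = nn.Linear(10, 1)
optimizer = DiffRGrad(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),
    weight_decay=0.0,
    n_sma_threshold=5,        # rectification threshold introduced by this commit
    degenerated_to_sgd=True,  # fall back to an SGD-like step while the variance is not tractable
    eps=1e-8,
)

x, y = torch.randn(32, 10), torch.randn(32, 1)

optimizer.zero_grad()
loss = nn.functional.mse_loss(model(x), y)
loss.backward()
optimizer.step()

As the new `step()` logic shows, `n_sma_threshold` gates the adaptive update: once `n_sma >= n_sma_threshold`, the dfc-scaled, Adam-style update is applied; otherwise the step either degenerates to an SGD-like update when `degenerated_to_sgd=True` or is skipped entirely (`step_size = -1`).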
