Skip to content

Commit 7800761

Browse files
committed
refactor: DAdaptAdaGrad
1 parent 6d30fa1 commit 7800761

File tree

1 file changed: 10 additions (+10) and 14 deletions (-14)

pytorch_optimizer/optimizer/dadapt.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def __init__(
4242
eps: float = 0.0,
4343
):
4444
self.validate_learning_rate(lr)
45-
self.validate_range(momentum, 'momentum', 0.0, 1.0)
45+
self.validate_range(momentum, 'momentum', 0.0, 1.0, range_type='[)')
4646
self.validate_non_negative(weight_decay, 'weight_decay')
4747
self.validate_non_negative(eps, 'eps')
4848

@@ -85,14 +85,11 @@ def step(self, closure: CLOSURE = None) -> LOSS:
8585
loss = closure()
8686

8787
group = self.param_groups[0]
88-
89-
lr, momentum, growth_rate = group['lr'], group['momentum'], group['growth_rate']
90-
91-
d = group['d']
92-
d_lr = float(d * lr)
93-
9488
device = group['params'][0].device
9589

90+
d, lr = group['d'], group['lr']
91+
d_lr: float = d * lr
92+
9693
g_sq = torch.tensor([0.0], device=device)
9794
sk_sq_weighted_change = torch.tensor([0.0], device=device)
9895
sk_l1_change = torch.tensor([0.0], device=device)
@@ -199,7 +196,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
199196

200197
if lr > 0.0:
201198
d_hat = (sk_sq_weighted - gsq_weighted) / sk_l1
202-
d = group['d'] = max(d, min(d_hat, d * growth_rate))
199+
d = group['d'] = max(d, min(d_hat, d * group['growth_rate']))
203200

204201
for group in self.param_groups:
205202
group['gsq_weighted'] = gsq_weighted
@@ -212,11 +209,10 @@ def step(self, closure: CLOSURE = None) -> LOSS:
212209
continue
213210

214211
grad = p.grad
212+
215213
state = self.state[p]
216214

217-
alpha_k = state['alpha_k']
218-
sk = state['sk']
219-
x0 = state['x0']
215+
alpha_k, sk, x0 = state['alpha_k'], state['sk'], state['x0']
220216

221217
if grad.is_sparse:
222218
grad = grad.coalesce()
@@ -232,10 +228,10 @@ def step(self, closure: CLOSURE = None) -> LOSS:
232228
loc_delta = torch.sparse_coo_tensor(grad.indices(), loc_delta_masked, grad.shape)
233229
p.add_(loc_delta)
234230
else:
235-
z = x0 - sk.div(torch.sqrt(alpha_k) + group['eps'])
231+
z = x0 - sk.div(alpha_k.sqrt().add_(group['eps']))
236232

237-
if momentum > 0.0:
238-
p.mul_(momentum).add_(z, alpha=1.0 - momentum)
233+
if group['momentum'] > 0.0:
234+
p.mul_(group['momentum']).add_(z, alpha=1.0 - group['momentum'])
239235
else:
240236
p.copy_(z)
241237

0 commit comments