Commit 6d30fa1

update: D-Adaptation v3
1 parent dcfdb9f commit 6d30fa1

File tree: 1 file changed (+59, -80 lines)


pytorch_optimizer/optimizer/dadapt.py

Lines changed: 59 additions & 80 deletions
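For orientation before the diff: the user-visible changes in this commit are the rename of DAdaptAdam's bias_correction flag to adam_debias, the switch of the per-group 'k' counter to 'step', and DAdaptSGD's new momentum default of 0.9 following the V3 reference algorithm. A minimal usage sketch under those assumptions (the import path mirrors the file shown above; the model is a placeholder):

    import torch

    # Assumes the classes are importable from the module touched by this commit.
    from pytorch_optimizer.optimizer.dadapt import DAdaptAdam, DAdaptSGD

    model = torch.nn.Linear(10, 1)

    # 'bias_correction' is renamed to 'adam_debias' in this commit.
    adam_like = DAdaptAdam(model.parameters(), lr=1.0, adam_debias=True)

    # DAdaptSGD now defaults to momentum=0.9; the docstring still recommends lr=1.0.
    sgd_like = DAdaptSGD(model.parameters(), lr=1.0)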
@@ -12,7 +12,7 @@
 from pytorch_optimizer.base.exception import NoSparseGradientError
 from pytorch_optimizer.base.optimizer import BaseOptimizer
 from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
-from pytorch_optimizer.optimizer.utils import to_real
+from pytorch_optimizer.optimizer.utils import get_global_gradient_norm, to_real


 class DAdaptAdaGrad(Optimizer, BaseOptimizer):
@@ -23,7 +23,6 @@ class DAdaptAdaGrad(Optimizer, BaseOptimizer):
     :param momentum: float. momentum.
     :param d0: float. initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
     :param growth_rate: float. prevent the D estimate from growing faster than this multiplicative rate.
-        Default is inf, for unrestricted.
     :param weight_decay: float. weight decay (L2 penalty).
     :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
     :param fixed_decay: bool. fix weight decay.
@@ -253,11 +252,10 @@ class DAdaptAdam(Optimizer, BaseOptimizer):
     :param betas: BETAS. betas.
     :param d0: float. initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
     :param growth_rate: float. prevent the D estimate from growing faster than this multiplicative rate.
-        Default is inf, for unrestricted.
     :param weight_decay: float. weight decay (L2 penalty).
     :param weight_decouple: bool. use AdamW style weight decay.
     :param fixed_decay: bool. fix weight decay.
-    :param bias_correction: bool. Turn on Adam's bias correction.
+    :param adam_debias: bool. Only correct the denominator to avoid inflating step sizes early in training.
     :param eps: float. term added to the denominator to improve numerical stability.
     """

@@ -271,7 +269,7 @@ def __init__(
         weight_decay: float = 0.0,
         weight_decouple: bool = False,
         fixed_decay: bool = False,
-        bias_correction: bool = False,
+        adam_debias: bool = False,
         eps: float = 0.0,
     ):
         self.validate_learning_rate(lr)
@@ -287,8 +285,8 @@ def __init__(
             'weight_decay': weight_decay,
             'weight_decouple': weight_decouple,
             'fixed_decay': fixed_decay,
-            'bias_correction': bias_correction,
-            'k': 0,
+            'adam_debias': adam_debias,
+            'step': 0,
             'eps': eps,
         }
         super().__init__(params, defaults)
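A hedged practical note on the hunk above (not part of the commit): because the per-group counter key changes from 'k' to 'step', an optimizer state dict saved before this change still carries the old key in its param_groups, so resuming from such a checkpoint would need a small migration along these lines (ckpt and the surrounding key names are illustrative):

    # Hypothetical migration sketch for checkpoints saved with the old 'k' key.
    state_dict = ckpt['optimizer']
    for pg in state_dict['param_groups']:
        if 'k' in pg:
            pg['step'] = pg.pop('k')
    optimizer.load_state_dict(state_dict)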
@@ -299,13 +297,13 @@ def __str__(self) -> str:
     @torch.no_grad()
     def reset(self):
         for group in self.param_groups:
+            group['step'] = 0
             for p in group['params']:
                 if p.grad is None:
                     continue

                 state = self.state[p]

-                state['step'] = 0
                 state['s'] = torch.zeros_like(p)
                 state['exp_avg'] = torch.zeros_like(p)
                 state['exp_avg_sq'] = torch.zeros_like(p)
@@ -318,26 +316,25 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 loss = closure()

         group = self.param_groups[0]
+        device = group['params'][0].device

         beta1, beta2 = group['betas']
-        k: int = group['k']

         beta2_sq: float = math.sqrt(beta2)

         d: float = group['d']
-        lr: float = max(group['lr'] for group in self.param_groups)
-        bias_correction: float = (
-            ((1.0 - beta2 ** (k + 1)) ** 0.5) / (1.0 - beta1 ** (k + 1)) if group['bias_correction'] else 1.0
-        )
-        d_lr = float(d * lr * bias_correction)
+        lr: float = group['lr']
+
+        bias_correction: float = 1.0 - pow(beta1, group['step'] + 1)
+        d_lr: float = self.apply_adam_debias(group['adam_debias'], step_size=d * lr, bias_correction1=bias_correction)
+        sk_l1 = torch.tensor([0.0], device=device)
+        numerator_acc = torch.tensor([0.0], device=device)

         if 'numerator_weighted' not in group:
-            group['numerator_weighted'] = torch.tensor([0.0], device=group['params'][0].device)
+            group['numerator_weighted'] = torch.tensor([0.0], device=device)
         numerator_weighted = group['numerator_weighted']

-        sk_l1 = torch.tensor([0.0], device=group['params'][0].device)
-        numerator_acc = torch.tensor([0.0], device=group['params'][0].device)
-
         for group in self.param_groups:
             for p in group['params']:
                 if p.grad is None:
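The hunk above routes the step size through self.apply_adam_debias, a BaseOptimizer helper whose body is not part of this diff. A rough sketch of its presumed behavior, inferred from the call site and the new adam_debias docstring (treat this as an assumption, not the library's actual code):

    def apply_adam_debias(adam_debias: bool, step_size: float, bias_correction1: float) -> float:
        # If adam_debias is enabled, only the denominator of the update is bias-corrected,
        # so the raw step size is returned as-is; otherwise the usual Adam correction
        # 1 / (1 - beta1^step) is applied to the step size.
        return step_size if adam_debias else step_size / bias_correction1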
@@ -349,19 +346,17 @@ def step(self, closure: CLOSURE = None) -> LOSS:

                 state = self.state[p]
                 if 'step' not in state:
-                    state['step'] = 0
                     state['s'] = torch.zeros_like(p)
                     state['exp_avg'] = torch.zeros_like(p)
                     state['exp_avg_sq'] = torch.zeros_like(p)

-                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
-                s = state['s']
+                exp_avg, exp_avg_sq, s = state['exp_avg'], state['exp_avg_sq'], state['s']

                 de_nom = exp_avg_sq.sqrt().add_(group['eps'])
                 numerator_acc.add_(torch.dot(grad.flatten(), s.div(de_nom).flatten()), alpha=d_lr)

                 exp_avg.mul_(beta1).add_(grad, alpha=d_lr * (1.0 - beta1))
-                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, alpha=1.0 - beta2)
+                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                 s.mul_(beta2_sq).add_(grad, alpha=d_lr * (1.0 - beta2_sq))

@@ -374,19 +369,18 @@ def step(self, closure: CLOSURE = None) -> LOSS:

         if lr > 0.0:
             d_hat = numerator_weighted / (1.0 - beta2_sq) * sk_l1
-            d = max(d, min(d_hat, d * group['growth_rate']))
+            d = max(d, min(d_hat.item(), d * group['growth_rate']))

         for group in self.param_groups:
             group['numerator_weighted'] = numerator_weighted
             group['d'] = d
+
             for p in group['params']:
                 if p.grad is None:
                     continue

                 state = self.state[p]

-                state['step'] += 1
-
                 exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                 de_nom = exp_avg_sq.sqrt().add_(group['eps'])
400394
fixed_decay=group['fixed_decay'],
401395
)
402396

403-
p.addcdiv_(exp_avg, de_nom, value=-1)
404-
405-
group['k'] += 1
397+
p.addcdiv_(exp_avg, de_nom, value=-1.0)
406398

407399
return loss
408400

409401

410402
class DAdaptSGD(Optimizer, BaseOptimizer):
411-
r"""SGD with D-Adaptation. Leave LR set to 1 unless you encounter instability.
403+
r"""SGD with D-Adaptation. Leave LR set to 1 unless you encounter instability. This implementation is based on V3.
412404
413405
:param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
414406
:param lr: float. learning rate.
415407
:param momentum: float. momentum.
416408
:param d0: float. initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
417409
:param growth_rate: float. prevent the D estimate from growing faster than this multiplicative rate.
418-
Default is inf, for unrestricted.
419410
:param weight_decay: float. weight decay (L2 penalty).
420411
:param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
421412
:param fixed_decay: bool. fix weight decay.
@@ -425,15 +416,15 @@ def __init__(
         self,
         params: PARAMETERS,
         lr: float = 1.0,
-        momentum: float = 0.0,
+        momentum: float = 0.9,
         d0: float = 1e-6,
         growth_rate: float = float('inf'),
         weight_decay: float = 0.0,
         weight_decouple: bool = False,
         fixed_decay: bool = False,
     ):
         self.validate_learning_rate(lr)
-        self.validate_range(momentum, 'momentum', 0.0, 1.0)
+        self.validate_range(momentum, 'momentum', 0.0, 1.0, range_type='[)')
         self.validate_non_negative(weight_decay, 'weight_decay')

         defaults: DEFAULTS = {
@@ -444,7 +435,7 @@ def __init__(
             'weight_decay': weight_decay,
             'weight_decouple': weight_decouple,
             'fixed_decay': fixed_decay,
-            'k': 0,
+            'step': 0,
         }
         super().__init__(params, defaults)

@@ -454,16 +445,16 @@ def __str__(self) -> str:
     @torch.no_grad()
     def reset(self):
         for group in self.param_groups:
+            group['step'] = 0
             for p in group['params']:
                 if p.grad is None:
                     continue

                 state = self.state[p]

-                state['step'] = 0
+                state['z'] = p.clone()
                 state['s'] = torch.zeros_like(p)
-                state['exp_avg'] = torch.zeros_like(p)
-                state['exp_avg_sq'] = torch.zeros_like(p)
+                state['x0'] = p.clone()

     @torch.no_grad()
     def step(self, closure: CLOSURE = None) -> LOSS:
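For readers new to D-Adaptation, the three per-parameter buffers that DAdaptSGD keeps after this change can be read as follows; the comments are an editorial interpretation of the reset above, not text from the source:

    import torch

    p = torch.nn.Parameter(torch.zeros(3))
    state = {}
    state['z'] = p.clone()            # dual-averaging iterate built from the accumulated scaled gradients
    state['s'] = torch.zeros_like(p)  # running sum of d_lr-weighted gradients, used for the D estimate
    state['x0'] = p.clone()           # starting point x_0 against which progress is measured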
@@ -473,14 +464,22 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 loss = closure()

         group = self.param_groups[0]
+        device = group['params'][0].device

-        growth_rate = group['growth_rate']
+        sk_sq = torch.tensor([0.0], device=device)
+        if 'numerator_weighted' not in group:
+            group['numerator_weighted'] = torch.tensor([0.0], device=device)
+        numerator_weighted = group['numerator_weighted']

-        g_sq = torch.tensor([0.0], device=group['params'][0].device)
-        sk_sq = torch.tensor([0.0], device=group['params'][0].device)
-        if 'gsq_weighted' not in group:
-            group['gsq_weighted'] = torch.tensor([0.0], device=group['params'][0].device)
-        gsq_weighted = group['gsq_weighted']
+        if group['step'] == 0:
+            group['g0_norm'] = get_global_gradient_norm(self.param_groups, device).sqrt_().item()
+        g0_norm = group['g0_norm']
+
+        if g0_norm == 0:
+            return loss
+
+        d, lr = group['d'], group['lr']
+        d_lr: float = d * lr / g0_norm

         for group in self.param_groups:
@@ -491,57 +490,39 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 if grad.is_sparse:
                     raise NoSparseGradientError(str(self))

+                state = self.state[p]
+                if len(state) == 0:
+                    state['z'] = p.clone()
+                    state['s'] = torch.zeros_like(p)
+                    state['x0'] = p.clone()
+
                 self.apply_weight_decay(
                     p=p,
-                    grad=grad,
-                    lr=group['lr'],
+                    grad=None,
+                    lr=d_lr,
                     weight_decay=group['weight_decay'],
                     weight_decouple=group['weight_decouple'],
                     fixed_decay=group['fixed_decay'],
                 )

-                state = self.state[p]
-                if 'z' not in state:
-                    state['z'] = torch.clone(p)
-                    state['s'] = torch.zeros_like(p)
-                    state['x0'] = torch.clone(p)
-
-                g_sq.add_(grad.pow(2).sum())
-
-        if g_sq == 0:
-            return loss
-
-        group = self.param_groups[0]
-
-        if group['k'] == 0:
-            group['g0_norm'] = g_sq.sqrt().item()
-        g0_norm = group['g0_norm']
-
-        d, lr = group['d'], group['lr']
-        d_lr = float(d * lr) / g0_norm
-
-        for group in self.param_groups:
-            for p in group['params']:
-                if p.grad is None:
-                    continue
-
-                state = self.state[p]
-
                 s = state['s']
-                s.add_(p.grad, alpha=d_lr)
+                numerator_weighted.add_(torch.dot(grad.flatten(), s.flatten()), alpha=d_lr)

+                s.add_(grad, alpha=d_lr)
                 sk_sq.add_(s.pow(2).sum())

-        gsq_weighted.add_(g_sq, alpha=d_lr ** 2)  # fmt: skip
-
         if lr > 0.0:
-            d_hat = (sk_sq - gsq_weighted) / sk_sq.sqrt()
-            d = max(d, min(d_hat, d * growth_rate))
+            d_hat = 2.0 * numerator_weighted / sk_sq.sqrt()
+            d = max(d, min(d_hat.item(), d * group['growth_rate']))

         for group in self.param_groups:
-            group['gsq_weighted'] = gsq_weighted
+            if 'step' in group:
+                group['step'] += 1
+            else:
+                group['step'] = 1
+
+            group['numerator_weighted'] = numerator_weighted
             group['d'] = d
-            group['g0_norm'] = g0_norm

             for p in group['params']:
                 if p.grad is None:
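In plain terms, the hunk above replaces the old estimate (sk_sq - gsq_weighted) / sqrt(sk_sq) with the V3 rule: numerator_weighted accumulates the d_lr-weighted inner products of gradient and s, and the new lower bound is d_hat = 2 * numerator_weighted / ||s||, still clipped by growth_rate and never allowed to shrink d. A tiny numeric sanity check of that clipping logic (the accumulator values are made up):

    import math

    numerator_weighted, sk_sq = 0.3, 4.0    # illustrative accumulator values
    d, growth_rate = 1e-6, float('inf')

    d_hat = 2.0 * numerator_weighted / math.sqrt(sk_sq)  # 0.3
    d = max(d, min(d_hat, d * growth_rate))              # d never shrinks, growth is capped
    print(d)  # 0.3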
@@ -554,8 +535,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:

                 p.mul_(group['momentum']).add_(z, alpha=1.0 - group['momentum'])

-            group['k'] += 1
-
         return loss
