Commit 7ded073

Merge pull request #134 from kozistr/feature/roto-grad-optimizer
[Feature] Implement RotoGrad optimizer
2 parents 8c7fd8e + a18fda1 commit 7ded073

17 files changed: +846 -109 lines changed

README.rst

Lines changed: 5 additions & 1 deletion
@@ -140,7 +140,9 @@ You can check the supported optimizers & lr schedulers.
 +--------------+-------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+
 | SM3 | *Memory-Efficient Adaptive Optimization* | `github <https://github.com/google-research/google-research/tree/master/sm3>`__ | `https://arxiv.org/abs/1901.11150 <https://arxiv.org/abs/1901.11150>`__ |
 +--------------+-------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+
-| AdaNorm` | *Adaptive Gradient Norm Correction based Optimizer for CNNs* | `github <https://github.com/shivram1987/AdaNorm>`__ | `https://arxiv.org/abs/2210.06364 <https://arxiv.org/abs/2210.06364>`__ |
+| AdaNorm | *Adaptive Gradient Norm Correction based Optimizer for CNNs* | `github <https://github.com/shivram1987/AdaNorm>`__ | `https://arxiv.org/abs/2210.06364 <https://arxiv.org/abs/2210.06364>`__ |
++--------------+-------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+
+| RotoGrad | *Gradient Homogenization in Multitask Learning* | `github <https://github.com/adrianjav/rotograd>`__ | `https://openreview.net/pdf?id=T8wHz4rnuGL <https://openreview.net/pdf?id=T8wHz4rnuGL>`__ |
 +--------------+-------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+
 
 Useful Resources

@@ -351,6 +353,8 @@ Citations
 
 `AdaNorm <https://github.com/shivram1987/AdaNorm/tree/main#citation>`__
 
+`RotoGrad <https://github.com/adrianjav/rotograd#citing>`__
+
 Citation
 --------
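
The new README row only names the method. As a rough illustration, here is a hypothetical multitask sketch written against the upstream adrianjav/rotograd interface (a constructor taking a backbone, a list of task heads, and the latent size, plus a backward(losses) call that combines the per-task gradients). The RotoGrad port added in this commit may expose a different signature, so treat every call below as an assumption rather than this repository's documented API.

import torch
from torch import nn
from pytorch_optimizer import RotoGrad

# Hypothetical two-task setup: a shared backbone plus one head per task.
backbone = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8))
head1, head2 = nn.Linear(8, 1), nn.Linear(8, 3)

# Assumed constructor, mirroring the upstream rotograd package:
# (backbone, list of heads, size of the shared representation).
model = RotoGrad(backbone, [head1, head2], 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

x, y1, y2 = torch.randn(4, 16), torch.randn(4, 1), torch.randint(0, 3, (4,))
pred1, pred2 = model(x)  # one prediction per head
losses = [nn.functional.mse_loss(pred1, y1), nn.functional.cross_entropy(pred2, y2)]

optimizer.zero_grad()
model.backward(losses)  # assumed API: RotoGrad rotates/homogenizes the shared gradients itself
optimizer.step()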

docs/optimizer_api.rst

Lines changed: 16 additions & 0 deletions
@@ -217,6 +217,14 @@ DAdaptSGD
 .. autoclass:: pytorch_optimizer.DAdaptSGD
     :members:
 
+.. _DAdaptAdan:
+
+DAdaptAdan
+----------
+
+.. autoclass:: pytorch_optimizer.DAdaptAdan
+    :members:
+
 .. _AdamS:
 
 AdamS

@@ -280,3 +288,11 @@ AdaNorm
 
 .. autoclass:: pytorch_optimizer.AdaNorm
     :members:
+
+.. _RotoGrad:
+
+RotoGrad
+--------
+
+.. autoclass:: pytorch_optimizer.RotoGrad
+    :members:
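
The two new API entries are DAdaptAdan and RotoGrad (RotoGrad is sketched after the README section above). For DAdaptAdan, a minimal, hedged sketch: D-Adaptation optimizers estimate the step size themselves, and since this diff does not show the constructor signature, it is called here with defaults only as an assumption.

import torch
from pytorch_optimizer import DAdaptAdan

model = torch.nn.Linear(10, 2)
# Constructed with defaults only; further arguments are not shown in this diff.
optimizer = DAdaptAdan(model.parameters())

loss = model(torch.randn(4, 10)).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()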

docs/util_api.rst

Lines changed: 8 additions & 0 deletions
@@ -177,3 +177,11 @@ reduce_max_except_dim
 
 .. autoclass:: pytorch_optimizer.reduce_max_except_dim
     :members:
+
+.. _get_global_gradient_norm:
+
+get_global_gradient_norm
+------------------------
+
+.. autoclass:: pytorch_optimizer.get_global_gradient_norm
+    :members:

pyproject.toml

Lines changed: 2 additions & 1 deletion
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "pytorch_optimizer"
-version = "2.6.1"
+version = "2.7.0"
 description = "optimizer & lr scheduler implementations in PyTorch with clean-code, strict types. Also, including useful optimization ideas."
 license = "Apache-2.0"
 authors = ["kozistr <[email protected]>"]

@@ -103,6 +103,7 @@ target-version = "py39"
 omit = [
     "./pytorch_optimizer/optimizer/gsam.py",
     "./pytorch_optimizer/optimizer/fp16.py",
+    "./pytorch_optimizer/optimizer/rotograd.py",
 ]
 
 [build-system]

pytorch_optimizer/__init__.py

Lines changed: 4 additions & 1 deletion
@@ -28,7 +28,7 @@
 from pytorch_optimizer.optimizer.agc import agc
 from pytorch_optimizer.optimizer.alig import AliG
 from pytorch_optimizer.optimizer.apollo import Apollo
-from pytorch_optimizer.optimizer.dadapt import DAdaptAdaGrad, DAdaptAdam, DAdaptSGD
+from pytorch_optimizer.optimizer.dadapt import DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptSGD
 from pytorch_optimizer.optimizer.diffgrad import DiffGrad
 from pytorch_optimizer.optimizer.fp16 import DynamicLossScaler, SafeFP16Optimizer
 from pytorch_optimizer.optimizer.gc import centralize_gradient

@@ -45,6 +45,7 @@
 from pytorch_optimizer.optimizer.radam import RAdam
 from pytorch_optimizer.optimizer.ranger import Ranger
 from pytorch_optimizer.optimizer.ranger21 import Ranger21
+from pytorch_optimizer.optimizer.rotograd import RotoGrad
 from pytorch_optimizer.optimizer.sam import SAM
 from pytorch_optimizer.optimizer.sgdp import SGDP
 from pytorch_optimizer.optimizer.shampoo import ScalableShampoo, Shampoo

@@ -68,6 +69,7 @@
     clip_grad_norm,
     disable_running_stats,
     enable_running_stats,
+    get_global_gradient_norm,
     get_optimizer_parameters,
     normalize_gradient,
     reduce_max_except_dim,

@@ -96,6 +98,7 @@
     DAdaptAdaGrad,
     DAdaptAdam,
     DAdaptSGD,
+    DAdaptAdan,
     AdamS,
     AdaFactor,
     Apollo,
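
The re-exports above make the new symbols importable from the package root. A quick smoke test of the imports, using only names taken directly from the diff:

# Names re-exported at the package root by this commit (see the diff above).
from pytorch_optimizer import DAdaptAdan, RotoGrad, get_global_gradient_norm

print(DAdaptAdan, RotoGrad, get_global_gradient_norm)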

pytorch_optimizer/optimizer/adafactor.py

Lines changed: 16 additions & 3 deletions
@@ -18,6 +18,7 @@ class AdaFactor(Optimizer, BaseOptimizer):
     :param decay_rate: float. coefficient used to compute running averages of square gradient.
     :param weight_decay: float. weight decay (L2 penalty).
     :param clip_threshold: float. threshold of root-mean-square of final gradient update.
+    :param amsgrad: bool. whether to use the AMSBound variant.
     :param scale_parameter: bool. if true, learning rate is scaled by root-mean-square of parameter.
     :param relative_step: bool. if true, time-dependent learning rate is computed instead of external learning rate.
     :param warmup_init: bool. time-dependent learning rate computation depends on whether warm-up initialization

@@ -34,6 +35,7 @@ def __init__(
         decay_rate: float = -0.8,
         weight_decay: float = 0.0,
         clip_threshold: float = 1.0,
+        amsgrad: bool = False,
         scale_parameter: bool = True,
         relative_step: bool = True,
         warmup_init: bool = False,

@@ -45,6 +47,7 @@
         self.decay_rate = decay_rate
         self.weight_decay = weight_decay
         self.clip_threshold = clip_threshold
+        self.amsgrad = amsgrad
         self.relative_step = relative_step
         self.eps1 = eps1
         self.eps2 = eps2

@@ -54,6 +57,7 @@
         defaults: DEFAULTS = {
             'lr': lr,
             'weight_decay': weight_decay,
+            'amsgrad': amsgrad,
             'scale_parameter': scale_parameter,
             'relative_step': relative_step,
             'warmup_init': warmup_init,

@@ -94,6 +98,9 @@ def reset(self):
                 else:
                     state['exp_avg_sq'] = torch.zeros_like(grad)
 
+                if group['amsgrad']:
+                    state['exp_avg_sq_hat'] = torch.zeros_like(grad)
+
                 state['RMS'] = 0.0
 
     def get_lr(

@@ -169,6 +176,9 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                     else:
                         state['exp_avg_sq'] = torch.zeros_like(grad)
 
+                    if group['amsgrad']:
+                        state['exp_avg_sq_hat'] = torch.zeros_like(grad)
+
                     state['RMS'] = 0.0
 
                 state['RMS'] = self.get_rms(p)

@@ -190,15 +200,18 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                     exp_avg_sq_row.mul_(beta2_t).add_(update.mean(dim=-1), alpha=1.0 - beta2_t)
                     exp_avg_sq_col.mul_(beta2_t).add_(update.mean(dim=-2), alpha=1.0 - beta2_t)
 
-                    self.approximate_sq_grad(exp_avg_sq_row, exp_avg_sq_col, output=update)
+                    self.approximate_sq_grad(exp_avg_sq_row, exp_avg_sq_col, update)
                 else:
                     exp_avg_sq = state['exp_avg_sq']
                     exp_avg_sq.mul_(beta2_t).add_(update, alpha=1.0 - beta2_t)
                     torch.rsqrt(exp_avg_sq, out=update)
 
-                update.mul_(grad)
+                if group['amsgrad']:
+                    exp_avg_sq_hat = state['exp_avg_sq_hat']
+                    torch.max(exp_avg_sq_hat, 1 / update, out=exp_avg_sq_hat)
+                    torch.rsqrt(exp_avg_sq_hat / beta2_t, out=update)
 
-                # TODO: implement AMSGrad
+                update.mul_(grad)
 
                 update.div_((self.get_rms(update) / self.clip_threshold).clamp_(min=1.0)).mul_(lr)
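
The diff threads a new amsgrad flag through the AdaFactor constructor, state, and update step, replacing the old "# TODO: implement AMSGrad" marker with a running maximum of the second moment (exp_avg_sq_hat). A minimal usage sketch; only parameters visible in the diff are passed, and pairing an explicit lr with relative_step=False is an assumption based on the docstring above rather than something this diff demonstrates.

import torch
from pytorch_optimizer import AdaFactor

model = torch.nn.Linear(10, 2)
# amsgrad=True enables the exp_avg_sq_hat running maximum added in this diff;
# amsgrad=False keeps the previous behaviour.
optimizer = AdaFactor(model.parameters(), lr=1e-3, relative_step=False, amsgrad=True)

loss = model(torch.randn(4, 10)).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()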

pytorch_optimizer/optimizer/adan.py

Lines changed: 43 additions & 23 deletions
@@ -8,6 +8,7 @@
 from pytorch_optimizer.base.optimizer import BaseOptimizer
 from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
 from pytorch_optimizer.optimizer.gc import centralize_gradient
+from pytorch_optimizer.optimizer.utils import get_global_gradient_norm
 
 
 class Adan(Optimizer, BaseOptimizer):

@@ -20,6 +21,8 @@ class Adan(Optimizer, BaseOptimizer):
     :param weight_decouple: bool. decoupled weight decay.
     :param max_grad_norm: float. max gradient norm to clip.
     :param use_gc: bool. use gradient centralization.
+    :param r: float. EMA factor. between 0.9 ~ 0.99 is preferred.
+    :param adanorm: bool. whether to use the AdaNorm variant.
     :param eps: float. term added to the denominator to improve numerical stability.
     """
 

@@ -32,6 +35,8 @@ def __init__(
         weight_decouple: bool = False,
         max_grad_norm: float = 0.0,
         use_gc: bool = False,
+        r: float = 0.95,
+        adanorm: bool = False,
         eps: float = 1e-8,
     ):
         self.lr = lr

@@ -49,8 +54,12 @@
             'weight_decay': weight_decay,
             'weight_decouple': weight_decouple,
             'max_grad_norm': max_grad_norm,
+            'adanorm': adanorm,
             'eps': eps,
         }
+        if adanorm:
+            defaults.update({'r': r})
+
         super().__init__(params, defaults)
 
     def validate_parameters(self):

@@ -71,25 +80,21 @@ def reset(self):
                 state = self.state[p]
 
                 state['exp_avg'] = torch.zeros_like(p)
+                state['exp_avg_sq'] = torch.zeros_like(p)
                 state['exp_avg_diff'] = torch.zeros_like(p)
-                state['exp_avg_nest'] = torch.zeros_like(p)
                 state['previous_grad'] = torch.zeros_like(p)
+                if group['adanorm']:
+                    state['exp_grad_norm'] = torch.zeros((1,), dtype=p.dtype, device=p.device)
 
     @torch.no_grad()
     def get_global_gradient_norm(self) -> Union[torch.Tensor, float]:
         if self.defaults['max_grad_norm'] == 0.0:
             return 1.0
 
-        global_grad_norm = torch.zeros(1, dtype=torch.float32, device=self.param_groups[0]['params'][0].device)
-
-        for group in self.param_groups:
-            for p in group['params']:
-                if p.grad is not None:
-                    global_grad_norm.add_(torch.linalg.norm(p.grad).pow(2))
-
-        global_grad_norm.sqrt_()
+        global_grad_norm = get_global_gradient_norm(self.param_groups, self.param_groups[0]['params'][0].device)
+        global_grad_norm.sqrt_().add_(self.eps)
 
-        return torch.clamp(self.defaults['max_grad_norm'] / (global_grad_norm + self.eps), max=1.0)
+        return torch.clamp(self.defaults['max_grad_norm'] / global_grad_norm, max=1.0)
 
     @torch.no_grad()
     def step(self, closure: CLOSURE = None) -> LOSS:

@@ -122,35 +127,50 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 state = self.state[p]
                 if len(state) == 0:
                     state['exp_avg'] = torch.zeros_like(p)
+                    state['exp_avg_sq'] = torch.zeros_like(p)
                     state['exp_avg_diff'] = torch.zeros_like(p)
-                    state['exp_avg_nest'] = torch.zeros_like(p)
-                    state['previous_grad'] = grad.clone()
+                    state['previous_grad'] = grad.clone().mul_(-clip_global_grad_norm)
+                    if group['adanorm']:
+                        state['exp_grad_norm'] = torch.zeros((1,), dtype=grad.dtype, device=grad.device)
 
                 grad.mul_(clip_global_grad_norm)
 
                 if self.use_gc:
                     grad = centralize_gradient(grad, gc_conv_only=False)
 
-                grad_diff = -state['previous_grad']
+                grad_diff = state['previous_grad']
                 grad_diff.add_(grad)
-                state['previous_grad'].copy_(grad)
 
-                update = grad + beta2 * grad_diff
+                s_grad = grad
+                if group['adanorm']:
+                    grad_norm = torch.linalg.norm(grad)
+
+                    exp_grad_norm = state['exp_grad_norm']
+                    exp_grad_norm.mul_(group['r']).add_(grad_norm, alpha=1.0 - group['r'])
 
-                exp_avg, exp_avg_diff, exp_avg_nest = state['exp_avg'], state['exp_avg_diff'], state['exp_avg_nest']
+                    if exp_grad_norm > grad_norm:
+                        s_grad *= exp_grad_norm / grad_norm
 
-                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
+                exp_avg, exp_avg_sq, exp_avg_diff = state['exp_avg'], state['exp_avg_sq'], state['exp_avg_diff']
+
+                exp_avg.mul_(beta1).add_(s_grad, alpha=1.0 - beta1)
                 exp_avg_diff.mul_(beta2).add_(grad_diff, alpha=1.0 - beta2)
-                exp_avg_nest.mul_(beta3).addcmul_(update, update, value=1.0 - beta3)
 
-                de_nom = (exp_avg_nest.sqrt_() / bias_correction3_sq).add_(self.eps)
-                perturb = (exp_avg / bias_correction1 + beta2 * exp_avg_diff / bias_correction2).div_(de_nom)
+                grad_diff.mul_(beta2).add_(grad)
+                exp_avg_sq.mul_(beta3).addcmul_(grad_diff, grad_diff, value=1.0 - beta3)
+
+                de_nom = exp_avg_sq.sqrt()
+                de_nom.div_(bias_correction3_sq).add_(group['eps'])
 
                 if group['weight_decouple']:
                     p.mul_(1.0 - group['lr'] * group['weight_decay'])
-                    p.add_(perturb, alpha=-group['lr'])
-                else:
-                    p.add_(perturb, alpha=-group['lr'])
+
+                p.addcdiv_(exp_avg, de_nom, value=-group['lr'] / bias_correction1)
+                p.addcdiv_(exp_avg_diff, de_nom, value=-group['lr'] * beta2 / bias_correction2)
+
+                if not group['weight_decouple']:
                     p.div_(1.0 + group['lr'] * group['weight_decay'])
 
+                state['previous_grad'].copy_(-grad)
+
         return loss
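
Besides the internal refactor (exp_avg_nest becomes exp_avg_sq, the update is applied via addcdiv_, and global-norm clipping moves to the shared get_global_gradient_norm helper), the diff adds an AdaNorm variant gated by adanorm with EMA factor r. A minimal sketch using only the constructor arguments visible above:

import torch
from pytorch_optimizer import Adan

model = torch.nn.Linear(10, 2)
# adanorm=True enables the exp_grad_norm EMA path added above; the docstring
# suggests r between 0.9 and 0.99. max_grad_norm=1.0 also exercises the
# refactored global gradient-norm clipping.
optimizer = Adan(model.parameters(), lr=1e-3, adanorm=True, r=0.95, max_grad_norm=1.0)

loss = model(torch.randn(8, 10)).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()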

pytorch_optimizer/optimizer/alig.py

Lines changed: 4 additions & 7 deletions
@@ -6,6 +6,7 @@
 from pytorch_optimizer.base.exception import NoClosureError, NoSparseGradientError
 from pytorch_optimizer.base.optimizer import BaseOptimizer
 from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS
+from pytorch_optimizer.optimizer.utils import get_global_gradient_norm
 
 
 class AliG(Optimizer, BaseOptimizer):

@@ -60,14 +61,10 @@ def reset(self):
     @torch.no_grad()
     def compute_step_size(self, loss: float) -> float:
         r"""Compute step_size."""
-        global_grad_norm: float = 0
+        global_grad_norm = get_global_gradient_norm(self.param_groups, torch.device('cpu'))
+        global_grad_norm.add_(self.eps)
 
-        for group in self.param_groups:
-            for p in group['params']:
-                if p.grad is not None:
-                    global_grad_norm += p.grad.norm(2.0).pow(2).item()
-
-        return loss / (global_grad_norm + self.eps)
+        return loss / global_grad_norm.item()
 
     @torch.no_grad()
     def step(self, closure: CLOSURE = None) -> LOSS:
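
Both call sites in this commit pass param_groups plus a device and then post-process the result (Adan takes a square root, AliG divides by it directly), which suggests the helper returns the sum of squared per-parameter gradient norms as a tensor. A hedged sketch under that assumption:

import torch
from pytorch_optimizer import get_global_gradient_norm

model = torch.nn.Linear(10, 2)
model(torch.randn(4, 10)).pow(2).mean().backward()

# Signature inferred from the Adan/AliG call sites above:
# (param_groups, device) -> tensor holding the sum of squared gradient norms.
param_groups = [{'params': list(model.parameters())}]
sq_norm_sum = get_global_gradient_norm(param_groups, torch.device('cpu'))
global_norm = sq_norm_sum.sqrt()  # the usual global L2 norm
print(float(global_norm))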
