Commit 5af762d

update: orthograd
1 parent 86aa3eb commit 5af762d

1 file changed: pytorch_optimizer/optimizer/experimental/ranger25.py (+23 additions, −1 deletion)

@@ -11,7 +11,7 @@
 class Ranger25(BaseOptimizer):
     r"""Mixin' every fancy optimizer hacks.
 
-    ADOPT + AdEMAMix + Cautious + StableAdamW + Adam-Atan2
+    ADOPT + AdEMAMix + Cautious + StableAdamW + Adam-Atan2 + OrthoGrad
 
     :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
     :param lr: float. learning rate.
@@ -23,6 +23,7 @@ class Ranger25(BaseOptimizer):
     :param t_alpha_beta3: Optional[float]. total number of iterations is preferred when needed.
     :param cautious: bool. whether to use the Cautious variant.
     :param stable_adamw: bool. whether to use stable AdamW variant.
+    :param orthograd: bool. whether to use orthograd variant.
     :param eps: Optional[float]. term added to the denominator to improve numerical stability. when eps is None and
         stable_adamw is False, adam-atan2 feature will be used.
     """
@@ -39,6 +40,7 @@ def __init__(
         t_alpha_beta3: Optional[float] = None,
         cautious: bool = True,
         stable_adamw: bool = True,
+        orthograd: bool = True,
         eps: Optional[float] = 1e-8,
         **kwargs,
     ):
@@ -51,6 +53,7 @@ def __init__(
 
         self.cautious = cautious
         self.stable_adamw: bool = stable_adamw if isinstance(eps, float) else False
+        self.orthograd = orthograd
 
         defaults: DEFAULTS = {
             'lr': lr,
@@ -97,13 +100,32 @@ def schedule_beta3(t_alpha_beta3: Optional[float], step: int, beta1: float, beta
             beta3,
         )
 
+    @torch.no_grad()
+    def orthogonalize_gradients(self, params, eps: float = 1e-16) -> None:
+        for p in params:
+            if p.grad is None:
+                continue
+
+            w = p.view(-1)
+            g = p.grad.view(-1)
+
+            proj = torch.dot(w, g).div_(torch.dot(w, w).add_(eps))
+            g_ortho = g.to(dtype=torch.float32, copy=True).sub_(w, alpha=proj)
+            g_ortho_scaled = g_ortho.mul_(g.norm(2).div_(g_ortho.norm(2).add_(eps)))
+
+            p.grad.copy_(g_ortho_scaled.view_as(p.grad))
+
     @torch.no_grad()
     def step(self, closure: CLOSURE = None) -> LOSS:
         loss: LOSS = None
         if closure is not None:
             with torch.enable_grad():
                 loss = closure()
 
+        if self.orthograd:
+            for group in self.param_groups:
+                self.orthogonalize_gradients(group['params'])
+
         for group in self.param_groups:
             if 'step' in group:
                 group['step'] += 1
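
For context, the OrthoGrad hook added above projects each parameter's gradient onto the complement of the parameter direction and then rescales the result back to the original gradient norm: g_perp = g - (<w, g> / (<w, w> + eps)) * w, followed by g_perp * ||g|| / (||g_perp|| + eps). A minimal standalone sketch of that projection (the function name and the toy tensors are illustrative, not part of the commit):

import torch

def orthogonalize_gradient(w: torch.Tensor, g: torch.Tensor, eps: float = 1e-16) -> torch.Tensor:
    """Remove the component of g parallel to w, then restore g's original norm."""
    w_flat, g_flat = w.reshape(-1), g.reshape(-1)
    # Projection coefficient <w, g> / (<w, w> + eps).
    proj = torch.dot(w_flat, g_flat) / (torch.dot(w_flat, w_flat) + eps)
    g_perp = g_flat - proj * w_flat
    # Rescale so the orthogonalized gradient keeps the original gradient norm.
    g_perp = g_perp * (g_flat.norm(2) / (g_perp.norm(2) + eps))
    return g_perp.view_as(g)

# Toy check: the result is (numerically) orthogonal to the weight vector.
w, g = torch.randn(4, 3), torch.randn(4, 3)
print(torch.dot(w.reshape(-1), orthogonalize_gradient(w, g).reshape(-1)))  # ~0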

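A sketch of how the new flag might be toggled from user code, assuming the class is importable from the module path shown in this commit (the model and loss below are placeholders):

import torch
from pytorch_optimizer.optimizer.experimental.ranger25 import Ranger25

model = torch.nn.Linear(10, 1)
# orthograd defaults to True in this commit; pass orthograd=False to disable the hook.
optimizer = Ranger25(model.parameters(), lr=1e-3, orthograd=True)

loss = model(torch.randn(8, 10)).pow(2).mean()
loss.backward()
optimizer.step()       # gradients are orthogonalized against their parameters before the update
optimizer.zero_grad()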