Commit d5f5d0c

Merge pull request #32 from kozistr/feature/pc-grad
[Feature] Implement PCGrad
2 parents 4d328fa + 7bcc873 commit d5f5d0c

File tree

3 files changed: +167 -12 lines


README.rst

Lines changed: 27 additions & 11 deletions
@@ -73,17 +73,17 @@ of the ideas are applied in ``Ranger21`` optimizer.
 
 Also, most of the captures are taken from ``Ranger21`` paper.
 
-+------------------------------------------+-------------------------------------+--------------------------------------------+
-| `Adaptive Gradient Clipping`_            | `Gradient Centralization`_          | `Softplus Transformation`_                 |
-+------------------------------------------+-------------------------------------+--------------------------------------------+
-| `Gradient Normalization`_                | `Norm Loss`_                        | `Positive-Negative Momentum`_              |
-+------------------------------------------+-------------------------------------+--------------------------------------------+
-| `Linear learning rate warmup`_           | `Stable weight decay`_              | `Explore-exploit learning rate schedule`_  |
-+------------------------------------------+-------------------------------------+--------------------------------------------+
-| `Lookahead`_                             | `Chebyshev learning rate schedule`_ | `(Adaptive) Sharpness-Aware Minimization`_ |
-+------------------------------------------+-------------------------------------+--------------------------------------------+
-| `On the Convergence of Adam and Beyond`_ |                                     |                                            |
-+------------------------------------------+-------------------------------------+--------------------------------------------+
++------------------------------------------+---------------------------------------------+--------------------------------------------+
+| `Adaptive Gradient Clipping`_            | `Gradient Centralization`_                  | `Softplus Transformation`_                 |
++------------------------------------------+---------------------------------------------+--------------------------------------------+
+| `Gradient Normalization`_                | `Norm Loss`_                                | `Positive-Negative Momentum`_              |
++------------------------------------------+---------------------------------------------+--------------------------------------------+
+| `Linear learning rate warmup`_           | `Stable weight decay`_                      | `Explore-exploit learning rate schedule`_  |
++------------------------------------------+---------------------------------------------+--------------------------------------------+
+| `Lookahead`_                             | `Chebyshev learning rate schedule`_         | `(Adaptive) Sharpness-Aware Minimization`_ |
++------------------------------------------+---------------------------------------------+--------------------------------------------+
+| `On the Convergence of Adam and Beyond`_ | `Gradient Surgery for Multi-Task Learning`_ |                                            |
++------------------------------------------+---------------------------------------------+--------------------------------------------+
 
 Adaptive Gradient Clipping
 --------------------------
@@ -195,6 +195,11 @@ On the Convergence of Adam and Beyond
 
 - paper : `paper <https://openreview.net/forum?id=ryQu7f-RZ>`__
 
+Gradient Surgery for Multi-Task Learning
+----------------------------------------
+
+- paper : `paper <https://arxiv.org/abs/2001.06782>`__
+
 Citations
 ---------
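
A note on the section added above: the PCGrad rule from the linked paper projects one task gradient onto the normal plane of another whenever the two conflict (negative inner product). The snippet below is only an illustration of that rule, not part of this commit, and the ``project`` helper is a hypothetical name::

    import torch

    def project(g_i: torch.Tensor, g_j: torch.Tensor) -> torch.Tensor:
        # if g_i conflicts with g_j, remove from g_i the component pointing against g_j
        dot = torch.dot(g_i, g_j)
        if dot < 0:
            g_i = g_i - dot * g_j / (g_j.norm() ** 2)
        return g_i

    # toy check with two conflicting 2-D gradients
    g1, g2 = torch.tensor([1.0, 0.0]), torch.tensor([-1.0, 1.0])
    print(project(g1, g2))                 # tensor([0.5000, 0.5000])
    print(torch.dot(project(g1, g2), g2))  # tensor(0.) : the conflict is removed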

@@ -430,6 +435,17 @@ On the Convergence of Adam and Beyond
         year={2019}
     }
 
+Gradient Surgery for Multi-Task Learning
+
+::
+
+    @article{yu2020gradient,
+        title={Gradient surgery for multi-task learning},
+        author={Yu, Tianhe and Kumar, Saurabh and Gupta, Abhishek and Levine, Sergey and Hausman, Karol and Finn, Chelsea},
+        journal={arXiv preprint arXiv:2001.06782},
+        year={2020}
+    }
+
 Author
 ------
 
pytorch_optimizer/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -10,10 +10,11 @@
 from pytorch_optimizer.gc import centralize_gradient
 from pytorch_optimizer.lookahead import Lookahead
 from pytorch_optimizer.madgrad import MADGRAD
+from pytorch_optimizer.pcgrad import PCGrad
 from pytorch_optimizer.radam import RAdam
 from pytorch_optimizer.ranger import Ranger
 from pytorch_optimizer.ranger21 import Ranger21
 from pytorch_optimizer.sam import SAM
 from pytorch_optimizer.sgdp import SGDP
 
-__VERSION__ = '0.0.10'
+__VERSION__ = '0.1.0'
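
With this change ``PCGrad`` is re-exported from the package root and the version is bumped to ``0.1.0``. A quick sanity check, assuming the new release is installed::

    import pytorch_optimizer
    from pytorch_optimizer import PCGrad

    print(pytorch_optimizer.__VERSION__)  # '0.1.0'
    print(PCGrad.__module__)              # 'pytorch_optimizer.pcgrad'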

pytorch_optimizer/pcgrad.py

Lines changed: 138 additions & 0 deletions
@@ -0,0 +1,138 @@
import random
from copy import deepcopy
from typing import Iterable, List

import numpy as np
import torch
from torch.optim.optimizer import Optimizer


class PCGrad:
    """
    Reference : https://github.com/WeiChengTseng/Pytorch-PCGrad
    Example :
        from pytorch_optimizer import AdamP, PCGrad
        ...
        model = YourModel()
        optimizer = PCGrad(AdamP(model.parameters()))

        loss1_fn, loss2_fn = nn.L1Loss(), nn.MSELoss()
        ...
        for input, output in data:
            optimizer.zero_grad()
            y_pred = model(input)
            loss1, loss2 = loss1_fn(y_pred, output), loss2_fn(y_pred, output)
            optimizer.pc_backward([loss1, loss2])
            optimizer.step()
    """

    def __init__(self, optimizer: Optimizer, reduction: str = 'mean'):
        self.optimizer = optimizer
        self.reduction = reduction
        self.check_valid_parameters()

    def check_valid_parameters(self):
        if self.reduction not in ('mean', 'sum'):
            raise ValueError(f'invalid reduction : {self.reduction}')

    @staticmethod
    def flatten_grad(grads: List[torch.Tensor]) -> torch.Tensor:
        return torch.cat([g.flatten() for g in grads])

    @staticmethod
    def un_flatten_grad(grads: torch.Tensor, shapes: List[torch.Size]) -> List[torch.Tensor]:
        un_flatten_grad = []
        idx: int = 0
        for shape in shapes:
            length = int(np.prod(shape))
            un_flatten_grad.append(grads[idx : idx + length].view(shape).clone())
            idx += length
        return un_flatten_grad

    def zero_grad(self):
        return self.optimizer.zero_grad(set_to_none=True)

    def step(self):
        return self.optimizer.step()

    def set_grad(self, grads: List[torch.Tensor]):
        idx: int = 0
        for group in self.optimizer.param_groups:
            for p in group['params']:
                p.grad = grads[idx]
                idx += 1

    def pc_backward(self, objectives: Iterable[torch.Tensor]):
        """Calculate the gradient of the parameters
        :param objectives: Iterable[torch.Tensor]. a list of objectives (losses)
        :return:
        """
        grads, shapes, has_grads = self.pack_grad(objectives)
        pc_grad = self.project_conflicting(grads, has_grads)
        pc_grad = self.un_flatten_grad(pc_grad, shapes[0])
        self.set_grad(pc_grad)

    def project_conflicting(self, grads, has_grads) -> torch.Tensor:
        """
        :param grads: a list of the gradients of the parameters, one flattened tensor per objective
        :param has_grads: a list of masks representing whether the parameter has a gradient
        :return:
        """
        # a parameter entry is "shared" if it received a gradient from every objective
        shared = torch.stack(has_grads).prod(0).bool()

        pc_grad = deepcopy(grads)
        for g_i in pc_grad:
            random.shuffle(grads)
            for g_j in grads:
                g_i_g_j = torch.dot(g_i, g_j)
                if g_i_g_j < 0:
                    # conflicting gradients : project g_i onto the normal plane of g_j
                    g_i -= g_i_g_j * g_j / (g_j.norm() ** 2)

        merged_grad = torch.zeros_like(grads[0]).to(grads[0].device)

        # reduce the projected gradients over objectives on the shared entries
        shared_pc_grads = torch.stack([g[shared] for g in pc_grad])
        if self.reduction == 'mean':
            merged_grad[shared] = shared_pc_grads.mean(dim=0)
        else:  # self.reduction == 'sum'
            merged_grad[shared] = shared_pc_grads.sum(dim=0)

        merged_grad[~shared] = torch.stack([g[~shared] for g in pc_grad]).sum(dim=0)

        return merged_grad

    def retrieve_grad(self):
        """Get the gradient of the parameters of the network with specific objective
        :return:
        """
        grad, shape, has_grad = [], [], []
        for group in self.optimizer.param_groups:
            for p in group['params']:
                if p.grad is None:
                    shape.append(p.shape)
                    grad.append(torch.zeros_like(p).to(p.device))
                    has_grad.append(torch.zeros_like(p).to(p.device))
                    continue

                shape.append(p.grad.shape)
                grad.append(p.grad.clone())
                has_grad.append(torch.ones_like(p).to(p.device))

        return grad, shape, has_grad

    def pack_grad(self, objectives: Iterable[torch.Tensor]):
        """Pack the gradient of the parameters of the network for each objective
        :param objectives: Iterable[torch.Tensor]. a list of objectives (losses)
        :return:
        """
        grads, shapes, has_grads = [], [], []
        for objective in objectives:
            self.zero_grad()

            objective.backward(retain_graph=True)

            grad, shape, has_grad = self.retrieve_grad()

            grads.append(self.flatten_grad(grad))
            has_grads.append(self.flatten_grad(has_grad))
            shapes.append(shape)

        return grads, shapes, has_grads
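
For reference, a minimal end-to-end sketch in the spirit of the class docstring: a toy two-output model trained on two objectives, with the wrapped optimizer stepping on the surgically combined gradients. The model, data, and loop are made up for illustration; only ``AdamP``, ``PCGrad``, and ``pc_backward`` come from the package::

    import torch
    from torch import nn

    from pytorch_optimizer import AdamP, PCGrad

    model = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 2))
    optimizer = PCGrad(AdamP(model.parameters()))

    loss1_fn, loss2_fn = nn.L1Loss(), nn.MSELoss()
    x, y = torch.randn(8, 4), torch.randn(8, 2)  # dummy batch shared by both objectives

    for _ in range(10):
        optimizer.zero_grad()
        y_pred = model(x)
        loss1, loss2 = loss1_fn(y_pred, y), loss2_fn(y_pred, y)
        optimizer.pc_backward([loss1, loss2])  # project conflicting task gradients
        optimizer.step()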
