Commit 433939f

refactor: PCGrad
1 parent f3e5a2b commit 433939f

File tree: 1 file changed (+40 −41)


pytorch_optimizer/pcgrad.py

Lines changed: 40 additions & 41 deletions
@@ -61,15 +61,41 @@ def set_grad(self, grads):
                 p.grad = grads[idx]
                 idx += 1
 
-    def pc_backward(self, objectives: Iterable[nn.Module]):
-        """Calculate the gradient of the parameters
-        :param objectives: Iterable[nn.Module]. a list of objectives
+    def retrieve_grad(self):
+        """get the gradient of the parameters of the network with specific objective"""
+        grad, shape, has_grad = [], [], []
+        for group in self.optimizer.param_groups:
+            for p in group['params']:
+                if p.grad is None:
+                    shape.append(p.shape)
+                    grad.append(torch.zeros_like(p).to(p.device))
+                    has_grad.append(torch.zeros_like(p).to(p.device))
+                    continue
+
+                shape.append(p.grad.shape)
+                grad.append(p.grad.clone())
+                has_grad.append(torch.ones_like(p).to(p.device))
+
+        return grad, shape, has_grad
+
+    def pack_grad(self, objectives: Iterable[nn.Module]):
+        """pack the gradient of the parameters of the network for each objective
+        :param objectives: Iterable[float]. a list of objectives
         :return:
         """
-        grads, shapes, has_grads = self.pack_grad(objectives)
-        pc_grad = self.project_conflicting(grads, has_grads)
-        pc_grad = self.un_flatten_grad(pc_grad, shapes[0])
-        self.set_grad(pc_grad)
+        grads, shapes, has_grads = [], [], []
+        for objective in objectives:
+            self.zero_grad()
+
+            objective.backward(retain_graph=True)
+
+            grad, shape, has_grad = self.retrieve_grad()
+
+            grads.append(self.flatten_grad(grad))
+            has_grads.append(self.flatten_grad(has_grad))
+            shapes.append(shape)
+
+        return grads, shapes, has_grads
 
     def project_conflicting(self, grads, has_grads) -> torch.Tensor:
         """
@@ -99,40 +125,13 @@ def project_conflicting(self, grads, has_grads) -> torch.Tensor:
 
         return merged_grad
 
-    def retrieve_grad(self):
-        """Get the gradient of the parameters of the network with specific objective
-        :return:
-        """
-        grad, shape, has_grad = [], [], []
-        for group in self.optimizer.param_groups:
-            for p in group['params']:
-                if p.grad is None:
-                    shape.append(p.shape)
-                    grad.append(torch.zeros_like(p).to(p.device))
-                    has_grad.append(torch.zeros_like(p).to(p.device))
-                    continue
-
-                shape.append(p.grad.shape)
-                grad.append(p.grad.clone())
-                has_grad.append(torch.ones_like(p).to(p.device))
-
-        return grad, shape, has_grad
-
-    def pack_grad(self, objectives: Iterable[nn.Module]):
-        """Pack the gradient of the parameters of the network for each objective
-        :param objectives: Iterable[float]. a list of objectives
+    def pc_backward(self, objectives: Iterable[nn.Module]):
+        """calculate the gradient of the parameters
+        :param objectives: Iterable[nn.Module]. a list of objectives
         :return:
         """
-        grads, shapes, has_grads = [], [], []
-        for objective in objectives:
-            self.zero_grad()
-
-            objective.backward(retain_graph=True)
-
-            grad, shape, has_grad = self.retrieve_grad()
-
-            grads.append(self.flatten_grad(grad))
-            has_grads.append(self.flatten_grad(has_grad))
-            shapes.append(shape)
+        grads, shapes, has_grads = self.pack_grad(objectives)
+        pc_grad = self.project_conflicting(grads, has_grads)
+        pc_grad = self.un_flatten_grad(pc_grad, shapes[0])
 
-        return grads, shapes, has_grads
+        self.set_grad(pc_grad)
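
The relocated pc_backward still funnels everything through project_conflicting, whose body is untouched and therefore elided in this diff. For reference, a sketch of the standard PCGrad projection from Yu et al., 2020 ("Gradient Surgery for Multi-Task Learning"): whenever two task gradients conflict (negative dot product), one is projected onto the normal plane of the other, and the results are merged using the has_grads masks produced by retrieve_grad. This illustrates the technique and is not necessarily the file's exact implementation:

import copy
import random

import torch

def project_conflicting(grads, has_grads):
    # parameters that received a real gradient from every objective
    shared = torch.stack(has_grads).prod(0).bool()

    pc_grad = copy.deepcopy(grads)
    for g_i in pc_grad:
        random.shuffle(grads)
        for g_j in grads:
            g_i_g_j = torch.dot(g_i, g_j)
            if g_i_g_j < 0:
                # remove the conflicting component of g_i along g_j
                g_i -= g_i_g_j * g_j / (g_j.norm() ** 2)

    merged_grad = torch.zeros_like(grads[0])
    # average projected gradients on shared parameters, sum them elsewhere
    merged_grad[shared] = torch.stack([g[shared] for g in pc_grad]).mean(dim=0)
    merged_grad[~shared] = torch.stack([g[~shared] for g in pc_grad]).sum(dim=0)
    return merged_grad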

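After this refactor the public entry point is unchanged: callers wrap a regular optimizer and call pc_backward with one loss per task in place of loss.backward(). A hypothetical two-objective training step follows; the class name and import path mirror this file, but the exact constructor signature and step() forwarding are assumptions:

import torch
from torch import nn

from pytorch_optimizer.pcgrad import PCGrad  # assumed import path for this file

net = nn.Linear(10, 2)
optimizer = PCGrad(torch.optim.SGD(net.parameters(), lr=1e-2))

x = torch.randn(8, 10)
y1, y2 = torch.randn(8, 2), torch.randn(8, 2)

out = net(x)
loss1 = nn.functional.mse_loss(out, y1)
loss2 = nn.functional.mse_loss(out, y2)

# pc_backward packs per-objective gradients (pack_grad), resolves conflicts
# (project_conflicting), and writes the merged gradient back via set_grad
optimizer.pc_backward([loss1, loss2])
optimizer.step()  # assumes the wrapper forwards step() to the inner optimizer
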
0 commit comments
