@@ -61,15 +61,41 @@ def set_grad(self, grads):
                 p.grad = grads[idx]
                 idx += 1

-    def pc_backward(self, objectives: Iterable[nn.Module]):
-        """Calculate the gradient of the parameters
-        :param objectives: Iterable[nn.Module]. a list of objectives
+    def retrieve_grad(self):
+        """get the gradient of the parameters of the network with specific objective"""
+        grad, shape, has_grad = [], [], []
+        for group in self.optimizer.param_groups:
+            for p in group['params']:
+                if p.grad is None:
+                    shape.append(p.shape)
+                    grad.append(torch.zeros_like(p).to(p.device))
+                    has_grad.append(torch.zeros_like(p).to(p.device))
+                    continue
+
+                shape.append(p.grad.shape)
+                grad.append(p.grad.clone())
+                has_grad.append(torch.ones_like(p).to(p.device))
+
+        return grad, shape, has_grad
+
+    def pack_grad(self, objectives: Iterable[nn.Module]):
+        """pack the gradient of the parameters of the network for each objective
+        :param objectives: Iterable[float]. a list of objectives
         :return:
         """
-        grads, shapes, has_grads = self.pack_grad(objectives)
-        pc_grad = self.project_conflicting(grads, has_grads)
-        pc_grad = self.un_flatten_grad(pc_grad, shapes[0])
-        self.set_grad(pc_grad)
+        grads, shapes, has_grads = [], [], []
+        for objective in objectives:
+            self.zero_grad()
+
+            objective.backward(retain_graph=True)
+
+            grad, shape, has_grad = self.retrieve_grad()
+
+            grads.append(self.flatten_grad(grad))
+            has_grads.append(self.flatten_grad(has_grad))
+            shapes.append(shape)
+
+        return grads, shapes, has_grads

     def project_conflicting(self, grads, has_grads) -> torch.Tensor:
         """
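(The body of project_conflicting is elided from this hunk. As orientation only, below is a minimal sketch of the projection rule PCGrad is known for, from "Gradient Surgery for Multi-Task Learning": it operates on the flattened per-task gradients produced by pack_grad above. The function name project_conflicting_sketch and the plain averaging merge are assumptions, not necessarily the exact implementation in this file, which also carries has_grads masks and likely averages only over tasks that produced a gradient for each parameter.)

import copy
import random
import torch

def project_conflicting_sketch(grads: list) -> torch.Tensor:
    # grads: list of flattened 1-D gradient tensors, one per task (assumed shape)
    pc_grad = copy.deepcopy(grads)
    for g_i in pc_grad:
        # visit the other task gradients in random order, as in the paper
        random.shuffle(grads)
        for g_j in grads:
            dot = torch.dot(g_i, g_j)
            if dot < 0:
                # gradients conflict: remove the component of g_i along g_j
                g_i -= (dot / (g_j.norm() ** 2)) * g_j
    # merge the projected gradients; here a simple average for illustration
    return torch.stack(pc_grad).mean(dim=0)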
@@ -99,40 +125,13 @@ def project_conflicting(self, grads, has_grads) -> torch.Tensor:

         return merged_grad

-    def retrieve_grad(self):
-        """Get the gradient of the parameters of the network with specific objective
-        :return:
-        """
-        grad, shape, has_grad = [], [], []
-        for group in self.optimizer.param_groups:
-            for p in group['params']:
-                if p.grad is None:
-                    shape.append(p.shape)
-                    grad.append(torch.zeros_like(p).to(p.device))
-                    has_grad.append(torch.zeros_like(p).to(p.device))
-                    continue
-
-                shape.append(p.grad.shape)
-                grad.append(p.grad.clone())
-                has_grad.append(torch.ones_like(p).to(p.device))
-
-        return grad, shape, has_grad
-
-    def pack_grad(self, objectives: Iterable[nn.Module]):
-        """Pack the gradient of the parameters of the network for each objective
-        :param objectives: Iterable[float]. a list of objectives
+    def pc_backward(self, objectives: Iterable[nn.Module]):
+        """calculate the gradient of the parameters
+        :param objectives: Iterable[nn.Module]. a list of objectives
         :return:
         """
-        grads, shapes, has_grads = [], [], []
-        for objective in objectives:
-            self.zero_grad()
-
-            objective.backward(retain_graph=True)
-
-            grad, shape, has_grad = self.retrieve_grad()
-
-            grads.append(self.flatten_grad(grad))
-            has_grads.append(self.flatten_grad(has_grad))
-            shapes.append(shape)
+        grads, shapes, has_grads = self.pack_grad(objectives)
+        pc_grad = self.project_conflicting(grads, has_grads)
+        pc_grad = self.un_flatten_grad(pc_grad, shapes[0])

-        return grads, shapes, has_grads
+        self.set_grad(pc_grad)
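(For context, a hedged usage sketch of the reordered API: only pc_backward comes from the code above. The wrapper being exposed as PCGrad, its constructor taking a wrapped base optimizer, and step() delegating to that inner optimizer are assumptions; the tiny model and losses are placeholders.)

import torch
import torch.nn as nn

# hypothetical setup: one shared model, two task losses on different outputs
model = nn.Linear(4, 2)
optimizer = PCGrad(torch.optim.Adam(model.parameters()))  # assumed constructor

x = torch.randn(8, 4)
y = torch.randn(8, 2)
out = model(x)

# one loss per task; pc_backward packs, projects, and sets the combined gradient
losses = [
    nn.functional.mse_loss(out[:, 0], y[:, 0]),
    nn.functional.mse_loss(out[:, 1], y[:, 1]),
]
optimizer.pc_backward(losses)
optimizer.step()  # assumed to delegate to the wrapped optimizer's step()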