@@ -39,15 +39,15 @@ def validate_parameters(self):
 
     @torch.no_grad()
     def reset(self):
-        pass
+        self.zero_grad()
 
     def zero_grad(self):
         return self.optimizer.zero_grad(set_to_none=True)
 
     def step(self):
         return self.optimizer.step()
 
-    def set_grad(self, grads):
+    def set_grad(self, grads: List[torch.Tensor]):
         idx: int = 0
         for group in self.optimizer.param_groups:
             for p in group['params']:
@@ -74,7 +74,7 @@ def retrieve_grad(self) -> Tuple[List[torch.Tensor], List[int], List[torch.Tenso
     def pack_grad(self, objectives: Iterable) -> Tuple[List[torch.Tensor], List[List[int]], List[torch.Tensor]]:
         """pack the gradient of the parameters of the network for each objective
         :param objectives: Iterable[nn.Module]. a list of objectives
-        :return:
+        :return: torch.Tensor. packed gradients
         """
         grads, shapes, has_grads = [], [], []
         for objective in objectives:
@@ -89,27 +89,29 @@ def pack_grad(self, objectives: Iterable) -> Tuple[List[torch.Tensor], List[List
 
         return grads, shapes, has_grads
 
-    def project_conflicting(self, grads, has_grads) -> torch.Tensor:
+    def project_conflicting(self, grads: List[torch.Tensor], has_grads: List[torch.Tensor]) -> torch.Tensor:
         """project conflicting
         :param grads: a list of the gradient of the parameters
         :param has_grads: a list of mask represent whether the parameter has gradient
-        :return:
+        :return: torch.Tensor. merged gradients
         """
-        shared = torch.stack(has_grads).prod(0).bool()
+        shared: torch.Tensor = torch.stack(has_grads).prod(0).bool()
 
-        pc_grad = deepcopy(grads)
+        pc_grad: List[torch.Tensor] = deepcopy(grads)
         for g_i in pc_grad:
             random.shuffle(grads)
             for g_j in grads:
-                g_i_g_j = torch.dot(g_i, g_j)
+                g_i_g_j: torch.Tensor = torch.dot(g_i, g_j)
                 if g_i_g_j < 0:
                     g_i -= g_i_g_j * g_j / (g_j.norm() ** 2)
 
-        merged_grad = torch.zeros_like(grads[0]).to(grads[0].device)
+        merged_grad: torch.Tensor = torch.zeros_like(grads[0], device=grads[0].device)
+
+        shared_pc_gradients: torch.Tensor = torch.stack([g[shared] for g in pc_grad])
         if self.reduction == 'mean':
-            merged_grad[shared] = torch.stack([g[shared] for g in pc_grad]).mean(dim=0)
+            merged_grad[shared] = shared_pc_gradients.mean(dim=0)
         else:
-            merged_grad[shared] = torch.stack([g[shared] for g in pc_grad]).sum(dim=0)
+            merged_grad[shared] = shared_pc_gradients.sum(dim=0)
 
         merged_grad[~shared] = torch.stack([g[~shared] for g in pc_grad]).sum(dim=0)
 
@@ -121,7 +123,7 @@ def pc_backward(self, objectives: Iterable[nn.Module]):
         :return:
         """
         grads, shapes, has_grads = self.pack_grad(objectives)
+
         pc_grad = self.project_conflicting(grads, has_grads)
         pc_grad = un_flatten_grad(pc_grad, shapes[0])
-
        self.set_grad(pc_grad)
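
For context, a minimal usage sketch of the optimizer wrapper this commit touches. It assumes the class is the PCGrad-style wrapper defined in this file, that it is constructed around a standard torch.optim optimizer, and that pc_backward accepts the per-task losses; the import path and constructor call below are illustrative assumptions, not part of the diff.

import torch
from torch import nn

from pytorch_optimizer import PCGrad  # assumed import path for the wrapper shown above

model = nn.Linear(10, 2)
optimizer = PCGrad(torch.optim.SGD(model.parameters(), lr=1e-2))  # assumed constructor signature

x = torch.randn(8, 10)
y1, y2 = torch.randn(8, 2), torch.randn(8, 2)

out = model(x)
losses = [nn.functional.mse_loss(out, y1), nn.functional.mse_loss(out, y2)]

optimizer.zero_grad()
optimizer.pc_backward(losses)  # pack per-task grads, project away conflicts, write merged grad back
optimizer.step()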