 import random
 from copy import deepcopy
-from typing import Iterable, List
+from typing import Iterable, List, Tuple
 
 import numpy as np
 import torch
@@ -35,12 +35,12 @@ def check_valid_parameters(self):
             raise ValueError(f'invalid reduction : {self.reduction}')
 
     @staticmethod
-    def flatten_grad(grads) -> torch.Tensor:
+    def flatten_grad(grads: List[torch.Tensor]) -> torch.Tensor:
         return torch.cat([g.flatten() for g in grads])
 
     @staticmethod
     def un_flatten_grad(grads, shapes) -> List[torch.Tensor]:
-        un_flatten_grad = []
+        un_flatten_grad: List[torch.Tensor] = []
         idx: int = 0
         for shape in shapes:
             length = np.prod(shape)
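The hunk above truncates the loop that rebuilds each tensor from the flat vector, so the round-trip sketch below fills in the rest on the assumption that it slices the flat gradient, reshapes with `.view(shape)`, and clones; the standalone functions only mirror the diff's helpers for illustration.

# Round-trip sketch of the flatten / un-flatten helpers; the slice/view/clone
# loop body is an assumption, since the hunk above cuts it off.
from typing import List

import numpy as np
import torch


def flatten_grad(grads: List[torch.Tensor]) -> torch.Tensor:
    # concatenate every per-parameter gradient into a single 1-D vector
    return torch.cat([g.flatten() for g in grads])


def un_flatten_grad(grads: torch.Tensor, shapes) -> List[torch.Tensor]:
    # split the flat vector back into per-parameter tensors, shape by shape
    un_flattened: List[torch.Tensor] = []
    idx: int = 0
    for shape in shapes:
        length = int(np.prod(shape))
        un_flattened.append(grads[idx:idx + length].view(shape).clone())
        idx += length
    return un_flattened


tensors = [torch.randn(2, 3), torch.randn(5)]
flat = flatten_grad(tensors)                                  # shape (11,)
restored = un_flatten_grad(flat, [t.shape for t in tensors])
assert all(torch.equal(a, b) for a, b in zip(tensors, restored))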
@@ -54,39 +54,40 @@ def zero_grad(self):
     def step(self):
         return self.optimizer.step()
 
-    def set_grad(self, grads):
+    def set_grad(self, grads: List[torch.Tensor]):
         idx: int = 0
         for group in self.optimizer.param_groups:
             for p in group['params']:
                 p.grad = grads[idx]
                 idx += 1
 
-    def retrieve_grad(self):
+    def retrieve_grad(self) -> Tuple[List[torch.Tensor], List[int], List[torch.Tensor]]:
         """get the gradient of the parameters of the network with specific objective"""
         grad, shape, has_grad = [], [], []
         for group in self.optimizer.param_groups:
             for p in group['params']:
                 if p.grad is None:
                     shape.append(p.shape)
-                    grad.append(torch.zeros_like(p).to(p.device))
-                    has_grad.append(torch.zeros_like(p).to(p.device))
+                    grad.append(torch.zeros_like(p, device=p.device))
+                    has_grad.append(torch.zeros_like(p, device=p.device))
                     continue
 
                 shape.append(p.grad.shape)
                 grad.append(p.grad.clone())
-                has_grad.append(torch.ones_like(p).to(p.device))
+                has_grad.append(torch.ones_like(p, device=p.device))
 
         return grad, shape, has_grad
 
-    def pack_grad(self, objectives: Iterable[nn.Module]):
+    def pack_grad(
+        self, objectives: Iterable[nn.Module]
+    ) -> Tuple[List[torch.Tensor], List[List[int]], List[torch.Tensor]]:
         """pack the gradient of the parameters of the network for each objective
-        :param objectives: Iterable[float]. a list of objectives
+        :param objectives: Iterable[nn.Module]. a list of objectives
         :return:
         """
         grads, shapes, has_grads = [], [], []
         for objective in objectives:
-            self.zero_grad()
-
+            self.optimizer.zero_grad(set_to_none=True)
             objective.backward(retain_graph=True)
 
             grad, shape, has_grad = self.retrieve_grad()
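For context on the zeros_like/ones_like change in this hunk: torch.zeros_like(p) already allocates on p's device, so the trailing .to(p.device) was a no-op; the new form simply states the device explicitly in a single call. A quick illustrative check:

import torch

p = torch.randn(3, requires_grad=True)
a = torch.zeros_like(p).to(p.device)        # old form: .to() is a no-op here
b = torch.zeros_like(p, device=p.device)    # new form: device stated up front
assert torch.equal(a, b) and a.device == b.device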
@@ -98,7 +99,7 @@ def pack_grad(self, objectives: Iterable[nn.Module]):
         return grads, shapes, has_grads
 
     def project_conflicting(self, grads, has_grads) -> torch.Tensor:
-        """
+        """project conflicting
         :param grads: a list of the gradient of the parameters
         :param has_grads: a list of mask represent whether the parameter has gradient
         :return:
@@ -114,12 +115,10 @@ def project_conflicting(self, grads, has_grads) -> torch.Tensor:
                     g_i -= g_i_g_j * g_j / (g_j.norm() ** 2)
 
         merged_grad = torch.zeros_like(grads[0]).to(grads[0].device)
-        merged_grad[shared] = torch.stack([g[shared] for g in pc_grad])
-
         if self.reduction == 'mean':
-            merged_grad = merged_grad.mean(dim=0)
-        else:  # self.reduction == 'sum'
-            merged_grad = merged_grad.sum(dim=0)
+            merged_grad[shared] = torch.stack([g[shared] for g in pc_grad]).mean(dim=0)
+        else:
+            merged_grad[shared] = torch.stack([g[shared] for g in pc_grad]).sum(dim=0)
 
         merged_grad[~shared] = torch.stack([g[~shared] for g in pc_grad]).sum(dim=0)
 
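A minimal end-to-end sketch of how the methods touched by this commit compose for two objectives. The wrapper class name (PCGrad) and its constructor signature are assumed from context and are not shown in this diff, and the model and loss names are illustrative only; in the full class these steps are normally chained by a single backward helper, so the manual composition below just makes the data flow explicit.

import torch
from torch import nn

model = nn.Linear(8, 2)
base_optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
optimizer = PCGrad(base_optimizer)  # assumed constructor; class name not shown in this diff

out = model(torch.randn(4, 8))
loss_1 = out[:, 0].mean()  # objective for task 1
loss_2 = out[:, 1].mean()  # objective for task 2

# per-objective flattened gradients, parameter shapes, and has-grad masks
grads, shapes, has_grads = optimizer.pack_grad([loss_1, loss_2])
# project away conflicting components, then reduce ('mean' or 'sum') over tasks
merged = optimizer.project_conflicting(grads, has_grads)
# write the merged gradient back onto the parameters and take a step
optimizer.set_grad(optimizer.un_flatten_grad(merged, shapes[0]))
optimizer.step()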