@@ -81,31 +81,55 @@ def __call__(
8181 return None
8282
8383
84- class LinearBasedGramianComputer (GramianComputer ):
85- def __init__ (self , module : nn .Linear ):
84+ class ModuleBasedGramianComputer (GramianComputer , ABC ):
85+ def __init__ (self , module : nn .Module ):
8686 self .module = module
8787
8888 def __call__ (
8989 self ,
90- _ : tuple [Tensor , ...],
90+ rg_outputs : tuple [Tensor , ...],
9191 grad_outputs : tuple [Tensor , ...],
9292 args : tuple [PyTree , ...],
93- __ : dict [str , PyTree ],
94- ) -> Optional [Tensor ]:
95-
96- X = args [0 ]
97- dY = grad_outputs [0 ]
98-
99- gramian = ComputeGramian .apply (self ._compute_gramian , dY , X )
93+ kwargs : dict [str , PyTree ],
94+ ) -> Tensor :
95+ gramian = ComputeGramian .apply (
96+ self ._compute_gramian , rg_outputs , grad_outputs , args , kwargs
97+ )
10098 return gramian
10199
102- def _compute_gramian (self , dY1 : Tensor , dY2 : Tensor , X : Tensor ) -> Tensor :
100+ @abstractmethod
101+ def _compute_gramian (
102+ self ,
103+ rg_outputs : tuple [Tensor , ...],
104+ jac_outputs1 : tuple [Tensor , ...],
105+ jac_outputs2 : tuple [Tensor , ...],
106+ args : tuple [PyTree , ...],
107+ kwargs : dict [str , PyTree ],
108+ ) -> Tensor :
103109 """
104- X is a matrix of shape [k, n] and dY1, dY2 are matrices of shape [k, m].
105- Returns the dY1 @ G @ dY2 where G is the Gramian of the Jacobian of the module output w.r.t.
106- to the module params.
110+ If G is the Gramian of the Jacobian of the model's output w.r.t. the parameters, and J1, J2
111+ are the jac_outputs ( Jacobian of losses w.r.t. outputs), then this should compute the matrix
112+ J1 @ G @ J2.T
107113 """
108114
115+
116+ class LinearBasedGramianComputer (ModuleBasedGramianComputer ):
117+ def __init__ (self , module : nn .Linear ):
118+ super ().__init__ (module )
119+
120+ def _compute_gramian (
121+ self ,
122+ _ : tuple [Tensor , ...],
123+ jac_outputs1 : tuple [Tensor , ...],
124+ jac_outputs2 : tuple [Tensor , ...],
125+ args : tuple [PyTree , ...],
126+ __ : dict [str , PyTree ],
127+ ) -> Tensor :
128+
129+ X = args [0 ]
130+ dY1 = jac_outputs1 [0 ]
131+ dY2 = jac_outputs2 [0 ]
132+
109133 # TODO: add support for ndim==4 or find solution that works for any ndim.
110134 if dY1 .ndim == 2 :
111135 G = torch .einsum (dY1 , [0 , 2 ], X , [0 , 3 ], X , [1 , 3 ], dY2 , [1 , 2 ], [0 , 1 ])
@@ -124,33 +148,46 @@ def _compute_gramian(self, dY1: Tensor, dY2: Tensor, X: Tensor) -> Tensor:
124148class ComputeGramian (torch .autograd .Function ):
125149 @staticmethod
126150 def forward (
127- compute_gramian_fn : Callable [[Tensor , Tensor , Tensor ], Tensor ],
128- dY : Tensor ,
129- X : Tensor ,
151+ compute_gramian_fn : Callable [
152+ [
153+ tuple [Tensor , ...],
154+ tuple [Tensor , ...],
155+ tuple [Tensor , ...],
156+ tuple [PyTree , ...],
157+ dict [str , PyTree ],
158+ ],
159+ Tensor ,
160+ ],
161+ rg_outputs : tuple [Tensor , ...],
162+ grad_outputs : tuple [Tensor , ...],
163+ args : tuple [PyTree , ...],
164+ kwargs : dict [str , PyTree ],
130165 ) -> Tensor :
131166 # There is no non-batched dimension
132- gramian = compute_gramian_fn (dY , dY , X )
167+ gramian = compute_gramian_fn (rg_outputs , grad_outputs , grad_outputs , args , kwargs )
133168 return gramian
134169
135170 @staticmethod
136171 def vmap (
137172 _ ,
138- in_dims : tuple [None , tuple [int , ...], None ],
139- compute_gramian_fn : Callable [[Tensor , Tensor , Tensor ], Tensor ],
140- dY : Tensor ,
141- X : Tensor ,
173+ in_dims : tuple [None , None , tuple [int , ...], None , None ],
174+ compute_gramian_fn : Callable ,
175+ rg_outputs : tuple [Tensor , ...],
176+ jac_outputs : tuple [Tensor , ...],
177+ args : tuple [PyTree , ...],
178+ kwargs : dict [str , PyTree ],
142179 ) -> tuple [Tensor , None ]:
143180 # There is a non-batched dimension
144181 generalized_gramian = torch .vmap (
145182 torch .vmap (
146183 compute_gramian_fn ,
147- in_dims = (in_dims [1 ] , None , None ),
184+ in_dims = (None , in_dims [2 ], None , None , None ),
148185 out_dims = 0 ,
149186 ),
150- in_dims = (None , in_dims [1 ] , None ),
187+ in_dims = (None , None , in_dims [2 ], None , None ),
151188 out_dims = - 1 ,
152- )(dY , dY , X )
153- shape = dY .shape
189+ )(rg_outputs , jac_outputs , jac_outputs , args , kwargs )
190+ shape = generalized_gramian .shape
154191 gramian = reshape_gramian (generalized_gramian , [shape [0 ] * shape [1 ]])
155192 return gramian , None
156193
0 commit comments