@@ -113,38 +113,6 @@ def _compute_gramian(
113113 """
114114
115115
116- class LinearBasedGramianComputer (ModuleBasedGramianComputer ):
117- def __init__ (self , module : nn .Linear ):
118- super ().__init__ (module )
119-
120- def _compute_gramian (
121- self ,
122- _ : tuple [Tensor , ...],
123- jac_outputs1 : tuple [Tensor , ...],
124- jac_outputs2 : tuple [Tensor , ...],
125- args : tuple [PyTree , ...],
126- __ : dict [str , PyTree ],
127- ) -> Tensor :
128-
129- X = args [0 ]
130- dY1 = jac_outputs1 [0 ]
131- dY2 = jac_outputs2 [0 ]
132-
133- # TODO: add support for ndim==4 or find solution that works for any ndim.
134- if dY1 .ndim == 2 :
135- G = torch .einsum (dY1 , [0 , 2 ], X , [0 , 3 ], X , [1 , 3 ], dY2 , [1 , 2 ], [0 , 1 ])
136- if self .module .bias is not None :
137- G += torch .einsum (dY1 , [0 , 2 ], dY2 , [1 , 2 ], [0 , 1 ])
138- elif dY1 .ndim == 3 : # Typical in transformers
139- G = torch .einsum (dY1 , [0 , 2 , 4 ], X , [0 , 2 , 5 ], X , [1 , 3 , 5 ], dY2 , [1 , 3 , 4 ], [0 , 1 ])
140- if self .module .bias is not None :
141- G += torch .einsum (dY1 , [0 , 2 , 4 ], dY2 , [1 , 3 , 4 ], [0 , 1 ])
142- else :
143- raise ValueError ("Higher dimensions not supported. Open an issue if needed." )
144-
145- return G
146-
147-
148116class ComputeGramian (torch .autograd .Function ):
149117 @staticmethod
150118 def forward (
@@ -194,3 +162,35 @@ def vmap(
     @staticmethod
     def setup_context(*_) -> None:
         pass
+
+
+class LinearBasedGramianComputer(ModuleBasedGramianComputer):
+    def __init__(self, module: nn.Linear):
+        super().__init__(module)
+
+    def _compute_gramian(
+        self,
+        _: tuple[Tensor, ...],
+        jac_outputs1: tuple[Tensor, ...],
+        jac_outputs2: tuple[Tensor, ...],
+        args: tuple[PyTree, ...],
+        __: dict[str, PyTree],
+    ) -> Tensor:
+
+        X = args[0]
+        dY1 = jac_outputs1[0]
+        dY2 = jac_outputs2[0]
+
+        # TODO: add support for ndim==4 or find solution that works for any ndim.
+        if dY1.ndim == 2:
+            G = torch.einsum(dY1, [0, 2], X, [0, 3], X, [1, 3], dY2, [1, 2], [0, 1])
+            if self.module.bias is not None:
+                G += torch.einsum(dY1, [0, 2], dY2, [1, 2], [0, 1])
+        elif dY1.ndim == 3:  # Typical in transformers
+            G = torch.einsum(dY1, [0, 2, 4], X, [0, 2, 5], X, [1, 3, 5], dY2, [1, 3, 4], [0, 1])
+            if self.module.bias is not None:
+                G += torch.einsum(dY1, [0, 2, 4], dY2, [1, 3, 4], [0, 1])
+        else:
+            raise ValueError("Higher dimensions not supported. Open an issue if needed.")
+
+        return G
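For reference, a minimal sketch (not part of the diff) of what the ndim == 2 branch computes, assuming each Jacobian row i corresponds to sample i, so that the per-row weight gradient is the outer product dW_i = dY[i]^T X[i] and the Gramian entry is G[i, j] = <dW_i, dW_j>. The names m, d_in, d_out, G_fast, and G_slow are illustrative, not from the source:

import torch

# Illustrative shapes (hypothetical): m Jacobian rows, d_in inputs, d_out outputs.
m, d_in, d_out = 4, 3, 2
X = torch.randn(m, d_in)    # per-sample inputs to the Linear module
dY = torch.randn(m, d_out)  # per-row output gradients (here dY1 == dY2 == dY)

# Fast path, as in the ndim == 2 branch above (bias term omitted):
G_fast = torch.einsum(dY, [0, 2], X, [0, 3], X, [1, 3], dY, [1, 2], [0, 1])

# Naive path: materialize each per-row weight gradient dW_i = dY[i] outer X[i],
# then take pairwise Frobenius inner products G[i, j] = <dW_i, dW_j>.
dW = torch.einsum("ik,il->ikl", dY, X)        # shape [m, d_out, d_in]
G_slow = torch.einsum("ikl,jkl->ij", dW, dW)

assert torch.allclose(G_fast, G_slow, atol=1e-5)

The ndim == 3 branch generalizes this by additionally summing each dW_i over the sequence dimension before taking the inner products, which is why its einsum carries the extra paired indices.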