@@ -108,15 +108,17 @@ def _compute_gramian(self, dY1: Tensor, dY2: Tensor, X: Tensor) -> Tensor:
 
        # TODO: add support for ndim==4 or find solution that works for any ndim.
        if dY1.ndim == 2:
-            G_b = torch.einsum(dY1, [0, 2], dY2, [1, 2], [0, 1])
-            G_W = torch.einsum(dY1, [0, 2], X, [0, 3], X, [1, 3], dY2, [1, 2], [0, 1])
+            G = torch.einsum(dY1, [0, 2], X, [0, 3], X, [1, 3], dY2, [1, 2], [0, 1])
+            if self.module.bias is not None:
+                G += torch.einsum(dY1, [0, 2], dY2, [1, 2], [0, 1])
        elif dY1.ndim == 3:  # Typical in transformers
-            G_b = torch.einsum(dY1, [0, 2, 4], dY2, [1, 3, 4], [0, 1])
-            G_W = torch.einsum(dY1, [0, 2, 4], X, [0, 2, 5], X, [1, 3, 5], dY2, [1, 3, 4], [0, 1])
+            G = torch.einsum(dY1, [0, 2, 4], X, [0, 2, 5], X, [1, 3, 5], dY2, [1, 3, 4], [0, 1])
+            if self.module.bias is not None:
+                G += torch.einsum(dY1, [0, 2, 4], dY2, [1, 3, 4], [0, 1])
        else:
            raise ValueError("Higher dimensions not supported. Open an issue if needed.")
 
-        return G_b + G_W
+        return G
 
 
 class ComputeLinearGramian(torch.autograd.Function):
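
Note: a minimal standalone sketch of what the merged einsum computes in the ndim == 2 case, with dY1 == dY2. All names here (B, out_f, in_f, dY, X, J) are illustrative, not taken from the diff. The weight term is (dY1 @ dY2.T) * (X @ X.T) and the bias term is dY1 @ dY2.T, i.e. together they form the Gramian of the flattened per-sample gradients of (W, b).

import torch

B, out_f, in_f = 4, 3, 5
dY = torch.randn(B, out_f)  # per-sample gradients w.r.t. the layer output
X = torch.randn(B, in_f)    # layer inputs

# Weight term of the merged einsum (dY1 == dY2 == dY in this sketch):
G = torch.einsum(dY, [0, 2], X, [0, 3], X, [1, 3], dY, [1, 2], [0, 1])
# Bias term, added only when the module has a bias:
G += torch.einsum(dY, [0, 2], dY, [1, 2], [0, 1])

# Reference: Gramian of the flattened per-sample gradients of (W, b).
# grad_W for sample i is outer(dY[i], X[i]); grad_b for sample i is dY[i].
J = torch.cat([torch.einsum("bo,bi->boi", dY, X).flatten(1), dY], dim=1)
assert torch.allclose(G, J @ J.T, atol=1e-5)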