Commit 99e4c78

Handle bias is None in LinearBasedGramianComputer
1 parent 47363bd commit 99e4c78

1 file changed: 7 additions, 5 deletions

src/torchjd/autogram/_gramian_computer.py

Lines changed: 7 additions & 5 deletions

@@ -108,15 +108,17 @@ def _compute_gramian(self, dY1: Tensor, dY2: Tensor, X: Tensor) -> Tensor:
 
         # TODO: add support for ndim==4 or find solution that works for any ndim.
         if dY1.ndim == 2:
-            G_b = torch.einsum(dY1, [0, 2], dY2, [1, 2], [0, 1])
-            G_W = torch.einsum(dY1, [0, 2], X, [0, 3], X, [1, 3], dY2, [1, 2], [0, 1])
+            G = torch.einsum(dY1, [0, 2], X, [0, 3], X, [1, 3], dY2, [1, 2], [0, 1])
+            if self.module.bias is not None:
+                G += torch.einsum(dY1, [0, 2], dY2, [1, 2], [0, 1])
         elif dY1.ndim == 3:  # Typical in transformers
-            G_b = torch.einsum(dY1, [0, 2, 4], dY2, [1, 3, 4], [0, 1])
-            G_W = torch.einsum(dY1, [0, 2, 4], X, [0, 2, 5], X, [1, 3, 5], dY2, [1, 3, 4], [0, 1])
+            G = torch.einsum(dY1, [0, 2, 4], X, [0, 2, 5], X, [1, 3, 5], dY2, [1, 3, 4], [0, 1])
+            if self.module.bias is not None:
+                G += torch.einsum(dY1, [0, 2, 4], dY2, [1, 3, 4], [0, 1])
         else:
             raise ValueError("Higher dimensions not supported. Open an issue if needed.")
 
-        return G_b + G_W
+        return G
 
 
 class ComputeLinearGramian(torch.autograd.Function):
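
Note on the change: the patch computes the weight Gramian first and adds the bias contribution only when self.module.bias is not None, instead of always forming G_b. The sketch below (not part of the commit; the helper names gramian_2d and gramian_explicit are made up for illustration) checks the 2D-case contraction against an explicit per-sample Jacobian Gramian of an nn.Linear, with and without a bias.

# Minimal sketch, assuming the sublist-form einsum shown in the diff,
# specialised to dY1 == dY2 == dY and dY.ndim == 2.
import torch
from torch import Tensor


def gramian_2d(dY: Tensor, X: Tensor, has_bias: bool) -> Tensor:
    # Weight contribution: G_W[i, j] = sum_{c, d} dY[i, c] X[i, d] X[j, d] dY[j, c]
    G = torch.einsum(dY, [0, 2], X, [0, 3], X, [1, 3], dY, [1, 2], [0, 1])
    if has_bias:
        # Bias contribution: G_b[i, j] = sum_c dY[i, c] dY[j, c]
        G += torch.einsum(dY, [0, 2], dY, [1, 2], [0, 1])
    return G


def gramian_explicit(dY: Tensor, X: Tensor, has_bias: bool) -> Tensor:
    # Per-sample gradients of a Linear layer: dW_i = dy_i ⊗ x_i, db_i = dy_i.
    grads = []
    for dy_i, x_i in zip(dY, X):
        g = torch.outer(dy_i, x_i).flatten()
        if has_bias:
            g = torch.cat([g, dy_i])
        grads.append(g)
    J = torch.stack(grads)  # shape [batch, n_params]
    return J @ J.T


batch, d_in, d_out = 4, 3, 2
X = torch.randn(batch, d_in)
dY = torch.randn(batch, d_out)

for has_bias in (True, False):
    assert torch.allclose(gramian_2d(dY, X, has_bias), gramian_explicit(dY, X, has_bias))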
