Skip to content

Commit 8862c16

Browse files
committed
Add support for ndim==3 inputs / outputs.
1 parent 6e7d051 commit 8862c16

File tree

1 file changed

+13
-2
lines changed

1 file changed

+13
-2
lines changed

src/torchjd/autogram/_gramian_computer.py

Lines changed: 13 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -106,8 +106,19 @@ def _compute_gramian(self, dY1: Tensor, dY2: Tensor, X: Tensor) -> Tensor:
106106
to the module params.
107107
"""
108108

109-
G_b = torch.einsum("ik,jk->ij", dY1, dY2)
110-
G_W = torch.einsum("ik,il,jl,jk->ij", dY1, X, X, dY2)
109+
# TODO: add support for ndim==4 or find solution that works for any ndim.
110+
if dY1.ndim == 1:
111+
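# TODO: the ndim == 1 case cannot succeed as written: the next check is `if dY1.ndim == 2` rather than `elif`, so 1-D inputs fall through to the final `else` and raise ValueError.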
112+
G_b = torch.einsum("k,k->", dY1, dY2)
113+
G_W = torch.einsum("k,l,l,k->", dY1, X, X, dY2)
114+
if dY1.ndim == 2:
115+
G_b = torch.einsum("ak,ik->ai", dY1, dY2)
116+
G_W = torch.einsum("ak,al,il,ik->ai", dY1, X, X, dY2)
117+
elif dY1.ndim == 3: # Typical in transformers
118+
G_b = torch.einsum("abk,ijk->ai", dY1, dY2)
119+
G_W = torch.einsum("abk,abl,ijl,ijk->ai", dY1, X, X, dY2)
120+
else:
121+
raise ValueError("Higher dimensions not supported. Open an issue if needed.")
111122

112123
return G_b + G_W
113124

0 commit comments

Comments
 (0)