Skip to content

Commit 33a9721

Browse files
committed
ndim=1 cannot happen (we are batched for now).
1 parent 0f1c909 commit 33a9721

File tree

1 file changed

+1
-5
lines changed

1 file changed

+1
-5
lines changed

src/torchjd/autogram/_gramian_computer.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -107,11 +107,7 @@ def _compute_gramian(self, dY1: Tensor, dY2: Tensor, X: Tensor) -> Tensor:
107107
"""
108108

109109
# TODO: add support for ndim==4 or find solution that works for any ndim.
110-
if dY1.ndim == 1:
111-
# TODO: not sure that this even works
112-
G_b = torch.einsum("k,k->", dY1, dY2)
113-
G_W = torch.einsum("k,l,l,k->", dY1, X, X, dY2)
114-
elif dY1.ndim == 2:
110+
if dY1.ndim == 2:
115111
G_b = torch.einsum("ak,ik->ai", dY1, dY2)
116112
G_W = torch.einsum("ak,al,il,ik->ai", dY1, X, X, dY2)
117113
elif dY1.ndim == 3: # Typical in transformers

0 commit comments

Comments
 (0)