1 parent 6e7d051 commit 8862c16
src/torchjd/autogram/_gramian_computer.py
@@ -106,8 +106,19 @@ def _compute_gramian(self, dY1: Tensor, dY2: Tensor, X: Tensor) -> Tensor:
         to the module params.
         """
 
-        G_b = torch.einsum("ik,jk->ij", dY1, dY2)
-        G_W = torch.einsum("ik,il,jl,jk->ij", dY1, X, X, dY2)
+        # TODO: add support for ndim == 4 or find a solution that works for any ndim.
+        if dY1.ndim == 1:
+            # TODO: not sure that this even works
+            G_b = torch.einsum("k,k->", dY1, dY2)
+            G_W = torch.einsum("k,l,l,k->", dY1, X, X, dY2)
+        elif dY1.ndim == 2:
+            G_b = torch.einsum("ak,ik->ai", dY1, dY2)
+            G_W = torch.einsum("ak,al,il,ik->ai", dY1, X, X, dY2)
+        elif dY1.ndim == 3:  # Typical in transformers
+            G_b = torch.einsum("abk,ijk->ai", dY1, dY2)
+            G_W = torch.einsum("abk,abl,ijl,ijk->ai", dY1, X, X, dY2)
+        else:
+            raise ValueError("Higher dimensions are not supported. Open an issue if needed.")
 
         return G_b + G_W
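For context, a quick sanity check of what these einsums compute in the `ndim == 2` case: for a linear layer `Y = X @ W.T + b`, the per-sample weight gradient is the outer product `dY[a] ⊗ X[a]`, and the Gramian entry factorizes as `G_W[a, i] = (dY[a] · dY[i]) * (X[a] · X[i])`, which is exactly `einsum("ak,al,il,ik->ai")`; the bias Gramian is `dY @ dY.T`. In the `ndim == 3` case the extra dimensions `b` and `j` are summed over, so each entry aggregates contributions across all positions (e.g. tokens). Below is a minimal verification sketch, not part of the commit; the shapes and names (`n`, `d_in`, `d_out`) are illustrative assumptions, with `dY1 = dY2 = dY` as in a plain Gramian.

```python
# Hypothetical sketch (not from the commit): checks that the ndim == 2 einsums
# match the Gramian of flattened per-sample gradients of a linear layer.
import torch

n, d_in, d_out = 4, 3, 5
X = torch.randn(n, d_in)     # layer input, one row per sample
dY = torch.randn(n, d_out)   # upstream gradient, one row per sample

# Per-sample gradients: dW_a = dY[a] outer X[a] (shape [d_out, d_in]), db_a = dY[a].
dW = torch.einsum("ak,al->akl", dY, X).reshape(n, -1)
db = dY

# Reference Gramians of the flattened per-sample gradients.
G_W_ref = dW @ dW.T
G_b_ref = db @ db.T

# Einsum forms used in _compute_gramian (with dY1 = dY2 = dY).
G_W = torch.einsum("ak,al,il,ik->ai", dY, X, X, dY)
G_b = torch.einsum("ak,ik->ai", dY, dY)

assert torch.allclose(G_W, G_W_ref, atol=1e-5)
assert torch.allclose(G_b, G_b_ref, atol=1e-5)
```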