1 parent 33a9721 · commit 47363bd
src/torchjd/autogram/_engine.py

@@ -330,7 +330,7 @@ def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]:
             non_batch_dim_len = output.shape[0]
             identity_matrix = torch.eye(non_batch_dim_len, device=output.device, dtype=output.dtype)
             ones = torch.ones_like(output[0])
-            jac_output = torch.einsum("ij, ... -> ij...", identity_matrix, ones)
+            jac_output = torch.einsum(identity_matrix, [0, 1], ones, [...], [0, 1, ...])

             _ = vmap(differentiation)(jac_output)
         else:
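
Both forms of the changed call are equivalent: in torch.einsum's sublist syntax, each operand is followed by a list of integer axis labels, and Ellipsis ([...]) plays the same role as "..." in the equation string. A minimal sketch of the correspondence, using made-up shapes (a 3x3 identity and a 4x5 ones tensor; in _engine.py the real tensors are derived from output):

import torch

# Hypothetical stand-ins for the tensors in _engine.py.
identity_matrix = torch.eye(3)
ones = torch.ones(4, 5)

# String-subscript form (before this commit).
jac_str = torch.einsum("ij, ... -> ij...", identity_matrix, ones)

# Sublist form (after this commit): integer labels per operand,
# with Ellipsis standing in for the broadcast dimensions.
jac_sub = torch.einsum(identity_matrix, [0, 1], ones, [...], [0, 1, ...])

assert torch.equal(jac_str, jac_sub)  # both have shape (3, 3, 4, 5)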
src/torchjd/autogram/_gramian_computer.py

@@ -108,11 +108,11 @@ def _compute_gramian(self, dY1: Tensor, dY2: Tensor, X: Tensor) -> Tensor:

         # TODO: add support for ndim==4 or find solution that works for any ndim.
         if dY1.ndim == 2:
-            G_b = torch.einsum("ak,ik->ai", dY1, dY2)
-            G_W = torch.einsum("ak,al,il,ik->ai", dY1, X, X, dY2)
+            G_b = torch.einsum(dY1, [0, 2], dY2, [1, 2], [0, 1])
+            G_W = torch.einsum(dY1, [0, 2], X, [0, 3], X, [1, 3], dY2, [1, 2], [0, 1])
         elif dY1.ndim == 3:  # Typical in transformers
-            G_b = torch.einsum("abk,ijk->ai", dY1, dY2)
-            G_W = torch.einsum("abk,abl,ijl,ijk->ai", dY1, X, X, dY2)
+            G_b = torch.einsum(dY1, [0, 2, 4], dY2, [1, 3, 4], [0, 1])
+            G_W = torch.einsum(dY1, [0, 2, 4], X, [0, 2, 5], X, [1, 3, 5], dY2, [1, 3, 4], [0, 1])
         else:
             raise ValueError("Higher dimensions not supported. Open an issue if needed.")
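
The same letter-to-integer translation applies here: in the 2-D branch the labels a, i, k, l map to 0, 1, 2, 3, and in the 3-D branch a, i, b, j, k, l map to 0, 1, 2, 3, 4, 5, with the trailing sublist [0, 1] replacing "->ai". A quick equivalence check for the 2-D branch, using made-up shapes (the real dY1, dY2, X are the arguments of _compute_gramian):

import torch

a, k, l = 3, 4, 5  # hypothetical axis sizes; dY2's first axis (i) shares a's size here
dY1, dY2, X = torch.randn(a, k), torch.randn(a, k), torch.randn(a, l)

# "ak,ik->ai" with a -> 0, i -> 1, k -> 2.
G_b_str = torch.einsum("ak,ik->ai", dY1, dY2)
G_b_sub = torch.einsum(dY1, [0, 2], dY2, [1, 2], [0, 1])

# "ak,al,il,ik->ai" with the additional label l -> 3.
G_W_str = torch.einsum("ak,al,il,ik->ai", dY1, X, X, dY2)
G_W_sub = torch.einsum(dY1, [0, 2], X, [0, 3], X, [1, 3], dY2, [1, 2], [0, 1])

assert torch.allclose(G_b_str, G_b_sub)
assert torch.allclose(G_W_str, G_W_sub)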