Skip to content

Commit 47363bd

Browse files
committed
Use interleaved input style for einsum.
1 parent 33a9721 commit 47363bd

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

src/torchjd/autogram/_engine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,7 @@ def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]:
330330
non_batch_dim_len = output.shape[0]
331331
identity_matrix = torch.eye(non_batch_dim_len, device=output.device, dtype=output.dtype)
332332
ones = torch.ones_like(output[0])
333-
jac_output = torch.einsum("ij, ... -> ij...", identity_matrix, ones)
333+
jac_output = torch.einsum(identity_matrix, [0, 1], ones, [...], [0, 1, ...])
334334

335335
_ = vmap(differentiation)(jac_output)
336336
else:

src/torchjd/autogram/_gramian_computer.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -108,11 +108,11 @@ def _compute_gramian(self, dY1: Tensor, dY2: Tensor, X: Tensor) -> Tensor:
108108

109109
# TODO: add support for ndim==4 or find solution that works for any ndim.
110110
if dY1.ndim == 2:
111-
G_b = torch.einsum("ak,ik->ai", dY1, dY2)
112-
G_W = torch.einsum("ak,al,il,ik->ai", dY1, X, X, dY2)
111+
G_b = torch.einsum(dY1, [0, 2], dY2, [1, 2], [0, 1])
112+
G_W = torch.einsum(dY1, [0, 2], X, [0, 3], X, [1, 3], dY2, [1, 2], [0, 1])
113113
elif dY1.ndim == 3: # Typical in transformers
114-
G_b = torch.einsum("abk,ijk->ai", dY1, dY2)
115-
G_W = torch.einsum("abk,abl,ijl,ijk->ai", dY1, X, X, dY2)
114+
G_b = torch.einsum(dY1, [0, 2, 4], dY2, [1, 3, 4], [0, 1])
115+
G_W = torch.einsum(dY1, [0, 2, 4], X, [0, 2, 5], X, [1, 3, 5], dY2, [1, 3, 4], [0, 1])
116116
else:
117117
raise ValueError("Higher dimensions not supported. Open an issue if needed.")
118118

0 commit comments

Comments
 (0)