Skip to content

Commit 628e0b0

Browse files
refactor(autogram): Create block-diagonal matrix using einsum (#456)
1 parent 5595f5d commit 628e0b0

File tree

1 file changed

+3
-5
lines changed

1 file changed

+3
-5
lines changed

src/torchjd/autogram/_engine.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -319,11 +319,9 @@ def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]:
319319
if has_non_batch_dim:
320320
# There is one non-batched dimension, it is the first one
321321
non_batch_dim_len = output.shape[0]
322-
jac_output_shape = [output.shape[0]] + list(output.shape)
323-
324-
jac_output = torch.zeros(jac_output_shape, device=output.device, dtype=output.dtype)
325-
for i in range(non_batch_dim_len):
326-
jac_output[i, i, ...] = 1.0
322+
identity_matrix = torch.eye(non_batch_dim_len, device=output.device, dtype=output.dtype)
323+
ones = torch.ones_like(output[0])
324+
jac_output = torch.einsum("ij, ... -> ij...", identity_matrix, ones)
327325

328326
_ = vmap(differentiation)(jac_output)
329327
else:

0 commit comments

Comments
 (0)