Skip to content

Commit eb4f677

Browse files
authored
Merge pull request #11 from yoyololicon/feat/remove-einsum
feat: remove einsum for efficiency
2 parents 01650aa + 4d9c5bb commit eb4f677

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

fft_conv_pytorch/fft_conv.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,17 @@ def complex_matmul(a: Tensor, b: Tensor, groups: int = 1) -> Tensor:
1313
# dimensions. Dimensions 3 and higher will have the same shape after multiplication.
1414
# We also allow for "grouped" multiplications, where multiple sections of channels
1515
# are multiplied independently of one another (required for group convolutions).
16-
scalar_matmul = partial(torch.einsum, "agc..., gbc... -> agb...")
1716
a = a.view(a.size(0), groups, -1, *a.shape[2:])
1817
b = b.view(groups, -1, *b.shape[1:])
1918

20-
# Compute the real and imaginary parts independently, then manually insert them
21-
# into the output Tensor. This is fairly hacky but necessary for PyTorch 1.7.0,
22-
# because Autograd is not enabled for complex matrix operations yet. Not exactly
23-
# idiomatic PyTorch code, but it should work for all future versions (>= 1.7.0).
24-
real = scalar_matmul(a.real, b.real) - scalar_matmul(a.imag, b.imag)
25-
imag = scalar_matmul(a.imag, b.real) + scalar_matmul(a.real, b.imag)
19+
a = torch.movedim(a, 2, a.dim() - 1).unsqueeze(-2)
20+
b = torch.movedim(b, (1, 2), (b.dim() - 1, b.dim() - 2))
21+
22+
# complex value matrix multiplication
23+
real = a.real @ b.real - a.imag @ b.imag
24+
imag = a.imag @ b.real + a.real @ b.imag
25+
real = torch.movedim(real, real.dim() - 1, 2).squeeze(-1)
26+
imag = torch.movedim(imag, imag.dim() - 1, 2).squeeze(-1)
2627
c = torch.zeros(real.shape, dtype=torch.complex64, device=a.device)
2728
c.real, c.imag = real, imag
2829

0 commit comments

Comments
 (0)