Skip to content

Commit 4c2374d

Browse files
committed
feat: 1.7 compatible implementation
1 parent 7e99110 commit 4c2374d

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

fft_conv_pytorch/fft_conv.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,14 @@ def complex_matmul(a: Tensor, b: Tensor, groups: int = 1) -> Tensor:
2020
b = torch.movedim(b, (1, 2), (b.dim() - 1, b.dim() - 2))
2121

2222
# complex value matrix multiplication
23-
c = a @ b
24-
c = torch.movedim(c, c.dim() - 1, 2)
25-
return c.reshape(c.size(0), -1, *c.shape[3:-1])
23+
real = a.real @ b.real - a.imag @ b.imag
24+
imag = a.imag @ b.real - a.real @ b.imag
25+
real = torch.movedim(real, real.dim() - 1, 2).squeeze(-1)
26+
imag = torch.movedim(imag, imag.dim() - 1, 2).squeeze(-1)
27+
c = torch.zeros(real.shape, dtype=torch.complex64, device=a.device)
28+
c.real, c.imag = real, imag
29+
30+
return c.reshape(c.size(0), -1, *c.shape[3:])
2631

2732

2833
def to_ntuple(val: Union[int, Iterable[int]], n: int) -> Tuple[int, ...]:

0 commit comments

Comments
 (0)