@@ -13,16 +13,17 @@ def complex_matmul(a: Tensor, b: Tensor, groups: int = 1) -> Tensor:
     # dimensions. Dimensions 3 and higher will have the same shape after multiplication.
     # We also allow for "grouped" multiplications, where multiple sections of channels
     # are multiplied independently of one another (required for group convolutions).
-    scalar_matmul = partial(torch.einsum, "agc..., gbc... -> agb...")
     a = a.view(a.size(0), groups, -1, *a.shape[2:])
     b = b.view(groups, -1, *b.shape[1:])
 
-    # Compute the real and imaginary parts independently, then manually insert them
-    # into the output Tensor. This is fairly hacky but necessary for PyTorch 1.7.0,
-    # because Autograd is not enabled for complex matrix operations yet. Not exactly
-    # idiomatic PyTorch code, but it should work for all future versions (>= 1.7.0).
-    real = scalar_matmul(a.real, b.real) - scalar_matmul(a.imag, b.imag)
-    imag = scalar_matmul(a.imag, b.real) + scalar_matmul(a.real, b.imag)
+    a = torch.movedim(a, 2, a.dim() - 1).unsqueeze(-2)
+    b = torch.movedim(b, (1, 2), (b.dim() - 1, b.dim() - 2))
+
+    # complex-valued matrix multiplication
+    real = a.real @ b.real - a.imag @ b.imag
+    imag = a.imag @ b.real + a.real @ b.imag
+    real = torch.movedim(real, real.dim() - 1, 2).squeeze(-1)
+    imag = torch.movedim(imag, imag.dim() - 1, 2).squeeze(-1)
     c = torch.zeros(real.shape, dtype=torch.complex64, device=a.device)
     c.real, c.imag = real, imag
 
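For reference, here is a minimal standalone sketch (not part of the commit) that checks the movedim-plus-matmul path against a direct complex `torch.einsum` on small random tensors. The shapes, variable names, and tolerance are illustrative assumptions, not taken from the repository:

```python
import torch

# A minimal sketch (not part of the commit): verify that the movedim + matmul
# formulation added above matches a direct complex einsum. All shapes below are
# illustrative assumptions (batch=2, groups=3, in/out channels per group=4/5,
# one frequency dimension of size 6); requires a PyTorch version where einsum
# and randn support complex dtypes (roughly >= 1.8).
batch, groups, c_in, c_out, freq = 2, 3, 4, 5, 6
a = torch.randn(batch, groups, c_in, freq, dtype=torch.complex64)
b = torch.randn(groups, c_out, c_in, freq, dtype=torch.complex64)

# Reference result: contract over the shared channel dimension "c".
expected = torch.einsum("agc...,gbc...->agb...", a, b)

# Movedim-based path, mirroring the lines added in this commit.
a2 = torch.movedim(a, 2, a.dim() - 1).unsqueeze(-2)        # (batch, g, ..., 1, c)
b2 = torch.movedim(b, (1, 2), (b.dim() - 1, b.dim() - 2))  # (g, ..., c, b_out)
real = a2.real @ b2.real - a2.imag @ b2.imag
imag = a2.imag @ b2.real + a2.real @ b2.imag
real = torch.movedim(real, real.dim() - 1, 2).squeeze(-1)  # (batch, g, b_out, ...)
imag = torch.movedim(imag, imag.dim() - 1, 2).squeeze(-1)
result = torch.complex(real, imag)

assert torch.allclose(result, expected, atol=1e-4)
```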