import torch.nn.functional as f


-def complex_matmul(a: Tensor, b: Tensor) -> Tensor:
+def complex_matmul(a: Tensor, b: Tensor, groups: int = 1) -> Tensor:
    """Multiplies two complex-valued tensors."""
-    # Scalar matrix multiplication of two tensors, over only the first two dimensions.
-    # Dimensions 3 and higher will have the same shape after multiplication.
-    scalar_matmul = partial(torch.einsum, "ab..., cb... -> ac...")
+    # Scalar matrix multiplication of two tensors, over only the channel dimensions.
+    # Dimensions 3 and higher will have the same shape after multiplication.
+    # We also allow "grouped" multiplications, where multiple sections of channels
+    # are multiplied independently of one another (required for group convolutions).
+    scalar_matmul = partial(torch.einsum, "agc..., gbc... -> agb...")
+    a = a.view(a.size(0), groups, -1, *a.shape[2:])
+    b = b.view(groups, -1, *b.shape[1:])

    # Compute the real and imaginary parts independently, then manually insert them
    # into the output Tensor. This is fairly hacky but necessary for PyTorch 1.7.0,
@@ -22,7 +26,7 @@ def complex_matmul(a: Tensor, b: Tensor) -> Tensor:
    c = torch.zeros(real.shape, dtype=torch.complex64, device=a.device)
    c.real, c.imag = real, imag

-    return c
+    return c.view(c.size(0), -1, *c.shape[3:])


def to_ntuple(val: Union[int, Iterable[int]], n: int) -> Tuple[int, ...]:
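As a sanity check on the grouped einsum above, here is a minimal, self-contained sketch. The names and toy sizes are made up, and real-valued tensors are used so it runs even where complex einsum is unavailable; it confirms that "agc..., gbc... -> agb..." over the reshaped operands equals multiplying each group separately and concatenating along channels:

import torch

# Hypothetical toy sizes: batch 2, 4 groups, 8 input and 12 output channels per group.
batch, groups, c_in, c_out, length = 2, 4, 8, 12, 10
a = torch.randn(batch, groups * c_in, length)
b = torch.randn(groups * c_out, c_in, length)

# Grouped form used above: fold the channel dim into (groups, channels_per_group).
a_g = a.view(batch, groups, c_in, length)
b_g = b.view(groups, c_out, c_in, length)
out = torch.einsum("agc..., gbc... -> agb...", a_g, b_g).reshape(batch, -1, length)

# Reference: multiply each group independently, then concatenate along channels.
ref = torch.cat(
    [torch.einsum("ab..., cb... -> ac...", a_g[:, g], b_g[g]) for g in range(groups)],
    dim=1,
)
assert torch.allclose(out, ref, atol=1e-5)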
@@ -52,6 +56,7 @@ def fft_conv(
    bias: Tensor = None,
    padding: Union[int, Iterable[int]] = 0,
    stride: Union[int, Iterable[int]] = 1,
+    groups: int = 1,
) -> Tensor:
    """Performs N-d convolution of Tensors using a fast fourier transform, which
    is very fast for large kernel sizes. Also, optionally adds a bias Tensor after
@@ -78,7 +83,7 @@ def fft_conv(

    # Because PyTorch computes a *one-sided* FFT, we need the final dimension to
    # have *even* length. Just pad with one more zero if the final dimension is odd.
-    if signal.size(-1) % 2:
+    if signal.size(-1) % 2 != 0:
        signal_ = f.pad(signal, [0, 1])
    else:
        signal_ = signal
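A small illustration of why the extra zero matters (toy 1-D example, plain torch.fft calls): without an explicit output size, the inverse of a one-sided FFT assumes the original length was even, so an odd-length signal comes back one sample short.

import torch
import torch.nn.functional as f

x = torch.randn(11)                      # odd final dimension
x_fr = torch.fft.rfft(x)                 # one-sided FFT: 11 // 2 + 1 == 6 bins
y = torch.fft.irfft(x_fr)                # no explicit size -> assumes even length
print(x_fr.shape, y.shape)               # torch.Size([6]) torch.Size([10]): one sample lost

x_even = f.pad(x, [0, 1])                # pad to even length, as done above
y_even = torch.fft.irfft(torch.fft.rfft(x_even))
assert y_even.shape == x_even.shape      # now round-trips to the full length
assert torch.allclose(y_even, x_even, atol=1e-5)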
@@ -91,11 +96,12 @@ def fft_conv(
    padded_kernel = f.pad(kernel, kernel_padding)

    # Perform fourier convolution -- FFT, matrix multiply, then IFFT
+    # signal_ = signal_.reshape(signal_.size(0), groups, -1, *signal_.shape[2:])
    signal_fr = rfftn(signal_, dim=tuple(range(2, signal.ndim)))
    kernel_fr = rfftn(padded_kernel, dim=tuple(range(2, signal.ndim)))

    kernel_fr.imag *= -1
-    output_fr = complex_matmul(signal_fr, kernel_fr)
+    output_fr = complex_matmul(signal_fr, kernel_fr, groups=groups)
    output = irfftn(output_fr, dim=tuple(range(2, signal.ndim)))

    # Remove extra padded values
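For intuition about the conjugate-multiply-invert step, a minimal single-channel sketch (toy sizes, direct torch.fft calls rather than this module's helpers). With one input and one output channel the channel matmul reduces to an elementwise product, and the result matches f.conv1d, which computes cross-correlation:

import torch
import torch.nn.functional as f

signal = torch.randn(1, 1, 16)           # (batch, channels, length), toy sizes
kernel = torch.randn(1, 1, 5)

direct = f.conv1d(signal, kernel)        # PyTorch "convolution" = cross-correlation

# Fourier route: pad the kernel to the signal length, multiply the signal spectrum
# by the conjugated kernel spectrum, invert, then crop to the valid output length.
padded_kernel = f.pad(kernel, [0, signal.size(-1) - kernel.size(-1)])
signal_fr = torch.fft.rfftn(signal, dim=(-1,))
kernel_fr = torch.fft.rfftn(padded_kernel, dim=(-1,))
output = torch.fft.irfftn(signal_fr * kernel_fr.conj(), dim=(-1,))
output = output[..., : signal.size(-1) - kernel.size(-1) + 1]

assert torch.allclose(direct, output, atol=1e-4)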
@@ -123,6 +129,7 @@ def __init__(
        kernel_size: Union[int, Iterable[int]],
        padding: Union[int, Iterable[int]] = 0,
        stride: Union[int, Iterable[int]] = 1,
+        groups: int = 1,
        bias: bool = True,
        ndim: int = 1,
    ):
@@ -142,10 +149,20 @@ def __init__(
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride
+        self.groups = groups
        self.use_bias = bias

+        if in_channels % groups != 0:
+            raise ValueError(
+                f"'in_channels' ({in_channels}) must be divisible by 'groups' ({groups})."
+            )
+        if out_channels % groups != 0:
+            raise ValueError(
+                f"'out_channels' ({out_channels}) must be divisible by 'groups' ({groups})."
+            )
+
        kernel_size = to_ntuple(kernel_size, ndim)
-        self.weight = nn.Parameter(torch.randn(out_channels, in_channels, *kernel_size))
+        self.weight = nn.Parameter(torch.randn(out_channels, in_channels // groups, *kernel_size))
        self.bias = nn.Parameter(torch.randn(out_channels,)) if bias else None

    def forward(self, signal):
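For comparison, the grouped weight layout and the divisibility requirement mirror torch.nn.Conv1d; a quick check with arbitrary toy channel counts:

import torch.nn as nn

# nn.Conv1d stores its weight as (out_channels, in_channels // groups, kernel_size),
# the same layout as the parameter above.
conv = nn.Conv1d(in_channels=12, out_channels=8, kernel_size=3, groups=4)
print(conv.weight.shape)                 # torch.Size([8, 3, 3])

# Channel counts that are not divisible by `groups` are rejected, just like the
# ValueError raised above.
try:
    nn.Conv1d(in_channels=10, out_channels=8, kernel_size=3, groups=4)
except ValueError as err:
    print(err)                           # in_channels must be divisible by groups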
@@ -155,6 +172,7 @@ def forward(self, signal):
            bias=self.bias,
            padding=self.padding,
            stride=self.stride,
+            groups=self.groups,
        )


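Putting it together, a hedged end-to-end sketch of the new groups argument. The import path is an assumption (adjust to wherever fft_conv lives in this repo) and the tolerance is a guess; it simply cross-checks fft_conv against f.conv1d with groups=4:

import torch
import torch.nn.functional as f

from fft_conv import fft_conv            # hypothetical import path

signal = torch.randn(3, 8, 128)          # (batch, channels, length)
kernel = torch.randn(12, 2, 31)          # groups=4 -> 8 // 4 == 2 in-channels per group
bias = torch.randn(12)

out_fft = fft_conv(signal, kernel, bias=bias, padding=15, groups=4)
out_ref = f.conv1d(signal, kernel, bias=bias, padding=15, groups=4)

print(out_fft.shape)                     # torch.Size([3, 12, 128])
print(torch.allclose(out_fft, out_ref, atol=1e-4))  # expected: True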