from functools import partial
+ from typing import Tuple, Union, Iterable

import torch
from torch import nn, Tensor
@@ -24,12 +25,33 @@ def complex_matmul(a: Tensor, b: Tensor) -> Tensor:
    return c


+ def to_ntuple(val: Union[int, Iterable[int]], n: int) -> Tuple[int, ...]:
+     """Casts to a tuple with length 'n'. Useful for automatically computing the
+     padding and stride for convolutions, where users may only provide an integer.
+
+     Args:
+         val: (Union[int, Iterable[int]]) Value to cast into a tuple.
+         n: (int) Desired length of the tuple
+
+     Returns:
+         (Tuple[int, ...]) Tuple of length 'n'
+     """
+     if isinstance(val, Iterable):
+         out = tuple(val)
+         if len(out) == n:
+             return out
+         else:
+             raise ValueError(f"Cannot cast tuple of length {len(out)} to length {n}.")
+     else:
+         return n * (val,)
+
+
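# Editor's illustrative aside (not part of this commit): how to_ntuple is expected to behave.
#   to_ntuple(3, n=2)          -> (3, 3)
#   to_ntuple((2, 4), n=2)     -> (2, 4)
#   to_ntuple((2, 4, 6), n=2)  -> raises ValueError (a length-3 tuple cannot be cast to length 2)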
def fft_conv(
    signal: Tensor,
    kernel: Tensor,
    bias: Tensor = None,
-     padding: int = 0,
-     stride: int = 1,
+     padding: Union[int, Iterable[int]] = 0,
+     stride: Union[int, Iterable[int]] = 1,
) -> Tensor:
    """Performs N-d convolution of Tensors using a fast Fourier transform, which
    is very fast for large kernel sizes. Also, optionally adds a bias Tensor after
@@ -38,15 +60,20 @@ def fft_conv(
    Args:
        signal: (Tensor) Input tensor to be convolved with the kernel.
        kernel: (Tensor) Convolution kernel.
-         bias: (Optional, Tensor) Bias tensor to add to the output.
-         padding: (int) Number of zero samples to pad the input on the last dimension.
-         stride: (int) Stride size for computing output values.
+         bias: (Tensor) Bias tensor to add to the output.
+         padding: (Union[int, Iterable[int]]) Number of zero samples to pad the
+             input on the last dimensions.
+         stride: (Union[int, Iterable[int]]) Stride size for computing output values.

    Returns:
        (Tensor) Convolved tensor
    """
+     # Cast padding & stride to tuples.
+     padding_ = to_ntuple(padding, n=signal.ndim - 2)
+     stride_ = to_ntuple(stride, n=signal.ndim - 2)
+
    # Pad the input signal & kernel tensors
-     signal_padding = (signal.ndim - 2) * [padding, padding]
+     signal_padding = [p for p in padding_[::-1] for _ in range(2)]
    signal = f.pad(signal, signal_padding)
    kernel_padding = [
        pad
@@ -65,7 +92,7 @@ def fft_conv(

    # Remove extra padded values
    crop_slices = [slice(0, output.size(0)), slice(0, output.size(1))] + [
-         slice(0, (signal.size(i) - kernel.size(i) + 1), stride)
+         slice(0, (signal.size(i) - kernel.size(i) + 1), stride_[i - 2])
        for i in range(2, signal.ndim)
    ]
    output = output[crop_slices].contiguous()
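A quick sketch of what the stride-aware crop does in one spatial dimension (editor's illustration with hypothetical sizes, not part of the commit):

# Padded signal length 10, kernel length 3, stride 2: the slice keeps valid
# positions 0..(10 - 3 + 1) in steps of 2.
list(range(0, 10 - 3 + 1, 2))  # -> [0, 2, 4, 6], i.e. 4 output samples per channel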
@@ -85,18 +112,20 @@ def __init__(
        self,
        in_channels: int,
        out_channels: int,
-         kernel_size: int,
-         padding: int = 0,
-         stride: int = 1,
+         kernel_size: Union[int, Iterable[int]],
+         padding: Union[int, Iterable[int]] = 0,
+         stride: Union[int, Iterable[int]] = 1,
        bias: bool = True,
+         ndim: int = 1,
    ):
        """
        Args:
            in_channels: (int) Number of channels in input tensors
            out_channels: (int) Number of channels in output tensors
-             kernel_size: (int) Square radius of the convolution kernel
-             padding: (int) Amount of zero-padding to add to the input tensor
-             stride: (int) Stride size for computing output values
+             kernel_size: (Union[int, Iterable[int]]) Size of the convolution kernel
+             padding: (Union[int, Iterable[int]]) Number of zero samples to pad the
+                 input on the last dimensions.
+             stride: (Union[int, Iterable[int]]) Stride size for computing output values.
            bias: (bool) If True, includes bias, which is added after convolution
        """
        super().__init__()
@@ -107,7 +136,8 @@ def __init__(
        self.stride = stride
        self.use_bias = bias

-         self.weight = torch.empty(0)
+         kernel_size = to_ntuple(kernel_size, ndim)
+         self.weight = nn.Parameter(torch.randn(out_channels, in_channels, *kernel_size))
        self.bias = nn.Parameter(torch.randn(out_channels,)) if bias else None

    def forward(self, signal):
@@ -120,65 +150,6 @@ def forward(self, signal):
        )


- class FFTConv1d(_FFTConv):
-     def __init__(
-         self,
-         in_channels: int,
-         out_channels: int,
-         kernel_size: int,
-         padding: int = 0,
-         stride: int = 1,
-         bias: bool = True,
-     ):
-         super().__init__(
-             in_channels, out_channels, kernel_size, padding=padding, bias=bias,
-         )
-         self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size))
-
-
- class FFTConv2d(_FFTConv):
-     def __init__(
-         self,
-         in_channels: int,
-         out_channels: int,
-         kernel_size: int,
-         padding: int = 0,
-         stride: int = 1,
-         bias: bool = True,
-     ):
-         super().__init__(
-             in_channels,
-             out_channels,
-             kernel_size,
-             padding=padding,
-             stride=stride,
-             bias=bias,
-         )
-         self.weight = nn.Parameter(
-             torch.randn(out_channels, in_channels, kernel_size, kernel_size)
-         )
-
-
- class FFTConv3d(_FFTConv):
-     def __init__(
-         self,
-         in_channels: int,
-         out_channels: int,
-         kernel_size: int,
-         padding: int = 0,
-         stride: int = 1,
-         bias: bool = True,
-     ):
-         super().__init__(
-             in_channels,
-             out_channels,
-             kernel_size,
-             padding=padding,
-             stride=stride,
-             bias=bias,
-         )
-         self.weight = nn.Parameter(
-             torch.randn(
-                 out_channels, in_channels, kernel_size, kernel_size, kernel_size
-             )
-         )
+ FFTConv1d = partial(_FFTConv, ndim=1)
+ FFTConv2d = partial(_FFTConv, ndim=2)
+ FFTConv3d = partial(_FFTConv, ndim=3)
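With the three dimension-specific subclasses collapsed into functools.partial bindings, construction stays the same for callers. A minimal usage sketch (editor's illustration; shapes and argument values are hypothetical, not part of the commit):

conv = FFTConv2d(in_channels=3, out_channels=8, kernel_size=3, padding=1)
x = torch.randn(1, 3, 32, 32)   # (batch, channels, height, width)
y = conv(x)                     # weight has shape (8, 3, 3, 3); kernel_size is cast by to_ntuple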