Commit f13c3eb

Add padding='same' for dilation=1, stride=1

1 parent: 40fcd51
3 files changed: +43 additions, -9 deletions

fft_conv_pytorch/fft_conv.py

Lines changed: 25 additions & 7 deletions
@@ -5,6 +5,7 @@
 import torch.nn.functional as f
 from torch import Tensor, nn
 from torch.fft import irfftn, rfftn
+from math import ceil, floor


 def complex_matmul(a: Tensor, b: Tensor, groups: int = 1) -> Tensor:

@@ -55,7 +56,7 @@ def fft_conv(
     signal: Tensor,
     kernel: Tensor,
     bias: Tensor = None,
-    padding: Union[int, Iterable[int]] = 0,
+    padding: Union[int, Iterable[int], str] = 0,
     padding_mode: str = "constant",
     stride: Union[int, Iterable[int]] = 1,
     dilation: Union[int, Iterable[int]] = 1,

@@ -69,19 +70,31 @@ def fft_conv(
         signal: (Tensor) Input tensor to be convolved with the kernel.
         kernel: (Tensor) Convolution kernel.
         bias: (Tensor) Bias tensor to add to the output.
-        padding: (Union[int, Iterable[int]]) Number of zero samples to pad the
-            input on the last dimension.
+        padding: (Union[int, Iterable[int], str]) If int, number of zero samples to pad the
+            input on the last dimension. If str, "same" pads the input to preserve the output size.
+        padding_mode: (str) Padding mode to use from {constant, reflection, replication}.
+            reflection is not available for 3d.
         stride: (Union[int, Iterable[int]]) Stride size for computing output values.
+        dilation: (Union[int, Iterable[int]]) Dilation rate for the kernel.
+        groups: (int) Number of groups for the convolution.

     Returns:
         (Tensor) Convolved tensor
     """

     # Cast padding, stride & dilation to tuples.
     n = signal.ndim - 2
-    padding_ = to_ntuple(padding, n=n)
     stride_ = to_ntuple(stride, n=n)
     dilation_ = to_ntuple(dilation, n=n)
+    if isinstance(padding, str):
+        if padding == "same":
+            if stride != 1 or dilation != 1:
+                raise ValueError("stride and dilation must be 1 for padding='same'.")
+            padding_ = [(k - 1) / 2 for k in kernel.shape[2:]]
+        else:
+            raise ValueError(f"Padding mode {padding} not supported.")
+    else:
+        padding_ = to_ntuple(padding, n=n)

     # internal dilation offsets
     offset = torch.zeros(1, 1, *dilation_, device=signal.device, dtype=signal.dtype)

@@ -93,8 +106,8 @@ def fft_conv(
     # pad the kernel internally according to the dilation parameters
     kernel = torch.kron(kernel, offset)[(slice(None), slice(None)) + cutoff]

-    # Pad the input signal & kernel tensors
-    signal_padding = [p for p in padding_[::-1] for _ in range(2)]
+    # Pad the input signal & kernel tensors (round to support even sized convolutions)
+    signal_padding = [r(p) for p in padding_[::-1] for r in (floor, ceil)]
     signal = f.pad(signal, signal_padding, mode=padding_mode)

     # Because PyTorch computes a *one-sided* FFT, we need the final dimension to
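
The floor/ceil rounding above is what makes padding='same' work for even kernel sizes: (k - 1) / 2 is fractional when k is even, so one side of each dimension receives one more padded sample than the other. A minimal standalone sketch of that computation (same logic as the two changed lines, with a hypothetical 2d kernel of size 2x3 for illustration):

    from math import ceil, floor

    kernel_sizes = (2, 3)  # hypothetical 2d kernel: size 2 in height, 3 in width
    padding_ = [(k - 1) / 2 for k in kernel_sizes]  # [0.5, 1.0]

    # f.pad expects pairs ordered from the last dimension backwards:
    # (width_left, width_right, height_left, height_right)
    signal_padding = [r(p) for p in padding_[::-1] for r in (floor, ceil)]
    print(signal_padding)  # [1, 1, 0, 1]: symmetric for k=3, asymmetric for k=2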
@@ -155,9 +168,14 @@ def __init__(
             out_channels: (int) Number of channels in output tensors
             kernel_size: (Union[int, Iterable[int]]) Square radius of the kernel
             padding: (Union[int, Iterable[int], str]) Number of zero samples to pad the
-                input on the last dimension.
+                input on the last dimension. If str, "same" pads the input to preserve the output size.
+            padding_mode: (str) Padding mode to use from {constant, reflection, replication}.
+                reflection is not available for 3d.
             stride: (Union[int, Iterable[int]]) Stride size for computing output values.
+            dilation: (Union[int, Iterable[int]]) Dilation rate for the kernel.
+            groups: (int) Number of groups for the convolution.
             bias: (bool) If True, includes bias, which is added after convolution
+            ndim: (int) Number of dimensions of the input tensor.
         """
         super().__init__()
         self.in_channels = in_channels
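
Taken together, these changes let callers pass padding="same" instead of computing pad sizes by hand. A minimal usage sketch (assuming fft_conv is re-exported from the package root, as the project README suggests; the tensor shapes are illustrative):

    import torch
    from fft_conv_pytorch import fft_conv  # import path assumed from the package layout

    signal = torch.randn(1, 3, 100)  # (batch, channels, length)
    kernel = torch.randn(2, 3, 4)    # (out_channels, in_channels, kernel_size), even kernel
    out = fft_conv(signal, kernel, padding="same")
    print(out.shape)  # expected torch.Size([1, 2, 100]): length preserved despite k=4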

tests/test_functional.py

Lines changed: 10 additions & 2 deletions
@@ -12,7 +12,7 @@
 @pytest.mark.parametrize("out_channels", [2, 3])
 @pytest.mark.parametrize("groups", [1, 2, 3])
 @pytest.mark.parametrize("kernel_size", [2, 3])
-@pytest.mark.parametrize("padding", [0, 1])
+@pytest.mark.parametrize("padding", [0, 1, "same"])
 @pytest.mark.parametrize("stride", [1, 2])
 @pytest.mark.parametrize("dilation", [1, 2])
 @pytest.mark.parametrize("bias", [True])

@@ -30,6 +30,10 @@ def test_fft_conv_functional(
     ndim: int,
     input_size: int,
 ):
+    if padding == "same" and (stride != 1 or dilation != 1):
+        # padding='same' is not compatible with strided or dilated convolutions
+        return
+
     torch_conv = getattr(f, f"conv{ndim}d")
     groups = _gcd(in_channels, _gcd(out_channels, groups))
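
One design note on the guard above: returning early makes the incompatible parameter combinations count as passed tests. A possible alternative, using pytest's built-in skip so they are reported as skipped instead (pytest.skip is part of pytest's public API):

    import pytest

    if padding == "same" and (stride != 1 or dilation != 1):
        pytest.skip("padding='same' requires stride=1 and dilation=1")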

@@ -70,7 +74,7 @@ def test_fft_conv_functional(
 @pytest.mark.parametrize("out_channels", [2, 3])
 @pytest.mark.parametrize("groups", [1, 2, 3])
 @pytest.mark.parametrize("kernel_size", [2, 3])
-@pytest.mark.parametrize("padding", [0, 1])
+@pytest.mark.parametrize("padding", [0, 1, "same"])
 @pytest.mark.parametrize("stride", [1, 2])
 @pytest.mark.parametrize("dilation", [1, 2])
 @pytest.mark.parametrize("bias", [True])

@@ -88,6 +92,10 @@ def test_fft_conv_backward_functional(
     ndim: int,
     input_size: int,
 ):
+    if padding == "same" and (stride != 1 or dilation != 1):
+        # padding='same' is not compatible with strided or dilated convolutions
+        return
+
     torch_conv = getattr(f, f"conv{ndim}d")
     groups = _gcd(in_channels, _gcd(out_channels, groups))

tests/test_module.py

Lines changed: 8 additions & 0 deletions
@@ -30,6 +30,10 @@ def test_fft_conv_module(
     ndim: int,
     input_size: int,
 ):
+    if padding == "same" and (stride != 1 or dilation != 1):
+        # padding='same' is not compatible with strided or dilated convolutions
+        return
+
     torch_conv = getattr(f, f"conv{ndim}d")
     groups = _gcd(in_channels, _gcd(out_channels, groups))
     fft_conv_layer = _FFTConv(

@@ -85,6 +89,10 @@ def test_fft_conv_backward_module(
     ndim: int,
     input_size: int,
 ):
+    if padding == "same" and (stride != 1 or dilation != 1):
+        # padding='same' is not compatible with strided or dilated convolutions
+        return
+
     torch_conv = getattr(f, f"conv{ndim}d")
     groups = _gcd(in_channels, _gcd(out_channels, groups))
     fft_conv_layer = _FFTConv(
