Skip to content

Commit 71ef0c4

Browse files
committed
Fixed bug for inputs with odd length in the final dimension.
1 parent 4abff90 commit 71ef0c4

File tree

2 files changed

+27
-12
lines changed

2 files changed

+27
-12
lines changed

benchmark.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Callable, NamedTuple
1+
from typing import Callable, NamedTuple, Union, Iterable
22
from timeit import Timer
33

44
import torch
@@ -29,7 +29,11 @@ def benchmark(fn: Callable, *args, num_iterations: int = 10, **kwargs) -> Benchm
2929

3030

3131
def benchmark_conv(
32-
signal: Tensor, kernel: Tensor, bias: Tensor, padding: int = 0, stride: int = 1
32+
signal: Tensor,
33+
kernel: Tensor,
34+
bias: Tensor,
35+
padding: Union[int, Iterable[int]] = 0,
36+
stride: Union[int, Iterable[int]] = 1,
3337
):
3438
print(f"Signal size: {signal.shape}")
3539
print(f"Kernel size: {kernel.shape}")
@@ -54,24 +58,27 @@ def benchmark_conv(
5458

5559
print("\n--- 1D Convolution ---")
5660
benchmark_conv(
57-
signal=torch.randn(3, 3, 4096),
61+
signal=torch.randn(3, 3, 4091),
5862
kernel=torch.randn(2, 3, 1025),
5963
bias=torch.randn(2),
6064
padding=512,
65+
stride=3,
6166
)
6267

6368
print("\n--- 2D Convolution ---")
6469
benchmark_conv(
65-
signal=torch.randn(3, 3, 256, 256),
66-
kernel=torch.randn(2, 3, 21, 21),
70+
signal=torch.randn(3, 3, 256, 235),
71+
kernel=torch.randn(2, 3, 19, 21),
6772
bias=torch.randn(2),
68-
padding=10,
73+
padding=(9, 10),
74+
stride=(2, 3),
6975
)
7076

7177
print("\n--- 3D Convolution ---")
7278
benchmark_conv(
73-
signal=torch.randn(3, 3, 64, 64, 64),
74-
kernel=torch.randn(2, 3, 9, 9, 9),
79+
signal=torch.randn(3, 3, 64, 72, 61),
80+
kernel=torch.randn(2, 3, 5, 7, 9),
7581
bias=torch.randn(2),
76-
padding=4,
82+
padding=(2, 3, 4),
83+
stride=(1, 2, 3)
7784
)

fft_conv.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,15 +75,23 @@ def fft_conv(
7575
# Pad the input signal & kernel tensors
7676
signal_padding = [p for p in padding_[::-1] for _ in range(2)]
7777
signal = f.pad(signal, signal_padding)
78+
79+
# Because PyTorch computes a *one-sided* FFT, we need the final dimension to
80+
# have *even* length. Just pad with one more zero if the final dimension is odd.
81+
if signal.size(-1) % 2:
82+
signal_ = f.pad(signal, [0, 1])
83+
else:
84+
signal_ = signal
85+
7886
kernel_padding = [
7987
pad
80-
for i in reversed(range(2, signal.ndim))
81-
for pad in [0, signal.size(i) - kernel.size(i)]
88+
for i in reversed(range(2, signal_.ndim))
89+
for pad in [0, signal_.size(i) - kernel.size(i)]
8290
]
8391
padded_kernel = f.pad(kernel, kernel_padding)
8492

8593
# Perform fourier convolution -- FFT, matrix multiply, then IFFT
86-
signal_fr = rfftn(signal, dim=tuple(range(2, signal.ndim)))
94+
signal_fr = rfftn(signal_, dim=tuple(range(2, signal.ndim)))
8795
kernel_fr = rfftn(padded_kernel, dim=tuple(range(2, signal.ndim)))
8896

8997
kernel_fr.imag *= -1

0 commit comments

Comments
 (0)