
Commit 9718c5c

Include benchmark figure in README, add scripts for generating the figure
1 parent: e6dc2cf

File tree: 3 files changed (+161, -2 lines)

README.md

Lines changed: 17 additions & 2 deletions
@@ -5,6 +5,7 @@ Implementation of 1D, 2D, and 3D FFT convolutions in PyTorch.
 * **Much slower** than direct convolution for small kernels.
 * In my local tests, FFT convolution is faster when the kernel has >100 or so elements.
 * Dependent on machine and PyTorch version.
+* Also see benchmarks below.


 ## Install
@@ -21,7 +22,7 @@ cd fft-conv-pytorch
 pip install .
 ```

-### Example Usage
+## Example Usage

 ```python
 import torch
@@ -45,4 +46,18 @@ fft_conv = FFTConv1d(3, 2, 128, bias=True)
 fft_conv.weight = torch.nn.Parameter(kernel)
 fft_conv.bias = torch.nn.Parameter(bias)
 out = fft_conv(signal)
-```
+```
+
+
+## Benchmarks
+
+Benchmarks of FFT convolution against the direct convolution from PyTorch in 1D, 2D,
+and 3D. The exact times depend heavily on your local machine, but the relative
+scaling with kernel size is always the same.
+
+Num dimensions | Input Array Size
+---------------|------------------
+1              | (4096)
+2              | (512, 512)
+3              | (64, 64, 64)
+
+![Benchmark Plot](doc/benchmark.png)
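As a rough illustration of the comparison described in the new Benchmarks section, the sketch below (not part of this commit) times one 1D case with a 256-element kernel. It imports `fft_conv` from `fft_conv_pytorch.fft_conv`, the same path the benchmark script uses, and takes `torch.nn.functional.conv1d` as the direct baseline; the `time_it` helper is purely illustrative.

```python
import time

import torch
import torch.nn.functional as F

from fft_conv_pytorch.fft_conv import fft_conv  # same import path as the benchmark script


def time_it(fn, repeats: int = 10) -> float:
    """Average wall-clock seconds per call, after one discarded warm-up call."""
    fn()  # warm-up, excluded from timing
    start = time.perf_counter()
    for _ in range(repeats):
        fn()
    return (time.perf_counter() - start) / repeats


# Shapes mirror the 1D benchmark config: (batch=2, channels=8, length=4096).
signal = torch.randn(2, 8, 4096)
weight = torch.randn(8, 8, 256)  # 256-element kernel, past the ~100-element crossover
bias = torch.randn(8)

direct_s = time_it(lambda: F.conv1d(signal, weight, bias=bias))
fft_s = time_it(lambda: fft_conv(signal, weight, bias=bias))
print(f"direct: {direct_s * 1e3:.2f} ms | fft: {fft_s * 1e3:.2f} ms")
```

For a kernel this large the FFT path should come out ahead, but as the README notes, the exact numbers depend on the machine and PyTorch version.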

doc/benchmark.png

49 KB
New file (benchmark script): 144 additions & 0 deletions
```python
from functools import lru_cache, partial
from timeit import Timer
from typing import Callable, Dict, Iterable, List, NamedTuple, Optional, Sequence, Union

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as f
from tqdm import tqdm

from fft_conv_pytorch.fft_conv import fft_conv, to_ntuple


class Benchmark(NamedTuple):
    mean: float
    std: float

    def __repr__(self):
        return f"BenchmarkResult(mean: {self.mean:.3e}, std: {self.std:.3e})"

    def __str__(self):
        return f"({self.mean:.3e} \u00B1 {self.std:.3e}) s"


def benchmark(fn: Callable, *args, num_iterations: int = 10, **kwargs) -> Benchmark:
    # Time 'fn(*args, **kwargs)' repeatedly; the first run is a warm-up and is
    # discarded before computing the mean and standard deviation.
    timer = Timer(
        "fn(*args, **kwargs)", globals={"fn": fn, "args": args, "kwargs": kwargs},
    )
    times = timer.repeat(number=1, repeat=num_iterations + 1)
    return Benchmark(np.mean(times[1:]).item(), np.std(times[1:]).item())


@lru_cache(maxsize=1)
def _get_conv_inputs(
    ndim: int,
    input_size: int,
    kernel_size: Union[int, Iterable[int]],
    batch_size: int = 2,
    in_channels: int = 8,
    out_channels: int = 8,
):
    # Cached so that the FFT and direct benchmarks for a given configuration
    # run on identical random tensors.
    dims = ndim * [input_size]
    signal = torch.randn(batch_size, in_channels, *dims)

    kernel_size = to_ntuple(kernel_size, n=signal.ndim - 2)
    weight = torch.randn(out_channels, in_channels, *kernel_size, requires_grad=True)
    bias = torch.randn(out_channels, requires_grad=True)

    return signal, weight, bias


def benchmark_conv(
    ndim: int,
    input_size: int,
    kernel_size: int,
    fft: bool = True,
    num_iterations: int = 10,
):
    # Select either fft_conv or the matching torch.nn.functional.conv{1,2,3}d.
    conv_fn = fft_conv if fft else getattr(f, f"conv{ndim}d")
    signal, weight, bias = _get_conv_inputs(
        ndim=ndim, input_size=input_size, kernel_size=kernel_size
    )
    return benchmark(conv_fn, signal, weight, bias=bias, num_iterations=num_iterations)


def benchmark_kernel_size(
    kernel_sizes: Sequence[int],
    ndim: int,
    input_size: int,
    fft: bool = True,
    num_iterations: int = 10,
    desc: str = "",
) -> List[Benchmark]:
    # Sweep over kernel sizes for a fixed input size and dimensionality.
    fn = partial(
        benchmark_conv,
        ndim=ndim,
        input_size=input_size,
        fft=fft,
        num_iterations=num_iterations,
    )
    return [fn(kernel_size=k) for k in tqdm(kernel_sizes, desc=desc)]


def _plot_benchmarks(
    benchmarks: List[Benchmark],
    config: Dict,
    ax: plt.Axes,
    color: str,
    label: Optional[str] = None,
):
    # Plot mean execution time (in ms) with a +/- 1 std shaded band.
    xs = config["kernel_sizes"]
    ys = np.array([b.mean * 1000 for b in benchmarks])
    std = np.array([b.std * 1000 for b in benchmarks])
    ax.plot(xs, ys, color, label=label)
    ax.fill_between(
        xs, ys - std, ys + std, facecolor=color, alpha=0.25, label="_nolegend_"
    )

    ndim = config["ndim"]
    ax.set_title(f"{ndim}D")
    kernel_size_str = "(" + " x ".join(["n"] * ndim) + ")"
    ax.set_xlabel(f"Kernel Size {kernel_size_str}")


if __name__ == "__main__":
    import os

    configs = [
        {
            "ndim": 1,
            "input_size": 4096,
            "num_iterations": 256,
            "kernel_sizes": np.arange(64, 513, 64),
        },
        {
            "ndim": 2,
            "input_size": 512,
            "num_iterations": 16,
            "kernel_sizes": np.arange(4, 49, 6),
        },
        {
            "ndim": 3,
            "input_size": 64,
            "num_iterations": 16,
            "kernel_sizes": np.arange(2, 17, 2),
        },
    ]

    save_dir = os.path.join(os.path.dirname(__file__), os.path.pardir)
    fig, ax = plt.subplots(
        1, len(configs), figsize=(4 * len(configs), 4), squeeze=False
    )

    for i, config in enumerate(configs):
        fft = benchmark_kernel_size(fft=True, **config, desc=f"FFT {config['ndim']}D")
        _plot_benchmarks(fft, config=config, ax=ax[0, i], color="r", label="FFT")

        direct = benchmark_kernel_size(
            fft=False, **config, desc=f"Direct {config['ndim']}D"
        )
        _plot_benchmarks(direct, config=config, ax=ax[0, i], color="b", label="Direct")

    ax[0, 0].set_ylabel("Execution Time (ms)")
    plt.savefig(os.path.join(save_dir, "benchmark.png"))
```
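
For reference, a minimal sketch of driving the script's helpers interactively rather than via `__main__`. The module name `benchmark` is an assumption, since the committed file path is not visible in this view; adjust the import to wherever the script lives in the repository.

```python
# Hypothetical usage: "benchmark" as a module name is an assumption.
from benchmark import benchmark_conv

# One 2D data point: 512x512 input with a 32x32 kernel (within the range swept
# by the script's 2D config).
fft_result = benchmark_conv(ndim=2, input_size=512, kernel_size=32, fft=True)
direct_result = benchmark_conv(ndim=2, input_size=512, kernel_size=32, fft=False)

# Printed in the "(mean ± std) s" format defined by Benchmark.__str__.
print(f"FFT:    {fft_result}")
print(f"Direct: {direct_result}")
```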
