Skip to content

Commit 833f271

Browse files
authored
Merge pull request #15 from fkodom/add-benchmarks
Add Benchmark Plots
2 parents 8286f76 + 78d9a9b commit 833f271

File tree

4 files changed

+293
-6
lines changed

4 files changed

+293
-6
lines changed

.gitignore

Lines changed: 131 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,132 @@
1-
__pycache__
2-
.idea
3-
.vscode
4-
.mypy_cache
1+
# Byte-compiled / optimized / DLL files
2+
__pycache__/
3+
*.py[cod]
4+
*$py.class
5+
6+
client_secrets.json
7+
8+
# C extensions
9+
*.so
10+
11+
# Distribution / packaging
12+
.Python
13+
build/
14+
develop-eggs/
15+
dist/
16+
downloads/
17+
eggs/
18+
.eggs/
19+
lib/
20+
lib64/
21+
parts/
22+
sdist/
23+
var/
24+
wheels/
25+
pip-wheel-metadata/
26+
share/python-wheels/
27+
*.egg-info/
28+
.installed.cfg
29+
*.egg
30+
MANIFEST
31+
32+
# PyInstaller
33+
# Usually these files are written by a python script from a template
34+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
35+
*.manifest
36+
*.spec
37+
38+
# Installer logs
39+
pip-log.txt
40+
pip-delete-this-directory.txt
41+
42+
# Unit test / coverage reports
43+
htmlcov/
44+
.tox/
45+
.nox/
46+
.coverage
47+
.coverage.*
48+
.cache
49+
nosetests.xml
50+
coverage.xml
51+
*.cover
52+
*.py,cover
53+
.hypothesis/
54+
.pytest_cache/
55+
56+
# Translations
57+
*.mo
58+
*.pot
59+
60+
# Django stuff:
61+
*.log
62+
local_settings.py
63+
db.sqlite3
64+
db.sqlite3-journal
65+
66+
# Flask stuff:
67+
instance/
68+
.webassets-cache
69+
70+
# Scrapy stuff:
71+
.scrapy
72+
73+
# Sphinx documentation
74+
docs/_build/
75+
76+
# PyBuilder
77+
target/
78+
79+
# Jupyter Notebook
580
.ipynb_checkpoints
81+
82+
# IPython
83+
profile_default/
84+
ipython_config.py
85+
86+
# pyenv
87+
.python-version
88+
89+
# pipenv
90+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
91+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
92+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
93+
# install all needed dependencies.
94+
#Pipfile.lock
95+
96+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
97+
__pypackages__/
98+
99+
# Celery stuff
100+
celerybeat-schedule
101+
celerybeat.pid
102+
103+
# SageMath parsed files
104+
*.sage.py
105+
106+
# Environments
107+
.env
108+
.venv
109+
env/
110+
venv/
111+
ENV/
112+
env.bak/
113+
venv.bak/
114+
115+
# Spyder project settings
116+
.spyderproject
117+
.spyproject
118+
119+
# Rope project settings
120+
.ropeproject
121+
122+
# mkdocs documentation
123+
/site
124+
125+
# mypy
126+
.mypy_cache/
127+
.dmypy.json
128+
dmypy.json
129+
130+
# Pyre type checker
131+
.pyre/
132+

README.md

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ Implementation of 1D, 2D, and 3D FFT convolutions in PyTorch.
55
* **Much slower** than direct convolution for small kernels.
66
* In my local tests, FFT convolution is faster when the kernel has >100 or so elements.
77
* Dependent on machine and PyTorch version.
8+
* Also see benchmarks below.
89

910

1011
## Install
@@ -21,7 +22,7 @@ cd fft-conv-pytorch
2122
pip install .
2223
```
2324

24-
### Example Usage
25+
## Example Usage
2526

2627
```python
2728
import torch
@@ -45,4 +46,18 @@ fft_conv = FFTConv1d(3, 2, 128, bias=True)
4546
fft_conv.weight = torch.nn.Parameter(kernel)
4647
fft_conv.bias = torch.nn.Parameter(bias)
4748
out = fft_conv(signal)
48-
```
49+
```
50+
51+
## Benchmarks
52+
53+
Benchmarking FFT convolution against the direct convolution from PyTorch in 1D, 2D,
54+
and 3D. The exact times are heavily dependent on your local machine, but relative
55+
scaling with kernel size is always the same.
56+
57+
Dimensions | Input Size | Input Channels | Output Channels | Bias | Padding | Stride | Dilation
58+
-----------|--------------|----------------|-----------------|------|---------|--------|---------
59+
1 | (4096) | 4 | 4 | True | 0 | 1 | 1
60+
2 | (512, 512) | 4 | 4 | True | 0 | 1 | 1
61+
3 | (64, 64, 64) | 4 | 4 | True | 0 | 1 | 1
62+
63+
![Benchmark Plot](doc/benchmark.png)

doc/benchmark.png

51 KB
Loading
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
from functools import lru_cache, partial
2+
from timeit import Timer
3+
from typing import Callable, Dict, Iterable, List, NamedTuple, Optional, Sequence, Union
4+
5+
import matplotlib.pyplot as plt
6+
import numpy as np
7+
import torch
8+
import torch.nn.functional as f
9+
from tqdm import tqdm
10+
11+
from fft_conv_pytorch.fft_conv import fft_conv, to_ntuple
12+
13+
14+
class Benchmark(NamedTuple):
15+
mean: float
16+
std: float
17+
18+
def __repr__(self):
19+
return f"BenchmarkResult(mean: {self.mean:.3e}, std: {self.std:.3e})"
20+
21+
def __str__(self):
22+
return f"({self.mean:.3e} \u00B1 {self.std:.3e}) s"
23+
24+
25+
def benchmark(fn: Callable, *args, num_iterations: int = 10, **kwargs) -> Benchmark:
26+
timer = Timer(
27+
"fn(*args, **kwargs)", globals={"fn": fn, "args": args, "kwargs": kwargs},
28+
)
29+
times = timer.repeat(number=1, repeat=num_iterations + 1)
30+
return Benchmark(np.mean(times[1:]).item(), np.std(times[1:]).item())
31+
32+
33+
@lru_cache(maxsize=1)
34+
def _get_conv_inputs(
35+
ndim: int,
36+
input_size: int,
37+
kernel_size: Union[int, Iterable[int]],
38+
batch_size: int = 2,
39+
in_channels: int = 8,
40+
out_channels: int = 8,
41+
):
42+
dims = ndim * [input_size]
43+
signal = torch.randn(batch_size, in_channels, *dims)
44+
45+
kernel_size = to_ntuple(kernel_size, n=signal.ndim - 2)
46+
weight = torch.randn(out_channels, in_channels, *kernel_size, requires_grad=True)
47+
bias = torch.randn(out_channels, requires_grad=True)
48+
49+
return signal, weight, bias
50+
51+
52+
def benchmark_conv(
53+
ndim: int,
54+
input_size: int,
55+
kernel_size: int,
56+
fft: bool = True,
57+
num_iterations: int = 10,
58+
):
59+
conv_fn = fft_conv if fft else getattr(f, f"conv{ndim}d")
60+
signal, weight, bias = _get_conv_inputs(
61+
ndim=ndim, input_size=input_size, kernel_size=kernel_size
62+
)
63+
return benchmark(conv_fn, signal, weight, bias=bias, num_iterations=num_iterations)
64+
65+
66+
def benchmark_kernel_size(
67+
kernel_sizes: Sequence[int],
68+
ndim: int,
69+
input_size: int,
70+
fft: bool = True,
71+
num_iterations: int = 10,
72+
desc: str = "",
73+
) -> List[Benchmark]:
74+
fn = partial(
75+
benchmark_conv,
76+
ndim=ndim,
77+
input_size=input_size,
78+
fft=fft,
79+
num_iterations=num_iterations,
80+
)
81+
return [fn(kernel_size=k) for k in tqdm(kernel_sizes, desc=desc)]
82+
83+
84+
def _plot_benchmarks(
85+
benchmarks: List[Benchmark],
86+
config: Dict,
87+
ax: plt.Axes,
88+
color: str,
89+
label: Optional[str] = None,
90+
):
91+
xs = config["kernel_sizes"]
92+
ys = np.array([b.mean * 1000 for b in benchmarks])
93+
std = np.array([b.std * 1000 for b in benchmarks])
94+
ax.plot(xs, ys, color, label=label)
95+
ax.fill_between(
96+
xs, ys - std, ys + std, facecolor=color, alpha=0.25, label="_nolegend_"
97+
)
98+
99+
ndim = config["ndim"]
100+
ax.set_title(f"{ndim}D")
101+
kernel_size_str = "(" + " x ".join(["n"] * ndim) + ")"
102+
ax.set_xlabel(f"Kernel Size {kernel_size_str}")
103+
104+
105+
if __name__ == "__main__":
106+
import os
107+
108+
configs = [
109+
{
110+
"ndim": 1,
111+
"input_size": 4096,
112+
"num_iterations": 256,
113+
"kernel_sizes": np.arange(64, 513, 64),
114+
},
115+
{
116+
"ndim": 2,
117+
"input_size": 512,
118+
"num_iterations": 16,
119+
"kernel_sizes": np.arange(4, 49, 6),
120+
},
121+
{
122+
"ndim": 3,
123+
"input_size": 64,
124+
"num_iterations": 16,
125+
"kernel_sizes": np.arange(2, 17, 2),
126+
},
127+
]
128+
129+
save_dir = os.path.join(os.path.dirname(__file__), os.path.pardir)
130+
fix, ax = plt.subplots(
131+
1, len(configs), figsize=(4 * len(configs), 4), squeeze=False
132+
)
133+
134+
for i, config in enumerate(configs):
135+
fft = benchmark_kernel_size(fft=True, **config, desc=f"FFT {config['ndim']}D")
136+
_plot_benchmarks(fft, config=config, ax=ax[0, i], color="r", label="FFT")
137+
138+
direct = benchmark_kernel_size(
139+
fft=False, **config, desc=f"Direct {config['ndim']}D"
140+
)
141+
_plot_benchmarks(direct, config=config, ax=ax[0, i], color="b", label="Direct")
142+
143+
ax[0, 0].set_ylabel("Execution Time (ms)")
144+
plt.legend(["FFT", "Direct"])
145+
plt.savefig(os.path.join(save_dir, "benchmark.png"))

0 commit comments

Comments
 (0)