|
| 1 | +from functools import lru_cache, partial |
| 2 | +from timeit import Timer |
| 3 | +from typing import Callable, Dict, Iterable, List, NamedTuple, Optional, Sequence, Union |
| 4 | + |
| 5 | +import matplotlib.pyplot as plt |
| 6 | +import numpy as np |
| 7 | +import torch |
| 8 | +import torch.nn.functional as f |
| 9 | +from tqdm import tqdm |
| 10 | + |
| 11 | +from fft_conv_pytorch.fft_conv import fft_conv, to_ntuple |
| 12 | + |
| 13 | + |
| 14 | +class Benchmark(NamedTuple): |
| 15 | + mean: float |
| 16 | + std: float |
| 17 | + |
| 18 | + def __repr__(self): |
| 19 | + return f"BenchmarkResult(mean: {self.mean:.3e}, std: {self.std:.3e})" |
| 20 | + |
| 21 | + def __str__(self): |
| 22 | + return f"({self.mean:.3e} \u00B1 {self.std:.3e}) s" |
| 23 | + |
| 24 | + |
| 25 | +def benchmark(fn: Callable, *args, num_iterations: int = 10, **kwargs) -> Benchmark: |
| 26 | + timer = Timer( |
| 27 | + "fn(*args, **kwargs)", globals={"fn": fn, "args": args, "kwargs": kwargs}, |
| 28 | + ) |
| 29 | + times = timer.repeat(number=1, repeat=num_iterations + 1) |
| 30 | + return Benchmark(np.mean(times[1:]).item(), np.std(times[1:]).item()) |
| 31 | + |
| 32 | + |
| 33 | +@lru_cache(maxsize=1) |
| 34 | +def _get_conv_inputs( |
| 35 | + ndim: int, |
| 36 | + input_size: int, |
| 37 | + kernel_size: Union[int, Iterable[int]], |
| 38 | + batch_size: int = 2, |
| 39 | + in_channels: int = 8, |
| 40 | + out_channels: int = 8, |
| 41 | +): |
| 42 | + dims = ndim * [input_size] |
| 43 | + signal = torch.randn(batch_size, in_channels, *dims) |
| 44 | + |
| 45 | + kernel_size = to_ntuple(kernel_size, n=signal.ndim - 2) |
| 46 | + weight = torch.randn(out_channels, in_channels, *kernel_size, requires_grad=True) |
| 47 | + bias = torch.randn(out_channels, requires_grad=True) |
| 48 | + |
| 49 | + return signal, weight, bias |
| 50 | + |
| 51 | + |
| 52 | +def benchmark_conv( |
| 53 | + ndim: int, |
| 54 | + input_size: int, |
| 55 | + kernel_size: int, |
| 56 | + fft: bool = True, |
| 57 | + num_iterations: int = 10, |
| 58 | +): |
| 59 | + conv_fn = fft_conv if fft else getattr(f, f"conv{ndim}d") |
| 60 | + signal, weight, bias = _get_conv_inputs( |
| 61 | + ndim=ndim, input_size=input_size, kernel_size=kernel_size |
| 62 | + ) |
| 63 | + return benchmark(conv_fn, signal, weight, bias=bias, num_iterations=num_iterations) |
| 64 | + |
| 65 | + |
| 66 | +def benchmark_kernel_size( |
| 67 | + kernel_sizes: Sequence[int], |
| 68 | + ndim: int, |
| 69 | + input_size: int, |
| 70 | + fft: bool = True, |
| 71 | + num_iterations: int = 10, |
| 72 | + desc: str = "", |
| 73 | +) -> List[Benchmark]: |
| 74 | + fn = partial( |
| 75 | + benchmark_conv, |
| 76 | + ndim=ndim, |
| 77 | + input_size=input_size, |
| 78 | + fft=fft, |
| 79 | + num_iterations=num_iterations, |
| 80 | + ) |
| 81 | + return [fn(kernel_size=k) for k in tqdm(kernel_sizes, desc=desc)] |
| 82 | + |
| 83 | + |
| 84 | +def _plot_benchmarks( |
| 85 | + benchmarks: List[Benchmark], |
| 86 | + config: Dict, |
| 87 | + ax: plt.Axes, |
| 88 | + color: str, |
| 89 | + label: Optional[str] = None, |
| 90 | +): |
| 91 | + xs = config["kernel_sizes"] |
| 92 | + ys = np.array([b.mean * 1000 for b in benchmarks]) |
| 93 | + std = np.array([b.std * 1000 for b in benchmarks]) |
| 94 | + ax.plot(xs, ys, color, label=label) |
| 95 | + ax.fill_between( |
| 96 | + xs, ys - std, ys + std, facecolor=color, alpha=0.25, label="_nolegend_" |
| 97 | + ) |
| 98 | + |
| 99 | + ndim = config["ndim"] |
| 100 | + ax.set_title(f"{ndim}D") |
| 101 | + kernel_size_str = "(" + " x ".join(["n"] * ndim) + ")" |
| 102 | + ax.set_xlabel(f"Kernel Size {kernel_size_str}") |
| 103 | + |
| 104 | + |
| 105 | +if __name__ == "__main__": |
| 106 | + import os |
| 107 | + |
| 108 | + configs = [ |
| 109 | + { |
| 110 | + "ndim": 1, |
| 111 | + "input_size": 4096, |
| 112 | + "num_iterations": 256, |
| 113 | + "kernel_sizes": np.arange(64, 513, 64), |
| 114 | + }, |
| 115 | + { |
| 116 | + "ndim": 2, |
| 117 | + "input_size": 512, |
| 118 | + "num_iterations": 16, |
| 119 | + "kernel_sizes": np.arange(4, 49, 6), |
| 120 | + }, |
| 121 | + { |
| 122 | + "ndim": 3, |
| 123 | + "input_size": 64, |
| 124 | + "num_iterations": 16, |
| 125 | + "kernel_sizes": np.arange(2, 17, 2), |
| 126 | + }, |
| 127 | + ] |
| 128 | + |
| 129 | + save_dir = os.path.join(os.path.dirname(__file__), os.path.pardir) |
| 130 | + fix, ax = plt.subplots( |
| 131 | + 1, len(configs), figsize=(4 * len(configs), 4), squeeze=False |
| 132 | + ) |
| 133 | + |
| 134 | + for i, config in enumerate(configs): |
| 135 | + fft = benchmark_kernel_size(fft=True, **config, desc=f"FFT {config['ndim']}D") |
| 136 | + _plot_benchmarks(fft, config=config, ax=ax[0, i], color="r", label="FFT") |
| 137 | + |
| 138 | + direct = benchmark_kernel_size( |
| 139 | + fft=False, **config, desc=f"Direct {config['ndim']}D" |
| 140 | + ) |
| 141 | + _plot_benchmarks(direct, config=config, ax=ax[0, i], color="b", label="Direct") |
| 142 | + |
| 143 | + ax[0, 0].set_ylabel("Execution Time (ms)") |
| 144 | + plt.savefig(os.path.join(save_dir, "benchmark.png")) |
0 commit comments