Skip to content

Commit bbe43cd

Browse files
committed
Reformat using Black.
1 parent 16acf16 commit bbe43cd

File tree

1 file changed

+20
-4
lines changed

1 file changed

+20
-4
lines changed

benchmark.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,16 +41,32 @@ def benchmark_conv(
4141

4242
torch_conv = {1: f.conv1d, 2: f.conv2d, 3: f.conv3d}[signal.ndim - 2]
4343
direct_time = benchmark(
44-
torch_conv, signal, kernel, bias=bias, padding=padding, stride=stride, groups=groups
44+
torch_conv,
45+
signal,
46+
kernel,
47+
bias=bias,
48+
padding=padding,
49+
stride=stride,
50+
groups=groups,
4551
)
4652
fourier_time = benchmark(
47-
fft_conv, signal, kernel, bias=bias, padding=padding, stride=stride, groups=groups
53+
fft_conv,
54+
signal,
55+
kernel,
56+
bias=bias,
57+
padding=padding,
58+
stride=stride,
59+
groups=groups,
4860
)
4961
print(f"Direct time: {direct_time}")
5062
print(f"Fourier time: {fourier_time}")
5163

52-
y0 = torch_conv(signal, kernel, bias=bias, padding=padding, stride=stride, groups=groups)
53-
y1 = fft_conv(signal, kernel, bias=bias, padding=padding, stride=stride, groups=groups)
64+
y0 = torch_conv(
65+
signal, kernel, bias=bias, padding=padding, stride=stride, groups=groups
66+
)
67+
y1 = fft_conv(
68+
signal, kernel, bias=bias, padding=padding, stride=stride, groups=groups
69+
)
5470
abs_error = torch.abs(y0 - y1)
5571
print(f"Output size: {y0.size()}")
5672
print(f"Abs Error Mean: {abs_error.mean():.3E}")

0 commit comments

Comments
 (0)