@@ -41,16 +41,32 @@ def benchmark_conv(
4141
4242 torch_conv = {1 : f .conv1d , 2 : f .conv2d , 3 : f .conv3d }[signal .ndim - 2 ]
4343 direct_time = benchmark (
44- torch_conv , signal , kernel , bias = bias , padding = padding , stride = stride , groups = groups
44+ torch_conv ,
45+ signal ,
46+ kernel ,
47+ bias = bias ,
48+ padding = padding ,
49+ stride = stride ,
50+ groups = groups ,
4551 )
4652 fourier_time = benchmark (
47- fft_conv , signal , kernel , bias = bias , padding = padding , stride = stride , groups = groups
53+ fft_conv ,
54+ signal ,
55+ kernel ,
56+ bias = bias ,
57+ padding = padding ,
58+ stride = stride ,
59+ groups = groups ,
4860 )
4961 print (f"Direct time: { direct_time } " )
5062 print (f"Fourier time: { fourier_time } " )
5163
52- y0 = torch_conv (signal , kernel , bias = bias , padding = padding , stride = stride , groups = groups )
53- y1 = fft_conv (signal , kernel , bias = bias , padding = padding , stride = stride , groups = groups )
64+ y0 = torch_conv (
65+ signal , kernel , bias = bias , padding = padding , stride = stride , groups = groups
66+ )
67+ y1 = fft_conv (
68+ signal , kernel , bias = bias , padding = padding , stride = stride , groups = groups
69+ )
5470 abs_error = torch .abs (y0 - y1 )
5571 print (f"Output size: { y0 .size ()} " )
5672 print (f"Abs Error Mean: { abs_error .mean ():.3E} " )
0 commit comments