@@ -71,79 +71,3 @@ def test_fft_conv(
7171 _assert_almost_equal (y0 , y1 )
7272 _assert_almost_equal (y0 , y2 )
7373 _assert_almost_equal (y1 , y2 )
74-
75-
76- # def benchmark_conv(
77- # signal: Tensor,
78- # kernel: Tensor,
79- # bias: Tensor,
80- # padding: Union[int, Iterable[int]] = 0,
81- # stride: Union[int, Iterable[int]] = 1,
82- # groups: int = 1,
83- # ):
84- # print(f"Signal size: {signal.shape}")
85- # print(f"Kernel size: {kernel.shape}")
86-
87- # torch_conv = {1: f.conv1d, 2: f.conv2d, 3: f.conv3d}[signal.ndim - 2]
88- # direct_time = benchmark(
89- # torch_conv,
90- # signal,
91- # kernel,
92- # bias=bias,
93- # padding=padding,
94- # stride=stride,
95- # groups=groups,
96- # )
97- # fourier_time = benchmark(
98- # fft_conv,
99- # signal,
100- # kernel,
101- # bias=bias,
102- # padding=padding,
103- # stride=stride,
104- # groups=groups,
105- # )
106- # print(f"Direct time: {direct_time}")
107- # print(f"Fourier time: {fourier_time}")
108-
109- # y0 = torch_conv(
110- # signal, kernel, bias=bias, padding=padding, stride=stride, groups=groups
111- # )
112- # y1 = fft_conv(
113- # signal, kernel, bias=bias, padding=padding, stride=stride, groups=groups
114- # )
115- # abs_error = torch.abs(y0 - y1)
116- # print(f"Output size: {y0.size()}")
117- # print(f"Abs Error Mean: {abs_error.mean():.3E}")
118- # print(f"Abs Error Std Dev: {abs_error.std():.3E}")
119-
120-
121- # print("\n--- 1D Convolution ---")
122- # benchmark_conv(
123- # signal=torch.randn(4, 4, 4091),
124- # kernel=torch.randn(6, 2, 1025),
125- # bias=torch.randn(6),
126- # padding=512,
127- # stride=3,
128- # groups=2,
129- # )
130-
131- # print("\n--- 2D Convolution ---")
132- # benchmark_conv(
133- # signal=torch.randn(4, 4, 256, 235),
134- # kernel=torch.randn(6, 2, 22, 21),
135- # bias=torch.randn(6),
136- # padding=(9, 10),
137- # stride=(2, 3),
138- # groups=2,
139- # )
140-
141- # print("\n--- 3D Convolution ---")
142- # benchmark_conv(
143- # signal=torch.randn(4, 4, 96, 72, 61),
144- # kernel=torch.randn(6, 2, 12, 7, 9),
145- # bias=torch.randn(6),
146- # padding=(2, 3, 4),
147- # stride=(1, 2, 3),
148- # groups=2,
149- # )
0 commit comments