44import torch
55import torch .nn .functional as f
66
7- from fft_conv_pytorch .fft_conv import _FFTConv , fft_conv , to_ntuple
7+ from fft_conv_pytorch .fft_conv import fft_conv , to_ntuple
88from tests .utils import _assert_almost_equal , _gcd
99
1010
11-
12- @pytest .mark .parametrize ("in_channels" , [1 , 2 , 3 ])
13- @pytest .mark .parametrize ("out_channels" , [1 , 2 , 3 ])
11+ @pytest .mark .parametrize ("in_channels" , [2 , 3 ])
12+ @pytest .mark .parametrize ("out_channels" , [2 , 3 ])
1413@pytest .mark .parametrize ("groups" , [1 , 2 , 3 ])
15- @pytest .mark .parametrize ("kernel_size" , [1 , 2 , 3 ])
14+ @pytest .mark .parametrize ("kernel_size" , [2 , 3 ])
1615@pytest .mark .parametrize ("padding" , [0 , 1 ])
17- @pytest .mark .parametrize ("stride" , [1 , 2 , 3 ])
18- @pytest .mark .parametrize ("dilation" , [1 , 2 , 3 ])
19- @pytest .mark .parametrize ("bias" , [True , False ])
16+ @pytest .mark .parametrize ("stride" , [1 , 2 ])
17+ @pytest .mark .parametrize ("dilation" , [1 , 2 ])
18+ @pytest .mark .parametrize ("bias" , [True ])
2019@pytest .mark .parametrize ("ndim" , [1 , 2 , 3 ])
2120@pytest .mark .parametrize ("input_size" , [7 , 8 ])
2221def test_fft_conv_functional (
@@ -46,34 +45,30 @@ def test_fft_conv_functional(
4645 )
4746
4847 kernel_size = to_ntuple (kernel_size , n = signal .ndim - 2 )
49- w0 = torch .randn (out_channels , in_channels // groups , * kernel_size ,
50- requires_grad = True )
48+ w0 = torch .randn (
49+ out_channels , in_channels // groups , * kernel_size , requires_grad = True
50+ )
5151 w1 = w0 .detach ().clone ().requires_grad_ ()
5252
5353 b0 = torch .randn (out_channels , requires_grad = True ) if bias else None
5454 b1 = b0 .detach ().clone ().requires_grad_ () if bias else None
5555
56- kwargs = dict (
57- padding = padding ,
58- stride = stride ,
59- dilation = dilation ,
60- groups = groups ,
61- )
56+ kwargs = dict (padding = padding , stride = stride , dilation = dilation , groups = groups ,)
6257
6358 y0 = fft_conv (signal , w0 , bias = b0 , ** kwargs )
6459 y1 = torch_conv (signal , w1 , bias = b1 , ** kwargs )
65-
60+
6661 _assert_almost_equal (y0 , y1 )
6762
6863
69- @pytest .mark .parametrize ("in_channels" , [1 , 2 , 3 ])
70- @pytest .mark .parametrize ("out_channels" , [1 , 2 , 3 ])
64+ @pytest .mark .parametrize ("in_channels" , [2 , 3 ])
65+ @pytest .mark .parametrize ("out_channels" , [2 , 3 ])
7166@pytest .mark .parametrize ("groups" , [1 , 2 , 3 ])
72- @pytest .mark .parametrize ("kernel_size" , [1 , 2 , 3 ])
67+ @pytest .mark .parametrize ("kernel_size" , [2 , 3 ])
7368@pytest .mark .parametrize ("padding" , [0 , 1 ])
74- @pytest .mark .parametrize ("stride" , [1 , 2 , 3 ])
75- @pytest .mark .parametrize ("dilation" , [1 , 2 , 3 ])
76- @pytest .mark .parametrize ("bias" , [True , False ])
69+ @pytest .mark .parametrize ("stride" , [1 , 2 ])
70+ @pytest .mark .parametrize ("dilation" , [1 , 2 ])
71+ @pytest .mark .parametrize ("bias" , [True ])
7772@pytest .mark .parametrize ("ndim" , [1 , 2 , 3 ])
7873@pytest .mark .parametrize ("input_size" , [7 , 8 ])
7974def test_fft_conv_backward_functional (
@@ -96,28 +91,24 @@ def test_fft_conv_backward_functional(
9691 signal = torch .randn (batch_size , in_channels , * dims )
9792
9893 kernel_size = to_ntuple (kernel_size , n = signal .ndim - 2 )
99- w0 = torch .randn (out_channels , in_channels // groups , * kernel_size ,
100- requires_grad = True )
94+ w0 = torch .randn (
95+ out_channels , in_channels // groups , * kernel_size , requires_grad = True
96+ )
10197 w1 = w0 .detach ().clone ().requires_grad_ ()
102-
98+
10399 b0 = torch .randn (out_channels , requires_grad = True ) if bias else None
104100 b1 = b0 .detach ().clone ().requires_grad_ () if bias else None
105101
106- kwargs = dict (
107- padding = padding ,
108- stride = stride ,
109- dilation = dilation ,
110- groups = groups ,
111- )
102+ kwargs = dict (padding = padding , stride = stride , dilation = dilation , groups = groups ,)
112103
113104 y0 = fft_conv (signal , w0 , bias = b0 , ** kwargs )
114105 y1 = torch_conv (signal , w1 , bias = b1 , ** kwargs )
115-
106+
116107 # Compute pseudo-loss and gradient
117108 y0 .sum ().backward ()
118109 y1 .sum ().backward ()
119-
110+
120111 _assert_almost_equal (w0 .grad , w1 .grad )
121112
122- if bias :
113+ if bias :
123114 _assert_almost_equal (b0 .grad , b1 .grad )
0 commit comments