@@ -29,16 +29,16 @@ testing: this would most commonly be a constructor such as `Array` or `CuArray`.
29
29
function test_fft_backend (array_constructor; test_real= true , test_inplace= true )
30
30
@testset " fft correctness" begin
31
31
# DFT along last dimension, results computed using FFTW
32
- for (_x, _fftw_fft) in (
33
- (collect (1 : 7 ),
32
+ for (_x, dims, real_input, _fftw_fft) in (
33
+ (collect (1 : 7 ), 1 , true ,
34
34
[28.0 + 0.0im ,
35
35
- 3.5 + 7.267824888003178im ,
36
36
- 3.5 + 2.7911568610884143im ,
37
37
- 3.5 + 0.7988521603655248im ,
38
38
- 3.5 - 0.7988521603655248im ,
39
39
- 3.5 - 2.7911568610884143im ,
40
40
- 3.5 - 7.267824888003178im ]),
41
- (collect (1 : 8 ),
41
+ (collect (1 : 8 ), 1 , true ,
42
42
[36.0 + 0.0im ,
43
43
- 4.0 + 9.65685424949238im ,
44
44
- 4.0 + 4.0im ,
@@ -47,21 +47,32 @@ function test_fft_backend(array_constructor; test_real=true, test_inplace=true)
47
47
- 4.0 - 1.6568542494923806im ,
48
48
- 4.0 - 4.0im ,
49
49
- 4.0 - 9.65685424949238im ]),
50
- (collect (reshape (1 : 8 , 2 , 4 )),
50
+ (collect (reshape (1 : 8 , 2 , 4 )), 2 , true ,
51
51
[16.0 + 0.0im - 4.0 + 4.0im - 4.0 + 0.0im - 4.0 - 4.0im ;
52
52
20.0 + 0.0im - 4.0 + 4.0im - 4.0 + 0.0im - 4.0 - 4.0im ]),
53
- (collect (reshape (1 : 9 , 3 , 3 )),
53
+ (collect (reshape (1 : 9 , 3 , 3 )), 2 , true ,
54
54
[12.0 + 0.0im - 4.5 + 2.598076211353316im - 4.5 - 2.598076211353316im ;
55
55
15.0 + 0.0im - 4.5 + 2.598076211353316im - 4.5 - 2.598076211353316im ;
56
56
18.0 + 0.0im - 4.5 + 2.598076211353316im - 4.5 - 2.598076211353316im ]),
57
+ (collect (reshape (1 : 8 , 2 , 2 , 2 )), 1 : 2 , true ,
58
+ [10.0 + 0.0im - 4.0 + 0.0im ; - 2.0 + 0.0im 0.0 + 0.0im ;;;
59
+ 26.0 + 0.0im - 4.0 + 0.0im ; - 2.0 + 0.0im 0.0 + 0.0im ]),
60
+ (collect (1 : 7 ) + im * collect (8 : 14 ), 1 , false ,
61
+ [28.0 + 77.0im ,
62
+ - 10.76782488800318 + 3.767824888003175im ,
63
+ - 6.291156861088416 - 0.7088431389115883im ,
64
+ - 4.298852160365525 - 2.7011478396344746im ,
65
+ - 2.7011478396344764 - 4.298852160365524im ,
66
+ - 0.7088431389115866 - 6.291156861088417im ,
67
+ 3.767824888003177 - 10.76782488800318im ]),
68
+ (collect (reshape (1 : 8 , 2 , 2 , 2 )) + im * reshape (9 : 16 , 2 , 2 , 2 ), 1 : 2 , false ,
69
+ [10.0 + 42.0im - 4.0 - 4.0im ; - 2.0 - 2.0im 0.0 + 0.0im ;;;
70
+ 26.0 + 58.0im - 4.0 - 4.0im ; - 2.0 - 2.0im 0.0 + 0.0im ]),
57
71
)
58
72
x = array_constructor (_x) # dummy array that will be passed to plans
59
- x_real = float .(x) # for testing real FFTs
60
- x_complex = complex .(x_real) # for testing complex FFTs
73
+ x_complex = complex .(float .(x)) # for testing complex FFTs
61
74
fftw_fft = array_constructor (_fftw_fft)
62
75
63
- dims = ndims (x) # TODO : this is a single dimension, should check multidimensional FFTs too
64
-
65
76
# FFT
66
77
y = AbstractFFTs. fft (x_complex, dims)
67
78
@test y ≈ fftw_fft
@@ -82,7 +93,7 @@ function test_fft_backend(array_constructor; test_real=true, test_inplace=true)
82
93
end
83
94
84
95
# BFFT
85
- fftw_bfft = size (x_complex, dims) .* x_complex
96
+ fftw_bfft = prod ( size (x_complex, d) for d in dims) .* x_complex
86
97
@test AbstractFFTs. bfft (y, dims) ≈ fftw_bfft
87
98
test_inplace && (@test AbstractFFTs. bfft! (copy (y), dims) ≈ fftw_bfft)
88
99
plans_to_test = [plan_bfft (similar (y), dims)]
@@ -114,16 +125,18 @@ function test_fft_backend(array_constructor; test_real=true, test_inplace=true)
114
125
@test fftdims (P) == dims
115
126
end
116
127
117
- if test_real
128
+ if test_real && real_input
129
+ x_real = float .(x) # for testing real FFTs
118
130
# RFFT
119
131
fftw_rfft = fftw_fft[
120
- (Colon () for _ in 1 : (ndims (fftw_fft) - 1 )). .. ,
121
- 1 : (size (fftw_fft, ndims (fftw_fft)) ÷ 2 + 1 )
132
+ (Colon () for _ in 1 : (first (dims) - 1 )). .. ,
133
+ 1 : (size (fftw_fft, first (dims)) ÷ 2 + 1 ),
134
+ (Colon () for _ in (first (dims) + 1 ): ndims (fftw_fft)). ..
122
135
]
123
136
ry = AbstractFFTs. rfft (x_real, dims)
124
137
@test ry ≈ fftw_rfft
125
- for P in [plan_rfft (x_real, dims), inv (plan_irfft (ry, size (x, dims), dims)),
126
- AbstractFFTs. plan_inv (plan_irfft (ry, size (x, dims), dims))]
138
+ for P in [plan_rfft (x_real, dims), inv (plan_irfft (ry, size (x, first ( dims) ), dims)),
139
+ AbstractFFTs. plan_inv (plan_irfft (ry, size (x, first ( dims) ), dims))]
127
140
@test eltype (P) <: Real
128
141
@test P * x_real ≈ fftw_rfft
129
142
@test mul! (similar (ry), P, x_real) ≈ fftw_rfft
@@ -132,18 +145,18 @@ function test_fft_backend(array_constructor; test_real=true, test_inplace=true)
132
145
end
133
146
134
147
# BRFFT
135
- fftw_brfft = complex . (size (x, dims) .* x_real)
136
- @test AbstractFFTs. brfft (ry, size (x_real, dims), dims) ≈ fftw_brfft
137
- P = plan_brfft (ry, size (x_real, dims), dims)
148
+ fftw_brfft = prod (size (x_real, d) for d in dims) .* x_real
149
+ @test AbstractFFTs. brfft (ry, size (x_real, first ( dims) ), dims) ≈ fftw_brfft
150
+ P = plan_brfft (ry, size (x_real, first ( dims) ), dims)
138
151
@test P * ry ≈ fftw_brfft
139
152
@test mul! (similar (x_real), P, ry) ≈ fftw_brfft
140
153
@test P \ (P * ry) ≈ ry
141
154
@test fftdims (P) == dims
142
155
143
156
# IRFFT
144
157
fftw_irfft = x_complex
145
- @test AbstractFFTs. irfft (ry, size (x, dims), dims) ≈ fftw_irfft
146
- for P in [plan_irfft (ry, size (x, dims), dims), inv (plan_rfft (x_real, dims)),
158
+ @test AbstractFFTs. irfft (ry, size (x, first ( dims) ), dims) ≈ fftw_irfft
159
+ for P in [plan_irfft (ry, size (x, first ( dims) ), dims), inv (plan_rfft (x_real, dims)),
147
160
AbstractFFTs. plan_inv (plan_rfft (x_real, dims))]
148
161
@test P * ry ≈ fftw_irfft
149
162
@test mul! (similar (x_real), P, ry) ≈ fftw_irfft
0 commit comments