@@ -11,34 +11,33 @@ using LinearAlgebra
11
11
using Test
12
12
13
13
"""
14
- test_fft_backend(array_constructor ; test_real=true, test_inplace=true)
14
+ test_fft_backend(ArrayType=Array ; test_real=true, test_inplace=true)
15
15
16
16
Run tests to verify correctness of FFT functions using a particular
17
17
backend plan implementation. The backend implementation is assumed to be loaded
18
18
prior to calling this function.
19
19
20
20
# Arguments
21
21
22
- - `array_constructor`: determines the `AbstractArray` implementation for
23
- which the correctness tests are run. It is assumed to be a callable object that
24
- takes in input arrays of type `Array` and return arrays of the desired type for
25
- testing. For example, this can be a constructor such as `Array` or `CUDA.CuArray`.
22
+ - `ArrayType`: determines the `AbstractArray` implementation for
23
+ which the correctness tests are run. Arrays are constructed via
24
+ `convert(ArrayType, ...)`.
26
25
- `test_real=true`: whether to test real-to-complex and complex-to-real FFTs.
27
26
- `test_inplace=true`: whether to test in-place plans.
28
27
"""
29
- function test_fft_backend (array_constructor ; test_real= true , test_inplace= true )
28
+ function test_fft_backend (ArrayType = Array ; test_real= true , test_inplace= true )
30
29
@testset " fft correctness" begin
31
30
# DFT along last dimension, results computed using FFTW
32
- for (_x, dims, real_input, _fftw_fft) in (
33
- (collect (1 : 7 ), 1 , true ,
31
+ for (_x, dims, _fftw_fft) in (
32
+ (collect (1 : 7 ), 1 ,
34
33
[28.0 + 0.0im ,
35
34
- 3.5 + 7.267824888003178im ,
36
35
- 3.5 + 2.7911568610884143im ,
37
36
- 3.5 + 0.7988521603655248im ,
38
37
- 3.5 - 0.7988521603655248im ,
39
38
- 3.5 - 2.7911568610884143im ,
40
39
- 3.5 - 7.267824888003178im ]),
41
- (collect (1 : 8 ), 1 , true ,
40
+ (collect (1 : 8 ), 1 ,
42
41
[36.0 + 0.0im ,
43
42
- 4.0 + 9.65685424949238im ,
44
43
- 4.0 + 4.0im ,
@@ -47,49 +46,50 @@ function test_fft_backend(array_constructor; test_real=true, test_inplace=true)
47
46
- 4.0 - 1.6568542494923806im ,
48
47
- 4.0 - 4.0im ,
49
48
- 4.0 - 9.65685424949238im ]),
50
- (collect (reshape (1 : 8 , 2 , 4 )), 2 , true ,
49
+ (collect (reshape (1 : 8 , 2 , 4 )), 2 ,
51
50
[16.0 + 0.0im - 4.0 + 4.0im - 4.0 + 0.0im - 4.0 - 4.0im ;
52
51
20.0 + 0.0im - 4.0 + 4.0im - 4.0 + 0.0im - 4.0 - 4.0im ]),
53
- (collect (reshape (1 : 9 , 3 , 3 )), 2 , true ,
52
+ (collect (reshape (1 : 9 , 3 , 3 )), 2 ,
54
53
[12.0 + 0.0im - 4.5 + 2.598076211353316im - 4.5 - 2.598076211353316im ;
55
54
15.0 + 0.0im - 4.5 + 2.598076211353316im - 4.5 - 2.598076211353316im ;
56
55
18.0 + 0.0im - 4.5 + 2.598076211353316im - 4.5 - 2.598076211353316im ]),
57
- (collect (reshape (1 : 8 , 2 , 2 , 2 )), 1 : 2 , true ,
56
+ (collect (reshape (1 : 8 , 2 , 2 , 2 )), 1 : 2 ,
58
57
cat ([10.0 + 0.0im - 4.0 + 0.0im ; - 2.0 + 0.0im 0.0 + 0.0im ],
59
58
[26.0 + 0.0im - 4.0 + 0.0im ; - 2.0 + 0.0im 0.0 + 0.0im ],
60
59
dims= 3 )),
61
- (collect (1 : 7 ) + im * collect (8 : 14 ), 1 , false ,
60
+ (collect (1 : 7 ) + im * collect (8 : 14 ), 1 ,
62
61
[28.0 + 77.0im ,
63
62
- 10.76782488800318 + 3.767824888003175im ,
64
63
- 6.291156861088416 - 0.7088431389115883im ,
65
64
- 4.298852160365525 - 2.7011478396344746im ,
66
65
- 2.7011478396344764 - 4.298852160365524im ,
67
66
- 0.7088431389115866 - 6.291156861088417im ,
68
67
3.767824888003177 - 10.76782488800318im ]),
69
- (collect (reshape (1 : 8 , 2 , 2 , 2 )) + im * reshape (9 : 16 , 2 , 2 , 2 ), 1 : 2 , false ,
68
+ (collect (reshape (1 : 8 , 2 , 2 , 2 )) + im * reshape (9 : 16 , 2 , 2 , 2 ), 1 : 2 ,
70
69
cat ([10.0 + 42.0im - 4.0 - 4.0im ; - 2.0 - 2.0im 0.0 + 0.0im ],
71
70
[26.0 + 58.0im - 4.0 - 4.0im ; - 2.0 - 2.0im 0.0 + 0.0im ],
72
71
dims= 3 )),
73
72
)
74
- x = array_constructor (_x) # dummy array that will be passed to plans
75
- x_complex = complex .(float .(x)) # for testing complex FFTs
76
- fftw_fft = array_constructor (_fftw_fft)
73
+ x = convert (ArrayType, _x) # dummy array that will be passed to plans
74
+ x_complex = convert (ArrayType, complex .(x)) # for testing complex FFTs
75
+ x_complexfloat = convert (ArrayType, complex .(float .(x))) # for in-place operations
76
+ fftw_fft = convert (ArrayType, _fftw_fft)
77
77
78
78
# FFT
79
79
y = AbstractFFTs. fft (x_complex, dims)
80
80
@test y ≈ fftw_fft
81
81
if test_inplace
82
- @test AbstractFFTs. fft! (copy (x_complex ), dims) ≈ fftw_fft
82
+ @test AbstractFFTs. fft! (copy (x_complexfloat ), dims) ≈ fftw_fft
83
83
end
84
84
# test plan_fft and also inv and plan_inv of plan_ifft, which should all give
85
85
# functionally identical plans
86
86
plans_to_test = [plan_fft (x, dims), inv (plan_ifft (x, dims)),
87
87
AbstractFFTs. plan_inv (plan_ifft (x, dims))]
88
88
for P in plans_to_test
89
- @test mul! (similar (y), P, copy (x_complex )) ≈ fftw_fft
89
+ @test mul! (similar (y), P, copy (x_complexfloat )) ≈ fftw_fft
90
90
end
91
91
if test_inplace
92
- push! (plans_to_test, plan_fft! (similar (x_complex ), dims))
92
+ push! (plans_to_test, plan_fft! (similar (x_complexfloat ), dims))
93
93
end
94
94
for P in plans_to_test
95
95
@test eltype (P) <: Complex
@@ -105,7 +105,7 @@ function test_fft_backend(array_constructor; test_real=true, test_inplace=true)
105
105
@test AbstractFFTs. bfft! (copy (y), dims) ≈ fftw_bfft
106
106
end
107
107
P = plan_bfft (similar (y), dims)
108
- @test mul! (similar (x_complex ), P, copy (y)) ≈ fftw_bfft
108
+ @test mul! (similar (x_complexfloat ), P, copy (y)) ≈ fftw_bfft
109
109
plans_to_test = if test_inplace
110
110
[P, plan_bfft! (similar (y), dims)]
111
111
else
@@ -127,10 +127,10 @@ function test_fft_backend(array_constructor; test_real=true, test_inplace=true)
127
127
plans_to_test = [plan_ifft (x, dims), inv (plan_fft (x, dims)),
128
128
AbstractFFTs. plan_inv (plan_fft (x, dims))]
129
129
for P in plans_to_test
130
- @test mul! (similar (x_complex ), P, copy (y)) ≈ fftw_ifft
130
+ @test mul! (similar (x_complexfloat ), P, copy (y)) ≈ fftw_ifft
131
131
end
132
132
if test_inplace
133
- push! (plans_to_test, plan_ifft! (similar (x_complex ), dims))
133
+ push! (plans_to_test, plan_ifft! (similar (x_complexfloat ), dims))
134
134
end
135
135
for P in plans_to_test
136
136
@test eltype (P) <: Complex
@@ -139,7 +139,7 @@ function test_fft_backend(array_constructor; test_real=true, test_inplace=true)
139
139
@test fftdims (P) == dims
140
140
end
141
141
142
- if test_real && real_input
142
+ if test_real && (x isa Real)
143
143
x_real = float .(x) # for testing real FFTs
144
144
# RFFT
145
145
fftw_rfft = selectdim (fftw_fft, first (dims), 1 : (size (fftw_fft, first (dims)) ÷ 2 + 1 ))
0 commit comments