Skip to content

Commit 88acdba

Browse files
committed
Improve array construction in tests
1 parent b835f72 commit 88acdba

File tree

1 file changed

+24
-24
lines changed

1 file changed

+24
-24
lines changed

src/TestUtils.jl

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -11,34 +11,33 @@ using LinearAlgebra
1111
using Test
1212

1313
"""
14-
test_fft_backend(array_constructor; test_real=true, test_inplace=true)
14+
test_fft_backend(ArrayType=Array; test_real=true, test_inplace=true)
1515
1616
Run tests to verify correctness of FFT functions using a particular
1717
backend plan implementation. The backend implementation is assumed to be loaded
1818
prior to calling this function.
1919
2020
# Arguments
2121
22-
- `array_constructor`: determines the `AbstractArray` implementation for
23-
which the correctness tests are run. It is assumed to be a callable object that
24-
takes in input arrays of type `Array` and return arrays of the desired type for
25-
testing. For example, this can be a constructor such as `Array` or `CUDA.CuArray`.
22+
- `ArrayType`: determines the `AbstractArray` implementation for
23+
which the correctness tests are run. Arrays are constructed via
24+
`convert(ArrayType, ...)`.
2625
- `test_real=true`: whether to test real-to-complex and complex-to-real FFTs.
2726
- `test_inplace=true`: whether to test in-place plans.
2827
"""
29-
function test_fft_backend(array_constructor; test_real=true, test_inplace=true)
28+
function test_fft_backend(ArrayType=Array; test_real=true, test_inplace=true)
3029
@testset "fft correctness" begin
3130
# DFT along last dimension, results computed using FFTW
32-
for (_x, dims, real_input, _fftw_fft) in (
33-
(collect(1:7), 1, true,
31+
for (_x, dims, _fftw_fft) in (
32+
(collect(1:7), 1,
3433
[28.0 + 0.0im,
3534
-3.5 + 7.267824888003178im,
3635
-3.5 + 2.7911568610884143im,
3736
-3.5 + 0.7988521603655248im,
3837
-3.5 - 0.7988521603655248im,
3938
-3.5 - 2.7911568610884143im,
4039
-3.5 - 7.267824888003178im]),
41-
(collect(1:8), 1, true,
40+
(collect(1:8), 1,
4241
[36.0 + 0.0im,
4342
-4.0 + 9.65685424949238im,
4443
-4.0 + 4.0im,
@@ -47,49 +46,50 @@ function test_fft_backend(array_constructor; test_real=true, test_inplace=true)
4746
-4.0 - 1.6568542494923806im,
4847
-4.0 - 4.0im,
4948
-4.0 - 9.65685424949238im]),
50-
(collect(reshape(1:8, 2, 4)), 2, true,
49+
(collect(reshape(1:8, 2, 4)), 2,
5150
[16.0+0.0im -4.0+4.0im -4.0+0.0im -4.0-4.0im;
5251
20.0+0.0im -4.0+4.0im -4.0+0.0im -4.0-4.0im]),
53-
(collect(reshape(1:9, 3, 3)), 2, true,
52+
(collect(reshape(1:9, 3, 3)), 2,
5453
[12.0+0.0im -4.5+2.598076211353316im -4.5-2.598076211353316im;
5554
15.0+0.0im -4.5+2.598076211353316im -4.5-2.598076211353316im;
5655
18.0+0.0im -4.5+2.598076211353316im -4.5-2.598076211353316im]),
57-
(collect(reshape(1:8, 2, 2, 2)), 1:2, true,
56+
(collect(reshape(1:8, 2, 2, 2)), 1:2,
5857
cat([10.0 + 0.0im -4.0 + 0.0im; -2.0 + 0.0im 0.0 + 0.0im],
5958
[26.0 + 0.0im -4.0 + 0.0im; -2.0 + 0.0im 0.0 + 0.0im],
6059
dims=3)),
61-
(collect(1:7) + im * collect(8:14), 1, false,
60+
(collect(1:7) + im * collect(8:14), 1,
6261
[28.0 + 77.0im,
6362
-10.76782488800318 + 3.767824888003175im,
6463
-6.291156861088416 - 0.7088431389115883im,
6564
-4.298852160365525 - 2.7011478396344746im,
6665
-2.7011478396344764 - 4.298852160365524im,
6766
-0.7088431389115866 - 6.291156861088417im,
6867
3.767824888003177 - 10.76782488800318im]),
69-
(collect(reshape(1:8, 2, 2, 2)) + im * reshape(9:16, 2, 2, 2), 1:2, false,
68+
(collect(reshape(1:8, 2, 2, 2)) + im * reshape(9:16, 2, 2, 2), 1:2,
7069
cat([10.0 + 42.0im -4.0 - 4.0im; -2.0 - 2.0im 0.0 + 0.0im],
7170
[26.0 + 58.0im -4.0 - 4.0im; -2.0 - 2.0im 0.0 + 0.0im],
7271
dims=3)),
7372
)
74-
x = array_constructor(_x) # dummy array that will be passed to plans
75-
x_complex = complex.(float.(x)) # for testing complex FFTs
76-
fftw_fft = array_constructor(_fftw_fft)
73+
x = convert(ArrayType, _x) # dummy array that will be passed to plans
74+
x_complex = convert(ArrayType, complex.(x)) # for testing complex FFTs
75+
x_complexfloat = convert(ArrayType, complex.(float.(x))) # for in-place operations
76+
fftw_fft = convert(ArrayType, _fftw_fft)
7777

7878
# FFT
7979
y = AbstractFFTs.fft(x_complex, dims)
8080
@test y fftw_fft
8181
if test_inplace
82-
@test AbstractFFTs.fft!(copy(x_complex), dims) fftw_fft
82+
@test AbstractFFTs.fft!(copy(x_complexfloat), dims) fftw_fft
8383
end
8484
# test plan_fft and also inv and plan_inv of plan_ifft, which should all give
8585
# functionally identical plans
8686
plans_to_test = [plan_fft(x, dims), inv(plan_ifft(x, dims)),
8787
AbstractFFTs.plan_inv(plan_ifft(x, dims))]
8888
for P in plans_to_test
89-
@test mul!(similar(y), P, copy(x_complex)) fftw_fft
89+
@test mul!(similar(y), P, copy(x_complexfloat)) fftw_fft
9090
end
9191
if test_inplace
92-
push!(plans_to_test, plan_fft!(similar(x_complex), dims))
92+
push!(plans_to_test, plan_fft!(similar(x_complexfloat), dims))
9393
end
9494
for P in plans_to_test
9595
@test eltype(P) <: Complex
@@ -105,7 +105,7 @@ function test_fft_backend(array_constructor; test_real=true, test_inplace=true)
105105
@test AbstractFFTs.bfft!(copy(y), dims) fftw_bfft
106106
end
107107
P = plan_bfft(similar(y), dims)
108-
@test mul!(similar(x_complex), P, copy(y)) fftw_bfft
108+
@test mul!(similar(x_complexfloat), P, copy(y)) fftw_bfft
109109
plans_to_test = if test_inplace
110110
[P, plan_bfft!(similar(y), dims)]
111111
else
@@ -127,10 +127,10 @@ function test_fft_backend(array_constructor; test_real=true, test_inplace=true)
127127
plans_to_test = [plan_ifft(x, dims), inv(plan_fft(x, dims)),
128128
AbstractFFTs.plan_inv(plan_fft(x, dims))]
129129
for P in plans_to_test
130-
@test mul!(similar(x_complex), P, copy(y)) fftw_ifft
130+
@test mul!(similar(x_complexfloat), P, copy(y)) fftw_ifft
131131
end
132132
if test_inplace
133-
push!(plans_to_test, plan_ifft!(similar(x_complex), dims))
133+
push!(plans_to_test, plan_ifft!(similar(x_complexfloat), dims))
134134
end
135135
for P in plans_to_test
136136
@test eltype(P) <: Complex
@@ -139,7 +139,7 @@ function test_fft_backend(array_constructor; test_real=true, test_inplace=true)
139139
@test fftdims(P) == dims
140140
end
141141

142-
if test_real && real_input
142+
if test_real && (x isa Real)
143143
x_real = float.(x) # for testing real FFTs
144144
# RFFT
145145
fftw_rfft = selectdim(fftw_fft, first(dims), 1:(size(fftw_fft, first(dims)) ÷ 2 + 1))

0 commit comments

Comments
 (0)