Skip to content

Commit 1078e0b

Browse files
committed
Add tests for multidimensional FFTs and complex input arrays
1 parent 41d216a commit 1078e0b

File tree

1 file changed

+33
-20
lines changed

1 file changed

+33
-20
lines changed

src/TestUtils.jl

Lines changed: 33 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -29,16 +29,16 @@ testing: this would most commonly be a constructor such as `Array` or `CuArray`.
2929
function test_fft_backend(array_constructor; test_real=true, test_inplace=true)
3030
@testset "fft correctness" begin
3131
# DFT along last dimension, results computed using FFTW
32-
for (_x, _fftw_fft) in (
33-
(collect(1:7),
32+
for (_x, dims, real_input, _fftw_fft) in (
33+
(collect(1:7), 1, true,
3434
[28.0 + 0.0im,
3535
-3.5 + 7.267824888003178im,
3636
-3.5 + 2.7911568610884143im,
3737
-3.5 + 0.7988521603655248im,
3838
-3.5 - 0.7988521603655248im,
3939
-3.5 - 2.7911568610884143im,
4040
-3.5 - 7.267824888003178im]),
41-
(collect(1:8),
41+
(collect(1:8), 1, true,
4242
[36.0 + 0.0im,
4343
-4.0 + 9.65685424949238im,
4444
-4.0 + 4.0im,
@@ -47,21 +47,32 @@ function test_fft_backend(array_constructor; test_real=true, test_inplace=true)
4747
-4.0 - 1.6568542494923806im,
4848
-4.0 - 4.0im,
4949
-4.0 - 9.65685424949238im]),
50-
(collect(reshape(1:8, 2, 4)),
50+
(collect(reshape(1:8, 2, 4)), 2, true,
5151
[16.0+0.0im -4.0+4.0im -4.0+0.0im -4.0-4.0im;
5252
20.0+0.0im -4.0+4.0im -4.0+0.0im -4.0-4.0im]),
53-
(collect(reshape(1:9, 3, 3)),
53+
(collect(reshape(1:9, 3, 3)), 2, true,
5454
[12.0+0.0im -4.5+2.598076211353316im -4.5-2.598076211353316im;
5555
15.0+0.0im -4.5+2.598076211353316im -4.5-2.598076211353316im;
5656
18.0+0.0im -4.5+2.598076211353316im -4.5-2.598076211353316im]),
57+
(collect(reshape(1:8, 2, 2, 2)), 1:2, true,
58+
[10.0 + 0.0im -4.0 + 0.0im; -2.0 + 0.0im 0.0 + 0.0im;;;
59+
26.0 + 0.0im -4.0 + 0.0im; -2.0 + 0.0im 0.0 + 0.0im]),
60+
(collect(1:7) + im * collect(8:14), 1, false,
61+
[28.0 + 77.0im,
62+
-10.76782488800318 + 3.767824888003175im,
63+
-6.291156861088416 - 0.7088431389115883im,
64+
-4.298852160365525 - 2.7011478396344746im,
65+
-2.7011478396344764 - 4.298852160365524im,
66+
-0.7088431389115866 - 6.291156861088417im,
67+
3.767824888003177 - 10.76782488800318im]),
68+
(collect(reshape(1:8, 2, 2, 2)) + im * reshape(9:16, 2, 2, 2), 1:2, false,
69+
[10.0 + 42.0im -4.0 - 4.0im; -2.0 - 2.0im 0.0 + 0.0im;;;
70+
26.0 + 58.0im -4.0 - 4.0im; -2.0 - 2.0im 0.0 + 0.0im]),
5771
)
5872
x = array_constructor(_x) # dummy array that will be passed to plans
59-
x_real = float.(x) # for testing real FFTs
60-
x_complex = complex.(x_real) # for testing complex FFTs
73+
x_complex = complex.(float.(x)) # for testing complex FFTs
6174
fftw_fft = array_constructor(_fftw_fft)
6275

63-
dims = ndims(x) # TODO: this is a single dimension, should check multidimensional FFTs too
64-
6576
# FFT
6677
y = AbstractFFTs.fft(x_complex, dims)
6778
@test y fftw_fft
@@ -82,7 +93,7 @@ function test_fft_backend(array_constructor; test_real=true, test_inplace=true)
8293
end
8394

8495
# BFFT
85-
fftw_bfft = size(x_complex, dims) .* x_complex
96+
fftw_bfft = prod(size(x_complex, d) for d in dims) .* x_complex
8697
@test AbstractFFTs.bfft(y, dims) fftw_bfft
8798
test_inplace && (@test AbstractFFTs.bfft!(copy(y), dims) fftw_bfft)
8899
plans_to_test = [plan_bfft(similar(y), dims)]
@@ -114,16 +125,18 @@ function test_fft_backend(array_constructor; test_real=true, test_inplace=true)
114125
@test fftdims(P) == dims
115126
end
116127

117-
if test_real
128+
if test_real && real_input
129+
x_real = float.(x) # for testing real FFTs
118130
# RFFT
119131
fftw_rfft = fftw_fft[
120-
(Colon() for _ in 1:(ndims(fftw_fft) - 1))...,
121-
1:(size(fftw_fft, ndims(fftw_fft)) ÷ 2 + 1)
132+
(Colon() for _ in 1:(first(dims) - 1))...,
133+
1:(size(fftw_fft, first(dims)) ÷ 2 + 1),
134+
(Colon() for _ in (first(dims) + 1):ndims(fftw_fft))...
122135
]
123136
ry = AbstractFFTs.rfft(x_real, dims)
124137
@test ry fftw_rfft
125-
for P in [plan_rfft(x_real, dims), inv(plan_irfft(ry, size(x, dims), dims)),
126-
AbstractFFTs.plan_inv(plan_irfft(ry, size(x, dims), dims))]
138+
for P in [plan_rfft(x_real, dims), inv(plan_irfft(ry, size(x, first(dims)), dims)),
139+
AbstractFFTs.plan_inv(plan_irfft(ry, size(x, first(dims)), dims))]
127140
@test eltype(P) <: Real
128141
@test P * x_real fftw_rfft
129142
@test mul!(similar(ry), P, x_real) fftw_rfft
@@ -132,18 +145,18 @@ function test_fft_backend(array_constructor; test_real=true, test_inplace=true)
132145
end
133146

134147
# BRFFT
135-
fftw_brfft = complex.(size(x, dims) .* x_real)
136-
@test AbstractFFTs.brfft(ry, size(x_real, dims), dims) fftw_brfft
137-
P = plan_brfft(ry, size(x_real, dims), dims)
148+
fftw_brfft = prod(size(x_real, d) for d in dims) .* x_real
149+
@test AbstractFFTs.brfft(ry, size(x_real, first(dims)), dims) fftw_brfft
150+
P = plan_brfft(ry, size(x_real, first(dims)), dims)
138151
@test P * ry fftw_brfft
139152
@test mul!(similar(x_real), P, ry) fftw_brfft
140153
@test P \ (P * ry) ry
141154
@test fftdims(P) == dims
142155

143156
# IRFFT
144157
fftw_irfft = x_complex
145-
@test AbstractFFTs.irfft(ry, size(x, dims), dims) fftw_irfft
146-
for P in [plan_irfft(ry, size(x, dims), dims), inv(plan_rfft(x_real, dims)),
158+
@test AbstractFFTs.irfft(ry, size(x, first(dims)), dims) fftw_irfft
159+
for P in [plan_irfft(ry, size(x, first(dims)), dims), inv(plan_rfft(x_real, dims)),
147160
AbstractFFTs.plan_inv(plan_rfft(x_real, dims))]
148161
@test P * ry fftw_irfft
149162
@test mul!(similar(x_real), P, ry) fftw_irfft

0 commit comments

Comments
 (0)