Skip to content

Commit cde9caa

Browse files
committed
add complex tests
1 parent 76b776e commit cde9caa

File tree

2 files changed

+15
-9
lines changed

2 files changed

+15
-9
lines changed

ext/AbstractFFTsForwardDiffExt.jl

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,17 +45,14 @@ end
4545

4646
for plan in (:plan_fft, :plan_ifft, :plan_bfft, :plan_rfft)
4747
@eval begin
48-
AbstractFFTs.$plan(x::AbstractArray{D}, dims=1:ndims(x)) where D<:Dual = dualplan(D, AbstractFFTs.$plan(dual2array(x), 1 .+ dims))
49-
AbstractFFTs.$plan(x::AbstractArray{<:Complex{D}}, dims=1:ndims(x)) where D<:Dual = dualplan(D, AbstractFFTs.$plan(dual2array(x), 1 .+ dims))
48+
AbstractFFTs.$plan(x::AbstractArray{D}, dims=1:ndims(x); kwds...) where D<:Dual = dualplan(D, AbstractFFTs.$plan(dual2array(x), 1 .+ dims; kwds...))
49+
AbstractFFTs.$plan(x::AbstractArray{<:Complex{D}}, dims=1:ndims(x); kwds...) where D<:Dual = dualplan(D, AbstractFFTs.$plan(dual2array(x), 1 .+ dims; kwds...))
5050
end
5151
end
5252

5353

5454
for plan in (:plan_irfft, :plan_brfft) # these take an extra argument, only when complex?
55-
@eval begin
56-
AbstractFFTs.$plan(x::AbstractArray{D}, dims=1:ndims(x)) where D<:Dual = dualplan(D, AbstractFFTs.$plan(dual2array(x), 1 .+ dims))
57-
AbstractFFTs.$plan(x::AbstractArray{<:Complex{D}}, d::Integer, dims=1:ndims(x)) where D<:Dual = dualplan(D, AbstractFFTs.$plan(dual2array(x), d, 1 .+ dims))
58-
end
55+
@eval AbstractFFTs.$plan(x::AbstractArray{<:Complex{D}}, d::Integer, dims=1:ndims(x); kwds...) where D<:Dual = dualplan(D, AbstractFFTs.$plan(dual2array(x), d, 1 .+ dims; kwds...))
5956
end
6057

6158

test/abstractfftsforwarddiff.jl

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ complexpartials(x, k) = partials(real(x), k) + im*partials(imag(x), k)
88

99
@testset "ForwardDiff extension tests" begin
1010
x1 = Dual.(1:4.0, 2:5, 3:6)
11+
c1 = Dual.(1:4.0, 2:5, 3:6) + im*Dual.(2:5.0, 3:6, 3:6)
1112

1213
@test AbstractFFTs.complexfloat(x1)[1] === Dual(1.0, 2.0, 3.0) + 0im
1314
@test AbstractFFTs.realfloat(x1)[1] === Dual(1.0, 2.0, 3.0)
@@ -54,7 +55,15 @@ complexpartials(x, k) = partials(real(x), k) + im*partials(imag(x), k)
5455
@test complexpartials.(fft(A, 2), 2) == fft(partials.(A, 2), 2)
5556
end
5657

57-
c1 = complex.(x1)
58-
@test mul!(similar(c1), plan_fft(x1), x1) == fft(x1)
59-
@test mul!(similar(c1), plan_fft(c1), c1) == fft(c1)
58+
@testset "complex" begin
59+
@test fft(c1) fft(real(c1)) + im*fft(imag(c1))
60+
dest = similar(c1)
61+
@test mul!(dest, plan_fft(x1), x1) == fft(x1) == dest
62+
@test mul!(dest, plan_fft(c1), c1) == fft(c1) == dest
63+
64+
C = c1 * ((1:10) .+ im*(2:11))'
65+
@test fft(C) fft(real(C)) + im*fft(imag(C))
66+
dest = similar(C)
67+
@test mul!(dest, plan_fft(C), C) == fft(C) == dest
68+
end
6069
end

0 commit comments

Comments
 (0)