1+ module AbstractFFTsForwardDiffExt
2+
3+ using AbstractFFTs
4+ import ForwardDiff
5+ import ForwardDiff: Dual
6+ import AbstractFFTs: Plan
7+
8+ for P in (:Plan , :ScaledPlan ) # need ScaledPlan to avoid ambiguities
9+ @eval begin
10+ Base.:* (p:: AbstractFFTs. $ P, x:: AbstractArray{DT} ) where DT<: Dual = array2dual (DT, p * dual2array (x))
11+ Base.:* (p:: AbstractFFTs. $ P, x:: AbstractArray{<:Complex{DT}} ) where DT<: Dual = array2dual (DT, p * dual2array (x))
12+ end
13+ end
14+
15+ mul! (y:: AbstractArray{<:Union{Dual,Complex{<:Dual}}} , p:: Plan , x:: AbstractArray{<:Union{Dual,Complex{<:Dual}}} ) = copyto! (y, p* x)
16+
17+ AbstractFFTs. complexfloat (x:: AbstractArray{<:Dual} ) = AbstractFFTs. complexfloat .(x)
18+ AbstractFFTs. complexfloat (d:: Dual{T,V,N} ) where {T,V,N} = convert (Dual{T,float (V),N}, d) + 0im
19+
20+ AbstractFFTs. realfloat (x:: AbstractArray{<:Dual} ) = AbstractFFTs. realfloat .(x)
21+ AbstractFFTs. realfloat (d:: Dual{T,V,N} ) where {T,V,N} = convert (Dual{T,float (V),N}, d)
22+
23+ dual2array (x:: Array{<:Dual{Tag,T}} ) where {Tag,T} = reinterpret (reshape, T, x)
24+ dual2array (x:: Array{<:Complex{<:Dual{Tag, T}}} ) where {Tag,T} = complex .(dual2array (real (x)), dual2array (imag (x)))
25+ array2dual (DT:: Type{<:Dual} , x:: Array{T} ) where T = reinterpret (reshape, DT, real (x))
26+ array2dual (DT:: Type{<:Dual} , x:: Array{<:Complex{T}} ) where T = complex .(array2dual (DT, real (x)), array2dual (DT, imag (x)))
27+
28+
29+ for plan in (:plan_fft , :plan_ifft , :plan_bfft , :plan_rfft )
30+ @eval begin
31+ AbstractFFTs.$ plan (x:: AbstractArray{<:Dual} , dims= 1 : ndims (x)) = AbstractFFTs.$ plan (dual2array (x), 1 .+ dims)
32+ AbstractFFTs.$ plan (x:: AbstractArray{<:Complex{<:Dual}} , dims= 1 : ndims (x)) = AbstractFFTs.$ plan (dual2array (x), 1 .+ dims)
33+ end
34+ end
35+
36+
37+
38+ end # module
0 commit comments