@@ -4,7 +4,7 @@ using AbstractFFTs
44using AbstractFFTs. LinearAlgebra
55import ForwardDiff
66import ForwardDiff: Dual
7- import AbstractFFTs: Plan, mul!
7+ import AbstractFFTs: Plan, mul!, dualplan, dual2array
88
99
1010AbstractFFTs. complexfloat (x:: AbstractArray{<:Dual} ) = AbstractFFTs. complexfloat .(x)
3232
3333DualPlan (:: Type{Dual{Tag,V,N}} , p:: Plan{T} ) where {Tag,T<: Real ,V,N} = DualPlan {Dual{Tag,T,N},typeof(p)} (p)
3434DualPlan (:: Type{Dual{Tag,V,N}} , p:: Plan{Complex{T}} ) where {Tag,T<: Real ,V,N} = DualPlan {Complex{Dual{Tag,T,N}},typeof(p)} (p)
35+ dualplan (D, p) = DualPlan (D, p)
3536Base. size (p:: DualPlan ) = Base. tail (size (p. p))
3637Base.:* (p:: DualPlan{DT} , x:: AbstractArray{DT} ) where DT<: Dual = array2dual (DT, p. p * dual2array (x))
3738Base.:* (p:: DualPlan{Complex{DT}} , x:: AbstractArray{Complex{DT}} ) where DT<: Dual = array2dual (DT, p. p * dual2array (x))
4849
4950for plan in (:plan_fft , :plan_ifft , :plan_bfft , :plan_rfft )
5051 @eval begin
51- AbstractFFTs.$ plan (x:: AbstractArray{D} , dims= 1 : ndims (x)) where D<: Dual = DualPlan (D, AbstractFFTs.$ plan (dual2array (x), 1 .+ dims))
52- AbstractFFTs.$ plan (x:: AbstractArray{<:Complex{D}} , dims= 1 : ndims (x)) where D<: Dual = DualPlan (D, AbstractFFTs.$ plan (dual2array (x), 1 .+ dims))
52+ AbstractFFTs.$ plan (x:: AbstractArray{D} , dims= 1 : ndims (x)) where D<: Dual = dualplan (D, AbstractFFTs.$ plan (dual2array (x), 1 .+ dims))
53+ AbstractFFTs.$ plan (x:: AbstractArray{<:Complex{D}} , dims= 1 : ndims (x)) where D<: Dual = dualplan (D, AbstractFFTs.$ plan (dual2array (x), 1 .+ dims))
5354 end
5455end
5556
5657
58+ for plan in (:plan_irfft , :plan_brfft ) # these take an extra argument, only when complex?
59+ @eval begin
60+ AbstractFFTs.$ plan (x:: AbstractArray{D} , dims= 1 : ndims (x)) where D<: Dual = dualplan (D, AbstractFFTs.$ plan (dual2array (x), 1 .+ dims))
61+ AbstractFFTs.$ plan (x:: AbstractArray{<:Complex{D}} , d:: Integer , mdims= 1 : ndims (x)) where D<: Dual = dualplan (D, AbstractFFTs.$ plan (dual2array (x), d, 1 .+ dims))
62+ end
63+ end
64+
5765
5866end # module
0 commit comments