diff --git a/src/plan.jl b/src/plan.jl index b5e5d27..0c8b390 100644 --- a/src/plan.jl +++ b/src/plan.jl @@ -1,5 +1,6 @@ import Base: * import LinearAlgebra: mul! +using AbstractFFTs: to1, plan_fft, plan_bfft, plan_ifft, plan_rfft abstract type FFTAPlan{T,N} <: Plan{T} end @@ -20,7 +21,7 @@ struct FFTAPlan_re{T,N} <: FFTAPlan{T,N} flen::Int end -function AbstractFFTs.plan_fft(x::AbstractArray{T}, region; kwargs...)::FFTAPlan_cx{T} where {T <: Complex} +function __FFTA_plan_fft(x::AbstractArray{T}, region = 1:ndims(x); kwargs...)::FFTAPlan_cx{T} where {T <: Complex} N = length(region) @assert N <= 2 "Only supports vectors and matrices" if N == 1 @@ -36,7 +37,7 @@ function AbstractFFTs.plan_fft(x::AbstractArray{T}, region; kwargs...)::FFTAPlan end end -function AbstractFFTs.plan_bfft(x::AbstractArray{T}, region; kwargs...)::FFTAPlan_cx{T} where {T <: Complex} +function __FFTA_plan_bfft(x::AbstractArray{T}, region = 1:ndims(x); kwargs...)::FFTAPlan_cx{T} where {T <: Complex} N = length(region) @assert N <= 2 "Only supports vectors and matrices" if N == 1 @@ -52,7 +53,7 @@ function AbstractFFTs.plan_bfft(x::AbstractArray{T}, region; kwargs...)::FFTAPla end end -function AbstractFFTs.plan_rfft(x::AbstractArray{T}, region; kwargs...)::FFTAPlan_re{Complex{T}} where {T <: Real} +function __FFTA_plan_rfft(x::AbstractArray{T}, region = 1:ndims(x); kwargs...)::FFTAPlan_re{Complex{T}} where {T <: Real} N = length(region) @assert N <= 2 "Only supports vectors and matrices" if N == 1 @@ -68,7 +69,7 @@ function AbstractFFTs.plan_rfft(x::AbstractArray{T}, region; kwargs...)::FFTAPla end end -function AbstractFFTs.plan_brfft(x::AbstractArray{T}, len, region; kwargs...)::FFTAPlan_re{T} where {T} +function __FFTA_plan_brfft(x::AbstractArray{T}, len, region = 1:ndims(x); kwargs...)::FFTAPlan_re{T} where {T} N = length(region) @assert N <= 2 "Only supports vectors and matrices" if N == 1 @@ -184,4 +185,16 @@ function *(p::FFTAPlan_re{T,N}, x::AbstractArray{T,2}) where {T<:Union{Real, Com LinearAlgebra.mul!(y, p, x_tmp) return y end -end \ No newline at end of file +end + +# Implement AbstractFFTs overloading +for f in (:fft, :bfft, :rfft, :brfft) + pf = Symbol("__FFTA_plan_", f) + abstract_pf = Symbol("plan_", f) + T_super = f == :rfft ? Real : Complex + @eval begin + @inline function AbstractFFTs.$abstract_pf(x::AbstractArray{T}, args...; kws...) where {T<:$T_super} + $pf(x, args...; kws...) + end + end +end