From 058339c0f8e69392a02212156e7bee5e6aced320 Mon Sep 17 00:00:00 2001 From: Daniel Sharp Date: Tue, 6 May 2025 10:30:18 -0400 Subject: [PATCH 1/3] Add publicly accessible internal planning functions --- src/plan.jl | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/src/plan.jl b/src/plan.jl index b5e5d27..d32f36e 100644 --- a/src/plan.jl +++ b/src/plan.jl @@ -20,7 +20,7 @@ struct FFTAPlan_re{T,N} <: FFTAPlan{T,N} flen::Int end -function AbstractFFTs.plan_fft(x::AbstractArray{T}, region; kwargs...)::FFTAPlan_cx{T} where {T <: Complex} +function __FFTA_plan_fft(x::AbstractArray{T}, region; kwargs...)::FFTAPlan_cx{T} where {T <: Complex} N = length(region) @assert N <= 2 "Only supports vectors and matrices" if N == 1 @@ -36,7 +36,11 @@ function AbstractFFTs.plan_fft(x::AbstractArray{T}, region; kwargs...)::FFTAPlan end end -function AbstractFFTs.plan_bfft(x::AbstractArray{T}, region; kwargs...)::FFTAPlan_cx{T} where {T <: Complex} +function AbstractFFTs.plan_fft(x::AbstractArray{T}, region; kwargs...)::FFTAPlan_cx{T} where {T <: Complex} + __FFTA_plan_fft(x::AbstractArray, region; kwargs...) +end + +function __FFTA_plan_bfft(x::AbstractArray{T}, region; kwargs...)::FFTAPlan_cx{T} where {T <: Complex} N = length(region) @assert N <= 2 "Only supports vectors and matrices" if N == 1 @@ -52,7 +56,11 @@ function AbstractFFTs.plan_bfft(x::AbstractArray{T}, region; kwargs...)::FFTAPla end end -function AbstractFFTs.plan_rfft(x::AbstractArray{T}, region; kwargs...)::FFTAPlan_re{Complex{T}} where {T <: Real} +function AbstractFFTs.plan_bfft(x::AbstractArray{T}, region; kwargs...)::FFTAPlan_cx{T} where {T <: Complex} + __FFTA_plan_bfft(x::AbstractArray, region; kwargs...) +end + +function __FFTA_plan_rfft(x::AbstractArray{T}, region; kwargs...)::FFTAPlan_re{Complex{T}} where {T <: Real} N = length(region) @assert N <= 2 "Only supports vectors and matrices" if N == 1 @@ -68,7 +76,11 @@ function AbstractFFTs.plan_rfft(x::AbstractArray{T}, region; kwargs...)::FFTAPla end end -function AbstractFFTs.plan_brfft(x::AbstractArray{T}, len, region; kwargs...)::FFTAPlan_re{T} where {T} +function AbstractFFTs.plan_rfft(x::AbstractArray{T}, region; kwargs...)::FFTAPlan_re{Complex{T}} where {T <: Real} + __FFTA_plan_rfft(x::AbstractArray, region; kwargs...) +end + +function __FFTA_plan_brfft(x::AbstractArray{T}, len, region; kwargs...)::FFTAPlan_re{T} where {T} N = length(region) @assert N <= 2 "Only supports vectors and matrices" if N == 1 @@ -84,6 +96,10 @@ function AbstractFFTs.plan_brfft(x::AbstractArray{T}, len, region; kwargs...)::F end end +function AbstractFFTs.plan_brfft(x::AbstractArray{T}, len, region; kwargs...)::FFTAPlan_re{T} where {T} + __FFTA_plan_brfft(x::AbstractArray, len, region; kwargs...) +end + function AbstractFFTs.plan_bfft(p::FFTAPlan_cx{T,N}) where {T,N} return FFTAPlan_cx{T,N}(p.callgraph, p.region, -p.dir, p.pinv) end From 0f67e0b189408880b321800bb34e45b218a5ee46 Mon Sep 17 00:00:00 2001 From: Daniel Sharp Date: Tue, 6 May 2025 13:26:53 -0400 Subject: [PATCH 2/3] Add working internal mechanism --- src/plan.jl | 39 ++++++++++++++++++--------------------- 1 file changed, 18 insertions(+), 21 deletions(-) diff --git a/src/plan.jl b/src/plan.jl index d32f36e..744a44e 100644 --- a/src/plan.jl +++ b/src/plan.jl @@ -1,5 +1,6 @@ import Base: * import LinearAlgebra: mul! +using AbstractFFTs: to1, plan_fft, plan_bfft, plan_ifft, plan_rfft abstract type FFTAPlan{T,N} <: Plan{T} end @@ -20,7 +21,7 @@ struct FFTAPlan_re{T,N} <: FFTAPlan{T,N} flen::Int end -function __FFTA_plan_fft(x::AbstractArray{T}, region; kwargs...)::FFTAPlan_cx{T} where {T <: Complex} +function __FFTA_plan_fft(x::AbstractArray{T}, region = 1:ndims(x); kwargs...)::FFTAPlan_cx{T} where {T <: Complex} N = length(region) @assert N <= 2 "Only supports vectors and matrices" if N == 1 @@ -36,11 +37,7 @@ function __FFTA_plan_fft(x::AbstractArray{T}, region; kwargs...)::FFTAPlan_cx{T} end end -function AbstractFFTs.plan_fft(x::AbstractArray{T}, region; kwargs...)::FFTAPlan_cx{T} where {T <: Complex} - __FFTA_plan_fft(x::AbstractArray, region; kwargs...) -end - -function __FFTA_plan_bfft(x::AbstractArray{T}, region; kwargs...)::FFTAPlan_cx{T} where {T <: Complex} +function __FFTA_plan_bfft(x::AbstractArray{T}, region = 1:ndims(x); kwargs...)::FFTAPlan_cx{T} where {T <: Complex} N = length(region) @assert N <= 2 "Only supports vectors and matrices" if N == 1 @@ -56,11 +53,7 @@ function __FFTA_plan_bfft(x::AbstractArray{T}, region; kwargs...)::FFTAPlan_cx{T end end -function AbstractFFTs.plan_bfft(x::AbstractArray{T}, region; kwargs...)::FFTAPlan_cx{T} where {T <: Complex} - __FFTA_plan_bfft(x::AbstractArray, region; kwargs...) -end - -function __FFTA_plan_rfft(x::AbstractArray{T}, region; kwargs...)::FFTAPlan_re{Complex{T}} where {T <: Real} +function __FFTA_plan_rfft(x::AbstractArray{T}, region = 1:ndims(x); kwargs...)::FFTAPlan_re{Complex{T}} where {T <: Real} N = length(region) @assert N <= 2 "Only supports vectors and matrices" if N == 1 @@ -76,11 +69,7 @@ function __FFTA_plan_rfft(x::AbstractArray{T}, region; kwargs...)::FFTAPlan_re{C end end -function AbstractFFTs.plan_rfft(x::AbstractArray{T}, region; kwargs...)::FFTAPlan_re{Complex{T}} where {T <: Real} - __FFTA_plan_rfft(x::AbstractArray, region; kwargs...) -end - -function __FFTA_plan_brfft(x::AbstractArray{T}, len, region; kwargs...)::FFTAPlan_re{T} where {T} +function __FFTA_plan_brfft(x::AbstractArray{T}, len, region = 1:ndims(x); kwargs...)::FFTAPlan_re{T} where {T} N = length(region) @assert N <= 2 "Only supports vectors and matrices" if N == 1 @@ -96,10 +85,6 @@ function __FFTA_plan_brfft(x::AbstractArray{T}, len, region; kwargs...)::FFTAPla end end -function AbstractFFTs.plan_brfft(x::AbstractArray{T}, len, region; kwargs...)::FFTAPlan_re{T} where {T} - __FFTA_plan_brfft(x::AbstractArray, len, region; kwargs...) -end - function AbstractFFTs.plan_bfft(p::FFTAPlan_cx{T,N}) where {T,N} return FFTAPlan_cx{T,N}(p.callgraph, p.region, -p.dir, p.pinv) end @@ -200,4 +185,16 @@ function *(p::FFTAPlan_re{T,N}, x::AbstractArray{T,2}) where {T<:Union{Real, Com LinearAlgebra.mul!(y, p, x_tmp) return y end -end \ No newline at end of file +end + +# Implement AbstractFFTs overloading +for f in (:fft, :bfft, :rfft, :brfft) + pf = Symbol("__FFTA_plan_", f) + abstract_pf = Symbol("plan_", f) + T_super = f == :rfft ? Real : Complex + @eval begin + function AbstractFFTs.$abstract_pf(x::AbstractArray{T}, args...; kws...) where {T<:$T_super} + $pf(x, args...; kws...) + end + end +end From f1dd1fe1437b3e012e1d918cbe0d63ace32a29db Mon Sep 17 00:00:00 2001 From: Daniel Sharp Date: Tue, 6 May 2025 13:33:38 -0400 Subject: [PATCH 3/3] Add explicit inline request --- src/plan.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/plan.jl b/src/plan.jl index 744a44e..0c8b390 100644 --- a/src/plan.jl +++ b/src/plan.jl @@ -193,7 +193,7 @@ for f in (:fft, :bfft, :rfft, :brfft) abstract_pf = Symbol("plan_", f) T_super = f == :rfft ? Real : Complex @eval begin - function AbstractFFTs.$abstract_pf(x::AbstractArray{T}, args...; kws...) where {T<:$T_super} + @inline function AbstractFFTs.$abstract_pf(x::AbstractArray{T}, args...; kws...) where {T<:$T_super} $pf(x, args...; kws...) end end