diff --git a/src/RustFFT.jl b/src/RustFFT.jl index 1350650..65bf0b8 100644 --- a/src/RustFFT.jl +++ b/src/RustFFT.jl @@ -7,7 +7,7 @@ using .Internal @reexport using AbstractFFTs import Base: *, size -import AbstractFFTs: Plan, ScaledPlan, plan_fft, fft, plan_fft!, fft!, plan_bfft, plan_bfft!, +import AbstractFFTs: Plan, ScaledPlan, AbstractFFTBackend, plan_fft, fft, plan_fft!, fft!, plan_bfft, plan_bfft!, ifft, ifft!, fftdims, plan_inv import LinearAlgebra: mul! @@ -87,6 +87,15 @@ struct IgnoreArrayChecks <: ArrayChecks end # or `ComplexF32`. const RustFFTNumber = Union{Complex{Float64},Complex{Float32}} +export RustFFTBackend +struct RustFFTBackend <: AbstractFFTBackend end +backend() = RustFFTBackend() +activate!() = AbstractFFTs.set_active_backend!(RustFFT) + +function __init__() + activate!() +end + mutable struct RustFFTPlan{T<:RustFFTNumber,inplace,direction<:Direction,gcsafety<:GcSafety,arraychecks<:ArrayChecks} <: Plan{T} plan::FftInstance{T} pinv::ScaledPlan @@ -116,7 +125,7 @@ end return FftPlanner64() end -function plan_fft(x::Vector{T}, region; +function plan_fft(::RustFFTBackend, x::Vector{T}, region; rustfft_checks::arraychecks=IgnoreArrayTracking(), rustfft_gc_safe::gcsafety=GcUnsafe(), rustfft_planner::Union{FftPlanner{T},Nothing}=nothing, @@ -133,7 +142,7 @@ function plan_fft(x::Vector{T}, region; RustFFTPlan{T,false,Forward,gcsafety,arraychecks}(instance) end -function plan_fft!(x::Vector{T}, region; +function plan_fft!(::RustFFTBackend, x::Vector{T}, region; rustfft_checks::arraychecks=AllArrayChecks(), rustfft_gc_safe::gcsafety=GcUnsafe(), rustfft_planner::Union{FftPlanner{T},Nothing}=nothing, @@ -150,7 +159,7 @@ function plan_fft!(x::Vector{T}, region; RustFFTPlan{T,true,Forward,gcsafety,arraychecks}(instance) end -function plan_bfft(x::Vector{T}, region; +function plan_bfft(::RustFFTBackend, x::Vector{T}, region; rustfft_checks::arraychecks=IgnoreArrayTracking(), rustfft_gc_safe::gcsafety=GcUnsafe(), rustfft_planner::Union{FftPlanner{T},Nothing}=nothing, @@ -167,7 +176,7 @@ function plan_bfft(x::Vector{T}, region; RustFFTPlan{T,false,Backward,gcsafety,arraychecks}(instance) end -function plan_bfft!(x::Vector{T}, region; +function plan_bfft!(::RustFFTBackend, x::Vector{T}, region; rustfft_checks::arraychecks=AllArrayChecks(), rustfft_gc_safe::gcsafety=GcUnsafe(), rustfft_planner::Union{FftPlanner{T},Nothing}=nothing,