Skip to content

Commit 9fef024

Browse files
committed
Add backend-based interface
1 parent 04a14f5 commit 9fef024

File tree

1 file changed

+77
-22
lines changed

1 file changed

+77
-22
lines changed

src/definitions.jl

Lines changed: 77 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -58,55 +58,95 @@ to1(x::AbstractArray) = _to1(axes(x), x)
5858
_to1(::Tuple{Base.OneTo,Vararg{Base.OneTo}}, x) = x
5959
_to1(::Tuple, x) = copy1(eltype(x), x)
6060

61+
# Abstract FFT Backend
62+
export AbstractFFTBackend
63+
abstract type AbstractFFTBackend end
64+
const ACTIVE_BACKEND = Ref{Union{Missing, AbstractFFTBackend}}(missing)
65+
66+
"""
67+
set_active_backend!(back::Union{Missing, Module, AbstractFFTBackend})
68+
69+
Set the default FFT plan backend. A module `back` must implement `back.backend()`.
70+
"""
71+
set_active_backend!(back::Module) = set_active_backend!(back.backend())
72+
function set_active_backend!(back::Union{Missing, AbstractFFTBackend})
73+
ACTIVE_BACKEND[] = back
74+
end
75+
active_backend() = ACTIVE_BACKEND[]
76+
function no_backend_error()
77+
error(
78+
"""
79+
No default backend available!
80+
Make sure to also "import/using" an FFT backend such as FFTW, FFTA or RustFFT.
81+
"""
82+
)
83+
end
84+
85+
for f in (:fft, :bfft, :ifft, :fft!, :bfft!, :ifft!, :rfft, :brfft, :irfft)
86+
pf = Symbol("plan_", f)
87+
@eval begin
88+
$f(x::AbstractArray, args...; kws...) = $f(active_backend(), x, args...; kws...)
89+
$pf(x::AbstractArray, args...; kws...) = $pf(active_backend(), x, args...; kws...)
90+
$f(::Missing, x::AbstractArray, args...; kws...) = no_backend_error()
91+
$pf(::Missing, x::AbstractArray, args...; kws...) = no_backend_error()
92+
end
93+
end
6194
# implementations only need to provide plan_X(x, region)
6295
# for X in (:fft, :bfft, ...):
6396
for f in (:fft, :bfft, :ifft, :fft!, :bfft!, :ifft!, :rfft)
6497
pf = Symbol("plan_", f)
6598
@eval begin
66-
$f(x::AbstractArray) = $f(x, 1:ndims(x))
67-
$f(x::AbstractArray, region) = (y = to1(x); $pf(y, region) * y)
68-
$pf(x::AbstractArray; kws...) = (y = to1(x); $pf(y, 1:ndims(y); kws...))
99+
$f(b::AbstractFFTBackend, x::AbstractArray) = $f(b, x, 1:ndims(x))
100+
$f(b::AbstractFFTBackend, x::AbstractArray, region) = (y = to1(x); $pf(b, y, region) * y)
101+
$pf(b::AbstractFFTBackend, x::AbstractArray; kws...) = (y = to1(x); $pf(b, y, 1:ndims(y); kws...))
69102
end
70103
end
71104

72105
"""
106+
plan_ifft(backend, A [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf)
73107
plan_ifft(A [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf)
74108
75109
Same as [`plan_fft`](@ref), but produces a plan that performs inverse transforms
76-
[`ifft`](@ref).
110+
[`ifft`](@ref). Uses active `backend` if no explicit `backend` is provided.
77111
"""
78112
plan_ifft
79113

80114
"""
115+
plan_ifft!(backend, A [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf)
81116
plan_ifft!(A [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf)
82117
83118
Same as [`plan_ifft`](@ref), but operates in-place on `A`.
84119
"""
85120
plan_ifft!
86121

87122
"""
123+
plan_bfft!(backend, A [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf)
88124
plan_bfft!(A [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf)
89125
90126
Same as [`plan_bfft`](@ref), but operates in-place on `A`.
91127
"""
92128
plan_bfft!
93129

94130
"""
131+
plan_bfft(backend, A [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf)
95132
plan_bfft(A [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf)
96133
97134
Same as [`plan_fft`](@ref), but produces a plan that performs an unnormalized
98-
backwards transform [`bfft`](@ref).
135+
backwards transform [`bfft`](@ref). Uses active `backend` if no explicit `backend` is provided.
99136
"""
100137
plan_bfft
101138

102139
"""
140+
plan_fft(backend, A [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf)
103141
plan_fft(A [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf)
104142
105143
Pre-plan an optimized FFT along given dimensions (`dims`) of arrays matching the shape and
106144
type of `A`. (The first two arguments have the same meaning as for [`fft`](@ref).)
107145
Returns an object `P` which represents the linear operator computed by the FFT, and which
108146
contains all of the information needed to compute `fft(A, dims)` quickly.
109147
148+
Uses active `backend` if no explicit `backend` is provided.
149+
110150
To apply `P` to an array `A`, use `P * A`; in general, the syntax for applying plans is much
111151
like that of matrices. (A plan can only be applied to arrays of the same size as the `A`
112152
for which the plan was created.) You can also apply a plan with a preallocated output array `Â`
@@ -132,34 +172,40 @@ plans that perform the equivalent of the inverse transforms [`ifft`](@ref) and s
132172
plan_fft
133173

134174
"""
175+
plan_fft!(backend A [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf)
135176
plan_fft!(A [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf)
136177
137178
Same as [`plan_fft`](@ref), but operates in-place on `A`.
138179
"""
139180
plan_fft!
140181

141182
"""
183+
rfft(backend, A [, dims])
142184
rfft(A [, dims])
143185
144186
Multidimensional FFT of a real array `A`, exploiting the fact that the transform has
145187
conjugate symmetry in order to save roughly half the computational time and storage costs
146188
compared with [`fft`](@ref). If `A` has size `(n_1, ..., n_d)`, the result has size
147189
`(div(n_1,2)+1, ..., n_d)`.
148190
191+
Uses active `backend` if no explicit `backend` is provided.
192+
149193
The optional `dims` argument specifies an iterable subset of one or more dimensions of `A`
150194
to transform, similar to [`fft`](@ref). Instead of (roughly) halving the first
151195
dimension of `A` in the result, the `dims[1]` dimension is (roughly) halved in the same way.
152196
"""
153197
rfft
154198

155199
"""
200+
ifft!(backend, A [, dims])
156201
ifft!(A [, dims])
157202
158203
Same as [`ifft`](@ref), but operates in-place on `A`.
159204
"""
160205
ifft!
161206

162207
"""
208+
ifft(backend, A [, dims])
163209
ifft(A [, dims])
164210
165211
Multidimensional inverse FFT.
@@ -177,6 +223,7 @@ A multidimensional inverse FFT simply performs this operation along each transfo
177223
ifft
178224

179225
"""
226+
fft!(backend, A [, dims])
180227
fft!(A [, dims])
181228
182229
Same as [`fft`](@ref), but operates in-place on `A`, which must be an array of
@@ -185,6 +232,7 @@ complex floating-point numbers.
185232
fft!
186233

187234
"""
235+
bfft(backend, A [, dims])
188236
bfft(A [, dims])
189237
190238
Similar to [`ifft`](@ref), but computes an unnormalized inverse (backward)
@@ -200,6 +248,7 @@ computational steps elsewhere.)
200248
bfft
201249

202250
"""
251+
bfft!(backend, A [, dims])
203252
bfft!(A [, dims])
204253
205254
Same as [`bfft`](@ref), but operates in-place on `A`.
@@ -211,14 +260,14 @@ bfft!
211260
for f in (:fft, :bfft, :ifft)
212261
pf = Symbol("plan_", f)
213262
@eval begin
214-
$f(x::AbstractArray{<:Real}, region) = $f(complexfloat(x), region)
215-
$pf(x::AbstractArray{<:Real}, region; kws...) = $pf(complexfloat(x), region; kws...)
216-
$f(x::AbstractArray{<:Complex{<:Union{Integer,Rational}}}, region) = $f(complexfloat(x), region)
217-
$pf(x::AbstractArray{<:Complex{<:Union{Integer,Rational}}}, region; kws...) = $pf(complexfloat(x), region; kws...)
263+
$f(b::AbstractFFTBackend, x::AbstractArray{<:Real}, region) = $f(b, complexfloat(x), region)
264+
$pf(b::AbstractFFTBackend, x::AbstractArray{<:Real}, region; kws...) = $pf(b, complexfloat(x), region; kws...)
265+
$f(b::AbstractFFTBackend, x::AbstractArray{<:Complex{<:Union{Integer,Rational}}}, region) = $f(b, complexfloat(x), region)
266+
$pf(b::AbstractFFTBackend, x::AbstractArray{<:Complex{<:Union{Integer,Rational}}}, region; kws...) = $pf(b, complexfloat(x), region; kws...)
218267
end
219268
end
220-
rfft(x::AbstractArray{<:Union{Integer,Rational}}, region=1:ndims(x)) = rfft(realfloat(x), region)
221-
plan_rfft(x::AbstractArray, region; kws...) = plan_rfft(realfloat(x), region; kws...)
269+
rfft(b::AbstractFFTBackend, x::AbstractArray{<:Union{Integer,Rational}}, region=1:ndims(x)) = rfft(b, realfloat(x), region)
270+
plan_rfft(b::AbstractFFTBackend, x::AbstractArray, region; kws...) = plan_rfft(b, realfloat(x), region; kws...)
222271

223272
# only require implementation to provide *(::Plan{T}, ::Array{T})
224273
*(p::Plan{T}, x::AbstractArray) where {T} = p * copy1(T, x)
@@ -279,10 +328,10 @@ summary(p::ScaledPlan) = string(p.scale, " * ", summary(p.p))
279328
end
280329
normalization(X, region) = normalization(real(eltype(X)), size(X), region)
281330

282-
plan_ifft(x::AbstractArray, region; kws...) =
283-
ScaledPlan(plan_bfft(x, region; kws...), normalization(x, region))
284-
plan_ifft!(x::AbstractArray, region; kws...) =
285-
ScaledPlan(plan_bfft!(x, region; kws...), normalization(x, region))
331+
plan_ifft(b::AbstractFFTBackend, x::AbstractArray, region; kws...) =
332+
ScaledPlan(plan_bfft(b, x, region; kws...), normalization(x, region))
333+
plan_ifft!(b::AbstractFFTBackend, x::AbstractArray, region; kws...) =
334+
ScaledPlan(plan_bfft!(b, x, region; kws...), normalization(x, region))
286335

287336
plan_inv(p::ScaledPlan) = ScaledPlan(plan_inv(p.p), inv(p.scale))
288337
# Don't cache inverse of scaled plan (only inverse of inner plan)
@@ -302,20 +351,21 @@ LinearAlgebra.mul!(y::AbstractArray, p::ScaledPlan, x::AbstractArray) =
302351
for f in (:brfft, :irfft)
303352
pf = Symbol("plan_", f)
304353
@eval begin
305-
$f(x::AbstractArray, d::Integer) = $f(x, d, 1:ndims(x))
306-
$f(x::AbstractArray, d::Integer, region) = $pf(x, d, region) * x
307-
$pf(x::AbstractArray, d::Integer;kws...) = $pf(x, d, 1:ndims(x);kws...)
354+
$f(b::AbstractFFTBackend, x::AbstractArray, d::Integer) = $f(b, x, d, 1:ndims(x))
355+
$f(b::AbstractFFTBackend, x::AbstractArray, d::Integer, region) = $pf(b, x, d, region) * x
356+
$pf(b::AbstractFFTBackend, x::AbstractArray, d::Integer;kws...) = $pf(b, x, d, 1:ndims(x);kws...)
308357
end
309358
end
310359

311360
for f in (:brfft, :irfft)
312361
@eval begin
313-
$f(x::AbstractArray{<:Real}, d::Integer, region) = $f(complexfloat(x), d, region)
314-
$f(x::AbstractArray{<:Complex{<:Union{Integer,Rational}}}, d::Integer, region) = $f(complexfloat(x), d, region)
362+
$f(b::AbstractFFTBackend, x::AbstractArray{<:Real}, d::Integer, region) = $f(b, complexfloat(x), d, region)
363+
$f(b::AbstractFFTBackend, x::AbstractArray{<:Complex{<:Union{Integer,Rational}}}, d::Integer, region) = $f(b, complexfloat(x), d, region)
315364
end
316365
end
317366

318367
"""
368+
irfft(backend, A, d [, dims])
319369
irfft(A, d [, dims])
320370
321371
Inverse of [`rfft`](@ref): for a complex array `A`, gives the corresponding real
@@ -330,6 +380,7 @@ transformed real array.)
330380
irfft
331381

332382
"""
383+
brfft(backend, A, d [, dims])
333384
brfft(A, d [, dims])
334385
335386
Similar to [`irfft`](@ref) but computes an unnormalized inverse transform (similar
@@ -351,11 +402,12 @@ function brfft_output_size(sz::Dims{N}, d::Integer, region) where {N}
351402
return ntuple(i -> i == d1 ? d : sz[i], Val(N))
352403
end
353404

354-
plan_irfft(x::AbstractArray{Complex{T}}, d::Integer, region; kws...) where {T} =
355-
ScaledPlan(plan_brfft(x, d, region; kws...),
405+
plan_irfft(b::AbstractFFTBackend, x::AbstractArray{Complex{T}}, d::Integer, region; kws...) where {T} =
406+
ScaledPlan(plan_brfft(b, x, d, region; kws...),
356407
normalization(T, brfft_output_size(x, d, region), region))
357408

358409
"""
410+
plan_irfft(backend, A, d [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf)
359411
plan_irfft(A, d [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf)
360412
361413
Pre-plan an optimized inverse real-input FFT, similar to [`plan_rfft`](@ref)
@@ -543,6 +595,7 @@ fftshift(x::Frequencies) = (x.n_nonnegative-x.n:x.n_nonnegative-1)*x.multiplier
543595
##############################################################################
544596

545597
"""
598+
fft(backend, A [, dims])
546599
fft(A [, dims])
547600
548601
Performs a multidimensional FFT of the array `A`. The optional `dims` argument specifies an
@@ -570,6 +623,7 @@ A multidimensional FFT simply performs this operation along each transformed dim
570623
fft
571624

572625
"""
626+
plan_rfft(backend, A [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf)
573627
plan_rfft(A [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf)
574628
575629
Pre-plan an optimized real-input FFT, similar to [`plan_fft`](@ref) except for
@@ -579,6 +633,7 @@ size of the transformed result, are the same as for [`rfft`](@ref).
579633
plan_rfft
580634

581635
"""
636+
plan_brfft(backend, A, d [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf)
582637
plan_brfft(A, d [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf)
583638
584639
Pre-plan an optimized real-input unnormalized transform, similar to

0 commit comments

Comments
 (0)