Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
256 changes: 170 additions & 86 deletions src/plan.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ abstract type FFTAPlan{T,N} <: AbstractFFTs.Plan{T} end

# Field-less plan type parameterised like the other plans in this file.
# Instances are created as `FFTAInvPlan{T,N}()` by the planning functions and
# handed to the plan constructors as their `pinv` argument.
struct FFTAInvPlan{T,N} <: FFTAPlan{T,N}
end

struct FFTAPlan_cx{T,N,R<:Union{Int,AbstractVector{Int}}} <: FFTAPlan{T,N}
const RegionTypes{N} = Union{Int,AbstractVector{Int},NTuple{N,Int}}

struct FFTAPlan_cx{T,N,R<:RegionTypes{N}} <: FFTAPlan{T,N}
callgraph::NTuple{N,CallGraph{T}}
region::R
dir::Direction
Expand All @@ -13,11 +15,11 @@ end
function FFTAPlan_cx{T,N}(
cg::NTuple{N,CallGraph{T}}, r::R,
dir::Direction, pinv::FFTAInvPlan{T,N}
) where {T,N,R<:Union{Int,AbstractVector{Int}}}
) where {T,N,R<:RegionTypes{N}}
FFTAPlan_cx{T,N,R}(cg, r, dir, pinv)
end

struct FFTAPlan_re{T,N,R<:Union{Int,AbstractVector{Int}}} <: FFTAPlan{T,N}
struct FFTAPlan_re{T,N,R<:RegionTypes{N}} <: FFTAPlan{T,N}
callgraph::NTuple{N,CallGraph{T}}
region::R
dir::Direction
Expand All @@ -27,7 +29,7 @@ end
function FFTAPlan_re{T,N}(
cg::NTuple{N,CallGraph{T}}, r::R,
dir::Direction, pinv::FFTAInvPlan{T,N}, flen::Int
) where {T,N,R<:Union{Int,AbstractVector{Int}}}
) where {T,N,R<:RegionTypes{N}}
FFTAPlan_re{T,N,R}(cg, r, dir, pinv, flen)
end

Expand All @@ -46,37 +48,66 @@ Base.size(p::FFTAPlan{<:Any,N}) where N = ntuple(Base.Fix1(size, p), Val{N}())

# Build the complex-plan counterpart of a real plan. The call graphs, region,
# direction and `pinv` of `p` are passed through unchanged, and the `T`, `N`
# and `R` type parameters are kept.
function Base.complex(p::FFTAPlan_re{T,N,R}) where {T,N,R}
    graphs = p.callgraph
    region = p.region
    return FFTAPlan_cx{T,N,R}(graphs, region, p.dir, p.pinv)
end

AbstractFFTs.plan_fft(x::AbstractArray{T,N}, region::R; kwargs...) where {T<:Complex,N,R} =
# Return the entries of the integer tuple `region` in ascending order, as a
# tuple of the same type (enforced by the `::T` return annotation).
#
# On Julia >= 1.12 this simply calls `sort(region)` (presumably because tuple
# sorting is only available in Base from that release on; confirm against the
# Julia NEWS). On older versions:
#   * length 2 uses a single `minmax`,
#   * length 3 uses three `minmax` compare-exchange steps,
#   * any other length collects into a vector, sorts it in place and converts
#     the result back to an `NTuple{N}`.
function _sort(region::T)::T where {N,T<:NTuple{N,Int}}
    @static if VERSION >= v"1.12"
        return sort(region)
    else
        N == 2 && return minmax(region[1], region[2])
        if N == 3
            lo, mid, hi = region
            lo, mid = minmax(lo, mid)   # order the first pair
            mid, hi = minmax(mid, hi)   # largest value ends up last
            lo, mid = minmax(lo, mid)   # re-order the first pair
            return (lo, mid, hi)
        end
        # Generic fallback: sort a temporary vector and rebuild the tuple.
        buffer = collect(region)
        sort!(buffer)
        return NTuple{N}(buffer)
    end
end

# Fallback for any `RegionTypes` value: return a sorted version of `region`
# without mutating the argument. An already sorted `region` is returned as a
# `copy`; otherwise the result of the non-mutating `sort` is returned.
function _sort(region::T) where {T<:RegionTypes}
    if issorted(region)
        return copy(region)
    end
    return sort(region)
end

# Forward plan for complex input: hands `x`, `region` and every keyword
# argument to `_plan_fft` with the direction fixed to `FFT_FORWARD`.
function AbstractFFTs.plan_fft(
    x::AbstractArray{T,N}, region; kwargs...
) where {T<:Complex,N}
    return _plan_fft(x, region, FFT_FORWARD; kwargs...)
end

AbstractFFTs.plan_bfft(x::AbstractArray{T,N}, region::R; kwargs...) where {T<:Complex,N,R} =
# Backward plan for complex input: hands `x`, `region` and every keyword
# argument to `_plan_fft` with the direction fixed to `FFT_BACKWARD`.
function AbstractFFTs.plan_bfft(
    x::AbstractArray{T,N}, region; kwargs...
) where {T<:Complex,N}
    return _plan_fft(x, region, FFT_BACKWARD; kwargs...)
end

function _plan_fft(x::AbstractArray{T,N}, region::R, dir::Direction; BLUESTEIN_CUTOFF=DEFAULT_BLUESTEIN_CUTOFF, _kwargs...) where {T<:Complex,N,R}
FFTN = length(region)
if FFTN == 1
function _plan_fft(
x::AbstractArray{T,N},
region::RegionTypes,
dir::Direction;
BLUESTEIN_CUTOFF=DEFAULT_BLUESTEIN_CUTOFF, _kwargs...
) where {T<:Complex,N}
M = length(region)
if M == 1
R1 = Int(region[])
g = CallGraph{T}(size(x, R1), BLUESTEIN_CUTOFF)
pinv = FFTAInvPlan{T,1}()
return FFTAPlan_cx{T,1,Int}((g,), R1, dir, pinv)
elseif FFTN == 2
sort!(region)
g1 = CallGraph{T}(size(x, region[1]), BLUESTEIN_CUTOFF)
g2 = CallGraph{T}(size(x, region[2]), BLUESTEIN_CUTOFF)
elseif M == 2
R2 = _sort(region)
g1 = CallGraph{T}(size(x, R2[1]), BLUESTEIN_CUTOFF)
g2 = CallGraph{T}(size(x, R2[2]), BLUESTEIN_CUTOFF)
pinv = FFTAInvPlan{T,2}()
return FFTAPlan_cx{T,2,R}((g1, g2), region, dir, pinv)
return FFTAPlan_cx{T,2,typeof(R2)}((g1, g2), R2, dir, pinv)
else
sort!(region)
return FFTAPlan_cx{T,FFTN,R}(
ntuple(i -> CallGraph{T}(size(x, region[i]), BLUESTEIN_CUTOFF), Val(FFTN)),
region, dir, FFTAInvPlan{T,FFTN}()
RM = _sort(region)
return FFTAPlan_cx{T,M,typeof(RM)}(
ntuple(i -> CallGraph{T}(size(x, RM[i]), BLUESTEIN_CUTOFF), Val(M)),
RM, dir, FFTAInvPlan{T,M}()
)
end
end

function AbstractFFTs.plan_rfft(x::AbstractArray{T,N}, region::R; BLUESTEIN_CUTOFF=DEFAULT_BLUESTEIN_CUTOFF, _kwargs...) where {T<:Real,N,R}
FFTN = length(region)
if FFTN == 1
function AbstractFFTs.plan_rfft(
x::AbstractArray{T,N},
region::RegionTypes;
BLUESTEIN_CUTOFF=DEFAULT_BLUESTEIN_CUTOFF, _kwargs...
) where {T<:Real,N}
M = length(region)
if M == 1
R1 = Int(region[])
n = size(x, R1)
# For even length problems, we solve the real problem with
Expand All @@ -86,20 +117,25 @@ function AbstractFFTs.plan_rfft(x::AbstractArray{T,N}, region::R; BLUESTEIN_CUTO
g = CallGraph{Complex{T}}(nn, BLUESTEIN_CUTOFF)
pinv = FFTAInvPlan{Complex{T},1}()
return FFTAPlan_re{Complex{T},1,Int}((g,), R1, FFT_FORWARD, pinv, n)
elseif FFTN == 2
sort!(region)
g1 = CallGraph{Complex{T}}(size(x, region[1]), BLUESTEIN_CUTOFF)
g2 = CallGraph{Complex{T}}(size(x, region[2]), BLUESTEIN_CUTOFF)
elseif M == 2
R2 = _sort(region)
g1 = CallGraph{Complex{T}}(size(x, R2[1]), BLUESTEIN_CUTOFF)
g2 = CallGraph{Complex{T}}(size(x, R2[2]), BLUESTEIN_CUTOFF)
pinv = FFTAInvPlan{Complex{T},2}()
return FFTAPlan_re{Complex{T},2,R}((g1, g2), region, FFT_FORWARD, pinv, size(x, region[1]))
return FFTAPlan_re{Complex{T},2,typeof(R2)}((g1, g2), R2, FFT_FORWARD, pinv, size(x, R2[1]))
else
throw(ArgumentError("only supports 1D and 2D FFTs"))
end
end

function AbstractFFTs.plan_brfft(x::AbstractArray{T,N}, len, region::R; BLUESTEIN_CUTOFF=DEFAULT_BLUESTEIN_CUTOFF, _kwargs...) where {T,N,R}
FFTN = length(region)
if FFTN == 1
function AbstractFFTs.plan_brfft(
x::AbstractArray{T,N},
len::Int,
region::RegionTypes;
BLUESTEIN_CUTOFF=DEFAULT_BLUESTEIN_CUTOFF, _kwargs...
) where {T,N}
M = length(region)
if M == 1
# For even length problems, we solve the real problem with
# two n/2 complex FFTs followed by a butterfly. For odd size
# problems, we just solve the problem as a single complex
Expand All @@ -108,12 +144,12 @@ function AbstractFFTs.plan_brfft(x::AbstractArray{T,N}, len, region::R; BLUESTEI
g = CallGraph{T}(nn, BLUESTEIN_CUTOFF)
pinv = FFTAInvPlan{T,1}()
return FFTAPlan_re{T,1,Int}((g,), R1, FFT_BACKWARD, pinv, len)
elseif FFTN == 2
sort!(region)
elseif M == 2
R2 = _sort(region)
g1 = CallGraph{T}(len, BLUESTEIN_CUTOFF)
g2 = CallGraph{T}(size(x, region[2]), BLUESTEIN_CUTOFF)
g2 = CallGraph{T}(size(x, R2[2]), BLUESTEIN_CUTOFF)
pinv = FFTAInvPlan{T,2}()
return FFTAPlan_re{T,2,R}((g1, g2), region, FFT_BACKWARD, pinv, len)
return FFTAPlan_re{T,2,typeof(R2)}((g1, g2), R2, FFT_BACKWARD, pinv, len)
else
throw(ArgumentError("only supports 1D and 2D FFTs"))
end
Expand All @@ -139,70 +175,89 @@ end
#### 1D plan ND array
function LinearAlgebra.mul!(y::AbstractArray{U,N}, p::FFTAPlan_cx{T,1}, x::AbstractArray{T,N}) where {T,U,N}
Base.require_one_based_indexing(x, y)
if axes(x) != axes(y)
throw(DimensionMismatch("input array has axes $(axes(x)), but output array has axes $(axes(y))"))

ax_x, ax_y = axes(x), axes(y)
if ax_x != ax_y
throw(DimensionMismatch("input array has axes $ax_x, but output array has axes $ax_y"))
end
if size(p, 1) != size(x, p.region[])
throw(DimensionMismatch("plan has size $(size(p, 1)), but input array has size $(size(x, p.region[])) along region $(p.region[])"))

R1 = p.region[]
plen, xlen = size(p, 1), size(x, R1)
if plen != xlen
throw(DimensionMismatch("plan has size $plen, but input array has size $xlen along region $R1"))
end

if @generated
quote
Base.Cartesian.@nif $N d -> (d == R1) dim -> (_mul_loop!(y, x, p, Val(dim)))
end
else
_mul_loop!(y, x, p, Val(R1))
end
Rpre = CartesianIndices(size(x)[1:p.region[]-1])
Rpost = CartesianIndices(size(x)[p.region[]+1:end])
_mul_loop!(y, x, Rpre, Rpost, p)
return y
end

function _mul_loop!(
y::AbstractArray{U,N},
x::AbstractArray{T,N},
Rpre::CartesianIndices,
Rpost::CartesianIndices,
p::FFTAPlan_cx{T,1}) where {T,U,N}
p::FFTAPlan_cx{T,1},
::Val{R}
) where {T,U,N,R}
Rpre = CartesianIndices(ntuple(Base.Fix1(size, x), Val(R - 1)))
Rpost = CartesianIndices(ntuple(i -> size(x, R + i), Val(N - R)))
for Ipost in Rpost, Ipre in Rpre
@views fft!(y[Ipre,:,Ipost], x[Ipre,:,Ipost], 1, 1, p.dir, p.callgraph[1][1].type, p.callgraph[1], 1)
end
end

#### ND plan ND array
@generated function LinearAlgebra.mul!(
function LinearAlgebra.mul!(
out::AbstractArray{U,N},
p::FFTAPlan_cx{T,N},
X::AbstractArray{T,N}
) where {T,U,N}
Base.require_one_based_indexing(out, X)
if size(out) != size(X)
throw(DimensionMismatch("input array has axes $(axes(X)), but output array has axes $(axes(out))"))
elseif size(p) != size(X)
throw(DimensionMismatch("plan has size $(size(p)), but input array has size $(size(X))"))
elseif !(p.region == 1:N || p.region == 1)
throw(DimensionMismatch("Plan region is outside array dimensions."))
end

quote
Base.require_one_based_indexing(out, X)
if size(out) != size(X)
throw(DimensionMismatch("input array has axes $(axes(X)), but output array has axes $(axes(out))"))
elseif size(p) != size(X)
throw(DimensionMismatch("plan has size $(size(p)), but input array has size $(size(X))"))
elseif !(p.region == N || p.region == 1:N)
throw(DimensionMismatch("Plan region is outside array dimensions."))
end
sz = size(X)
max_sz = maximum(sz)
obuf = Vector{T}(undef, max_sz)
ibuf = Vector{T}(undef, max_sz)
sizehint!(obuf, max_sz) # not guaranteed but hopefully prevents allocations
sizehint!(ibuf, max_sz)
dir = p.dir

sz = size(X)
max_sz = maximum(sz)
obuf = Vector{T}(undef, max_sz)
ibuf = Vector{T}(undef, max_sz)
sizehint!(obuf, max_sz) # not guaranteed but hopefully prevents allocations
sizehint!(ibuf, max_sz)
dir = p.dir
copyto!(out, X) # operate in-place on output array

copyto!(out, X) # operate in-place on output array
if @generated
quote
Base.Cartesian.@nexprs $N dim -> begin
n = size(out, dim)
resize!(obuf, n)
resize!(ibuf, n)
cg = p.callgraph[dim]

Base.Cartesian.@nexprs $N dim -> begin
fft_along_dim!(out, ibuf, obuf, cg, dir, Val(dim))
end
end
else
for dim in 1:N
n = size(out, dim)
resize!(obuf, n)
resize!(ibuf, n)
cg = p.callgraph[dim]

Rpre_{dim} = CartesianIndices(sz[1:dim-1])
Rpost_{dim} = CartesianIndices(sz[dim+1:N])

fft_along_dim!(out, ibuf, obuf, cg, dir, Rpre_{dim}, Rpost_{dim})
fft_along_dim!(out, ibuf, obuf, cg, dir, Val(dim))
end

return out
end

return out
end

#### MD plan ND array (M<N)
Expand All @@ -214,46 +269,75 @@ function LinearAlgebra.mul!(
Base.require_one_based_indexing(out, X)
if size(out) != size(X)
throw(DimensionMismatch("input array has axes $(axes(X)), but output array has axes $(axes(out))"))
elseif length(p.region) != M || !issorted(p.region; lt=(<=))
throw(DimensionMismatch("Region is invalid."))
elseif M > N || first(p.region) < 1 || last(p.region) > N
throw(DimensionMismatch("Plan region is outside array dimensions."))
end

sz = size(X)
max_sz = maximum(Base.Fix1(size, out), p.region)
obuf = Vector{T}(undef, max_sz)
ibuf = Vector{T}(undef, max_sz)
sizehint!(obuf, max_sz) # not guaranteed but hopefully prevents allocations
sizehint!(ibuf, max_sz)
dir = p.dir

copyto!(out, X) # operate in-place on output array

# don't use generated functions because this cannot be type-stable anyway
for dim in 1:M
pdim = p.region[dim]
n = size(out, pdim)
resize!(obuf, n)
resize!(ibuf, n)
cg = p.callgraph[dim]
_execute_mdfft!(out, ibuf, obuf, p.dir, p.region, p.callgraph)

Rpre = CartesianIndices(sz[1:pdim-1])
Rpost = CartesianIndices(sz[pdim+1:N])
return out
end

fft_along_dim!(out, ibuf, obuf, cg, dir, Rpre, Rpost)
end
@noinline function _execute_mdfft!(
out::AbstractArray{U,N},
ibuf::Vector{T}, obuf::Vector{T},
dir::Direction,
@nospecialize(region::RegionTypes),
@nospecialize(callgraphs::NTuple)
) where {T,U,N}

return out
M = length(region)
if @generated
quote
k = 1
# region is assumed to be pre-sorted during planning
Base.Cartesian.@nexprs $N dim -> begin
if region[k] == dim
n = size(out, dim)
resize!(obuf, n)
resize!(ibuf, n)
cg = callgraphs[k]

fft_along_dim!(out, ibuf, obuf, cg, dir, Val(dim))

k = min(k + 1, M)
end
end
return nothing
end
else
for dim in 1:M
pdim = region[dim]
n = size(out, pdim)
resize!(obuf, n)
resize!(ibuf, n)
cg = callgraphs[dim]

fft_along_dim!(out, ibuf, obuf, cg, dir, Val(pdim))
end
end
end

function fft_along_dim!(
A::AbstractArray,
A::AbstractArray{U,N},
ibuf::Vector{T}, obuf::Vector{T},
cg::CallGraph{T}, d::Direction,
Rpre::CartesianIndices{M}, Rpost::CartesianIndices
) where {T <: Complex{<:AbstractFloat}, M}
::Val{dim}
) where {T <: Complex{<:AbstractFloat}, U, N, dim}

Rpre = CartesianIndices(ntuple(Base.Fix1(size, A), Val(dim - 1)))
Rpost = CartesianIndices(ntuple(i -> size(A, dim + i), Val(N - dim)))
t = cg[1].type
dim = M + 1
cols = eachindex(axes(A, dim), ibuf, obuf)

for Ipost in Rpost, Ipre in Rpre
Expand Down
Loading
Loading