Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 104 additions & 39 deletions src/plan.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,11 @@ function _plan_fft(x::AbstractArray{T,N}, region::R, dir::Direction; BLUESTEIN_C
pinv = FFTAInvPlan{T,2}()
return FFTAPlan_cx{T,2,R}((g1, g2), region, dir, pinv)
else
throw(ArgumentError("only supports 1D and 2D FFTs"))
sort!(region)
return FFTAPlan_cx{T,FFTN,R}(
ntuple(i -> CallGraph{T}(size(x, region[i]), BLUESTEIN_CUTOFF), Val(FFTN)),
region, dir, FFTAInvPlan{T,FFTN}()
)
end
end

Expand Down Expand Up @@ -129,6 +133,7 @@ end
### Complex
#### 1D plan 1D array
function LinearAlgebra.mul!(y::AbstractVector{U}, p::FFTAPlan_cx{T,1}, x::AbstractVector{T}) where {T,U}
Base.require_one_based_indexing(x, y)
if axes(x) != axes(y)
throw(DimensionMismatch("input array has axes $(axes(x)), but output array has axes $(axes(y))"))
end
Expand All @@ -141,14 +146,14 @@ end

#### 1D plan ND array
function LinearAlgebra.mul!(y::AbstractArray{U,N}, p::FFTAPlan_cx{T,1}, x::AbstractArray{T,N}) where {T,U,N}
Base.require_one_based_indexing(x)
Base.require_one_based_indexing(x, y)
if axes(x) != axes(y)
throw(DimensionMismatch("input array has axes $(axes(x)), but output array has axes $(axes(y))"))
end
if size(p, 1) != size(x, p.region[])
throw(DimensionMismatch("plan has size $(size(p, 1)), but input array has size $(size(x, p.region[])) along region $(p.region[])"))
end
Rpre = CartesianIndices(size(x)[1:p.region[]-1])
Rpre = CartesianIndices(size(x)[1:p.region[]-1])
Rpost = CartesianIndices(size(x)[p.region[]+1:end])
_mul_loop!(y, x, Rpre, Rpost, p)
return y
Expand All @@ -165,52 +170,112 @@ function _mul_loop!(
end
end

#### 2D plan ND array
function LinearAlgebra.mul!(y::AbstractArray{U,N}, p::FFTAPlan_cx{T,2}, x::AbstractArray{T,N}) where {T,U,N}
Base.require_one_based_indexing(x)
if axes(x) != axes(y)
throw(DimensionMismatch("input array has axes $(axes(x)), but output array has axes $(axes(y))"))
#### ND plan ND array
@generated function LinearAlgebra.mul!(
out::AbstractArray{U,N},
p::FFTAPlan_cx{T,N},
X::AbstractArray{T,N}
) where {T,U,N}

quote
Base.require_one_based_indexing(out, X)
if size(out) != size(X)
throw(DimensionMismatch("input array has axes $(axes(X)), but output array has axes $(axes(out))"))
elseif size(p) != size(X)
throw(DimensionMismatch("plan has size $(size(p)), but input array has size $(size(X))"))
elseif !(p.region == N || p.region == 1:N)
throw(DimensionMismatch("Plan region is outside array dimensions."))
end

sz = size(X)
max_sz = maximum(sz)
obuf = Vector{T}(undef, max_sz)
ibuf = Vector{T}(undef, max_sz)
sizehint!(obuf, max_sz) # not guaranteed but hopefully prevents allocations
sizehint!(ibuf, max_sz)
dir = p.dir

copyto!(out, X) # operate in-place on output array

Base.Cartesian.@nexprs $N dim -> begin
n = size(out, dim)
resize!(obuf, n)
resize!(ibuf, n)
cg = p.callgraph[dim]

Rpre_{dim} = CartesianIndices(sz[1:dim-1])
Rpost_{dim} = CartesianIndices(sz[dim+1:N])

fft_along_dim!(out, ibuf, obuf, cg, dir, Rpre_{dim}, Rpost_{dim})
end

return out
end
if N < 2
throw(DimensionMismatch("array dimension $N cannot be smaller than the plan size 2"))
end

#### MD plan ND array (M<N)
function LinearAlgebra.mul!(
out::AbstractArray{U,N},
p::FFTAPlan_cx{T,M},
X::AbstractArray{T,N}
) where {T,U,N,M}
Base.require_one_based_indexing(out, X)
if size(out) != size(X)
throw(DimensionMismatch("input array has axes $(axes(X)), but output array has axes $(axes(out))"))
elseif M > N || p.region[1] < 1 || p.region[end] > N
throw(DimensionMismatch("Plan region is outside array dimensions."))
end
if size(p) != (size(x, p.region[1]), size(x, p.region[2]))
throw(DimensionMismatch("plan has size $(size(p)), but input array has size $((size(x, p.region[1]), size(x, p.region[2]))) along regions $(p.region)"))

sz = size(X)
max_sz = maximum(Base.Fix1(size, out), p.region)
obuf = Vector{T}(undef, max_sz)
ibuf = Vector{T}(undef, max_sz)
sizehint!(obuf, max_sz) # not guaranteed but hopefully prevents allocations
sizehint!(ibuf, max_sz)
dir = p.dir

copyto!(out, X) # operate in-place on output array

# don't use generated functions because this cannot be type-stable anyway
for dim in 1:M
pdim = p.region[dim]
n = size(out, pdim)
resize!(obuf, n)
resize!(ibuf, n)
cg = p.callgraph[dim]

Rpre = CartesianIndices(sz[1:pdim-1])
Rpost = CartesianIndices(sz[pdim+1:N])

fft_along_dim!(out, ibuf, obuf, cg, dir, Rpre, Rpost)
end
R1 = CartesianIndices(size(x)[1:p.region[1]-1])
R2 = CartesianIndices(size(x)[p.region[1]+1:p.region[2]-1])
R3 = CartesianIndices(size(x)[p.region[2]+1:end])
y_tmp = similar(y, axes(y)[p.region])
rows, cols = size(x)[p.region]
# Introduce function barrier here since the variables used in the loop ranges aren't inferred. This
# is partly because the region field of the plan is abstractly typed but even if that wasn't the case,
# it might be a bit tricky to construct the Rxs in an inferred way.
_mul_loop!(y_tmp, y, x, p, R1, R2, R3, rows, cols)
return y

return out
end

function _mul_loop!(
y_tmp::AbstractArray,
y::AbstractArray,
x::AbstractArray,
p::FFTAPlan,
R1::CartesianIndices,
R2::CartesianIndices,
R3::CartesianIndices,
rows::Int,
cols::Int
)
for I3 in R3, I2 in R2, I1 in R1
for k in 1:cols
@views fft!(y_tmp[:,k], x[I1,:,I2,k,I3], 1, 1, p.dir, p.callgraph[1][1].type, p.callgraph[1], 1)
end
function fft_along_dim!(
A::AbstractArray,
ibuf::Vector{T}, obuf::Vector{T},
cg::CallGraph{T}, d::Direction,
Rpre::CartesianIndices{M}, Rpost::CartesianIndices
) where {T <: Complex{<:AbstractFloat}, M}

t = cg[1].type
dim = M + 1
cols = eachindex(axes(A, dim), ibuf, obuf)

for k in 1:rows
@views fft!(y[I1,k,I2,:,I3], y_tmp[k,:], 1, 1, p.dir, p.callgraph[2][1].type, p.callgraph[2], 1)
for Ipost in Rpost, Ipre in Rpre
for j in cols
ibuf[j] = A[Ipre, j, Ipost]
end
fft!(obuf, ibuf, 1, 1, d, t, cg, 1)
for j in cols
A[Ipre, j, Ipost] = obuf[j]
end
end
end


## *
### Complex
function Base.:*(p::FFTAPlan_cx{T,1}, x::AbstractVector{T}) where {T<:Complex}
Expand Down
37 changes: 26 additions & 11 deletions test/argument_checking.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@ using Test, FFTA
using LinearAlgebra: LinearAlgebra

@testset "Only 1D and 2D FFTs" begin
xr = zeros(2, 2)
xr = zeros(2, 2, 2)
xc = complex(xr)
@test_throws ArgumentError("only supports 1D and 2D FFTs") plan_fft(xc, 1:3)
@test_throws ArgumentError("only supports 1D and 2D FFTs") plan_bfft(xc, 1:3)

@test_throws ArgumentError("only supports 1D and 2D FFTs") plan_rfft(xr, 1:3)
@test_throws ArgumentError("only supports 1D and 2D FFTs") plan_brfft(xc, 2, 1:3)
end
Expand Down Expand Up @@ -59,26 +58,42 @@ end
@test_throws DimensionMismatch plan_brfft(yr2, size(xr2, 1)) * yr2p
end
end
@testset "3D array" begin
xc3 = randn(ComplexF64, 3, 3, 3)
yc3 = randn(ComplexF64, 5, 5, 5)
pxc3 = plan_fft(xc3)
@test_throws DimensionMismatch pxc3 * yc3
invalid_p = plan_fft(randn(ComplexF64, ntuple(i -> 3, 5)), 3:5)
xc4 = randn(ComplexF64, (1, ntuple(i -> 5, 3)...))

### plan region out of bounds

# all same dims
@test_throws DimensionMismatch("Plan region is outside array dimensions.") invalid_p * xc3
# dim(p) < dim(out) = dim(in)
@test_throws DimensionMismatch("Plan region is outside array dimensions.") LinearAlgebra.mul!(xc4, invalid_p, xc4)
end
end

@testset "mismatch between input and output arrays" begin
@testset "1D plan 1D array" begin
x1 = complex(randn(3))
x1 = randn(ComplexF64, 3)
y1 = similar(x1, length(x1) + 1)

@test_throws DimensionMismatch LinearAlgebra.mul!(y1, plan_fft(x1), x1)
end

@testset "2D array" begin
x2 = complex.(randn(3, 3), randn(3, 3))
y2 = similar(x2, size(x2, 1) + 1, size(x2, 2) + 1)
@testset "$(N)D array" for N in 2:4
xN = randn(ComplexF64, ntuple(i -> 3, N))
yN = similar(xN, size(xN) .+ 1)

@testset "1D plan, region=$(region)" for region in [1, 2]
@test_throws DimensionMismatch LinearAlgebra.mul!(y2, plan_fft(x2, region), x2)
@testset "1D plan, region=$(region)" for region in 1:N
@test_throws DimensionMismatch LinearAlgebra.mul!(yN, plan_fft(xN, region), xN)
end

@testset "2D plan" begin
@test_throws DimensionMismatch LinearAlgebra.mul!(y2, plan_fft(x2), x2)
@testset "$(N)D plan" begin
@test_throws DimensionMismatch LinearAlgebra.mul!(yN, plan_fft(xN), xN)
@test_throws DimensionMismatch LinearAlgebra.mul!(yN, plan_fft(xN, 1:N-1), xN)
end
end
end
27 changes: 27 additions & 0 deletions test/ndim/minimal_complex.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
using FFTA, Test

@testset "Basic ND checks" begin
for sz in ((3, 5, 7), (4, 14, 9), (103, 5, 13), (26, 33, 35, 4), ntuple(i -> 3, 5))
x = ones(sz)
@test fft(x) ≈ setindex!(zeros(sz), prod(sz), 1)
end

y = zeros((3, 3, 3))
y[2, 2, 2] = 1
w1 = -0.5 - sqrt(3)im / 2
w2 = conj(w1)
y_ref = ComplexF64[
1 w1 w2;
w1 w2 1;
w2 1 w1
;;;
w1 w2 1;
w2 1 w1;
1 w1 w2
;;;
w2 1 w1;
1 w1 w2;
w1 w2 1
]
@test isapprox(fft(y), y_ref)
end
2 changes: 1 addition & 1 deletion test/qa/explicit_imports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import ExplicitImports
# No non-public accesses in FFTA (ie. no `... MyPkg._non_public_internal_func(...)`)
# AbstractFFTs requires subtyping of `Plan` but it is not public
# This is an upstream bug in AbstractFFTs.jl
@test ExplicitImports.check_all_qualified_accesses_are_public(FFTA; ignore = (:Plan, :require_one_based_indexing, :Fix1)) === nothing
@test ExplicitImports.check_all_qualified_accesses_are_public(FFTA; ignore = (:Plan, :require_one_based_indexing, :Fix1, :Cartesian)) === nothing

# No self-qualified accesses in FFTA (ie. no `... FFTA.func(...)`)
@test ExplicitImports.check_no_self_qualified_accesses(FFTA) === nothing
Expand Down
5 changes: 5 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,11 @@ Random.seed!(1)
end
end
end
@testset verbose = true "N-D" begin
@testset verbose = true "Minimal tests" begin
include("ndim/minimal_complex.jl")
end
end
@testset verbose = true "Custom element types" begin
include("custom_element_types.jl")
end
Expand Down
Loading