Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/FastTransforms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ for f in (:jac2jac,
@eval $f(x::AbstractArray, y...; z...) = $lib_f(x, y...; z...)
end

include("arrays.jl")
# following use Toeplitz-Hankel to avoid expensive plans
# for f in (:leg2cheb, :cheb2leg, :ultra2ultra)
# th_f = Symbol("th_", f)
Expand Down
92 changes: 92 additions & 0 deletions src/arrays.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
struct ArrayPlan{T, FF<:FTPlan{<:T}, Szs<:Tuple, Dims<:Tuple{<:Int}} <: Plan{T}
F::FF
szs::Szs
dims::Dims
end
size(P::ArrayPlan) = P.szs
size(P::ArrayPlan, k::Int) = P.szs[k]
size(P::ArrayPlan, k...) = P.szs[[k...]]

function ArrayPlan(F::FTPlan{<:T}, c::AbstractArray{T}, dims::Tuple{<:Int}=(1,)) where T
szs = size(c)
@assert F.n == szs[dims[1]]
ArrayPlan(F, size(c), dims)
end

function inv_perm(d::Vector{<:Int})
inv_d = Vector{Int}(undef, length(d))
for (i, val) in enumerate(d)
inv_d[val] = i
end
return inv_d
end
inv_perm(d::Tuple) = inv_perm([d...])

function *(P::ArrayPlan, f::AbstractArray)
F, dims, szs = P.F, P.dims, P.szs
@assert length(dims) == 1
@assert szs == size(f)
d = first(dims)

perm = [d; setdiff(1:ndims(f), d)]
fp = permutedims(f, perm)

fr = reshape(fp, size(fp,1), prod(size(fp)[2:end]))

permutedims(reshape(F*fr, size(fp)...), inv_perm(perm))
end

function \(P::ArrayPlan, f::AbstractArray)
F, dims, szs = P.F, P.dims, P.szs
@assert length(dims) == 1
@assert szs == size(f)
d = first(dims)

perm = [d; setdiff(1:ndims(f), d)]
fp = permutedims(f, perm)

fr = reshape(fp, size(fp,1), prod(size(fp)[2:end]))

permutedims(reshape(F\fr, size(fp)...), inv_perm(perm))
end

struct NDimsPlan{T, FF<:ArrayPlan{<:T}, Dims<:Tuple} <: Plan{T}
F::FF
dims::Dims
function NDimsPlan(F, dims)
if length(Set(size(F, dims...))) > 1
error("Different size in dims axes not yet implemented in N-dimensional transform.")
end
new{eltype(F), typeof(F), typeof(dims)}(F, dims)
end
end

size(P::NDimsPlan) = size(P.F)
size(P::NDimsPlan, k::Int) = size(P.F, k)
size(P::NDimsPlan, k...) = size(P.F, k...)

function *(P::NDimsPlan, f::AbstractArray)
F, dims = P.F, P.dims
@assert size(F) == size(f)
g = copy(f)
t = 1:ndims(g)
for d in dims
perm = ntuple(k -> k == 1 ? t[d] : k == d ? t[1] : t[k], length(t))
gp = permutedims(g, perm)
g = permutedims(F*gp, inv_perm(perm))
end
return g
end

function \(P::NDimsPlan, f::AbstractArray)
F, dims = P.F, P.dims
@assert size(F) == size(f)
g = copy(f)
t = 1:ndims(g)
for d in dims
perm = ntuple(k -> k == 1 ? t[d] : k == d ? t[1] : t[k], length(t))
gp = permutedims(g, perm)
g = permutedims(F\gp, inv_perm(perm))
end
return g
end
63 changes: 63 additions & 0 deletions test/arraystests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
using FastTransforms, Test
import FastTransforms: ArrayPlan, NDimsPlan

@testset "Array transform" begin
@testset "ArrayPlan" begin
c = randn(5,20,10)
F = plan_cheb2leg(c)
FT = ArrayPlan(F, c)

@test size(FT) == size(c)
@test size(FT,1) == size(c,1)
@test size(FT,1,2) == (size(c,1), size(c,2))

f = similar(c);
for k in axes(c,3)
f[:,:,k] = (F*c[:,:,k])
end
@test f ≈ FT*c
@test c ≈ FT\f

F = plan_cheb2leg(Vector{Float64}(axes(c,2)))
FT = ArrayPlan(F, c, (2,))
for k in axes(c,3)
f[:,:,k] = (F*c[:,:,k]')'
end
@test f ≈ FT*c
@test c ≈ FT\f
end

@testset "NDimsPlan" begin
c = randn(20,10,20)
@test_throws ErrorException("Different size in dims axes not yet implemented in N-dimensional transform.") NDimsPlan(ArrayPlan(plan_cheb2leg(c), c), (1,2))

c = randn(20,20,5);
F = plan_cheb2leg(c)
FT = ArrayPlan(F, c)
P = NDimsPlan(FT, (1,2))

@test size(P) == size(c)
@test size(P,1) == size(c,1)
@test size(P,1,2) == (size(c,1), size(c,2))


f = similar(c);
for k in axes(f,3)
f[:,:,k] = (F*(F*c[:,:,k])')'
end
@test f ≈ P*c
@test c ≈ P\f

c = randn(10,5,10,60)
F = plan_cheb2leg(c)
P = NDimsPlan(ArrayPlan(F, c), (1,3))
f = similar(c)
for i in axes(f,2), j in axes(f,4)
f[:,i,:,j] = (F*(F*c[:,i,:,j])')'
end
@test f ≈ P*c
@test c ≈ P\f
end
end


1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ include("clenshawtests.jl")
include("toeplitzplanstests.jl")
include("toeplitzhankeltests.jl")
include("symmetrictoeplitzplushankeltests.jl")
include("arraystests.jl")
Loading