diff --git a/Project.toml b/Project.toml index a9269fe0..6d10b624 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "FastTransforms" uuid = "057dd010-8810-581a-b7be-e3fc3b93f78c" -version = "0.16.4" +version = "0.16.5" [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" diff --git a/src/FastTransforms.jl b/src/FastTransforms.jl index 82c34990..fe23f46e 100644 --- a/src/FastTransforms.jl +++ b/src/FastTransforms.jl @@ -128,6 +128,7 @@ for f in (:jac2jac, @eval $f(x::AbstractArray, y...; z...) = $lib_f(x, y...; z...) end +include("arrays.jl") # following use Toeplitz-Hankel to avoid expensive plans # for f in (:leg2cheb, :cheb2leg, :ultra2ultra) # th_f = Symbol("th_", f) diff --git a/src/arrays.jl b/src/arrays.jl new file mode 100644 index 00000000..5472e736 --- /dev/null +++ b/src/arrays.jl @@ -0,0 +1,86 @@ +struct ArrayPlan{T, FF<:FTPlan{<:T}, Szs<:Tuple, Dims<:Tuple{<:Int}} <: Plan{T} + F::FF + szs::Szs + dims::Dims +end +size(P::ArrayPlan) = P.szs + +function ArrayPlan(F::FTPlan{<:T}, c::AbstractArray{T}, dims::Tuple{<:Int}=(1,)) where T + szs = size(c) + @assert F.n == szs[dims[1]] + ArrayPlan(F, size(c), dims) +end + +function *(P::ArrayPlan, f::AbstractArray) + F, dims, szs = P.F, P.dims, P.szs + @assert length(dims) == 1 + @assert szs == size(f) + d = first(dims) + + perm = (d, ntuple(i-> i + (i >= d), ndims(f) -1)...) + fp = permutedims(f, perm) + + fr = reshape(fp, size(fp,1), :) + + permutedims(reshape(F*fr, size(fp)...), invperm(perm)) +end + +function \(P::ArrayPlan, f::AbstractArray) + F, dims, szs = P.F, P.dims, P.szs + @assert length(dims) == 1 + @assert szs == size(f) + d = first(dims) + + perm = (d, ntuple(i-> i + (i >= d), ndims(f) -1)...) + fp = permutedims(f, perm) + + fr = reshape(fp, size(fp,1), :) + + permutedims(reshape(F\fr, size(fp)...), invperm(perm)) +end + +struct NDimsPlan{T, FF<:ArrayPlan{<:T}, Szs<:Tuple, Dims<:Tuple} <: Plan{T} + F::FF + szs::Szs + dims::Dims + function NDimsPlan(F, szs, dims) + if length(Set(szs[[dims...]])) > 1 + error("Different size in dims axes not yet implemented in N-dimensional transform.") + end + new{eltype(F), typeof(F), typeof(szs), typeof(dims)}(F, szs, dims) + end +end + +size(P::NDimsPlan) = P.szs + +function NDimsPlan(F::FTPlan, szs::Tuple, dims::Tuple) + NDimsPlan(ArrayPlan(F, szs, (first(dims),)), szs, dims) +end + +function *(P::NDimsPlan, f::AbstractArray) + F, dims = P.F, P.dims + @assert size(P) == size(f) + g = copy(f) + t = 1:ndims(g) + d1 = dims[1] + for d in dims + perm = ntuple(k -> k == d1 ? t[d] : k == d ? t[d1] : t[k], ndims(g)) + gp = permutedims(g, perm) + g = permutedims(F*gp, invperm(perm)) + end + return g +end + +function \(P::NDimsPlan, f::AbstractArray) + F, dims = P.F, P.dims + @assert size(P) == size(f) + g = copy(f) + t = 1:ndims(g) + d1 = dims[1] + for d in dims + perm = ntuple(k -> k == d1 ? t[d] : k == d ? t[d1] : t[k], ndims(g)) + gp = permutedims(g, perm) + g = permutedims(F\gp, invperm(perm)) + end + return g +end \ No newline at end of file diff --git a/test/arraystests.jl b/test/arraystests.jl new file mode 100644 index 00000000..55167a90 --- /dev/null +++ b/test/arraystests.jl @@ -0,0 +1,64 @@ +using FastTransforms, Test +import FastTransforms: ArrayPlan, NDimsPlan + +@testset "Array transform" begin + @testset "ArrayPlan" begin + c = randn(5,20,10) + F = plan_cheb2leg(c) + FT = ArrayPlan(F, c) + + @test size(FT) == size(c) + + f = similar(c); + for k in axes(c,3) + f[:,:,k] = (F*c[:,:,k]) + end + @test f ≈ FT*c + @test c ≈ FT\f + + F = plan_cheb2leg(Vector{Float64}(axes(c,2))) + FT = ArrayPlan(F, c, (2,)) + for k in axes(c,3) + f[:,:,k] = (F*c[:,:,k]')' + end + @test f ≈ FT*c + @test c ≈ FT\f + end + + @testset "NDimsPlan" begin + c = randn(20,10,20) + @test_throws ErrorException("Different size in dims axes not yet implemented in N-dimensional transform.") NDimsPlan(ArrayPlan(plan_cheb2leg(c), c), size(c), (1,2)) + + c = randn(5,20) + F = plan_cheb2leg(c) + FT = ArrayPlan(F, c) + P = NDimsPlan(F, size(c), (1,)) + @test F*c ≈ FT*c ≈ P*c + + c = randn(20,20,5); + F = plan_cheb2leg(c) + FT = ArrayPlan(F, c) + P = NDimsPlan(FT, size(c), (1,2)) + + @test size(P) == size(c) + + f = similar(c); + for k in axes(f,3) + f[:,:,k] = (F*(F*c[:,:,k])')' + end + @test f ≈ P*c + @test c ≈ P\f + + c = randn(5,10,10,60) + F = plan_cheb2leg(randn(10)) + P = NDimsPlan(F, size(c), (2,3)) + f = similar(c) + for i in axes(f,1), j in axes(f,4) + f[i,:,:,j] = (F*(F*c[i,:,:,j])')' + end + @test f ≈ P*c + @test c ≈ P\f + end +end + + diff --git a/test/runtests.jl b/test/runtests.jl index de16f36b..a4881c3b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -12,3 +12,4 @@ include("clenshawtests.jl") include("toeplitzplanstests.jl") include("toeplitzhankeltests.jl") include("symmetrictoeplitzplushankeltests.jl") +include("arraystests.jl") \ No newline at end of file