Skip to content

Commit 16b6f43

Browse files
author
ioannisPApapadopoulos
committed
N-dimensional plan
1 parent 04556c5 commit 16b6f43

File tree

2 files changed

+102
-18
lines changed

2 files changed

+102
-18
lines changed

src/arrays.jl

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1-
struct ArrayPlan{T, F<:FTPlan{<:T}, Szs<:Tuple, Dims<:Tuple{<:Int}} <: Plan{T}
2-
F::F
1+
struct ArrayPlan{T, FF<:FTPlan{<:T}, Szs<:Tuple, Dims<:Tuple{<:Int}} <: Plan{T}
2+
F::FF
33
szs::Szs
44
dims::Dims
55
end
6+
size(P::ArrayPlan) = P.szs
7+
size(P::ArrayPlan, k::Int) = P.szs[k]
68
size(P::ArrayPlan, k...) = P.szs[[k...]]
79

810
function ArrayPlan(F::FTPlan{<:T}, c::AbstractArray{T}, dims::Tuple{<:Int}=(1,)) where T
@@ -18,6 +20,7 @@ function inv_perm(d::Vector{<:Int})
1820
end
1921
return inv_d
2022
end
23+
inv_perm(d::Tuple) = inv_perm([d...])
2124

2225
function *(P::ArrayPlan, f::AbstractArray)
2326
F, dims, szs = P.F, P.dims, P.szs
@@ -45,4 +48,45 @@ function \(P::ArrayPlan, f::AbstractArray)
4548
fr = reshape(fp, size(fp,1), prod(size(fp)[2:end]))
4649

4750
permutedims(reshape(F\fr, size(fp)...), inv_perm(perm))
51+
end
52+
53+
struct NDimsPlan{T, FF<:ArrayPlan{<:T}, Dims<:Tuple} <: Plan{T}
54+
F::FF
55+
dims::Dims
56+
function NDimsPlan(F, dims)
57+
if length(Set(size(F, dims...))) > 1
58+
error("Different size in dims axes not yet implemented in N-dimensional transform.")
59+
end
60+
new{eltype(F), typeof(F), typeof(dims)}(F, dims)
61+
end
62+
end
63+
64+
size(P::NDimsPlan) = size(P.F)
65+
size(P::NDimsPlan, k::Int) = size(P.F, k)
66+
size(P::NDimsPlan, k...) = size(P.F, k...)
67+
68+
function *(P::NDimsPlan, f::AbstractArray)
69+
F, dims = P.F, P.dims
70+
@assert size(F) == size(f)
71+
g = copy(f)
72+
t = 1:ndims(g)
73+
for d in dims
74+
perm = ntuple(k -> k == 1 ? t[d] : k == d ? t[1] : t[k], length(t))
75+
gp = permutedims(g, perm)
76+
g = permutedims(F*gp, inv_perm(perm))
77+
end
78+
return g
79+
end
80+
81+
function \(P::NDimsPlan, f::AbstractArray)
82+
F, dims = P.F, P.dims
83+
@assert size(F) == size(f)
84+
g = copy(f)
85+
t = 1:ndims(g)
86+
for d in dims
87+
perm = ntuple(k -> k == 1 ? t[d] : k == d ? t[1] : t[k], length(t))
88+
gp = permutedims(g, perm)
89+
g = permutedims(F\gp, inv_perm(perm))
90+
end
91+
return g
4892
end

test/arraystests.jl

Lines changed: 56 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,63 @@
11
using FastTransforms, Test
2-
import FastTransforms: ArrayPlan
2+
import FastTransforms: ArrayPlan, NDimsPlan
33

44
@testset "Array transform" begin
5-
c = randn(5,20,10)
6-
F = plan_cheb2leg(c)
7-
FT = ArrayPlan(F, c)
5+
@testset "ArrayPlan" begin
6+
c = randn(5,20,10)
7+
F = plan_cheb2leg(c)
8+
FT = ArrayPlan(F, c)
89

9-
f = similar(c);
10-
for k in axes(c,3)
11-
f[:,:,k] = (F*c[:,:,k])
10+
@test size(FT) == size(c)
11+
@test size(FT,1) == size(c,1)
12+
@test size(FT,1,2) == (size(c,1), size(c,2))
13+
14+
f = similar(c);
15+
for k in axes(c,3)
16+
f[:,:,k] = (F*c[:,:,k])
17+
end
18+
@test f FT*c
19+
@test c FT\f
20+
21+
F = plan_cheb2leg(Vector{Float64}(axes(c,2)))
22+
FT = ArrayPlan(F, c, (2,))
23+
for k in axes(c,3)
24+
f[:,:,k] = (F*c[:,:,k]')'
25+
end
26+
@test f FT*c
27+
@test c FT\f
1228
end
13-
@test f FT*c
14-
@test c FT\f
1529

16-
F = plan_cheb2leg(Vector{Float64}(axes(c,2)))
17-
FT = ArrayPlan(F, c, (2,))
18-
for k in axes(c,3)
19-
f[:,:,k] = (F*c[:,:,k]')'
30+
@testset "NDimsPlan" begin
31+
c = randn(20,10,20)
32+
@test_throws ErrorException("Different size in dims axes not yet implemented in N-dimensional transform.") NDimsPlan(ArrayPlan(plan_cheb2leg(c), c), (1,2))
33+
34+
c = randn(20,20,5);
35+
F = plan_cheb2leg(c)
36+
FT = ArrayPlan(F, c)
37+
P = NDimsPlan(FT, (1,2))
38+
39+
@test size(P) == size(c)
40+
@test size(P,1) == size(c,1)
41+
@test size(P,1,2) == (size(c,1), size(c,2))
42+
43+
44+
f = similar(c);
45+
for k in axes(f,3)
46+
f[:,:,k] = (F*(F*c[:,:,k])')'
47+
end
48+
@test f P*c
49+
@test c P\f
50+
51+
c = randn(10,5,10,60)
52+
F = plan_cheb2leg(c)
53+
P = NDimsPlan(ArrayPlan(F, c), (1,3))
54+
f = similar(c)
55+
for i in axes(f,2), j in axes(f,4)
56+
f[:,i,:,j] = (F*(F*c[:,i,:,j])')'
57+
end
58+
@test f P*c
59+
@test c P\f
2060
end
21-
@test f FT*c
22-
@test c FT\f
23-
end
61+
end
62+
63+

0 commit comments

Comments
 (0)