Skip to content

Commit e0dd9d6

Browse files
author
ioannisPApapadopoulos
committed
fix N-dimensional transform when dims[1] != 1
1 parent 87b1924 commit e0dd9d6

File tree

2 files changed

+28
-21
lines changed

2 files changed

+28
-21
lines changed

src/arrays.jl

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ function *(P::ArrayPlan, f::AbstractArray)
3131
perm = [d; setdiff(1:ndims(f), d)]
3232
fp = permutedims(f, perm)
3333

34-
fr = reshape(fp, size(fp,1), prod(size(fp)[2:end]))
34+
fr = reshape(fp, size(fp,1), :)
3535

3636
permutedims(reshape(F*fr, size(fp)...), inv_perm(perm))
3737
end
@@ -45,33 +45,39 @@ function \(P::ArrayPlan, f::AbstractArray)
4545
perm = [d; setdiff(1:ndims(f), d)]
4646
fp = permutedims(f, perm)
4747

48-
fr = reshape(fp, size(fp,1), prod(size(fp)[2:end]))
48+
fr = reshape(fp, size(fp,1), :)
4949

5050
permutedims(reshape(F\fr, size(fp)...), inv_perm(perm))
5151
end
5252

53-
struct NDimsPlan{T, FF<:ArrayPlan{<:T}, Dims<:Tuple} <: Plan{T}
53+
struct NDimsPlan{T, FF<:ArrayPlan{<:T}, Szs<:Tuple, Dims<:Tuple} <: Plan{T}
5454
F::FF
55+
szs::Szs
5556
dims::Dims
56-
function NDimsPlan(F, dims)
57-
if length(Set(size(F, dims...))) > 1
57+
function NDimsPlan(F, szs, dims)
58+
if length(Set(szs[[dims...]])) > 1
5859
error("Different size in dims axes not yet implemented in N-dimensional transform.")
5960
end
60-
new{eltype(F), typeof(F), typeof(dims)}(F, dims)
61+
new{eltype(F), typeof(F), typeof(szs), typeof(dims)}(F, szs, dims)
6162
end
6263
end
6364

64-
size(P::NDimsPlan) = size(P.F)
65-
size(P::NDimsPlan, k::Int) = size(P.F, k)
66-
size(P::NDimsPlan, k...) = size(P.F, k...)
65+
size(P::NDimsPlan) = P.szs
66+
size(P::NDimsPlan, k::Int) = P.szs[k]
67+
size(P::NDimsPlan, k...) = P.szs[[k...]]
68+
69+
function NDimsPlan(F::FTPlan, szs::Tuple, dims::Tuple)
70+
NDimsPlan(ArrayPlan(F, szs, (first(dims),)), szs, dims)
71+
end
6772

6873
function *(P::NDimsPlan, f::AbstractArray)
6974
F, dims = P.F, P.dims
70-
@assert size(F) == size(f)
75+
@assert size(P) == size(f)
7176
g = copy(f)
7277
t = 1:ndims(g)
78+
d1 = dims[1]
7379
for d in dims
74-
perm = ntuple(k -> k == 1 ? t[d] : k == d ? t[1] : t[k], ndims(g))
80+
perm = ntuple(k -> k == d1 ? t[d] : k == d ? t[d1] : t[k], ndims(g))
7581
gp = permutedims(g, perm)
7682
g = permutedims(F*gp, inv_perm(perm))
7783
end
@@ -80,11 +86,12 @@ end
8086

8187
function \(P::NDimsPlan, f::AbstractArray)
8288
F, dims = P.F, P.dims
83-
@assert size(F) == size(f)
89+
@assert size(P) == size(f)
8490
g = copy(f)
8591
t = 1:ndims(g)
92+
d1 = dims[1]
8693
for d in dims
87-
perm = ntuple(k -> k == 1 ? t[d] : k == d ? t[1] : t[k], ndims(g))
94+
perm = ntuple(k -> k == d1 ? t[d] : k == d ? t[d1] : t[k], ndims(g))
8895
gp = permutedims(g, perm)
8996
g = permutedims(F\gp, inv_perm(perm))
9097
end

test/arraystests.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,18 +29,18 @@ import FastTransforms: ArrayPlan, NDimsPlan
2929

3030
@testset "NDimsPlan" begin
3131
c = randn(20,10,20)
32-
@test_throws ErrorException("Different size in dims axes not yet implemented in N-dimensional transform.") NDimsPlan(ArrayPlan(plan_cheb2leg(c), c), (1,2))
32+
@test_throws ErrorException("Different size in dims axes not yet implemented in N-dimensional transform.") NDimsPlan(ArrayPlan(plan_cheb2leg(c), c), size(c), (1,2))
3333

3434
c = randn(5,20)
3535
F = plan_cheb2leg(c)
3636
FT = ArrayPlan(F, c)
37-
P = NDimsPlan(FT, (1,))
37+
P = NDimsPlan(F, size(c), (1,))
3838
@test F*c FT*c P*c
3939

4040
c = randn(20,20,5);
4141
F = plan_cheb2leg(c)
4242
FT = ArrayPlan(F, c)
43-
P = NDimsPlan(FT, (1,2))
43+
P = NDimsPlan(FT, size(c), (1,2))
4444

4545
@test size(P) == size(c)
4646
@test size(P,1) == size(c,1)
@@ -53,12 +53,12 @@ import FastTransforms: ArrayPlan, NDimsPlan
5353
@test f P*c
5454
@test c P\f
5555

56-
c = randn(10,5,10,60)
57-
F = plan_cheb2leg(c)
58-
P = NDimsPlan(ArrayPlan(F, c), (1,3))
56+
c = randn(5,10,10,60)
57+
F = plan_cheb2leg(randn(10))
58+
P = NDimsPlan(F, size(c), (2,3))
5959
f = similar(c)
60-
for i in axes(f,2), j in axes(f,4)
61-
f[:,i,:,j] = (F*(F*c[:,i,:,j])')'
60+
for i in axes(f,1), j in axes(f,4)
61+
f[i,:,:,j] = (F*(F*c[i,:,:,j])')'
6262
end
6363
@test f P*c
6464
@test c P\f

0 commit comments

Comments
 (0)