Skip to content

Commit 198d220

Browse files
authored
Support arbitrary tensors in Chebyshev U (#232)
* Support arbitrary tensors * tests pass
1 parent db87313 commit 198d220

File tree

2 files changed

+74
-71
lines changed

2 files changed

+74
-71
lines changed

src/chebyshevtransform.jl

Lines changed: 31 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -423,19 +423,13 @@ plan_chebyshevutransform!(x::AbstractArray, dims...; kws...) = plan_chebyshevutr
423423
plan_chebyshevutransform(x::AbstractArray, dims...; kws...) = plan_chebyshevutransform(x, Val(1), dims...; kws...)
424424

425425

426-
@inline function _chebu1_prescale!(d::Number, x::AbstractVecOrMat{T}) where T
427-
m,n = size(x,1),size(x,2)
428-
if d == 1
429-
for j = 1:n, k = 1:m # sqrt(1-x_j^2) weight
430-
x[k,j] *= sinpi(one(T)/(2m) + (k-one(T))/m)/m
431-
end
432-
else
433-
@assert d == 2
434-
for j = 1:n, k = 1:m # sqrt(1-x_j^2) weight
435-
x[k,j] *= sinpi(one(T)/(2n) + (j-one(T))/n)/n
436-
end
437-
end
438-
x
426+
_permfirst(d, N) = [d; 1:d-1; d+1:N]
427+
428+
@inline function _chebu1_prescale!(d::Number, X::AbstractArray{T,N}) where {T,N}
429+
= PermutedDimsArray(X, _permfirst(d, N))
430+
m = size(X̃,1)
431+
X̃ .= (sinpi.(one(T)/(2m) .+ ((1:m) .- one(T))/m) ./ m) .*
432+
X
439433
end
440434

441435
@inline function _chebu1_prescale!(d, y::AbstractArray)
@@ -445,19 +439,11 @@ end
445439
y
446440
end
447441

448-
@inline function _chebu1_postscale!(d::Number, x::AbstractVecOrMat{T}) where T
449-
m,n = size(x,1),size(x,2)
450-
if d == 1
451-
for j = 1:n, k = 1:m # sqrt(1-x_j^2) weight
452-
x[k,j] /= sinpi(one(T)/(2m) + (k-one(T))/m)/m
453-
end
454-
else
455-
@assert d == 2
456-
for j = 1:n, k = 1:m # sqrt(1-x_j^2) weight
457-
x[k,j] /= sinpi(one(T)/(2n) + (j-one(T))/n)/n
458-
end
459-
end
460-
x
442+
@inline function _chebu1_postscale!(d::Number, X::AbstractArray{T,N}) where {T,N}
443+
= PermutedDimsArray(X, _permfirst(d, N))
444+
m = size(X̃,1)
445+
X̃ .=./ (sinpi.(one(T)/(2m) .+ ((1:m) .- one(T))/m) ./ m)
446+
X
461447
end
462448

463449
@inline function _chebu1_postscale!(d, y::AbstractArray)
@@ -485,21 +471,13 @@ function mul!(y::AbstractArray{T}, P::ChebyshevUTransformPlan{T,1,K,false}, x::A
485471
y
486472
end
487473

488-
@inline function _chebu2_prescale!(d::Number, x::AbstractVecOrMat{T}) where T
489-
m,n = size(x,1),size(x,2)
490-
if d == 1
491-
c = one(T)/ (m+1)
492-
for j = 1:n, k = 1:m # sqrt(1-x_j^2) weight
493-
x[k,j] *= sinpi(k*c)
494-
end
495-
else
496-
@assert d == 2
497-
c = one(T)/ (n+1)
498-
for j = 1:n, k = 1:m # sqrt(1-x_j^2) weight
499-
x[k,j] *= sinpi(j*c)
500-
end
501-
end
502-
x
474+
475+
@inline function _chebu2_prescale!(d::Number, X::AbstractArray{T,N}) where {T,N}
476+
= PermutedDimsArray(X, _permfirst(d, N))
477+
m = size(X̃,1)
478+
c = one(T)/ (m+1)
479+
X̃ .= sinpi.((1:m) .* c) .*
480+
X
503481
end
504482

505483
@inline function _chebu2_prescale!(d, y::AbstractArray)
@@ -510,21 +488,12 @@ end
510488
end
511489

512490

513-
@inline function _chebu2_postscale!(d::Number, x::AbstractVecOrMat{T}) where T
514-
m,n = size(x,1),size(x,2)
515-
if d == 1
516-
c = one(T)/ (m+1)
517-
for j = 1:n, k = 1:m # sqrt(1-x_j^2) weight
518-
x[k,j] /= sinpi(k*c)
519-
end
520-
else
521-
@assert d == 2
522-
c = one(T)/ (n+1)
523-
for j = 1:n, k = 1:m # sqrt(1-x_j^2) weight
524-
x[k,j] /= sinpi(j*c)
525-
end
526-
end
527-
x
491+
@inline function _chebu2_postscale!(d::Number, X::AbstractArray{T,N}) where {T,N}
492+
= PermutedDimsArray(X, _permfirst(d, N))
493+
m = size(X̃,1)
494+
c = one(T)/ (m+1)
495+
X̃ .=./ sinpi.((1:m) .* c)
496+
X
528497
end
529498

530499
@inline function _chebu2_postscale!(d, y::AbstractArray)
@@ -618,21 +587,14 @@ inv(P::IChebyshevUTransformPlan{T,2}) where {T} = ChebyshevUTransformPlan{T,2}(P
618587
inv(P::ChebyshevUTransformPlan{T,1}) where {T} = IChebyshevUTransformPlan{T,1}(inv(P.plan).p)
619588
inv(P::IChebyshevUTransformPlan{T,1}) where {T} = ChebyshevUTransformPlan{T,1}(inv(P.plan).p)
620589

621-
@inline function _ichebu1_postscale!(d::Number, x::AbstractVecOrMat{T}) where T
622-
m,n = size(x,1),size(x,2)
623-
if d == 1
624-
for j = 1:n, k = 1:m # sqrt(1-x_j^2) weight
625-
x[k,j] /= 2sinpi(one(T)/(2m) + (k-one(T))/m)
626-
end
627-
else
628-
@assert d == 2
629-
for j = 1:n, k = 1:m # sqrt(1-x_j^2) weight
630-
x[k,j] /= 2sinpi(one(T)/(2n) + (j-one(T))/n)
631-
end
632-
end
633-
x
590+
@inline function _ichebu1_postscale!(d::Number, X::AbstractArray{T,N}) where {T,N}
591+
= PermutedDimsArray(X, _permfirst(d, N))
592+
m = size(X̃,1)
593+
X̃ .=./ (2 .* sinpi.(one(T)/(2m) .+ ((1:m) .- one(T))/m))
594+
X
634595
end
635596

597+
636598
@inline function _ichebu1_postscale!(d, y::AbstractArray)
637599
for k in d
638600
_ichebu1_postscale!(k, y)

test/chebyshevtests.jl

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,11 +164,11 @@ using FastTransforms, Test
164164
gcopy = copy(g)
165165
P = @inferred(plan_chebyshevutransform(f))
166166
@test P*f g
167-
@test f == fcopy
167+
@test f fcopy
168168
@test_throws ArgumentError P * T[1,2]
169169
P = @inferred(plan_chebyshevutransform(f, 1:1))
170170
@test P*f g
171-
@test f == fcopy
171+
@test f fcopy
172172
@test_throws ArgumentError P * T[1,2]
173173

174174
P = @inferred(plan_chebyshevutransform!(f))
@@ -364,6 +364,47 @@ using FastTransforms, Test
364364
@test ichebyshevtransform(chebyshevtransform(X)) X
365365
@test chebyshevtransform(ichebyshevtransform(X)) X
366366
end
367+
368+
@testset "chebyshevutransform" begin
369+
for k = axes(X,2), j = axes(X,3) X̃[:,k,j] = chebyshevutransform(X[:,k,j]) end
370+
@test @inferred(chebyshevutransform(X,1)) @inferred(chebyshevutransform!(copy(X),1))
371+
for k = axes(X,1), j = axes(X,3) X̃[k,:,j] = chebyshevutransform(X[k,:,j]) end
372+
@test chebyshevutransform(X,2) chebyshevutransform!(copy(X),2)
373+
for k = axes(X,1), j = axes(X,2) X̃[k,j,:] = chebyshevutransform(X[k,j,:]) end
374+
@test chebyshevutransform(X,3) chebyshevutransform!(copy(X),3)
375+
376+
for k = axes(X,2), j = axes(X,3) X̃[:,k,j] = chebyshevutransform(X[:,k,j],Val(2)) end
377+
@test @inferred(chebyshevutransform(X,Val(2),1)) @inferred(chebyshevutransform!(copy(X),Val(2),1))
378+
for k = axes(X,1), j = axes(X,3) X̃[k,:,j] = chebyshevutransform(X[k,:,j],Val(2)) end
379+
@test chebyshevutransform(X,Val(2),2) chebyshevutransform!(copy(X),Val(2),2)
380+
for k = axes(X,1), j = axes(X,2) X̃[k,j,:] = chebyshevutransform(X[k,j,:],Val(2)) end
381+
@test chebyshevutransform(X,Val(2),3) chebyshevutransform!(copy(X),Val(2),3)
382+
383+
@test @inferred(chebyshevutransform(X)) @inferred(chebyshevutransform!(copy(X))) chebyshevutransform(chebyshevutransform(chebyshevutransform(X,1),2),3)
384+
@test @inferred(chebyshevutransform(X,Val(2))) @inferred(chebyshevutransform!(copy(X),Val(2))) chebyshevutransform(chebyshevutransform(chebyshevutransform(X,Val(2),1),Val(2),2),Val(2),3)
385+
end
386+
387+
@testset "ichebyshevutransform" begin
388+
for k = axes(X,2), j = axes(X,3) X̃[:,k,j] = ichebyshevutransform(X[:,k,j]) end
389+
@test @inferred(ichebyshevutransform(X,1)) @inferred(ichebyshevutransform!(copy(X),1))
390+
for k = axes(X,1), j = axes(X,3) X̃[k,:,j] = ichebyshevutransform(X[k,:,j]) end
391+
@test ichebyshevutransform(X,2) ichebyshevutransform!(copy(X),2)
392+
for k = axes(X,1), j = axes(X,2) X̃[k,j,:] = ichebyshevutransform(X[k,j,:]) end
393+
@test ichebyshevutransform(X,3) ichebyshevutransform!(copy(X),3)
394+
395+
for k = axes(X,2), j = axes(X,3) X̃[:,k,j] = ichebyshevutransform(X[:,k,j],Val(2)) end
396+
@test @inferred(ichebyshevutransform(X,Val(2),1)) @inferred(ichebyshevutransform!(copy(X),Val(2),1))
397+
for k = axes(X,1), j = axes(X,3) X̃[k,:,j] = ichebyshevutransform(X[k,:,j],Val(2)) end
398+
@test ichebyshevutransform(X,Val(2),2) ichebyshevutransform!(copy(X),Val(2),2)
399+
for k = axes(X,1), j = axes(X,2) X̃[k,j,:] = ichebyshevutransform(X[k,j,:],Val(2)) end
400+
@test ichebyshevutransform(X,Val(2),3) ichebyshevutransform!(copy(X),Val(2),3)
401+
402+
@test @inferred(ichebyshevutransform(X)) @inferred(ichebyshevutransform!(copy(X))) ichebyshevutransform(ichebyshevutransform(ichebyshevutransform(X,1),2),3)
403+
@test @inferred(ichebyshevutransform(X,Val(2))) @inferred(ichebyshevutransform!(copy(X),Val(2))) ichebyshevutransform(ichebyshevutransform(ichebyshevutransform(X,Val(2),1),Val(2),2),Val(2),3)
404+
405+
@test ichebyshevutransform(chebyshevutransform(X)) X
406+
@test chebyshevutransform(ichebyshevutransform(X)) X
407+
end
367408

368409
X = randn(1,1,1)
369410
@test chebyshevtransform!(copy(X), Val(1)) == ichebyshevtransform!(copy(X), Val(1)) == X

0 commit comments

Comments
 (0)