Skip to content

Commit 260b154

Browse files
authored
DiagTrav coefficients should be treated as an ExpansionLayout (#158)
1 parent 1ebc7f1 commit 260b154

File tree

4 files changed

+50
-3
lines changed

4 files changed

+50
-3
lines changed

src/MultivariateOrthogonalPolynomials.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,13 @@ import Base.Broadcast: Broadcasted, broadcasted, DefaultArrayStyle
1111
import DomainSets: boundary
1212

1313
import QuasiArrays: LazyQuasiMatrix, LazyQuasiArrayStyle, domain
14-
import ContinuumArrays: @simplify, Weight, weight, grid, plotgrid, TransformFactorization, ExpansionLayout, plotvalues, unweighted, plan_grid_transform, checkpoints, transform_ldiv
14+
import ContinuumArrays: @simplify, Weight, weight, grid, plotgrid, TransformFactorization, ExpansionLayout, plotvalues, unweighted, plan_grid_transform, checkpoints, transform_ldiv, AbstractBasisLayout, basis_axes
1515

1616
import ArrayLayouts: MemoryLayout, sublayout, sub_materialize
1717
import BlockArrays: block, blockindex, BlockSlice, viewblock, blockcolsupport, AbstractBlockStyle, BlockStyle
1818
import BlockBandedMatrices: _BandedBlockBandedMatrix, AbstractBandedBlockBandedMatrix, _BandedMatrix, blockbandwidths, subblockbandwidths
1919
import LinearAlgebra: factorize
20-
import LazyArrays: arguments, paddeddata, LazyArrayStyle, LazyLayout, PaddedLayout
20+
import LazyArrays: arguments, paddeddata, LazyArrayStyle, LazyLayout, PaddedLayout, applylayout
2121
import LazyBandedMatrices: LazyBandedBlockBandedLayout, AbstractBandedBlockBandedLayout, AbstractLazyBandedBlockBandedLayout, _krontrav_axes, DiagTravLayout
2222
import InfiniteArrays: InfiniteCardinal, OneToInf
2323

src/rect.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,9 @@ function plan_grid_transform(P::KronPolynomial{d,<:Any,<:Fill}, B::Tuple{Block{1
7777
SVector.(x̃, x̃'), ApplyPlan(DiagTrav, F)
7878
end
7979

80+
applylayout(::Type{typeof(*)}, ::Lay, ::DiagTravLayout) where Lay <: AbstractBasisLayout = ExpansionLayout{Lay}()
81+
ContinuumArrays._mul_plotgrid(::Tuple{Any,DiagTravLayout{<:PaddedLayout}}, (P,c)) = plotgrid(P, maximum(blockcolsupport(c)))
82+
8083
pad(C::DiagTrav, ::BlockedUnitRange{RangeCumsum{Int,OneToInf{Int}}}) = DiagTrav(pad(C.array, ∞, ∞))
8184

8285
QuasiArrays.mul(A::BivariateOrthogonalPolynomial, b::DiagTrav) = ApplyQuasiArray(*, A, b)

src/triangle.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ show(io::IO, P::JacobiTriangle) = summary(io, P)
2727
summary(io::IO, P::JacobiTriangle) = print(io, "JacobiTriangle($(P.a), $(P.b), $(P.c))")
2828

2929

30+
basis_axes(::Inclusion{<:Any,<:UnitTriangle}, v) = JacobiTriangle()
31+
3032
"""
3133
TriangleWeight(a, b, c)
3234
@@ -639,6 +641,13 @@ function tridenormalize!(F̌,a,b,c)
639641
640642
end
641643

644+
function trinormalize!(F̌,a,b,c)
645+
for n = 0:size(F̌,1)-1, k = 0:n
646+
F̌[n-k+1,k+1] /= _ft_trinorm(n,k,a,b,c)
647+
end
648+
649+
end
650+
642651
function trigrid(N::Integer)
643652
M = N
644653
x = [sinpi((2N-2n-1)/(4N))^2 for n in 0:N-1]
@@ -666,4 +675,35 @@ function plan_grid_transform(P::JacobiTriangle, Bs::Tuple{Block{1}}, dims=1:1)
666675
T = eltype(P)
667676
N = Bs[1]
668677
grid(P, N), TriPlan{T}(N, P.a, P.b, P.c)
678+
end
679+
680+
struct TriIPlan{T}
681+
tri2cheb::FastTransforms.FTPlan{T,2,FastTransforms.TRIANGLE}
682+
cheb2grid::FastTransforms.FTPlan{T,2,FastTransforms.TRIANGLESYNTHESIS}
683+
a::T
684+
b::T
685+
c::T
686+
end
687+
688+
TriIPlan{T}(F::AbstractMatrix{T}, a, b, c) where T =
689+
TriIPlan{T}(plan_tri2cheb(F, a, b, c), plan_tri_synthesis(F), a, b, c)
690+
691+
TriIPlan{T}(N::Block{1}, a, b, c) where T = TriIPlan{T}(Matrix{T}(undef, Int(N), Int(N)), a, b, c)
692+
693+
*(T::TriIPlan, F::DiagTrav) = T.cheb2grid*(T.tri2cheb*trinormalize!(Matrix(F.array),T.a,T.b,T.c))
694+
695+
696+
function plotgrid(S::JacobiTriangle{T}, B::Block{1}) where T
697+
N = min(2Int(B), MAX_PLOT_BLOCKS)
698+
grid(S, Block(N)) # double sampling
699+
end
700+
701+
function plotvalues(u::ApplyQuasiVector{T,typeof(*),<:Tuple{JacobiTriangle, AbstractVector}}, x) where T
702+
P,c = u.args
703+
B = findblock(axes(P,2), last(colsupport(c)))
704+
705+
N = min(2Int(B), MAX_PLOT_BLOCKS)
706+
F = TriIPlan{T}(Block(N), P.a, P.b, P.c)
707+
C = F * DiagTrav(c.array[1:N,1:N]) # transform to grid
708+
C
669709
end

test/test_triangle.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using MultivariateOrthogonalPolynomials, StaticArrays, BlockArrays, BlockBandedMatrices, ArrayLayouts,
22
QuasiArrays, Test, ClassicalOrthogonalPolynomials, BandedMatrices, FastTransforms, LinearAlgebra
3-
import MultivariateOrthogonalPolynomials: tri_forwardrecurrence, grid, TriangleRecurrenceA, TriangleRecurrenceB, TriangleRecurrenceC, xy_muladd!
3+
import MultivariateOrthogonalPolynomials: tri_forwardrecurrence, grid, TriangleRecurrenceA, TriangleRecurrenceB, TriangleRecurrenceC, xy_muladd!, ExpansionLayout
44

55
@testset "Triangle" begin
66
@testset "basics" begin
@@ -52,6 +52,7 @@ import MultivariateOrthogonalPolynomials: tri_forwardrecurrence, grid, TriangleR
5252
𝐱 = SVector(0.1,0.2)
5353
c = PseudoBlockVector([1; Zeros(∞)], (axes(P,2),))
5454
f = P*c
55+
@test MemoryLayout(f) isa ExpansionLayout
5556
@test @inferred(f[𝐱]) == 1.0
5657
c = PseudoBlockVector([1:3; Zeros(∞)], (axes(P,2),))
5758
f = P*c
@@ -115,13 +116,16 @@ import MultivariateOrthogonalPolynomials: tri_forwardrecurrence, grid, TriangleR
115116
N = 20
116117
P_N = P[:,Block.(Base.OneTo(N))]
117118
u = P_N * (P_N \ (exp.(x) .* cos.(y)))
119+
@test MemoryLayout(u) isa ExpansionLayout
118120
@test u[SVector(0.1,0.2)] exp(0.1)*cos(0.2)
119121

120122
P_n = P[:,1:200]
121123
u = P_n * (P_n \ (exp.(x) .* cos.(y)))
124+
@test MemoryLayout(u) isa ExpansionLayout
122125
@test u[SVector(0.1,0.2)] exp(0.1)*cos(0.2)
123126

124127
@time u = P * (P \ (exp.(x) .* cos.(y)))
128+
@test MemoryLayout(u) isa ExpansionLayout
125129
@test u[SVector(0.1,0.2)] exp(0.1)*cos(0.2)
126130
end
127131
end

0 commit comments

Comments
 (0)