Skip to content

Commit f2cf8e8

Browse files
committed
overload broadcast for RectPoly
1 parent dfaf6b7 commit f2cf8e8

File tree

5 files changed

+55
-19
lines changed

5 files changed

+55
-19
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
1919
LazyBandedMatrices = "d7e5e226-e90b-4449-9968-0f923699bf6f"
2020
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
2121
QuasiArrays = "c4ea9172-b204-11e9-377d-29865faadc5c"
22-
RecurrenceRelationshipArrays = "b889d2dc-af3c-4820-88a8-238fa91d3518"
2322
RecurrenceRelationships = "807425ed-42ea-44d6-a357-6771516d7b2c"
2423
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
2524
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
@@ -35,12 +34,13 @@ ContinuumArrays = "0.18"
3534
DomainSets = "0.7"
3635
FastTransforms = "0.17"
3736
FillArrays = "1.0"
38-
HarmonicOrthogonalPolynomials = "0.6"
37+
HarmonicOrthogonalPolynomials = "0.6.3"
3938
InfiniteArrays = "0.15"
4039
InfiniteLinearAlgebra = "0.9"
4140
LazyArrays = "2.3.1"
4241
LazyBandedMatrices = "0.11.1"
4342
QuasiArrays = "0.11"
43+
RecurrenceRelationships = "0.2"
4444
SpecialFunctions = "1, 2"
4545
StaticArrays = "1"
4646
julia = "1.10"

src/MultivariateOrthogonalPolynomials.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,20 +11,20 @@ import Base.Broadcast: Broadcasted, broadcasted, DefaultArrayStyle
1111
import DomainSets: boundary, EuclideanDomain
1212

1313
import QuasiArrays: LazyQuasiMatrix, LazyQuasiArrayStyle, domain
14-
import ContinuumArrays: @simplify, Weight, weight, grid, plotgrid, TransformFactorization, ExpansionLayout, plotvalues, unweighted, plan_transform, checkpoints, transform_ldiv, AbstractBasisLayout, basis_axes, Inclusion, grammatrix, weaklaplacian
14+
import ContinuumArrays: @simplify, Weight, weight, grid, plotgrid, TransformFactorization, ExpansionLayout, plotvalues, unweighted, plan_transform, checkpoints, transform_ldiv, AbstractBasisLayout, basis_axes, Inclusion, grammatrix, weaklaplacian, layout_broadcasted
1515

1616
import ArrayLayouts: MemoryLayout, sublayout, sub_materialize
1717
import BlockArrays: block, blockindex, BlockSlice, viewblock, blockcolsupport, AbstractBlockStyle, BlockStyle
1818
import BlockBandedMatrices: _BandedBlockBandedMatrix, AbstractBandedBlockBandedMatrix, _BandedMatrix, blockbandwidths, subblockbandwidths
1919
import LinearAlgebra: factorize
2020
import LazyArrays: arguments, paddeddata, LazyArrayStyle, LazyLayout, PaddedLayout, applylayout, LazyMatrix, ApplyMatrix
21-
import LazyBandedMatrices: LazyBandedBlockBandedLayout, AbstractBandedBlockBandedLayout, AbstractLazyBandedBlockBandedLayout, _krontrav_axes, DiagTravLayout, invdiagtrav, ApplyBandedBlockBandedLayout
21+
import LazyBandedMatrices: LazyBandedBlockBandedLayout, AbstractBandedBlockBandedLayout, AbstractLazyBandedBlockBandedLayout, _krontrav_axes, DiagTravLayout, invdiagtrav, ApplyBandedBlockBandedLayout, krontrav
2222
import InfiniteArrays: InfiniteCardinal, OneToInf
2323

2424
import ClassicalOrthogonalPolynomials: jacobimatrix, Weighted, orthogonalityweight, HalfWeighted, WeightedBasis, pad, recurrencecoefficients, clenshaw, weightedgrammatrix, Clenshaw
2525
import HarmonicOrthogonalPolynomials: BivariateOrthogonalPolynomial, MultivariateOrthogonalPolynomial, Plan,
2626
PartialDerivative, AngularMomentum, BlockOneTo, BlockRange1, interlace,
27-
MultivariateOPLayout, MAX_PLOT_BLOCKS
27+
MultivariateOPLayout, AbstractMultivariateOPLayout, MAX_PLOT_BLOCKS
2828

2929
export MultivariateOrthogonalPolynomial, BivariateOrthogonalPolynomial,
3030
UnitTriangle, UnitDisk,

src/clenshawkron.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ axes(M::ClenshawKron) = (blockedrange(oneto(∞)),blockedrange(oneto(∞)))
2323
blockbandwidths(M::ClenshawKron) = (size(M.c,1)-1,size(M.c,1)-1)
2424
subblockbandwidths(M::ClenshawKron) = (size(M.c,2)-1,size(M.c,2)-1)
2525

26+
struct ClenshawKronLayout <: AbstractLazyBandedBlockBandedLayout end
27+
MemoryLayout(::Type{<:ClenshawKron}) = ClenshawKronLayout()
28+
2629
function square_getindex(M::ClenshawKron, N::Block{1})
2730
# Consider P(x) = L^1_x \ 𝐞_0
2831
# So that if a(x,y) = P(x)*c*Q(y)' then we have
@@ -47,5 +50,7 @@ getindex(M::ClenshawKron, K::Block{1}, J::Block{1}) = square_getindex(M, max(K,
4750
getindex(M::ClenshawKron, Kk::BlockIndex{1}, Jj::BlockIndex{1}) = M[block(Kk), block(Jj)][blockindex(Kk), blockindex(Jj)]
4851
getindex(M::ClenshawKron, k::Int, j::Int) = M[findblockindex(axes(M,1),k), findblockindex(axes(M,2),j)]
4952

50-
Base.array_summary(io::IO, C::ClenshawKron{T}, inds::Tuple{Vararg{OneToInf{Int}}}) where T =
53+
Base.array_summary(io::IO, C::ClenshawKron{T}, inds) where T =
5154
print(io, Base.dims2string(length.(inds)), " ClenshawKron{$T} with $(size(C.c)) polynomial")
55+
56+
Base.summary(io::IO, C::ClenshawKron) = Base.array_summary(io, C, axes(C))

src/rect.jl

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@ end
1818

1919
==(A::KronPolynomial, B::KronPolynomial) = length(A.args) == length(B.args) && all(map(==, A.args, B.args))
2020

21+
22+
struct KronOPLayout{d} <: AbstractMultivariateOPLayout{d} end
23+
MemoryLayout(::Type{<:KronPolynomial{d}}) where d = KronOPLayout{d}()
24+
2125
const RectPolynomial{T, PP} = KronPolynomial{2, T, PP}
2226

2327

@@ -71,7 +75,7 @@ end
7175
function \(P::RectPolynomial, Q::RectPolynomial)
7276
PA,PB = P.args
7377
QA,QB = Q.args
74-
KronTrav(PA\QA, PB\QB)
78+
krontrav(PA\QA, PB\QB)
7579
end
7680

7781
@simplify function *(Ac::QuasiAdjoint{<:Any,<:RectPolynomial}, B::RectPolynomial)
@@ -144,14 +148,14 @@ pad(C::DiagTrav, ::BlockedOneTo{Int,RangeCumsum{Int,OneToInf{Int}}}) = DiagTrav(
144148

145149
QuasiArrays.mul(A::BivariateOrthogonalPolynomial, b::DiagTrav) = ApplyQuasiArray(*, A, b)
146150

147-
function Base.unsafe_getindex(f::Mul{MultivariateOPLayout{2},<:DiagTravLayout{<:PaddedLayout}}, 𝐱::SVector)
151+
function Base.unsafe_getindex(f::Mul{KronOPLayout{2},<:DiagTravLayout{<:PaddedLayout}}, 𝐱::SVector)
148152
P,c = f.A, f.B
149153
A,B = P.args
150154
x,y = 𝐱
151155
clenshaw(vec(clenshaw(paddeddata(c.array), recurrencecoefficients(A)..., x; dims=1)), recurrencecoefficients(B)..., y)
152156
end
153157

154-
Base.@propagate_inbounds function getindex(f::Mul{MultivariateOPLayout{2},<:DiagTravLayout{<:PaddedLayout}}, x::SVector, j...)
158+
Base.@propagate_inbounds function getindex(f::Mul{KronOPLayout{2},<:DiagTravLayout{<:PaddedLayout}}, x::SVector, j...)
155159
@inbounds checkbounds(ApplyQuasiArray(*,f.A,f.B), x, j...)
156160
Base.unsafe_getindex(f, x, j...)
157161
end
@@ -171,3 +175,16 @@ function Base._sum(P::RectPolynomial, dims)
171175
@assert dims == 1
172176
KronTrav(sum.(P.args; dims=1)...)
173177
end
178+
179+
## multiplication
180+
181+
function layout_broadcasted(::Tuple{ExpansionLayout{KronOPLayout{2}},KronOPLayout{2}}, ::typeof(*), a, P)
182+
axes(a,1) == axes(P,1) || throw(DimensionMismatch())
183+
184+
A,B = basis(a).args
185+
T,U = P.args
186+
187+
C = paddeddata(invdiagtrav(coefficients(a)))
188+
189+
P * ClenshawKron(C, (recurrencecoefficients(A), recurrencecoefficients(B)), (jacobimatrix(T), jacobimatrix(U)))
190+
end

test/test_rect.jl

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -169,19 +169,33 @@ using Base: oneto
169169

170170
a = (x,y) -> I + x + 2y + 3x^2 +4x*y + 5y^2
171171
𝐚 = expand(P,splat(a))
172-
173-
C = LazyBandedMatrices.paddeddata(LazyBandedMatrices.invdiagtrav(coefficients(𝐚)))
174-
175-
A = ClenshawKron(C, (recurrencecoefficients(T), recurrencecoefficients(U)), (jacobimatrix(T), jacobimatrix(U)))
176172

173+
@testset "ClenshawKron" begin
174+
C = LazyBandedMatrices.paddeddata(LazyBandedMatrices.invdiagtrav(coefficients(𝐚)))
177175

178-
= a(X,Y)
179-
for (k,j) in ((Block.(oneto(5)),Block.(oneto(5))), Block.(oneto(5)),Block.(oneto(6)), (Block(2), Block(3)), (4,5),
180-
(Block(2)[2], Block(3)[3]), (Block(2)[2], Block(3)))
181-
@test A[k,j] Ã[k,j]
176+
A = ClenshawKron(C, (recurrencecoefficients(T), recurrencecoefficients(U)), (jacobimatrix(T), jacobimatrix(U)))
177+
178+
@test copy(A) A
179+
@test size(A) == size(X)
180+
@test summary(A) == "ℵ₀×ℵ₀ ClenshawKron{Float64} with (3, 3) polynomial"
181+
182+
= a(X,Y)
183+
for (k,j) in ((Block.(oneto(5)),Block.(oneto(5))), Block.(oneto(5)),Block.(oneto(6)), (Block(2), Block(3)), (4,5),
184+
(Block(2)[2], Block(3)[3]), (Block(2)[2], Block(3)))
185+
@test A[k,j] Ã[k,j]
186+
end
187+
188+
@test A[Block(1,2)] Ã[Block(1,2)]
189+
@test A[Block(1,2)][1,2] Ã[Block(1,2)[1,2]]
182190
end
183191

184-
@test A[Block(1,2)] Ã[Block(1,2)]
185-
@test A[Block(1,2)][1,2] Ã[Block(1,2)[1,2]]
192+
@test P \ (𝐚 .* P) isa ClenshawKron
193+
194+
@test (𝐚 .* 𝐚)[SVector(0.1,0.2)] 𝐚[SVector(0.1,0.2)]^2
195+
196+
𝐛 = expand(RectPolynomial(Legendre(),Ultraspherical(3/2)),splat((x,y) -> cos(x*sin(y))))
197+
@test (𝐛 .* 𝐚)[SVector(0.1,0.2)] 𝐚[SVector(0.1,0.2)]𝐛[SVector(0.1,0.2)]
198+
199+
𝐜 = expand(RectPolynomial(Legendre(),Jacobi(1,0)),splat((x,y) -> cos(x*sin(y))))
186200
end
187201
end

0 commit comments

Comments
 (0)