Skip to content

Commit 5b55d9d

Browse files
authored
Make * and \ fast for ModalInterlace (#128)
* Make * and \ fast for ModalInterlace, DiskTrav -> ModalTrav * Update ModalInterlace.jl * modaltrav broadcasting * tests pass * Add setindex! * ModalInterlace * and \ tests * increase cov * v0.3
1 parent 7488bdb commit 5b55d9d

File tree

7 files changed

+286
-117
lines changed

7 files changed

+286
-117
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "MultivariateOrthogonalPolynomials"
22
uuid = "4f6956fd-4f93-5457-9149-7bfc4b2ce06d"
3-
version = "0.2.7"
3+
version = "0.3"
44

55
[deps]
66
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
@@ -29,7 +29,7 @@ BandedMatrices = "0.16, 0.17"
2929
BlockArrays = "0.16.14"
3030
BlockBandedMatrices = "0.11.5"
3131
ClassicalOrthogonalPolynomials = "0.5.1, 0.6"
32-
ContinuumArrays = "0.10"
32+
ContinuumArrays = "0.10.2"
3333
DomainSets = "0.5"
3434
FastTransforms = "0.13, 0.14"
3535
FillArrays = "0.12, 0.13"

examples/diskheat.jl

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,29 @@
1-
using MultivariateOrthogonalPolynomials, DifferentialEquations
1+
using MultivariateOrthogonalPolynomials, DifferentialEquations, Plots
2+
pyplot() # pyplot supports disks
23

34
Z = Zernike(1)
45
W = Weighted(Z)
56
xy = axes(W,1)
6-
Δ = Z \ Laplacian(xy) * W
7+
x,y = first.(xy),last.(xy)
8+
Δ = Z \ Laplacian(xy) * W
9+
S = Z \ W
10+
11+
# initial condition is (1-r^2) * exp(-(x-0.1)^2 - (y-0.2)^2)
12+
13+
K = Block.(Base.OneTo(11))
14+
15+
Δₙ = Δ[K,K]
16+
Sₙ = S[K,K]
17+
Zₙ = Z[:,K]
18+
Wₙ = W[:,K]
19+
c₀ = Zₙ \ @.(exp(-(x-0.1)^2 - (y-0.2)^2))
20+
21+
diskheat(c, (Δₙ, Sₙ), t) = Sₙ \ (Δₙ * c)
22+
u = solve(ODEProblem(diskheat, c₀, (0.,1.), (Δₙ, Sₙ)), Tsit5(), reltol=1e-8, abstol=1e-8)
23+
24+
surface(Wₙ * u(1.0))
25+
26+
diskheat(c, (Δₙ, Sₙ), t) = Δₙ * c
27+
u = solve(ODEProblem(ODEFunction(diskheat; jac=(u, (Δₙ, Sₙ), t) -> Δₙ, mass_matrix=Sₙ), c₀, (0.,1.), (Δₙ, Sₙ)), reltol=1e-8, abstol=1e-8)
28+
29+

src/ModalInterlace.jl

Lines changed: 197 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,170 @@
11
"""
2-
ModalInterlace
2+
ModalTrav(A::AbstractMatrix)
3+
4+
takes coefficients as provided by the Zernike polynomial layout of FastTransforms.jl and
5+
makes them accessible sorted such that in each block the m=0 entries are always in first place,
6+
followed by alternating sin and cos terms of increasing |m|.
7+
"""
8+
struct ModalTrav{T, AA<:AbstractMatrix{T}} <: AbstractBlockVector{T}
9+
matrix::AA
10+
function ModalTrav{T, AA}(matrix::AA) where {T,AA<:AbstractMatrix{T}}
11+
m,n = size(matrix)
12+
if isfinite(m)
13+
isfinite(n) && isodd(n) && m == n ÷ 4 + 1 || throw(ArgumentError("size must match"))
14+
end
15+
new{T,AA}(matrix)
16+
end
17+
end
18+
19+
ModalTrav{T}(matrix::AbstractMatrix{T}) where T = ModalTrav{T,typeof(matrix)}(matrix)
20+
ModalTrav(matrix::AbstractMatrix{T}) where T = ModalTrav{T}(matrix)
21+
22+
function ModalTrav{T}(::UndefInitializer, n::Int) where T
23+
N = (isqrt(8n+1)-1) ÷ 2
24+
@assert sum(1:N) == n
25+
m = N ÷ 2 + 1
26+
n = 4(m-1) + 1
27+
ModalTrav(Matrix{T}(undef, m, n))
28+
end
29+
30+
convert(::Type{ModalTrav{T,M}}, v::ModalTrav) where {T,M} = ModalTrav{T,M}(convert(M, v.matrix))
31+
32+
function convert(::Type{ModalTrav{T,M}}, v_in::AbstractVector) where {T,M}
33+
N = (isqrt(8length(v_in)+1)-1) ÷ 2
34+
v = PseudoBlockVector(v_in, OneTo(N))
35+
m = N ÷ 2 + 1
36+
n = 4(m-1) + 1
37+
mat = zeros(T, m, n)
38+
for K in blockaxes(v,1)
39+
= Int(K)
40+
w = v[K]
41+
if isodd(K̃)
42+
mat[K̃÷2 + 1,1] = w[1]
43+
for j = 2:2:-1
44+
mat[K̃÷2-j÷2+1,2(j-1)+2] = w[j]
45+
mat[K̃÷2-j÷2+1,2(j-1)+3] = w[j+1]
46+
end
47+
else
48+
for j = 1:2:
49+
mat[K̃÷2-j÷2,2(j-1)+2] = w[j]
50+
mat[K̃÷2-j÷2,2(j-1)+3] = w[j+1]
51+
end
52+
end
53+
end
54+
ModalTrav{T,M}(mat)
55+
end
56+
57+
ModalTrav{T,M}(v::AbstractVector) where {T,M} = convert(ModalTrav{T,M}, v)
58+
ModalTrav{T}(v::AbstractVector) where T = ModalTrav{T,Matrix{T}}(v)
59+
ModalTrav(v::AbstractVector{T}) where T = ModalTrav{T}(v)
60+
61+
copy(A::ModalTrav) = ModalTrav(copy(A.matrix))
62+
63+
_diviffinite(n) = div(n,2,RoundUp)
64+
_diviffinite(n::InfiniteCardinal) = n
65+
66+
axes(A::ModalTrav) = (blockedrange(oneto(_diviffinite(size(A.matrix,2)))),)
67+
68+
getindex(A::ModalTrav, K::Block{1}) = _modaltravgetindex(A.matrix, K)
69+
70+
_modaltravgetindex(mat, K) = _modaltravgetindex(MemoryLayout(mat), mat, K)
71+
function _modaltravgetindex(_, mat, K::Block{1})
72+
k = Int(K)
73+
m = k ÷ 2 + 1
74+
n = 4(m-1) + 1
75+
_modaltravgetindex(Matrix(mat[1:m, 1:n]), K)
76+
end
77+
78+
function _modaltravgetindex(::AbstractStridedLayout, mat, K::Block{1})
79+
k = Int(K)
80+
k == 1 && return mat[1:1]
81+
k == 2 && return mat[1,2:3]
82+
st = stride(mat,2)
83+
if isodd(k)
84+
# nonnegative terms
85+
p = mat[range(k÷2+1, step=4*st-1, length=k÷2+1)]
86+
# negative terms
87+
n = mat[range(k÷2+3*st, step=4*st-1, length=k÷2)]
88+
interlace(p,n)
89+
else
90+
# negative terms
91+
n = mat[range(st+k÷2, step=4*st-1, length=k÷2)]
92+
# positive terms
93+
p = mat[range(2st+k÷2, step=4*st-1, length=k÷2)]
94+
interlace(n,p)
95+
end
96+
end
97+
98+
getindex(A::ModalTrav, k::Int) = A[findblockindex(axes(A,1), k)]
99+
function setindex!(A::ModalTrav, v, k::Int)
100+
Kk = findblockindex(axes(A,1), k)
101+
K,j = block(Kk),blockindex(Kk)
102+
= Int(K)
103+
mat = A.matrix
104+
if isodd(K̃)
105+
if j == 1
106+
mat[K̃÷2 + 1,1] = v
107+
elseif iseven(j)
108+
mat[K̃÷2-j÷2+1,2(j-1)+2] = v
109+
else
110+
mat[K̃÷2-(j-1)÷2+1,2(j-2)+3] = v
111+
end
112+
else
113+
if iseven(j)
114+
mat[K̃÷2-(j-1)÷2,2(j-2)+3] = v
115+
else
116+
mat[K̃÷2-j÷2,2(j-1)+2] = v
117+
end
118+
end
119+
A
120+
end
121+
122+
similar(A::ModalTrav, ::Type{T}) where T = ModalTrav(similar(A.matrix, T))
123+
function fill!(A::ModalTrav, x)
124+
fill!(A.matrix, x)
125+
A
126+
end
127+
128+
struct ModalTravStyle <: AbstractBlockStyle{1} end
129+
130+
ModalTravStyle(::Val{1}) = ModalTravStyle()
131+
132+
BroadcastStyle(::Type{<:ModalTrav}) = ModalTravStyle()
133+
BroadcastStyle(a::ModalTravStyle, b::DefaultArrayStyle{0}) = ModalTravStyle()
134+
BroadcastStyle(a::DefaultArrayStyle{0}, b::ModalTravStyle) = ModalTravStyle()
135+
BroadcastStyle(a::ModalTravStyle, b::DefaultArrayStyle{M}) where {M} = BroadcastStyle(BlockStyle{1}(), b)
136+
BroadcastStyle(a::DefaultArrayStyle{M}, b::ModalTravStyle) where {M} = BroadcastStyle(a, BlockStyle{1}())
137+
138+
function similar(bc::Broadcasted{ModalTravStyle}, ::Type{T}) where T
139+
N = blocklength(axes(bc,1))
140+
n = 2N-1
141+
m = n ÷ 4 + 1
142+
ModalTrav(Matrix{T}(undef,m,n))
143+
end
144+
145+
_modal2matrix(a::ModalTrav) = a.matrix
146+
_modal2matrix(a::Broadcasted) = broadcasted(a.f, map(_modal2matrix, a.args)...)
147+
_modal2matrix(a) = a
148+
149+
function copyto!(dest::ModalTrav, bc::Broadcasted{ModalTravStyle})
150+
broadcast!(bc.f, dest.matrix, map(_modal2matrix, bc.args)...)
151+
dest
152+
end
153+
154+
function resize!(a::ModalTrav, N::Block{1})
155+
n = 2Int(N)-1
156+
m = n ÷ 4 + 1
157+
ModalTrav(a.matrix[1:m,1:n])
158+
end
159+
160+
161+
"""
162+
ModalInterlace(ops, (M,N), (l,u))
163+
164+
interlaces the entries of a vector of banded matrices
165+
acting on the different Fourier modes. That is, a ModalInterlace
166+
multiplying a DiagTrav is the same as the operators multiplying the matrix
167+
that the DiagTrav wraps. We assume the same operator acts on the Sin and Cos.
3168
"""
4169
struct ModalInterlace{T, MMNN<:Tuple} <: AbstractBandedBlockBandedMatrix{T}
5170
ops
@@ -8,7 +173,7 @@ struct ModalInterlace{T, MMNN<:Tuple} <: AbstractBandedBlockBandedMatrix{T}
8173
end
9174

10175
ModalInterlace{T}(ops, MN::NTuple{2,Integer}, bandwidths::NTuple{2,Int}) where T = ModalInterlace{T,typeof(MN)}(ops, MN, bandwidths)
11-
ModalInterlace(ops::AbstractVector{<:AbstractMatrix{T}}, MN::NTuple{2,Integer}, bandwidths::NTuple{2,Int}) where T = ModalInterlace{T}(ops, MN, bandwidths)
176+
ModalInterlace(ops::AbstractVector{<:AbstractMatrix}, MN::NTuple{2,Integer}, bandwidths::NTuple{2,Int}) = ModalInterlace{eltype(eltype(ops))}(ops, MN, bandwidths)
12177

13178
axes(Z::ModalInterlace) = blockedrange.(oneto.(Z.MN))
14179

@@ -52,8 +217,36 @@ function sub_materialize(::ModalInterlaceLayout, V::AbstractMatrix{T}) where T
52217
KR,JR = kr.block,jr.block
53218
M,N = Int(last(KR)), Int(last(JR))
54219
R = parent(V)
55-
ModalInterlace{T}([R.ops[m][1:(M-m+2)÷2,1:(N-m+2)÷2] for m=1:min(N,M)], (M,N), R.bandwidths)
220+
ModalInterlace{T}([layout_getindex(R.ops[m],1:(M-m+2)÷2,1:(N-m+2)÷2) for m=1:min(N,M)], (M,N), R.bandwidths)
56221
end
57222

58223
# act like lazy array
59-
Base.BroadcastStyle(::Type{<:ModalInterlace{<:Any,NTuple{2,InfiniteCardinal{0}}}}) = LazyArrayStyle{2}()
224+
BroadcastStyle(::Type{<:ModalInterlace{<:Any,NTuple{2,InfiniteCardinal{0}}}}) = LazyArrayStyle{2}()
225+
226+
# TODO: overload muladd!
227+
function *(A::ModalInterlace, b::ModalTrav)
228+
M = b.matrix
229+
ret = ModalTrav{promote_type(eltype(A), eltype(b))}(undef, size(A,1)).matrix
230+
mul!(view(ret,:,1), A.ops[1], M[:,1])
231+
for j = 1:size(ret,2)÷4
232+
mul!(@view(ret[1:end-j,4j-2]), A.ops[2j], @view(M[1:end-j,4j-2]))
233+
mul!(@view(ret[1:end-j,4j-1]), A.ops[2j], @view(M[1:end-j,4j-1]))
234+
mul!(@view(ret[1:end-j,4j]), A.ops[2j+1], @view(M[1:end-j,4j]))
235+
mul!(@view(ret[1:end-j,4j+1]), A.ops[2j+1], @view(M[1:end-j,4j+1]))
236+
end
237+
ModalTrav(ret)
238+
end
239+
240+
241+
function \(A::ModalInterlace, b::ModalTrav)
242+
M = b.matrix
243+
ret = similar(M, promote_type(eltype(A),eltype(b)))
244+
ldiv!(view(ret,:,1), A.ops[1], M[:,1])
245+
for j = 1:size(M,2)÷4
246+
ldiv!(@view(ret[1:end-j,4j-2]), A.ops[2j], @view(M[1:end-j,4j-2]))
247+
ldiv!(@view(ret[1:end-j,4j-1]), A.ops[2j], @view(M[1:end-j,4j-1]))
248+
ldiv!(@view(ret[1:end-j,4j]), A.ops[2j+1], @view(M[1:end-j,4j]))
249+
ldiv!(@view(ret[1:end-j,4j+1]), A.ops[2j+1], @view(M[1:end-j,4j+1]))
250+
end
251+
ModalTrav(ret)
252+
end

src/MultivariateOrthogonalPolynomials.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,15 @@ using ClassicalOrthogonalPolynomials, FastTransforms, BlockBandedMatrices, Block
66
LazyArrays, SpecialFunctions, LinearAlgebra, BandedMatrices, LazyBandedMatrices, ArrayLayouts,
77
HarmonicOrthogonalPolynomials
88

9-
import Base: axes, in, ==, *, ^, \, copy, OneTo, getindex, size, oneto, all, resize!
9+
import Base: axes, in, ==, *, ^, \, copy, copyto!, OneTo, getindex, size, oneto, all, resize!, BroadcastStyle, similar, fill!, setindex!, convert
10+
import Base.Broadcast: Broadcasted, broadcasted, DefaultArrayStyle
1011
import DomainSets: boundary
1112

1213
import QuasiArrays: LazyQuasiMatrix, LazyQuasiArrayStyle
1314
import ContinuumArrays: @simplify, Weight, weight, grid, plotgrid, TransformFactorization, ExpansionLayout, plotvalues, unweighted
1415

1516
import ArrayLayouts: MemoryLayout, sublayout, sub_materialize
16-
import BlockArrays: block, blockindex, BlockSlice, viewblock, blockcolsupport
17+
import BlockArrays: block, blockindex, BlockSlice, viewblock, blockcolsupport, AbstractBlockStyle, BlockStyle
1718
import BlockBandedMatrices: _BandedBlockBandedMatrix, AbstractBandedBlockBandedMatrix, _BandedMatrix, blockbandwidths, subblockbandwidths
1819
import LinearAlgebra: factorize
1920
import LazyArrays: arguments, paddeddata, LazyArrayStyle, LazyLayout

0 commit comments

Comments
 (0)