Skip to content

Commit e815efe

Browse files
committed
adaptively populate UX
1 parent b8c29cf commit e815efe

File tree

2 files changed

+45
-36
lines changed

2 files changed

+45
-36
lines changed

src/banded/tridiagonalconjugation.jl

Lines changed: 43 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@ function upper_mul_tri_triview!(UX::Tridiagonal, U::BandedMatrix, X::Tridiagonal
2424
# Tridiagonal bands can be resized
2525
@assert length(Xdl)+1 == length(Xd) == length(Xdu)+1 == length(UXdl)+1 == length(UXd) == length(UXdu)+1 == n
2626

27-
UX, bⱼ, aⱼ, cⱼ, cⱼ₋₁ = initiate_upper_mul_tri_triview!(UX, U, X)
28-
UX, bⱼ, aⱼ, cⱼ, cⱼ₋₁ = main_upper_mul_tri_triview!(UX, U, X, 2:n-2, bⱼ, aⱼ, cⱼ, cⱼ₋₁)
29-
finalize_upper_mul_tri_triview!(UX, U, X, n-1, bⱼ, aⱼ, cⱼ, cⱼ₋₁)
27+
UX, bₖ, aₖ, cₖ, cₖ₋₁ = initiate_upper_mul_tri_triview!(UX, U, X)
28+
UX, bₖ, aₖ, cₖ, cₖ₋₁ = main_upper_mul_tri_triview!(UX, U, X, 2:n-2, bₖ, aₖ, cₖ, cₖ₋₁)
29+
finalize_upper_mul_tri_triview!(UX, U, X, n-1, bₖ, aₖ, cₖ, cₖ₋₁)
3030
end
3131

3232

@@ -37,50 +37,50 @@ function initiate_upper_mul_tri_triview!(UX, U, X)
3737

3838
l,u = bandwidths(U)
3939

40-
j = 1
41-
aⱼ, cⱼ = Xd[1], Xdl[1]
42-
Uⱼⱼ, Uⱼⱼ₊₁, Uⱼⱼ₊₂ = Udat[u+1,1], Udat[u,2], Udat[u-1,3] # U[j,j], U[j,j+1], U[j,j+2]
43-
UXd[1] = Uⱼⱼ*aⱼ + Uⱼⱼ₊₁*cⱼ # UX[j,j] = U[j,j]*X[j,j] + U[j,j+1]*X[j+1,j]
44-
bⱼ, aⱼ, cⱼ, cⱼ₋₁ = Xdu[1], Xd[2], Xdl[2], cⱼ # X[j,j+1], X[j+1,j+1], X[j+2,j+1], X[j+1,j]
45-
UXdu[1] = Uⱼⱼ*bⱼ + Uⱼⱼ₊₁*aⱼ + Uⱼⱼ₊₂*cⱼ # UX[j,j+1] = U[j,j]*X[j,j+1] + U[j,j+1]*X[j+1,j+1] + U[j,j+1]*X[j+1,j]
40+
k = 1
41+
aₖ, cₖ = Xd[1], Xdl[1]
42+
Uₖₖ, Uₖₖ₊₁, Uₖₖ₊₂ = Udat[u+1,1], Udat[u,2], Udat[u-1,3] # U[k,k], U[k,k+1], U[k,k+2]
43+
UXd[1] = Uₖₖ*aₖ + Uₖₖ₊₁*cₖ # UX[k,k] = U[k,k]*X[k,k] + U[k,k+1]*X[k+1,k]
44+
bₖ, aₖ, cₖ, cₖ₋₁ = Xdu[1], Xd[2], Xdl[2], cₖ # X[k,k+1], X[k+1,k+1], X[k+2,k+1], X[k+1,k]
45+
UXdu[1] = Uₖₖ*bₖ + Uₖₖ₊₁*aₖ + Uₖₖ₊₂*cₖ # UX[k,k+1] = U[k,k]*X[k,k+1] + U[k,k+1]*X[k+1,k+1] + U[k,k+1]*X[k+1,k]
4646

47-
UX, bⱼ, aⱼ, cⱼ, cⱼ₋₁
47+
UX, bₖ, aₖ, cₖ, cₖ₋₁
4848
end
4949

50-
51-
function main_upper_mul_tri_triview!(UX, U, X, jr, bⱼ, aⱼ, cⱼ, cⱼ₋₁)
50+
# fills in the rows kr of UX
51+
function main_upper_mul_tri_triview!(UX, U, X, kr, bₖ=X.du[kr[1]-1], aₖ=X.d[kr[1]], cₖ=X.dl[kr[1]], cₖ₋₁=X.du[kr[1]-1])
5252
Xdl, Xd, Xdu = X.dl, X.d, X.du
5353
UXdl, UXd, UXdu = UX.dl, UX.d, UX.du
5454
Udat = U.data
5555
l,u = bandwidths(U)
5656

57-
@inbounds for j = jr
58-
Uⱼⱼ, Uⱼⱼ₊₁, Uⱼⱼ₊₂ = Udat[u+1,j], Udat[u,j+1], Udat[u-1,j+2] # U[j,j], U[j,j+1], U[j,j+2]
59-
UXdl[j-1] = Uⱼⱼ*cⱼ₋₁ # UX[j,j-1] = U[j,j]*X[j,j-1]
60-
UXd[j] = Uⱼⱼ*aⱼ + Uⱼⱼ₊₁*cⱼ # UX[j,j] = U[j,j]*X[j,j] + U[j,j+1]*X[j+1,j]
61-
bⱼ, aⱼ, cⱼ, cⱼ₋₁ = Xdu[j], Xd[j+1], Xdl[j+1], cⱼ # X[j,j+1], X[j+1,j+1], X[j+2,j+1], X[j+1,j]
62-
UXdu[j] = Uⱼⱼ*bⱼ + Uⱼⱼ₊₁*aⱼ + Uⱼⱼ₊₂*cⱼ # UX[j,j+1] = U[j,j]*X[j,j+1] + U[j,j+1]*X[j+1,j+1] + U[j,j+2]*X[j+2,j+1]
57+
for k = kr
58+
Uₖₖ, Uₖₖ₊₁, Uₖₖ₊₂ = Udat[u+1,k], Udat[u,k+1], Udat[u-1,k+2] # U[k,k], U[k,k+1], U[k,k+2]
59+
UXdl[k-1] = Uₖₖ*cₖ₋₁ # UX[k,k-1] = U[k,k]*X[k,k-1]
60+
UXd[k] = Uₖₖ*aₖ + Uₖₖ₊₁*cₖ # UX[k,k] = U[k,k]*X[k,k] + U[k,k+1]*X[k+1,k]
61+
bₖ, aₖ, cₖ, cₖ₋₁ = Xdu[k], Xd[k+1], Xdl[k+1], cₖ # X[k,k+1], X[k+1,k+1], X[k+2,k+1], X[k+1,k]
62+
UXdu[k] = Uₖₖ*bₖ + Uₖₖ₊₁*aₖ + Uₖₖ₊₂*cₖ # UX[k,k+1] = U[k,k]*X[k,k+1] + U[k,k+1]*X[k+1,k+1] + U[k,k+2]*X[k+2,k+1]
6363
end
6464

65-
UX, bⱼ, aⱼ, cⱼ, cⱼ₋₁
65+
UX, bₖ, aₖ, cₖ, cₖ₋₁
6666
end
6767

68-
function finalize_upper_mul_tri_triview!(UX, U, X, j, bⱼ, aⱼ, cⱼ, cⱼ₋₁)
68+
function finalize_upper_mul_tri_triview!(UX, U, X, k, bₖ, aₖ, cₖ, cₖ₋₁)
6969
Xdl, Xd, Xdu = X.dl, X.d, X.du
7070
UXdl, UXd, UXdu = UX.dl, UX.d, UX.du
7171
Udat = U.data
7272
l,u = bandwidths(U)
7373

74-
Uⱼⱼ, Uⱼⱼ₊₁ = Udat[u+1,j], Udat[u,j+1] # U[j,j], U[j,j+1]
75-
UXdl[j-1] = Uⱼⱼ*cⱼ₋₁ # UX[j,j-1] = U[j,j]*X[j,j-1]
76-
UXd[j] = Uⱼⱼ*aⱼ + Uⱼⱼ₊₁*cⱼ # UX[j,j] = U[j,j]*X[j,j] + U[j,j+1]*X[j+1,j]
77-
bⱼ, aⱼ, cⱼ₋₁ = Xdu[j], Xd[j+1], cⱼ # X[j,j+1], X[j+1,j+1], X[j+2,j+1], X[j+1,j]
78-
UXdu[j] = Uⱼⱼ*bⱼ + Uⱼⱼ₊₁*aⱼ # UX[j,j+1] = U[j,j]*X[j,j+1] + U[j,j+1]*X[j+1,j+1] + U[j,j+2]*X[j+2,j+1]
74+
Uₖₖ, Uₖₖ₊₁ = Udat[u+1,k], Udat[u,k+1] # U[k,k], U[k,k+1]
75+
UXdl[k-1] = Uₖₖ*cₖ₋₁ # UX[k,k-1] = U[k,k]*X[k,k-1]
76+
UXd[k] = Uₖₖ*aₖ + Uₖₖ₊₁*cₖ # UX[k,k] = U[k,k]*X[k,k] + U[k,k+1]*X[k+1,k]
77+
bₖ, aₖ, cₖ₋₁ = Xdu[k], Xd[k+1], cₖ # X[k,k+1], X[k+1,k+1], X[k+2,k+1], X[k+1,k]
78+
UXdu[k] = Uₖₖ*bₖ + Uₖₖ₊₁*aₖ # UX[k,k+1] = U[k,k]*X[k,k+1] + U[k,k+1]*X[k+1,k+1] + U[k,k+2]*X[k+2,k+1]
7979

80-
j += 1
81-
Uⱼⱼ = Udat[u+1,j] # U[j,j]
82-
UXdl[j-1] = Uⱼⱼ*cⱼ₋₁ # UX[j,j-1] = U[j,j]*X[j,j-1]
83-
UXd[j] = Uⱼⱼ*aⱼ # UX[j,j] = U[j,j]*X[j,j] + U[j,j+1]*X[j+1,j]
80+
k += 1
81+
Uₖₖ = Udat[u+1,k] # U[k,k]
82+
UXdl[k-1] = Uₖₖ*cₖ₋₁ # UX[k,k-1] = U[k,k]*X[k,k-1]
83+
UXd[k] = Uₖₖ*aₖ # UX[k,k] = U[k,k]*X[k,k] + U[k,k+1]*X[k+1,k]
8484

8585
UX
8686
end
@@ -152,19 +152,24 @@ mutable struct TridiagonalConjugationData{T}
152152
datasize::Int
153153
end
154154

155-
function TridiagonalConjugationData(U, X, V, uplo::Char)
156-
T = promote_type(typeof(inv(V[1, 1])), eltype(U), eltype(C)) # include inv so that we can't get Ints
157-
return BidiagonalConjugationData(U, X, V, Tridiagonal(T[], T[], T[]), Tridiagonal(T[], T[], T[]), 0)
155+
function TridiagonalConjugationData(U, X, V)
156+
T = promote_type(typeof(inv(V[1, 1])), eltype(U), eltype(X)) # include inv so that we can't get Ints
157+
n_init = 100
158+
UX = Tridiagonal(Vector{T}(undef, n_init-1), Vector{T}(undef, n_init), Vector{T}(undef, n_init-1))
159+
Y = Tridiagonal(Vector{T}(undef, n_init-1), Vector{T}(undef, n_init), Vector{T}(undef, n_init-1))
160+
initiate_upper_mul_tri_triview!(UX, U, X) # fill-in 1st row
161+
return TridiagonalConjugationData(U, X, V, UX, Y, 1)
158162
end
159163

164+
TridiagonalConjugationData(U, X) = TridiagonalConjugationData(U, X, U)
165+
160166
copy(data::TridiagonalConjugationData) = TridiagonalConjugationData(copy(data.U), copy(data.X), copy(data.V), copy(data.UX), copy(data.Y), data.datasize)
161167

162168

163169
function resizedata!(data::TridiagonalConjugationData, n)
164-
n 0 && return data
165-
n = max(v, n)
166-
dv, ev = data.dv, data.ev
167-
if n > length(ev) # Avoid O(n²) growing. Note min(length(dv), length(ev)) == length(ev)
170+
n data.datasize && return data
171+
172+
if n > length(data.UX.d) # Avoid O(n²) growing. Note min(length(dv), length(ev)) == length(ev)
168173
resize!(data.UX.dl, 2n)
169174
resize!(data.UX.d, 2n + 1)
170175
resize!(data.UX.du, 2n)
@@ -174,6 +179,8 @@ function resizedata!(data::TridiagonalConjugationData, n)
174179
resize!(data.Y.du, 2n)
175180
end
176181

182+
main_upper_mul_tri_triview!(data.UX, data.U, data.X, data.datasize+1:n)
177183

184+
data.datasize = n
178185
end
179186

test/test_bidiagonalconjugation.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,8 @@ end
115115
# U*X*inv(U) only depends on Tridiagonal(U*X)
116116
@time Y = InfiniteLinearAlgebra.tri_mul_invupper_triview(UX, U)
117117
@test Tridiagonal(U*X / U) Tridiagonal(UX / U) Y
118+
119+
InfiniteLinearAlgebra.TridiagonalConjugationData(U, X, U)
118120
end
119121
@testset "P -> Ultraspherical(3/2)" begin
120122
R = BandedMatrices._BandedMatrix(Vcat((-1 ./ (1:2:∞))',

0 commit comments

Comments
 (0)