Skip to content

Commit 0780e9a

Browse files
committed
split tri_mul_invupper in 3
1 parent e815efe commit 0780e9a

File tree

1 file changed

+48
-18
lines changed

1 file changed

+48
-18
lines changed

src/banded/tridiagonalconjugation.jl

Lines changed: 48 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,20 @@
1-
"""
2-
upper_mul_tri_triview(U, X) == Tridiagonal(U*X) where U is Upper triangular BandedMatrix and X is Tridiagonal
3-
"""
1+
2+
# upper_mul_tri_triview(U, X) == Tridiagonal(U*X) where U is Upper triangular BandedMatrix and X is Tridiagonal
43
# Allocating front end for upper_mul_tri_triview!: builds a Tridiagonal of the right
# size and element type and hands it to the in-place routine, whose result is returned.
function upper_mul_tri_triview(U::BandedMatrix, X::Tridiagonal)
    # element type able to hold entries coming from both U and X
    S = promote_type(eltype(U), eltype(X))
    m = size(U, 1)

    # the three bands are left uninitialised; the in-place call is relied on to fill them
    lower = Vector{S}(undef, m - 1)
    main = Vector{S}(undef, m)
    upper = Vector{S}(undef, m - 1)

    return upper_mul_tri_triview!(Tridiagonal(lower, main, upper), U, X)
end
1110

1211
function upper_mul_tri_triview!(UX::Tridiagonal, U::BandedMatrix, X::Tridiagonal)
1312
n = size(UX,1)
1413

15-
14+
1615
Xdl, Xd, Xdu = X.dl, X.d, X.du
1716
UXdl, UXd, UXdu = UX.dl, UX.d, UX.du
18-
Udat = U.data
19-
17+
2018
l,u = bandwidths(U)
2119

2220
@assert size(U) == (n,n)
@@ -29,7 +27,7 @@ function upper_mul_tri_triview!(UX::Tridiagonal, U::BandedMatrix, X::Tridiagonal
2927
finalize_upper_mul_tri_triview!(UX, U, X, n-1, bₖ, aₖ, cₖ, cₖ₋₁)
3028
end
3129

32-
30+
# populate first row of UX with U*X
3331
function initiate_upper_mul_tri_triview!(UX, U, X)
3432
Xdl, Xd, Xdu = X.dl, X.d, X.du
3533
UXdl, UXd, UXdu = UX.dl, UX.d, UX.du
@@ -65,6 +63,7 @@ function main_upper_mul_tri_triview!(UX, U, X, kr, bₖ=X.du[kr[1]-1], aₖ=X.d[
6563
UX, bₖ, aₖ, cₖ, cₖ₋₁
6664
end
6765

66+
# populate rows k and k+1 of UX, assuming we are at the bottom-right
6867
function finalize_upper_mul_tri_triview!(UX, U, X, k, bₖ, aₖ, cₖ, cₖ₋₁)
6968
Xdl, Xd, Xdu = X.dl, X.d, X.du
7069
UXdl, UXd, UXdu = UX.dl, UX.d, UX.du
@@ -92,43 +91,74 @@ end
9291

9392
# Allocating front end for tri_mul_invupper_triview!: the output is a `similar` copy of X
# with the promoted element type of X and R, filled in place by the mutating routine.
function tri_mul_invupper_triview(X::Tridiagonal, R::BandedMatrix)
    T = promote_type(eltype(X), eltype(R))
    Y = similar(X, T)
    return tri_mul_invupper_triview!(Y, X, R)
end
9493

94+
9595
# Fill the tridiagonal Y in place from the tridiagonal X and the upper-triangular banded R
# (the helper comments in this file describe the rows being written as rows of X/R).
#
# The work is split in three stages that hand the pair (Rₖₖ, Rₖₖ₊₁) from one row to the next:
#   1. initiate_…  writes row 1,
#   2. main_…      writes rows 2:n-1,
#   3. finalize_…  writes row n.
# The value returned is whatever finalize_tri_mul_invupper_triview! returns.
#
# Requirements (checked with @assert): R is n×n with no subdiagonals and at least two
# superdiagonals, and all bands of X and Y have lengths consistent with an n×n Tridiagonal.
function tri_mul_invupper_triview!(Y::Tridiagonal, X::Tridiagonal, R::BandedMatrix)
    n = size(X,1)
    Xdl, Xd, Xdu = X.dl, X.d, X.du
    Ydl, Yd, Ydu = Y.dl, Y.d, Y.du

    l,u = bandwidths(R)

    @assert size(R) == (n,n)
    # the row kernels read R.data[u-1, k+1] (second superdiagonal), hence u ≥ 2
    @assert l == 0 && u ≥ 2
    # Tridiagonal bands can be resized
    @assert length(Xdl)+1 == length(Xd) == length(Xdu)+1 == length(Ydl)+1 == length(Yd) == length(Ydu)+1 == n

    # the first returned value is Y itself, which we already hold, so it is discarded
    _, Rₖₖ, Rₖₖ₊₁ = initiate_tri_mul_invupper_triview!(Y, X, R)
    _, Rₖₖ, Rₖₖ₊₁ = main_tri_mul_invupper_triview!(Y, X, R, 2:n-1, Rₖₖ, Rₖₖ₊₁)
    finalize_tri_mul_invupper_triview!(Y, X, R, n, Rₖₖ, Rₖₖ₊₁)
end
111+
112+
# populate first row of X/R
113+
# Write the first row of Y (entries Y.d[1] and Y.du[1]) from the first row of X and the
# entries R[1,1], R[1,2] of R.
#
# Y.du[1] is not divided by any diagonal entry of R here; the main loop later applies
# `Ydu[k-1] /= Rₖₖ` when it processes the next row.
#
# Returns the tuple (Y, R[1,1], R[1,2]) so the caller can carry the two entries of R on.
function initiate_tri_mul_invupper_triview!(Y, X, R)
    x₁₁, x₁₂ = X.d[1], X.du[1]

    bands = R.data
    _, u = bandwidths(R)
    # in the band storage, row u+1 is the diagonal and row u the first superdiagonal
    r₁₁ = bands[u+1, 1]   # R[1,1]
    r₁₂ = bands[u, 2]     # R[1,2]

    Y.d[1] = x₁₁/r₁₁
    Y.du[1] = x₁₂ - x₁₁ * r₁₂/r₁₁

    return Y, r₁₁, r₁₂
end
128+
129+
130+
# populate rows kr of X/R
131+
# Write rows kr of Y from X and R.
#
# Arguments
#   Y, X        - output and input Tridiagonal; only their dl/d/du bands are touched.
#   R           - upper-triangular BandedMatrix; its band storage R.data is read directly.
#   kr          - rows to process, visited in iteration order. Row k reads Xdl[k-1], Xd[k],
#                 Xdu[k] and column k+1 of R under @inbounds, so every k must satisfy
#                 2 ≤ k ≤ n-1 and the bands must be long enough; nothing is checked here.
#   Rₖₖ, Rₖₖ₊₁  - the entries of R carried over from the row *preceding* first(kr), i.e.
#                 R[k-1,k-1] and R[k-1,k] for k = first(kr). This is what the loop body
#                 uses before it reloads them, and it matches the defaults of
#                 finalize_tri_mul_invupper_triview!.
#
# Returns (Y, R[k,k], R[k,k+1]) for the last processed k, ready to be passed to the next
# stage. If kr is empty the two R entries are returned unchanged.
#
# Ydu[k] written for the last row is still missing a division by R[k+1,k+1]; the
# step that handles row k+1 applies it through `Ydu[k-1] /= Rₖₖ`.
function main_tri_mul_invupper_triview!(Y::Tridiagonal, X::Tridiagonal, R::BandedMatrix, kr, Rₖₖ=R[first(kr)-1,first(kr)-1], Rₖₖ₊₁=R[first(kr)-1,first(kr)])
    Xdl, Xd, Xdu = X.dl, X.d, X.du
    Ydl, Yd, Ydu = Y.dl, Y.d, Y.du
    Rdat = R.data
    _, u = bandwidths(R)

    @inbounds for k = kr
        cₖ₋₁,aₖ,bₖ = Xdl[k-1], Xd[k], Xdu[k]
        # on entry Rₖₖ, Rₖₖ₊₁ still hold the previous row's R[k-1,k-1], R[k-1,k]
        Ydl[k-1] = cₖ₋₁/Rₖₖ
        Yd[k] = aₖ-cₖ₋₁*Rₖₖ₊₁/Rₖₖ
        Ydu[k] = cₖ₋₁/Rₖₖ
        # reload for row k; the right-hand side is evaluated first, so the old Rₖₖ₊₁ becomes Rₖ₋₁ₖ
        Rₖₖ,Rₖₖ₊₁,Rₖ₋₁ₖ₊₁,Rₖ₋₁ₖ = Rdat[u+1,k], Rdat[u,k+1],Rdat[u-1,k+1],Rₖₖ₊₁ # R[k,k], R[k,k+1], R[k-1,k+1], R[k-1,k]
        Yd[k] /= Rₖₖ
        Ydu[k-1] /= Rₖₖ
        Ydu[k] *= Rₖ₋₁ₖ*Rₖₖ₊₁/Rₖₖ - Rₖ₋₁ₖ₊₁
        Ydu[k] += bₖ - aₖ * Rₖₖ₊₁ / Rₖₖ
    end
    Y, Rₖₖ, Rₖₖ₊₁
end
150+
126151

127-
k = n
152+
# populate row k of X/R, assuming we are at the bottom-right
153+
function finalize_tri_mul_invupper_triview!(Y::Tridiagonal, X::Tridiagonal, R::BandedMatrix, k, Rₖₖ=R[k-1,k-1], Rₖₖ₊₁=R[k-1,k])
154+
Xdl, Xd, Xdu = X.dl, X.d, X.du
155+
Ydl, Yd, Ydu = Y.dl, Y.d, Y.du
156+
Rdat = R.data
157+
l,u = bandwidths(R)
128158
cₖ₋₁,aₖ = Xdl[k-1], Xd[k]
129159
Ydl[k-1] = cₖ₋₁/Rₖₖ
130160
Yd[k] = aₖ-cₖ₋₁*Rₖₖ₊₁/Rₖₖ
131-
Rₖₖ = Rdat[u+1,k] # R[2,2], R[2,3], R[1,3]
161+
Rₖₖ = Rdat[u+1,k] # R[k,k]
132162
Yd[k] /= Rₖₖ
133163
Ydu[k-1] /= Rₖₖ
134164

@@ -173,7 +203,7 @@ function resizedata!(data::TridiagonalConjugationData, n)
173203
resize!(data.UX.dl, 2n)
174204
resize!(data.UX.d, 2n + 1)
175205
resize!(data.UX.du, 2n)
176-
206+
177207
resize!(data.Y.dl, 2n)
178208
resize!(data.Y.d, 2n + 1)
179209
resize!(data.Y.du, 2n)

0 commit comments

Comments
 (0)