Skip to content

Commit e64a854

Browse files
committed
Fast Tridiagonal(U*X)
1 parent 1a59fa8 commit e64a854

File tree

1 file changed

+78
-12
lines changed

1 file changed

+78
-12
lines changed

test/test_bidiagonalconjugation.jl

Lines changed: 78 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -102,32 +102,98 @@ end
102102

103103

104104
"""
105-
tridiagonal(A, B) == Tridiagonal(A*B)
105+
upper_mul_tri_triview(U, X) == Tridiagonal(U*X) where U is an upper-triangular BandedMatrix and X is Tridiagonal.
106106
"""
107-
function tridiagonalmul(A, B)
108-
T = promote_type(eltype(A), eltype(B))
109-
UX = Tridiagonal(Vector{T}(undef, n-1), Vector{T}(undef, n), Vector{T}(undef, n-1))
107+
function upper_mul_tri_triview(U, X)
108+
T = promote_type(eltype(U), eltype(X))
109+
n = size(U,1)
110+
upper_mul_tri_triview!(Tridiagonal(Vector{T}(undef, n-1), Vector{T}(undef, n), Vector{T}(undef, n-1)), U, X)
111+
end
112+
113+
114+
# function upper_mul_tri_triview!(UX::Tridiagonal, U::BandedMatrix, X::Tridiagonal)
115+
# n = size(U,1)
116+
# @inbounds for j = 1:n-1
117+
# UX.d[j] = U.data[3,j]*X.d[j] + U.data[2,j]*X.dl[j]
118+
# end
119+
# UX.d[n] = U.data[3,n]*X.d[n]
120+
121+
# @inbounds for j = 1:n-1
122+
# UX.dl[j] = U.data[3,j+1]*X.dl[j]
123+
# end
124+
125+
# @inbounds for j = 1:n-2
126+
# UX.du[j] = U.data[3,j]*X.du[j] + U.data[2,j+1]*X.d[j+1] + U.data[1,j+2]*X.dl[j+1]
127+
# end
128+
129+
# UX.du[n-1] = U.data[3,n-1]*X.du[n-1] + U.data[2,n]*X.d[n]
130+
131+
# UX
132+
# end
133+
134+
function upper_mul_tri_triview!(UX::Tridiagonal, U::BandedMatrix, X::Tridiagonal)
135+
n = size(U,1)
136+
j = 1
137+
Xⱼⱼ, Xⱼ₊₁ⱼ = X.d[1], X.dl[1]
138+
Uⱼⱼ, Uⱼⱼ₊₁, Uⱼⱼ₊₂ = U.data[3,1], U.data[2,2], U.data[1,3] # U[j,j], U[j,j+1], U[j,j+2]
139+
UX.d[1] = Uⱼⱼ*Xⱼⱼ + Uⱼⱼ₊₁*Xⱼ₊₁ⱼ # UX[j,j] = U[j,j]*X[j,j] + U[j,j+1]*X[j+1,j]
140+
Xⱼⱼ₊₁, Xⱼⱼ, Xⱼ₊₁ⱼ, Xⱼⱼ₋₁ = X.du[1], X.d[2], X.dl[2], Xⱼ₊₁ⱼ # X[j,j+1], X[j+1,j+1], X[j+2,j+1], X[j+1,j]
141+
UX.du[1] = Uⱼⱼ*Xⱼⱼ₊₁ + Uⱼⱼ₊₁*Xⱼⱼ + Uⱼⱼ₊₂*Xⱼ₊₁ⱼ # UX[j,j+1] = U[j,j]*X[j,j+1] + U[j,j+1]*X[j+1,j+1] + U[j,j+1]*X[j+1,j]
142+
143+
@inbounds for j = 2:n-2
144+
Uⱼⱼ, Uⱼⱼ₊₁, Uⱼⱼ₊₂ = U.data[3,j], U.data[2,j+1], U.data[1,j+2] # U[j,j], U[j,j+1], U[j,j+2]
145+
UX.dl[j-1] = Uⱼⱼ*Xⱼⱼ₋₁ # UX[j,j-1] = U[j,j]*X[j,j-1]
146+
UX.d[j] = Uⱼⱼ*Xⱼⱼ + Uⱼⱼ₊₁*Xⱼ₊₁ⱼ # UX[j,j] = U[j,j]*X[j,j] + U[j,j+1]*X[j+1,j]
147+
Xⱼⱼ₊₁, Xⱼⱼ, Xⱼ₊₁ⱼ, Xⱼⱼ₋₁ = X.du[j], X.d[j+1], X.dl[j+1], Xⱼ₊₁ⱼ # X[j,j+1], X[j+1,j+1], X[j+2,j+1], X[j+1,j]
148+
UX.du[j] = Uⱼⱼ*Xⱼⱼ₊₁ + Uⱼⱼ₊₁*Xⱼⱼ + Uⱼⱼ₊₂*Xⱼ₊₁ⱼ # UX[j,j+1] = U[j,j]*X[j,j+1] + U[j,j+1]*X[j+1,j+1] + U[j,j+2]*X[j+2,j+1]
149+
end
150+
151+
j = n-1
152+
Uⱼⱼ, Uⱼⱼ₊₁ = U.data[3,j], U.data[2,j+1] # U[j,j], U[j,j+1]
153+
UX.dl[j-1] = Uⱼⱼ*Xⱼⱼ₋₁ # UX[j,j-1] = U[j,j]*X[j,j-1]
154+
UX.d[j] = Uⱼⱼ*Xⱼⱼ + Uⱼⱼ₊₁*Xⱼ₊₁ⱼ # UX[j,j] = U[j,j]*X[j,j] + U[j,j+1]*X[j+1,j]
155+
Xⱼⱼ₊₁, Xⱼⱼ, Xⱼⱼ₋₁ = X.du[j], X.d[j+1], Xⱼ₊₁ⱼ # X[j,j+1], X[j+1,j+1], X[j+2,j+1], X[j+1,j]
156+
UX.du[j] = Uⱼⱼ*Xⱼⱼ₊₁ + Uⱼⱼ₊₁*Xⱼⱼ # UX[j,j+1] = U[j,j]*X[j,j+1] + U[j,j+1]*X[j+1,j+1] + U[j,j+2]*X[j+2,j+1]
157+
158+
j = n
159+
Uⱼⱼ = U.data[3,j] # U[j,j]
160+
UX.dl[j-1] = Uⱼⱼ*Xⱼⱼ₋₁ # UX[j,j-1] = U[j,j]*X[j,j-1]
161+
UX.d[j] = Uⱼⱼ*Xⱼⱼ # UX[j,j] = U[j,j]*X[j,j] + U[j,j+1]*X[j+1,j]
162+
163+
UX
164+
end
165+
166+
tri_mul_invupper_triview(UX::Tridiagonal, R::BandedMatrix) = tri_mul_invupper_triview!(similar(UX, promote_type(eltype(UX), eltype(R))), UX, R)
110167

111-
for j = 1:n-1
112-
UX.d[j] = U.data[3,j]*X.d[j] + U.data[2,j]*X.dl[j]
168+
function tri_mul_invupper_triview!(X, UX, R)
169+
@inbounds for j = 1:n-1
170+
UX.d[j] = UX.d[j]/ U.data[3,j]*X.d[j] + U.data[2,j]*X.dl[j]
113171
end
114172
UX.d[n] = U.data[3,n]*X.d[n]
115173

116-
for j = 1:n-1
117-
UX.dl[j] = U.data[3,j]*X.d[j] + U.data[2,j]*X.dl[j]
174+
@inbounds for j = 1:n-1
175+
UX.dl[j] = U.data[3,j+1]*X.dl[j]
176+
end
177+
178+
@inbounds for j = 1:n-2
179+
UX.du[j] = U.data[3,j]*X.du[j] + U.data[2,j+1]*X.d[j+1] + U.data[1,j+2]*X.dl[j+1]
118180
end
181+
182+
UX.du[n-1] = U.data[3,n-1]*X.du[n-1] + U.data[2,n]*X.d[n]
183+
184+
UX
119185
end
120186

187+
121188
@testset "TridiagonalConjugation" begin
122189
R0 = BandedMatrices._BandedMatrix(Vcat(-Ones(1,∞)/2,
123190
Zeros(1,∞),
124191
Hcat(Ones(1,1),Ones(1,∞)/2)), ℵ₀, 0,2)
125192
X_T = LazyBandedMatrices.Tridiagonal(Vcat(1.0, Fill(1/2,∞)), Zeros(∞), Fill(1/2,∞))
126193

127-
n = 1000; @time U = V = R0[1:n,1:n]
128-
@time X = Tridiagonal(Vector(X_T.dl[1:n-1]), Vector(X_T.d[1:n]), Vector(X_T.du[1:n-1]))
194+
n = 1000; @time U = V = R0[1:n,1:n];
195+
@time X = Tridiagonal(Vector(X_T.dl[1:n-1]), Vector(X_T.d[1:n]), Vector(X_T.du[1:n-1]));
129196

130-
@time U*X
131-
T = Float64
197+
@test Tridiagonal(U*X) upper_mul_tri_triview(U, X)
132198

133199
end

0 commit comments

Comments
 (0)