Skip to content

Commit 763d37c

Browse files
dkarraschLilithHafner
authored andcommitted
Make (Sym)Tridiag-Diagonal solves return TriDiagonal (JuliaLang#42744)
1 parent c503e40 commit 763d37c

File tree

2 files changed

+139
-4
lines changed

2 files changed

+139
-4
lines changed

stdlib/LinearAlgebra/src/diagonal.jl

Lines changed: 85 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,7 @@ function _rdiv!(B::AbstractVecOrMat, A::AbstractVecOrMat, D::Diagonal)
363363
require_one_based_indexing(A)
364364
dd = D.diag
365365
m, n = size(A, 1), size(A, 2)
366-
if (k = length(dd)) n
366+
if (k = length(dd)) != n
367367
throw(DimensionMismatch("left hand side has $n columns but D is $k by $k"))
368368
end
369369
@inbounds for j in 1:n
@@ -395,7 +395,7 @@ end
395395
\(D::Diagonal, B::Diagonal) = ldiv!(similar(B, promote_op(\, eltype(D), eltype(B))), D, B)
396396
/(A::Diagonal, D::Diagonal) = _rdiv!(similar(A, promote_op(/, eltype(A), eltype(D))), A, D)
397397
function _rdiv!(Dc::Diagonal, Db::Diagonal, Da::Diagonal)
398-
n, k = length(Db.diag), length(Db.diag)
398+
n, k = length(Db.diag), length(Da.diag)
399399
n == k || throw(DimensionMismatch("left hand side has $n columns but D is $k by $k"))
400400
j = findfirst(iszero, Da.diag)
401401
isnothing(j) || throw(SingularException(j))
@@ -404,6 +404,88 @@ function _rdiv!(Dc::Diagonal, Db::Diagonal, Da::Diagonal)
404404
end
405405
ldiv!(Dc::Diagonal, Da::Diagonal, Db::Diagonal) = Diagonal(ldiv!(Dc.diag, Da, Db.diag))
406406

407+
# optimizations for (Sym)Tridiagonal and Diagonal
408+
@propagate_inbounds _getudiag(T::Tridiagonal, i) = T.du[i]
409+
@propagate_inbounds _getudiag(S::SymTridiagonal, i) = S.ev[i]
410+
@propagate_inbounds _getdiag(T::Tridiagonal, i) = T.d[i]
411+
@propagate_inbounds _getdiag(S::SymTridiagonal, i) = symmetric(S.dv[i], :U)::symmetric_type(eltype(S.dv))
412+
@propagate_inbounds _getldiag(T::Tridiagonal, i) = T.dl[i]
413+
@propagate_inbounds _getldiag(S::SymTridiagonal, i) = transpose(S.ev[i])
414+
415+
function (\)(D::Diagonal, S::SymTridiagonal)
416+
T = promote_op(\, eltype(D), eltype(S))
417+
du = similar(S.ev, T, max(length(S.dv)-1, 0))
418+
d = similar(S.dv, T, length(S.dv))
419+
dl = similar(S.ev, T, max(length(S.dv)-1, 0))
420+
ldiv!(Tridiagonal(dl, d, du), D, S)
421+
end
422+
(\)(D::Diagonal, T::Tridiagonal) = ldiv!(similar(T, promote_op(\, eltype(D), eltype(T))), D, T)
423+
function ldiv!(T::Tridiagonal, D::Diagonal, S::Union{SymTridiagonal,Tridiagonal})
424+
m = size(S, 1)
425+
dd = D.diag
426+
if (k = length(dd)) != m
427+
throw(DimensionMismatch("diagonal matrix is $k by $k but right hand side has $m rows"))
428+
end
429+
if length(T.d) != m
430+
throw(DimensionMismatch("target matrix size $(size(T)) does not match input matrix size $(size(S))"))
431+
end
432+
m == 0 && return T
433+
j = findfirst(iszero, dd)
434+
isnothing(j) || throw(SingularException(j))
435+
ddj = dd[1]
436+
T.d[1] = ddj \ _getdiag(S, 1)
437+
@inbounds if m > 1
438+
T.du[1] = ddj \ _getudiag(S, 1)
439+
for j in 2:m-1
440+
ddj = dd[j]
441+
T.dl[j-1] = ddj \ _getldiag(S, j-1)
442+
T.d[j] = ddj \ _getdiag(S, j)
443+
T.du[j] = ddj \ _getudiag(S, j)
444+
end
445+
ddj = dd[m]
446+
T.dl[m-1] = ddj \ _getldiag(S, m-1)
447+
T.d[m] = ddj \ _getdiag(S, m)
448+
end
449+
return T
450+
end
451+
452+
function (/)(S::SymTridiagonal, D::Diagonal)
453+
T = promote_op(\, eltype(D), eltype(S))
454+
du = similar(S.ev, T, max(length(S.dv)-1, 0))
455+
d = similar(S.dv, T, length(S.dv))
456+
dl = similar(S.ev, T, max(length(S.dv)-1, 0))
457+
_rdiv!(Tridiagonal(dl, d, du), S, D)
458+
end
459+
(/)(T::Tridiagonal, D::Diagonal) = _rdiv!(similar(T, promote_op(/, eltype(T), eltype(D))), T, D)
460+
function _rdiv!(T::Tridiagonal, S::Union{SymTridiagonal,Tridiagonal}, D::Diagonal)
461+
n = size(S, 2)
462+
dd = D.diag
463+
if (k = length(dd)) != n
464+
throw(DimensionMismatch("left hand side has $n columns but D is $k by $k"))
465+
end
466+
if length(T.d) != n
467+
throw(DimensionMismatch("target matrix size $(size(T)) does not match input matrix size $(size(S))"))
468+
end
469+
n == 0 && return T
470+
j = findfirst(iszero, dd)
471+
isnothing(j) || throw(SingularException(j))
472+
ddj = dd[1]
473+
T.d[1] = _getdiag(S, 1) / ddj
474+
@inbounds if n > 1
475+
T.dl[1] = _getldiag(S, 1) / ddj
476+
for j in 2:n-1
477+
ddj = dd[j]
478+
T.dl[j] = _getldiag(S, j) / ddj
479+
T.d[j] = _getdiag(S, j) / ddj
480+
T.du[j-1] = _getudiag(S, j-1) / ddj
481+
end
482+
ddj = dd[n]
483+
T.d[n] = _getdiag(S, n) / ddj
484+
T.du[n-1] = _getudiag(S, n-1) / ddj
485+
end
486+
return T
487+
end
488+
407489
# Optimizations for [l/r]mul!, l/rdiv!, *, / and \ between Triangular and Diagonal.
408490
# These functions are generally more efficient if we calculate the whole data field.
409491
# The following code implements them in a unified pattern to avoid missing.
@@ -623,7 +705,7 @@ dot(A::AbstractMatrix, B::Diagonal) = conj(dot(B, A))
623705

624706
function _mapreduce_prod(f, x, D::Diagonal, y)
625707
if isempty(x) && isempty(D) && isempty(y)
626-
return zero(Base.promote_op(f, eltype(x), eltype(D), eltype(y)))
708+
return zero(promote_op(f, eltype(x), eltype(D), eltype(y)))
627709
else
628710
return mapreduce(t -> f(t[1], t[2], t[3]), +, zip(x, D.diag, y))
629711
end

stdlib/LinearAlgebra/test/diagonal.jl

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
module TestDiagonal
44

55
using Test, LinearAlgebra, SparseArrays, Random
6-
using LinearAlgebra: mul!, mul!, rmul!, lmul!, ldiv!, rdiv!, BlasFloat, BlasComplex, SingularException
6+
using LinearAlgebra: BlasFloat, BlasComplex
77

88
n=12 #Size of matrix problem to test
99
Random.seed!(1)
@@ -776,6 +776,59 @@ end
776776
end
777777
end
778778

779+
@testset "(Sym)Tridiagonal division by Diagonal" begin
780+
for K in (5, 1), elty in (Float64, ComplexF32), overlength in (1, 0)
781+
S = SymTridiagonal(randn(elty, K), randn(elty, K-overlength))
782+
T = Tridiagonal(randn(elty, K-1), randn(elty, K), randn(elty, K-1))
783+
D = Diagonal(randn(elty, K))
784+
D0 = Diagonal(zeros(elty, K))
785+
@test (D \ S)::Tridiagonal{elty} == Tridiagonal(Matrix(D) \ Matrix(S))
786+
@test (D \ T)::Tridiagonal{elty} == Tridiagonal(Matrix(D) \ Matrix(T))
787+
@test (S / D)::Tridiagonal{elty} == Tridiagonal(Matrix(S) / Matrix(D))
788+
@test (T / D)::Tridiagonal{elty} == Tridiagonal(Matrix(T) / Matrix(D))
789+
@test_throws SingularException D0 \ S
790+
@test_throws SingularException D0 \ T
791+
@test_throws SingularException S / D0
792+
@test_throws SingularException T / D0
793+
end
794+
# 0-length case
795+
S = SymTridiagonal(Float64[], Float64[])
796+
T = Tridiagonal(Float64[], Float64[], Float64[])
797+
D = Diagonal(Float64[])
798+
@test (D \ S)::Tridiagonal{Float64} == T
799+
@test (D \ T)::Tridiagonal{Float64} == T
800+
@test (S / D)::Tridiagonal{Float64} == T
801+
@test (T / D)::Tridiagonal{Float64} == T
802+
# matrix eltype case
803+
K = 5
804+
for elty in (Float64, ComplexF32), overlength in (1, 0)
805+
S = SymTridiagonal([rand(elty, 2, 2) for _ in 1:K], [rand(elty, 2, 2) for _ in 1:K-overlength])
806+
T = Tridiagonal([rand(elty, 2, 2) for _ in 1:K-1], [rand(elty, 2, 2) for _ in 1:K], [rand(elty, 2, 2) for _ in 1:K-1])
807+
D = Diagonal(randn(elty, K))
808+
SM = fill(zeros(elty, 2, 2), K, K)
809+
TM = copy(SM)
810+
SM[1,1] = S[1,1]; TM[1,1] = T[1,1]
811+
for j in 2:K
812+
SM[j,j-1] = S[j,j-1]; SM[j,j] = S[j,j]; SM[j-1,j] = S[j-1,j]
813+
TM[j,j-1] = T[j,j-1]; TM[j,j] = T[j,j]; TM[j-1,j] = T[j-1,j]
814+
end
815+
for (M, Mm) in ((S, SM), (T, TM))
816+
DS = D \ M
817+
@test DS isa Tridiagonal
818+
DM = D \ Mm
819+
for i in -1:1; @test diag(DS, i) diag(DM, i) end
820+
end
821+
end
822+
# eltype promotion case
823+
S = SymTridiagonal(rand(-20:20, K), rand(-20:20, K-1))
824+
T = Tridiagonal(rand(-20:20, K-1), rand(-20:20, K), rand(-20:20, K-1))
825+
D = Diagonal(rand(1:20, K))
826+
@test (D \ S)::Tridiagonal{Float64} == Tridiagonal(Matrix(D) \ Matrix(S))
827+
@test (D \ T)::Tridiagonal{Float64} == Tridiagonal(Matrix(D) \ Matrix(T))
828+
@test (S / D)::Tridiagonal{Float64} == Tridiagonal(Matrix(S) / Matrix(D))
829+
@test (T / D)::Tridiagonal{Float64} == Tridiagonal(Matrix(T) / Matrix(D))
830+
end
831+
779832
@testset "eigenvalue sorting" begin
780833
D = Diagonal([0.4, 0.2, -1.3])
781834
@test eigvals(D) == eigen(D).values == [0.4, 0.2, -1.3] # not sorted by default

0 commit comments

Comments
 (0)