Skip to content

Commit b3f02e7

Browse files
authored
Add alloc-free in-place Tridiagonal solves and lu! (#50535)
1 parent 2d74f5e commit b3f02e7

File tree

3 files changed

+106
-20
lines changed

3 files changed

+106
-20
lines changed

src/lu.jl

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -494,27 +494,34 @@ inv!(A::LU{T,<:StridedMatrix}) where {T} =
494494
inv(A::LU{<:BlasFloat,<:StridedMatrix}) = inv!(copy(A))
495495

496496
# Tridiagonal
497-
498-
# See dgttrf.f
499497
function lu!(A::Tridiagonal{T,V}, pivot::Union{RowMaximum,NoPivot} = RowMaximum(); check::Bool = true) where {T,V}
500-
# Extract values
501498
n = size(A, 1)
502-
503-
# Initialize variables
504-
info = 0
505-
ipiv = Vector{BlasInt}(undef, n)
506-
dl = A.dl
507-
d = A.d
508-
du = A.du
509-
if dl === du
510-
throw(ArgumentError("off-diagonals of `A` must not alias"))
511-
end
512-
# Check if Tridiagonal matrix already has du2 for pivoting
513499
has_du2_defined = isdefined(A, :du2) && length(A.du2) == max(0, n-2)
514500
if has_du2_defined
515501
du2 = A.du2::V
516502
else
517-
du2 = similar(d, max(0, n-2))::V
503+
du2 = similar(A.d, max(0, n-2))::V
504+
end
505+
_lu_tridiag!(A.dl, A.d, A.du, du2, Vector{BlasInt}(undef, n), pivot, check)
506+
end
507+
# Refactorize `A` into the storage already held by `F`: the diagonals of `A` are copied
# into `F.factors` and the elimination then runs on that storage, reusing `F.ipiv` as
# the pivot vector. `A` itself is only read (by `copyto!`).
function lu!(F::LU{<:Any,<:Tridiagonal}, A::Tridiagonal, pivot::Union{RowMaximum,NoPivot} = RowMaximum(); check::Bool = true)
    factors = F.factors
    if size(factors) != size(A)
        throw(DimensionMismatch())
    end
    copyto!(factors, A)
    # The result is whatever `_lu_tridiag!` builds from the storage passed in here.
    return _lu_tridiag!(factors.dl, factors.d, factors.du, factors.du2, F.ipiv, pivot, check)
end
513+
# See dgttrf.f
514+
@inline function _lu_tridiag!(dl, d, du, du2, ipiv, pivot, check)
515+
T = eltype(d)
516+
V = typeof(d)
517+
518+
# Extract values
519+
n = length(d)
520+
521+
# Initialize variables
522+
info = 0
523+
if dl === du
524+
throw(ArgumentError("off-diagonals must not alias"))
518525
end
519526
fill!(du2, 0)
520527

@@ -571,9 +578,8 @@ function lu!(A::Tridiagonal{T,V}, pivot::Union{RowMaximum,NoPivot} = RowMaximum(
571578
end
572579
end
573580
end
574-
B = has_du2_defined ? A : Tridiagonal{T,V}(dl, d, du, du2)
575581
check && checknonsingular(info, pivot)
576-
return LU{T,Tridiagonal{T,V},typeof(ipiv)}(B, ipiv, convert(BlasInt, info))
582+
return LU{T,Tridiagonal{T,V},typeof(ipiv)}(Tridiagonal{T,V}(dl, d, du, du2), ipiv, convert(BlasInt, info))
577583
end
578584

579585
factorize(A::Tridiagonal) = lu(A)

src/tridiag.jl

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -876,3 +876,77 @@ function cholesky(S::SymTridiagonal, ::NoPivot = NoPivot(); check::Bool = true)
876876
T = choltype(eltype(S))
877877
cholesky!(Hermitian(Bidiagonal{T}(diag(S, 0), diag(S, 1), :U)), NoPivot(); check = check)
878878
end
879+
880+
# See dgtsv.f
"""
    ldiv!(A::Tridiagonal, B::AbstractVecOrMat) -> B

Compute `A \\ B` in-place by Gaussian elimination with partial pivoting and store the result
in `B`, returning the result. In the process, the diagonals of `A` are overwritten as well.

If `A` is empty (`size(A, 1) == 0`) there is nothing to solve and `B` is returned unchanged.

A `DimensionMismatch` is thrown if the number of rows of `B` differs from the size of `A`,
and an `ArgumentError` is thrown if the sub- and superdiagonal of `A` are the same array.
A zero pivot encountered during elimination is reported through `checknonsingular`.

!!! compat "Julia 1.11"
    `ldiv!` for `Tridiagonal` left-hand sides requires at least Julia 1.11.
"""
function ldiv!(A::Tridiagonal, B::AbstractVecOrMat)
    LinearAlgebra.require_one_based_indexing(B)
    n = size(A, 1)
    if n != size(B,1)
        throw(DimensionMismatch("matrix has dimensions ($n,$n) but right hand side has $(size(B,1)) rows"))
    end
    nrhs = size(B, 2)

    # Initialize variables
    dl = A.dl
    d = A.d
    du = A.du
    if dl === du
        throw(ArgumentError("off-diagonals of `A` must not alias"))
    end

    # Empty system: the elimination loop below would be skipped, but the code after it
    # accesses `d[n]` and `B[n,j]` without bounds checks, so bail out before reaching it.
    n == 0 && return B

    @inbounds begin
        # forward elimination, one row at a time
        for i in 1:n-1
            # pivot or not?
            if abs(d[i]) >= abs(dl[i])
                # No interchange
                if d[i] != 0
                    fact = dl[i]/d[i]
                    d[i+1] -= fact*du[i]
                    for j in 1:nrhs
                        B[i+1,j] -= fact*B[i,j]
                    end
                else
                    # both candidates for the pivot in column i are zero
                    checknonsingular(i, RowMaximum())
                end
                # `dl[i]` is reused below as the second superdiagonal of the
                # eliminated system; without an interchange that entry is zero
                i < n-1 && (dl[i] = 0)
            else
                # Interchange
                fact = d[i]/dl[i]
                d[i] = dl[i]
                tmp = d[i+1]
                d[i+1] = du[i] - fact*tmp
                du[i] = tmp
                if i < n-1
                    # fill-in created by the row swap is stored in `dl[i]`
                    dl[i] = du[i+1]
                    du[i+1] = -fact*dl[i]
                end
                for j in 1:nrhs
                    temp = B[i,j]
                    B[i,j] = B[i+1,j]
                    B[i+1,j] = temp - fact*B[i+1,j]
                end
            end
        end
        iszero(d[n]) && checknonsingular(n, RowMaximum())
        # backward substitution
        for j in 1:nrhs
            B[n,j] /= d[n]
            if n > 1
                B[n-1,j] = (B[n-1,j] - du[n-1]*B[n,j])/d[n-1]
            end
            for i in n-2:-1:1
                # `dl[i]` holds the second-superdiagonal entry of row i here
                B[i,j] = (B[i,j] - du[i]*B[i+1,j] - dl[i]*B[i+2,j]) / d[i]
            end
        end
    end
    return B
end

test/tridiag.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -434,17 +434,23 @@ end
434434
end
435435
else # mat_type is Tridiagonal
436436
@testset "tridiagonal linear algebra" begin
437-
for (BB, vv) in ((copy(B), copy(v)), (view(B, 1:n, 1), view(v, 1:n)))
437+
for vv in (copy(v), view(copy(v), 1:n))
438438
@test A*vv ≈ fA*vv
439439
invFv = fA\vv
440440
@test A\vv ≈ invFv
441-
# @test Base.solve(T,v) ≈ invFv
442-
# @test Base.solve(T, B) ≈ F\B
443441
Tlu = factorize(A)
444442
x = Tlu\vv
445443
@test x ≈ invFv
446444
end
445+
elty != Int && @test A \ v ≈ ldiv!(copy(A), copy(v))
447446
end
447+
F = lu(A)
448+
L1, U1, p1 = F
449+
G = lu!(F, 2A)
450+
L2, U2, p2 = F
451+
@test L1 ≈ L2
452+
@test 2U1 ≈ U2
453+
@test p1 == p2
448454
end
449455
@testset "generalized dot" begin
450456
x = fill(convert(elty, 1), n)

0 commit comments

Comments
 (0)