From b5473ee2c6b136636040e27ed5f95f1209f1042c Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Tue, 25 Feb 2025 21:03:57 +0530 Subject: [PATCH] Specialize `lmul!`/`rmul!` for strided triangular matrices --- src/diagonal.jl | 47 +++++++++++++++++++++++++++++++++++++++++------ test/diagonal.jl | 4 ++-- 2 files changed, 43 insertions(+), 8 deletions(-) diff --git a/src/diagonal.jl b/src/diagonal.jl index b5642800..330a0ad6 100644 --- a/src/diagonal.jl +++ b/src/diagonal.jl @@ -379,6 +379,26 @@ function rmul!(T::Tridiagonal, D::Diagonal) end return T end +for T in [:UpperTriangular, :UnitUpperTriangular, + :LowerTriangular, :UnitLowerTriangular] + @eval rmul!(A::$T{<:Any, <:StridedMatrix}, D::Diagonal) = _rmul!(A, D) + @eval lmul!(D::Diagonal, A::$T{<:Any, <:StridedMatrix}) = _lmul!(D, A) +end +function _rmul!(A::UpperOrLowerTriangular, D::Diagonal) + P = parent(A) + isunit = A isa UnitUpperOrUnitLowerTriangular + isupper = A isa UpperOrUnitUpperTriangular + for col in axes(A,2) + rowstart = isupper ? firstindex(A,1) : col+isunit + rowstop = isupper ? col-isunit : lastindex(A,1) + for row in rowstart:rowstop + P[row, col] *= D.diag[col] + end + end + isunit && _setdiag!(P, identity, D.diag) + TriWrapper = isupper ? UpperTriangular : LowerTriangular + return TriWrapper(P) +end function lmul!(D::Diagonal, B::AbstractVecOrMat) matmul_size_check(size(D), size(B)) @@ -388,6 +408,13 @@ function lmul!(D::Diagonal, B::AbstractVecOrMat) end return B end +# A' = D * A' => A = A * D' +# This uses the fact that D' is a Diagonal +function lmul!(D::Diagonal, A::AdjOrTransAbsMat) + f = wrapperop(A) + rmul!(f(A), f(D)) + A +end # in-place multiplication with a diagonal # T .= D * T @@ -402,12 +429,20 @@ function lmul!(D::Diagonal, T::Tridiagonal) end return T end -# A' = D * A' => A = A * D' -# This uses the fact that D' is a Diagonal -function lmul!(D::Diagonal, A::AdjOrTransAbsMat) - f = wrapperop(A) - rmul!(f(A), f(D)) - A +function _lmul!(D::Diagonal, A::UpperOrLowerTriangular) + P = parent(A) + isunit = A isa UnitUpperOrUnitLowerTriangular + isupper = A isa UpperOrUnitUpperTriangular + for col in axes(A,2) + rowstart = isupper ? firstindex(A,1) : col+isunit + rowstop = isupper ? col-isunit : lastindex(A,1) + for row in rowstart:rowstop + P[row, col] = D.diag[row] * P[row, col] + end + end + isunit && _setdiag!(P, identity, D.diag) + TriWrapper = isupper ? UpperTriangular : LowerTriangular + return TriWrapper(P) end @inline function __muldiag_nonzeroalpha!(out, D::Diagonal, B, alpha::Number, beta::Number) diff --git a/test/diagonal.jl b/test/diagonal.jl index 4675e6f2..d53dbc78 100644 --- a/test/diagonal.jl +++ b/test/diagonal.jl @@ -1197,11 +1197,11 @@ end outTri = similar(TriA) out = similar(A) # 2 args - for fun in (*, rmul!, rdiv!, /) + @testset for fun in (*, rmul!, rdiv!, /) @test fun(copy(TriA), D)::Tri == fun(Matrix(TriA), D) @test fun(copy(UTriA), D)::Tri == fun(Matrix(UTriA), D) end - for fun in (*, lmul!, ldiv!, \) + @testset for fun in (*, lmul!, ldiv!, \) @test fun(D, copy(TriA))::Tri == fun(D, Matrix(TriA)) @test fun(D, copy(UTriA))::Tri == fun(D, Matrix(UTriA)) end