From 2148fa45a68a4767a6f1663d7fd50fd2be70ef11 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Sun, 11 May 2025 19:05:54 +0530 Subject: [PATCH] Fix scaling block unit triangular matrices --- src/triangular.jl | 31 +++++++++++++------------------ test/triangular.jl | 10 ++++++++++ 2 files changed, 23 insertions(+), 18 deletions(-) diff --git a/src/triangular.jl b/src/triangular.jl index 5b476d24..b0af1c34 100644 --- a/src/triangular.jl +++ b/src/triangular.jl @@ -1323,6 +1323,15 @@ end # Generic routines # #################### +function _set_diag!(B::UpperOrLowerTriangular, x) + # get a mutable array to modify the diagonal + Bm = parent(B) isa StridedArray ? B : copy!(similar(B), B) + for i in diagind(Bm.data, IndexStyle(Bm.data)) + Bm.data[i] = x + end + Bm +end + for (t, unitt) in ((UpperTriangular, UnitUpperTriangular), (LowerTriangular, UnitLowerTriangular)) tstrided = t{<:Any, <:StridedMaybeAdjOrTransMat} @@ -1336,10 +1345,7 @@ for (t, unitt) in ((UpperTriangular, UnitUpperTriangular), function (*)(A::$unitt, x::Number) B = $t(A.data)*x - for i in axes(A, 1) - B.data[i,i] = x - end - return B + _set_diag!(B, oneunit(eltype(A)) * x) end (*)(x::Number, A::$t) = $t(x*A.data) @@ -1351,10 +1357,7 @@ for (t, unitt) in ((UpperTriangular, UnitUpperTriangular), function (*)(x::Number, A::$unitt) B = x*$t(A.data) - for i in axes(A, 1) - B.data[i,i] = x - end - return B + _set_diag!(B, x * oneunit(eltype(A))) end (/)(A::$t, x::Number) = $t(A.data/x) @@ -1366,11 +1369,7 @@ for (t, unitt) in ((UpperTriangular, UnitUpperTriangular), function (/)(A::$unitt, x::Number) B = $t(A.data)/x - invx = inv(x) - for i in axes(A, 1) - B.data[i,i] = invx - end - return B + _set_diag!(B, oneunit(eltype(A)) / x) end (\)(x::Number, A::$t) = $t(x\A.data) @@ -1382,11 +1381,7 @@ for (t, unitt) in ((UpperTriangular, UnitUpperTriangular), function (\)(x::Number, A::$unitt) B = x\$t(A.data) - invx = inv(x) - for i in axes(A, 1) - B.data[i,i] = invx - end - return B + _set_diag!(B, x \ oneunit(eltype(A))) end end end diff --git a/test/triangular.jl b/test/triangular.jl index aeb41aa6..c5dca32d 100644 --- a/test/triangular.jl +++ b/test/triangular.jl @@ -934,4 +934,14 @@ end end end +@testset "block unit triangular scaling" begin + m = SizedArrays.SizedArray{(2,2)}([1 2; 3 4]) + U = UnitUpperTriangular(fill(m, 4, 4)) + M = Matrix{eltype(U)}(U) + @test U/2 == M/2 + @test 2\U == 2\M + @test U*2 == M*2 + @test 2*U == 2*M +end + end # module TestTriangular