Skip to content

Commit 168f7e7

Browse files
committed
Specialize adding/subtracting mixed Upper/LowerTriangular (#56149)
1 parent 1f5c07f commit 168f7e7

File tree

2 files changed

+70
-2
lines changed

2 files changed

+70
-2
lines changed

stdlib/LinearAlgebra/src/triangular.jl

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ UnitUpperTriangular
152152
const UpperOrUnitUpperTriangular{T,S} = Union{UpperTriangular{T,S}, UnitUpperTriangular{T,S}}
153153
const LowerOrUnitLowerTriangular{T,S} = Union{LowerTriangular{T,S}, UnitLowerTriangular{T,S}}
154154
const UpperOrLowerTriangular{T,S} = Union{UpperOrUnitUpperTriangular{T,S}, LowerOrUnitLowerTriangular{T,S}}
155+
const UnitUpperOrUnitLowerTriangular{T,S} = Union{UnitUpperTriangular{T,S}, UnitLowerTriangular{T,S}}
155156

156157
uppertriangular(M) = UpperTriangular(M)
157158
lowertriangular(M) = LowerTriangular(M)
@@ -221,6 +222,16 @@ function Matrix{T}(A::UnitUpperTriangular) where T
221222
B
222223
end
223224

225+
function full(A::Union{UpperTriangular,LowerTriangular})
226+
return _triangularize(A)(parent(A))
227+
end
228+
function full(A::UnitUpperOrUnitLowerTriangular)
229+
isupper = A isa UnitUpperTriangular
230+
Ap = _triangularize(A)(parent(A), isupper ? 1 : -1)
231+
Ap[diagind(Ap, IndexStyle(Ap))] = @view A[diagind(A, IndexStyle(A))]
232+
return Ap
233+
end
234+
224235
function full!(A::LowerTriangular)
225236
B = A.data
226237
tril!(B)
@@ -553,6 +564,9 @@ function copyto!(A::T, B::T) where {T<:Union{LowerTriangular,UnitLowerTriangular
553564
return A
554565
end
555566

567+
_triangularize(::UpperOrUnitUpperTriangular) = triu
568+
_triangularize(::LowerOrUnitLowerTriangular) = tril
569+
556570
@inline _rscale_add!(A::AbstractTriangular, B::AbstractTriangular, C::Number, alpha::Number, beta::Number) =
557571
_triscale!(A, B, C, MulAddMul(alpha, beta))
558572
@inline _lscale_add!(A::AbstractTriangular, B::Number, C::AbstractTriangular, alpha::Number, beta::Number) =
@@ -812,7 +826,8 @@ function +(A::UnitLowerTriangular, B::UnitLowerTriangular)
812826
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .+ B
813827
LowerTriangular(tril(A.data, -1) + tril(B.data, -1) + 2I)
814828
end
815-
+(A::AbstractTriangular, B::AbstractTriangular) = copyto!(similar(parent(A)), A) + copyto!(similar(parent(B)), B)
829+
+(A::UpperOrLowerTriangular, B::UpperOrLowerTriangular) = full(A) + full(B)
830+
+(A::AbstractTriangular, B::AbstractTriangular) = copyto!(similar(parent(A), size(A)), A) + copyto!(similar(parent(B), size(B)), B)
816831

817832
function -(A::UpperTriangular, B::UpperTriangular)
818833
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .- B
@@ -846,7 +861,8 @@ function -(A::UnitLowerTriangular, B::UnitLowerTriangular)
846861
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .- B
847862
LowerTriangular(tril(A.data, -1) - tril(B.data, -1))
848863
end
849-
-(A::AbstractTriangular, B::AbstractTriangular) = copyto!(similar(parent(A)), A) - copyto!(similar(parent(B)), B)
864+
-(A::UpperOrLowerTriangular, B::UpperOrLowerTriangular) = full(A) - full(B)
865+
-(A::AbstractTriangular, B::AbstractTriangular) = copyto!(similar(parent(A), size(A)), A) - copyto!(similar(parent(B), size(B)), B)
850866

851867
# use broadcasting if the parents are strided, where we loop only over the triangular part
852868
for op in (:+, :-)

stdlib/LinearAlgebra/test/triangular.jl

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,12 @@ debug && println("Test basic type functionality")
2525
@test_throws DimensionMismatch LowerTriangular(randn(5, 4))
2626
@test LowerTriangular(randn(3, 3)) |> t -> [size(t, i) for i = 1:3] == [size(Matrix(t), i) for i = 1:3]
2727

28+
struct MyTriangular{T, A<:LinearAlgebra.AbstractTriangular{T}} <: LinearAlgebra.AbstractTriangular{T}
29+
data :: A
30+
end
31+
Base.size(A::MyTriangular) = size(A.data)
32+
Base.getindex(A::MyTriangular, i::Int, j::Int) = A.data[i,j]
33+
2834
# The following test block tries to call all methods in base/linalg/triangular.jl in order for a combination of input element types. Keep the ordering when adding code.
2935
@testset for elty1 in (Float32, Float64, BigFloat, ComplexF32, ComplexF64, Complex{BigFloat}, Int)
3036
# Begin loop for first Triangular matrix
@@ -1078,4 +1084,50 @@ end
10781084
end
10791085
end
10801086

1087+
@testset "addition/subtraction of mixed triangular" begin
1088+
for A in (Hermitian(rand(4, 4)), Diagonal(rand(5)))
1089+
for T in (UpperTriangular, LowerTriangular,
1090+
UnitUpperTriangular, UnitLowerTriangular)
1091+
B = T(A)
1092+
M = Matrix(B)
1093+
R = B - B'
1094+
if A isa Diagonal
1095+
@test R isa Diagonal
1096+
end
1097+
@test R == M - M'
1098+
R = B + B'
1099+
if A isa Diagonal
1100+
@test R isa Diagonal
1101+
end
1102+
@test R == M + M'
1103+
C = MyTriangular(B)
1104+
@test C - C' == M - M'
1105+
@test C + C' == M + M'
1106+
end
1107+
end
1108+
@testset "unfilled parent" begin
1109+
@testset for T in (UpperTriangular, LowerTriangular,
1110+
UnitUpperTriangular, UnitLowerTriangular)
1111+
F = Matrix{BigFloat}(undef, 2, 2)
1112+
B = T(F)
1113+
isupper = B isa Union{UpperTriangular, UnitUpperTriangular}
1114+
B[1+!isupper, 1+isupper] = 2
1115+
if !(B isa Union{UnitUpperTriangular, UnitLowerTriangular})
1116+
B[1,1] = B[2,2] = 3
1117+
end
1118+
M = Matrix(B)
1119+
# These are broken, as triu/tril don't work with
1120+
# unfilled adjoint matrices
1121+
# See https://github.com/JuliaLang/julia/pull/55312
1122+
@test_broken B - B' == M - M'
1123+
@test_broken B + B' == M + M'
1124+
@test B - copy(B') == M - M'
1125+
@test B + copy(B') == M + M'
1126+
C = MyTriangular(B)
1127+
@test C - C' == M - M'
1128+
@test C + C' == M + M'
1129+
end
1130+
end
1131+
end
1132+
10811133
end # module TestTriangular

0 commit comments

Comments
 (0)