Skip to content

Commit b21f100

Browse files
authored
Make *Triangular handle units (#43972)
1 parent 528949f commit b21f100

File tree

12 files changed

+638
-585
lines changed

12 files changed

+638
-585
lines changed

stdlib/LinearAlgebra/src/bidiag.jl

Lines changed: 40 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,7 @@ const BandedMatrix = Union{Bidiagonal,Diagonal,Tridiagonal,SymTridiagonal} # or
409409
const BiTriSym = Union{Bidiagonal,Tridiagonal,SymTridiagonal}
410410
const BiTri = Union{Bidiagonal,Tridiagonal}
411411
@inline mul!(C::AbstractVector, A::BandedMatrix, B::AbstractVector, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
412+
@inline mul!(C::AbstractMatrix, A::BandedMatrix, B::AbstractVector, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
412413
@inline mul!(C::AbstractMatrix, A::BandedMatrix, B::AbstractMatrix, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
413414
@inline mul!(C::AbstractMatrix, A::AbstractMatrix, B::BandedMatrix, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
414415
@inline mul!(C::AbstractMatrix, A::BandedMatrix, B::BandedMatrix, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
@@ -747,39 +748,27 @@ ldiv!(c::AbstractVecOrMat, A::AdjOrTrans{<:Any,<:Bidiagonal}, b::AbstractVecOrMa
747748
\(xA::AdjOrTrans{<:Any,<:Bidiagonal}, B::AbstractVecOrMat) = copy(xA) \ B
748749

749750
### Triangular specializations
750-
function \(B::Bidiagonal, U::UpperTriangular)
751-
A = ldiv!(_initarray(\, eltype(B), eltype(U), U), B, U)
752-
return B.uplo == 'U' ? UpperTriangular(A) : A
753-
end
754-
function \(B::Bidiagonal, U::UnitUpperTriangular)
755-
A = ldiv!(_initarray(\, eltype(B), eltype(U), U), B, U)
756-
return B.uplo == 'U' ? UpperTriangular(A) : A
757-
end
758-
function \(B::Bidiagonal, L::LowerTriangular)
759-
A = ldiv!(_initarray(\, eltype(B), eltype(L), L), B, L)
760-
return B.uplo == 'L' ? LowerTriangular(A) : A
751+
for tri in (:UpperTriangular, :UnitUpperTriangular)
752+
@eval function \(B::Bidiagonal, U::$tri)
753+
A = ldiv!(_initarray(\, eltype(B), eltype(U), U), B, U)
754+
return B.uplo == 'U' ? UpperTriangular(A) : A
755+
end
756+
@eval function \(U::$tri, B::Bidiagonal)
757+
A = ldiv!(_initarray(\, eltype(U), eltype(B), U), U, B)
758+
return B.uplo == 'U' ? UpperTriangular(A) : A
759+
end
761760
end
762-
function \(B::Bidiagonal, L::UnitLowerTriangular)
763-
A = ldiv!(_initarray(\, eltype(B), eltype(L), L), B, L)
764-
return B.uplo == 'L' ? LowerTriangular(A) : A
761+
for tri in (:LowerTriangular, :UnitLowerTriangular)
762+
@eval function \(B::Bidiagonal, L::$tri)
763+
A = ldiv!(_initarray(\, eltype(B), eltype(L), L), B, L)
764+
return B.uplo == 'L' ? LowerTriangular(A) : A
765+
end
766+
@eval function \(L::$tri, B::Bidiagonal)
767+
A = ldiv!(_initarray(\, eltype(L), eltype(B), L), L, B)
768+
return B.uplo == 'L' ? LowerTriangular(A) : A
769+
end
765770
end
766771

767-
function \(U::UpperTriangular, B::Bidiagonal)
768-
A = ldiv!(U, copy_similar(B, _init_eltype(\, eltype(U), eltype(B))))
769-
return B.uplo == 'U' ? UpperTriangular(A) : A
770-
end
771-
function \(U::UnitUpperTriangular, B::Bidiagonal)
772-
A = ldiv!(U, copy_similar(B, _init_eltype(\, eltype(U), eltype(B))))
773-
return B.uplo == 'U' ? UpperTriangular(A) : A
774-
end
775-
function \(L::LowerTriangular, B::Bidiagonal)
776-
A = ldiv!(L, copy_similar(B, _init_eltype(\, eltype(L), eltype(B))))
777-
return B.uplo == 'L' ? LowerTriangular(A) : A
778-
end
779-
function \(L::UnitLowerTriangular, B::Bidiagonal)
780-
A = ldiv!(L, copy_similar(B, _init_eltype(\, eltype(L), eltype(B))))
781-
return B.uplo == 'L' ? LowerTriangular(A) : A
782-
end
783772
### Diagonal specialization
784773
function \(B::Bidiagonal, D::Diagonal)
785774
A = ldiv!(_initarray(\, eltype(B), eltype(D), D), B, D)
@@ -835,38 +824,27 @@ _rdiv!(C::AbstractMatrix, A::AbstractMatrix, B::AdjOrTrans{<:Any,<:Bidiagonal})
835824
/(A::AbstractMatrix, B::Bidiagonal) = _rdiv!(_initarray(/, eltype(A), eltype(B), A), A, B)
836825

837826
### Triangular specializations
838-
function /(U::UpperTriangular, B::Bidiagonal)
839-
A = _rdiv!(_initarray(/, eltype(U), eltype(B), U), U, B)
840-
return B.uplo == 'U' ? UpperTriangular(A) : A
841-
end
842-
function /(U::UnitUpperTriangular, B::Bidiagonal)
843-
A = _rdiv!(_initarray(/, eltype(U), eltype(B), U), U, B)
844-
return B.uplo == 'U' ? UpperTriangular(A) : A
845-
end
846-
function /(L::LowerTriangular, B::Bidiagonal)
847-
A = _rdiv!(_initarray(/, eltype(L), eltype(B), L), L, B)
848-
return B.uplo == 'L' ? LowerTriangular(A) : A
849-
end
850-
function /(L::UnitLowerTriangular, B::Bidiagonal)
851-
A = _rdiv!(_initarray(/, eltype(L), eltype(B), L), L, B)
852-
return B.uplo == 'L' ? LowerTriangular(A) : A
853-
end
854-
function /(B::Bidiagonal, U::UpperTriangular)
855-
A = rdiv!(copy_similar(B, _init_eltype(/, eltype(B), eltype(U))), U)
856-
return B.uplo == 'U' ? UpperTriangular(A) : A
857-
end
858-
function /(B::Bidiagonal, U::UnitUpperTriangular)
859-
A = rdiv!(copy_similar(B, _init_eltype(/, eltype(B), eltype(U))), U)
860-
return B.uplo == 'U' ? UpperTriangular(A) : A
861-
end
862-
function /(B::Bidiagonal, L::LowerTriangular)
863-
A = rdiv!(copy_similar(B, _init_eltype(/, eltype(B), eltype(L))), L)
864-
return B.uplo == 'L' ? LowerTriangular(A) : A
827+
for tri in (:UpperTriangular, :UnitUpperTriangular)
828+
@eval function /(U::$tri, B::Bidiagonal)
829+
A = _rdiv!(_initarray(/, eltype(U), eltype(B), U), U, B)
830+
return B.uplo == 'U' ? UpperTriangular(A) : A
831+
end
832+
@eval function /(B::Bidiagonal, U::$tri)
833+
A = _rdiv!(_initarray(/, eltype(B), eltype(U), U), B, U)
834+
return B.uplo == 'U' ? UpperTriangular(A) : A
835+
end
865836
end
866-
function /(B::Bidiagonal, L::UnitLowerTriangular)
867-
A = rdiv!(copy_similar(B, _init_eltype(/, eltype(B), eltype(L))), L)
868-
return B.uplo == 'L' ? LowerTriangular(A) : A
837+
for tri in (:LowerTriangular, :UnitLowerTriangular)
838+
@eval function /(L::$tri, B::Bidiagonal)
839+
A = _rdiv!(_initarray(/, eltype(L), eltype(B), L), L, B)
840+
return B.uplo == 'L' ? LowerTriangular(A) : A
841+
end
842+
@eval function /(B::Bidiagonal, L::$tri)
843+
A = _rdiv!(_initarray(/, eltype(B), eltype(L), L), B, L)
844+
return B.uplo == 'L' ? LowerTriangular(A) : A
845+
end
869846
end
847+
870848
### Diagonal specialization
871849
function /(D::Diagonal, B::Bidiagonal)
872850
A = _rdiv!(_initarray(/, eltype(D), eltype(B), D), D, B)
@@ -886,8 +864,8 @@ end
886864
factorize(A::Bidiagonal) = A
887865
function inv(B::Bidiagonal{T}) where T
888866
n = size(B, 1)
889-
dest = zeros(typeof(oneunit(T)\one(T)), (n, n))
890-
ldiv!(dest, B, Diagonal{typeof(one(T)\one(T))}(I, n))
867+
dest = zeros(typeof(inv(oneunit(T))), (n, n))
868+
ldiv!(dest, B, Diagonal{typeof(one(T)/one(T))}(I, n))
891869
return B.uplo == 'U' ? UpperTriangular(dest) : LowerTriangular(dest)
892870
end
893871

stdlib/LinearAlgebra/src/dense.jl

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -907,14 +907,12 @@ sqrt(A::TransposeAbsMat) = transpose(sqrt(parent(A)))
907907

908908
function inv(A::StridedMatrix{T}) where T
909909
checksquare(A)
910-
S = typeof((oneunit(T)*zero(T) + oneunit(T)*zero(T))/oneunit(T))
911-
AA = convert(AbstractArray{S}, A)
912-
if istriu(AA)
913-
Ai = triu!(parent(inv(UpperTriangular(AA))))
914-
elseif istril(AA)
915-
Ai = tril!(parent(inv(LowerTriangular(AA))))
910+
if istriu(A)
911+
Ai = triu!(parent(inv(UpperTriangular(A))))
912+
elseif istril(A)
913+
Ai = tril!(parent(inv(LowerTriangular(A))))
916914
else
917-
Ai = inv!(lu(AA))
915+
Ai = inv!(lu(A))
918916
Ai = convert(typeof(parent(Ai)), Ai)
919917
end
920918
return Ai

stdlib/LinearAlgebra/src/diagonal.jl

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -375,12 +375,23 @@ function __muldiag!(out, D1::Diagonal, D2::Diagonal, _add::MulAddMul{ais1,bis0})
375375
return out
376376
end
377377

378-
function _mul!(out, A, B, _add)
378+
function _mul_diag!(out, A, B, _add)
379379
_muldiag_size_check(out, A, B)
380380
__muldiag!(out, A, B, _add)
381381
return out
382382
end
383383

384+
_mul!(out::AbstractVecOrMat, D::Diagonal, V::AbstractVector, _add) =
385+
_mul_diag!(out, D, V, _add)
386+
_mul!(out::AbstractMatrix, D::Diagonal, B::AbstractMatrix, _add) =
387+
_mul_diag!(out, D, B, _add)
388+
_mul!(out::AbstractMatrix, A::AbstractMatrix, D::Diagonal, _add) =
389+
_mul_diag!(out, A, D, _add)
390+
_mul!(C::Diagonal, Da::Diagonal, Db::Diagonal, _add) =
391+
_mul_diag!(C, Da, Db, _add)
392+
_mul!(C::AbstractMatrix, Da::Diagonal, Db::Diagonal, _add) =
393+
_mul_diag!(C, Da, Db, _add)
394+
384395
function (*)(Da::Diagonal, A::AbstractMatrix, Db::Diagonal)
385396
_muldiag_size_check(Da, A)
386397
_muldiag_size_check(A, Db)
@@ -395,6 +406,7 @@ end
395406

396407
/(A::AbstractVecOrMat, D::Diagonal) = _rdiv!(similar(A, _init_eltype(/, eltype(A), eltype(D))), A, D)
397408
/(A::HermOrSym, D::Diagonal) = _rdiv!(similar(A, _init_eltype(/, eltype(A), eltype(D)), size(A)), A, D)
409+
398410
rdiv!(A::AbstractVecOrMat, D::Diagonal) = @inline _rdiv!(A, A, D)
399411
# avoid copy when possible via internal 3-arg backend
400412
function _rdiv!(B::AbstractVecOrMat, A::AbstractVecOrMat, D::Diagonal)
@@ -557,22 +569,21 @@ for Tri in (:UpperTriangular, :LowerTriangular)
557569
# 3-arg ldiv!
558570
@eval ldiv!(C::$Tri, D::Diagonal, A::$Tri) = $Tri(ldiv!(C.data, D, A.data))
559571
@eval ldiv!(C::$Tri, D::Diagonal, A::$UTri) = $Tri(_setdiag!(ldiv!(C.data, D, A.data), inv, D.diag))
560-
# 3-arg mul!: invoke 5-arg mul! rather than lmul!
561-
@eval mul!(C::$Tri, A::Union{$Tri,$UTri}, D::Diagonal) = mul!(C, A, D, true, false)
572+
# 3-arg mul! is disambiguated in special.jl
562573
# 5-arg mul!
563574
@eval _mul!(C::$Tri, D::Diagonal, A::$Tri, _add) = $Tri(mul!(C.data, D, A.data, _add.alpha, _add.beta))
564-
@eval function _mul!(C::$Tri, D::Diagonal, A::$UTri, _add)
575+
@eval function _mul!(C::$Tri, D::Diagonal, A::$UTri, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
565576
α, β = _add.alpha, _add.beta
566577
iszero(α) && return _rmul_or_fill!(C, β)
567-
diag′ = iszero(β) ? nothing : diag(C)
578+
diag′ = bis0 ? nothing : diag(C)
568579
data = mul!(C.data, D, A.data, α, β)
569580
$Tri(_setdiag!(data, _add, D.diag, diag′))
570581
end
571582
@eval _mul!(C::$Tri, A::$Tri, D::Diagonal, _add) = $Tri(mul!(C.data, A.data, D, _add.alpha, _add.beta))
572-
@eval function _mul!(C::$Tri, A::$UTri, D::Diagonal, _add)
583+
@eval function _mul!(C::$Tri, A::$UTri, D::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
573584
α, β = _add.alpha, _add.beta
574585
iszero(α) && return _rmul_or_fill!(C, β)
575-
diag′ = iszero(β) ? nothing : diag(C)
586+
diag′ = bis0 ? nothing : diag(C)
576587
data = mul!(C.data, A.data, D, α, β)
577588
$Tri(_setdiag!(data, _add, D.diag, diag′))
578589
end

stdlib/LinearAlgebra/src/hessenberg.jl

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -129,41 +129,29 @@ for T = (:Number, :UniformScaling, :Diagonal)
129129
end
130130

131131
function *(H::UpperHessenberg, U::UpperOrUnitUpperTriangular)
132-
T = typeof(oneunit(eltype(H))*oneunit(eltype(U)))
133-
HH = copy_similar(H, T)
134-
rmul!(HH, U)
132+
HH = _mulmattri!(_initarray(*, eltype(H), eltype(U), H), H, U)
135133
UpperHessenberg(HH)
136134
end
137135
function *(U::UpperOrUnitUpperTriangular, H::UpperHessenberg)
138-
T = typeof(oneunit(eltype(H))*oneunit(eltype(U)))
139-
HH = copy_similar(H, T)
140-
lmul!(U, HH)
136+
HH = _multrimat!(_initarray(*, eltype(U), eltype(H), H), U, H)
141137
UpperHessenberg(HH)
142138
end
143139

144140
function /(H::UpperHessenberg, U::UpperTriangular)
145-
T = typeof(oneunit(eltype(H))/oneunit(eltype(U)))
146-
HH = copy_similar(H, T)
147-
rdiv!(HH, U)
141+
HH = _rdiv!(_initarray(/, eltype(H), eltype(U), H), H, U)
148142
UpperHessenberg(HH)
149143
end
150144
function /(H::UpperHessenberg, U::UnitUpperTriangular)
151-
T = typeof(oneunit(eltype(H))/oneunit(eltype(U)))
152-
HH = copy_similar(H, T)
153-
rdiv!(HH, U)
145+
HH = _rdiv!(_initarray(/, eltype(H), eltype(U), H), H, U)
154146
UpperHessenberg(HH)
155147
end
156148

157149
function \(U::UpperTriangular, H::UpperHessenberg)
158-
T = typeof(oneunit(eltype(U))\oneunit(eltype(H)))
159-
HH = copy_similar(H, T)
160-
ldiv!(U, HH)
150+
HH = ldiv!(_initarray(\, eltype(U), eltype(H), H), U, H)
161151
UpperHessenberg(HH)
162152
end
163153
function \(U::UnitUpperTriangular, H::UpperHessenberg)
164-
T = typeof(oneunit(eltype(U))\oneunit(eltype(H)))
165-
HH = copy_similar(H, T)
166-
ldiv!(U, HH)
154+
HH = ldiv!(_initarray(\, eltype(U), eltype(H), H), U, H)
167155
UpperHessenberg(HH)
168156
end
169157

stdlib/LinearAlgebra/src/matmul.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -265,15 +265,14 @@ julia> C
265265
730.0 740.0
266266
```
267267
"""
268-
@inline mul!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat,
269-
alpha::Number, beta::Number) =
268+
@inline mul!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat, α::Number, β::Number) =
270269
generic_matmatmul!(
271270
C,
272271
adj_or_trans_char(A),
273272
adj_or_trans_char(B),
274273
_parent(A),
275274
_parent(B),
276-
MulAddMul(alpha, beta)
275+
MulAddMul(α, β)
277276
)
278277

279278
"""

stdlib/LinearAlgebra/src/special.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,14 @@ for op in (:+, :-)
107107
end
108108
end
109109

110+
# disambiguation between triangular and banded matrices, banded ones "dominate"
111+
mul!(C::AbstractMatrix, A::AbstractTriangular, B::BandedMatrix) = _mul!(C, A, B, MulAddMul())
112+
mul!(C::AbstractMatrix, A::BandedMatrix, B::AbstractTriangular) = _mul!(C, A, B, MulAddMul())
113+
mul!(C::AbstractMatrix, A::AbstractTriangular, B::BandedMatrix, alpha::Number, beta::Number) =
114+
_mul!(C, A, B, MulAddMul(alpha, beta))
115+
mul!(C::AbstractMatrix, A::BandedMatrix, B::AbstractTriangular, alpha::Number, beta::Number) =
116+
_mul!(C, A, B, MulAddMul(alpha, beta))
117+
110118
function *(H::UpperHessenberg, B::Bidiagonal)
111119
T = promote_op(matprod, eltype(H), eltype(B))
112120
A = mul!(similar(H, T, size(H)), H, B)

0 commit comments

Comments
 (0)