Skip to content

Commit 7212c03

Browse files
authored
Clean-up Bidiagonal mul/solve code (#47223)
1 parent f70b5e4 commit 7212c03

File tree

4 files changed

+86
-76
lines changed

4 files changed

+86
-76
lines changed

stdlib/LinearAlgebra/src/LinearAlgebra.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -478,6 +478,25 @@ _makevector(x::AbstractVector) = Vector(x)
478478
_pushzero(A) = (B = similar(A, length(A)+1); @inbounds B[begin:end-1] .= A; @inbounds B[end] = zero(eltype(B)); B)
479479
_droplast!(A) = deleteat!(A, lastindex(A))
480480

481+
# some trait like this would be cool
482+
# onedefined(::Type{T}) where {T} = hasmethod(one, (T,))
483+
# but we are actually asking for oneunit(T), that is, however, defined for generic T as
484+
# `T(one(T))`, so the question is equivalent for whether one(T) is defined
485+
onedefined(::Type) = false
486+
onedefined(::Type{<:Number}) = true
487+
488+
# initialize return array for op(A, B)
489+
_init_eltype(::typeof(*), ::Type{TA}, ::Type{TB}) where {TA,TB} =
490+
(onedefined(TA) && onedefined(TB)) ?
491+
typeof(matprod(oneunit(TA), oneunit(TB))) :
492+
promote_op(matprod, TA, TB)
493+
_init_eltype(op, ::Type{TA}, ::Type{TB}) where {TA,TB} =
494+
(onedefined(TA) && onedefined(TB)) ?
495+
typeof(op(oneunit(TA), oneunit(TB))) :
496+
promote_op(op, TA, TB)
497+
_initarray(op, ::Type{TA}, ::Type{TB}, C) where {TA,TB} =
498+
similar(C, _init_eltype(op, TA, TB), size(C))
499+
481500
# General fallback definition for handling under- and overdetermined system as well as square problems
482501
# While this definition is pretty general, it does e.g. promote to common element type of lhs and rhs
483502
# which is required by LAPACK but not SuiteSparse which allows real-complex solves in some cases. Hence,

stdlib/LinearAlgebra/src/bidiag.jl

Lines changed: 46 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -783,49 +783,47 @@ ldiv!(c::AbstractVecOrMat, A::Adjoint{<:Any,<:Bidiagonal}, b::AbstractVecOrMat)
783783
(_rdiv!(adjoint(c), adjoint(b), adjoint(A)); return c)
784784

785785
### Generic promotion methods and fallbacks
786-
function \(A::Bidiagonal{<:Number}, B::AbstractVecOrMat{<:Number})
787-
TA, TB = eltype(A), eltype(B)
788-
TAB = typeof((oneunit(TA))\oneunit(TB))
789-
ldiv!(zeros(TAB, size(B)), A, B)
790-
end
791-
\(A::Bidiagonal, B::AbstractVecOrMat) = ldiv!(copy(B), A, B)
786+
\(A::Bidiagonal, B::AbstractVecOrMat) = ldiv!(_initarray(\, eltype(A), eltype(B), B), A, B)
792787
\(tA::Transpose{<:Any,<:Bidiagonal}, B::AbstractVecOrMat) = copy(tA) \ B
793788
\(adjA::Adjoint{<:Any,<:Bidiagonal}, B::AbstractVecOrMat) = copy(adjA) \ B
794789

795790
### Triangular specializations
796-
function \(B::Bidiagonal{<:Number}, U::UpperOrUnitUpperTriangular{<:Number})
797-
T = typeof((oneunit(eltype(B)))\oneunit(eltype(U)))
798-
A = ldiv!(zeros(T, size(U)), B, U)
791+
function \(B::Bidiagonal, U::UpperTriangular)
792+
A = ldiv!(_initarray(\, eltype(B), eltype(U), U), B, U)
799793
return B.uplo == 'U' ? UpperTriangular(A) : A
800794
end
801-
function \(B::Bidiagonal, U::UpperOrUnitUpperTriangular)
802-
A = ldiv!(copy(parent(U)), B, U)
795+
function \(B::Bidiagonal, U::UnitUpperTriangular)
796+
A = ldiv!(_initarray(\, eltype(B), eltype(U), U), B, U)
803797
return B.uplo == 'U' ? UpperTriangular(A) : A
804798
end
805-
function \(B::Bidiagonal{<:Number}, L::LowerOrUnitLowerTriangular{<:Number})
806-
T = typeof((oneunit(eltype(B)))\oneunit(eltype(L)))
807-
A = ldiv!(zeros(T, size(L)), B, L)
799+
function \(B::Bidiagonal, L::LowerTriangular)
800+
A = ldiv!(_initarray(\, eltype(B), eltype(L), L), B, L)
808801
return B.uplo == 'L' ? LowerTriangular(A) : A
809802
end
810-
function \(B::Bidiagonal, L::LowerOrUnitLowerTriangular)
811-
A = ldiv!(copy(parent(L)), B, L)
803+
function \(B::Bidiagonal, L::UnitLowerTriangular)
804+
A = ldiv!(_initarray(\, eltype(B), eltype(L), L), B, L)
812805
return B.uplo == 'L' ? LowerTriangular(A) : A
813806
end
814807

815-
function \(U::UpperOrUnitUpperTriangular{<:Number}, B::Bidiagonal{<:Number})
816-
T = typeof((oneunit(eltype(U)))/oneunit(eltype(B)))
817-
A = ldiv!(U, copy_similar(B, T))
808+
function \(U::UpperTriangular, B::Bidiagonal)
809+
A = ldiv!(U, copy_similar(B, _init_eltype(\, eltype(U), eltype(B))))
810+
return B.uplo == 'U' ? UpperTriangular(A) : A
811+
end
812+
function \(U::UnitUpperTriangular, B::Bidiagonal)
813+
A = ldiv!(U, copy_similar(B, _init_eltype(\, eltype(U), eltype(B))))
818814
return B.uplo == 'U' ? UpperTriangular(A) : A
819815
end
820-
function \(L::LowerOrUnitLowerTriangular{<:Number}, B::Bidiagonal{<:Number})
821-
T = typeof((oneunit(eltype(L)))/oneunit(eltype(B)))
822-
A = ldiv!(L, copy_similar(B, T))
816+
function \(L::LowerTriangular, B::Bidiagonal)
817+
A = ldiv!(L, copy_similar(B, _init_eltype(\, eltype(L), eltype(B))))
818+
return B.uplo == 'L' ? LowerTriangular(A) : A
819+
end
820+
function \(L::UnitLowerTriangular, B::Bidiagonal)
821+
A = ldiv!(L, copy_similar(B, _init_eltype(\, eltype(L), eltype(B))))
823822
return B.uplo == 'L' ? LowerTriangular(A) : A
824823
end
825824
### Diagonal specialization
826-
function \(B::Bidiagonal{<:Number}, D::Diagonal{<:Number})
827-
T = typeof((oneunit(eltype(B)))\oneunit(eltype(D)))
828-
A = ldiv!(zeros(T, size(D)), B, D)
825+
function \(B::Bidiagonal, D::Diagonal)
826+
A = ldiv!(_initarray(\, eltype(B), eltype(D), D), B, D)
829827
return B.uplo == 'U' ? UpperTriangular(A) : LowerTriangular(A)
830828
end
831829

@@ -878,54 +876,50 @@ _rdiv!(C::AbstractMatrix, A::AbstractMatrix, B::Adjoint{<:Any,<:Bidiagonal}) =
878876
_rdiv!(C::AbstractMatrix, A::AbstractMatrix, B::Transpose{<:Any,<:Bidiagonal}) =
879877
(ldiv!(transpose(C), transpose(B), transpose(A)); return C)
880878

881-
function /(A::AbstractMatrix{<:Number}, B::Bidiagonal{<:Number})
882-
TA, TB = eltype(A), eltype(B)
883-
TAB = typeof((oneunit(TA))/oneunit(TB))
884-
_rdiv!(zeros(TAB, size(A)), A, B)
885-
end
886-
/(A::AbstractMatrix, B::Bidiagonal) = _rdiv!(copy(A), A, B)
879+
/(A::AbstractMatrix, B::Bidiagonal) = _rdiv!(_initarray(/, eltype(A), eltype(B), A), A, B)
887880

888881
### Triangular specializations
889-
function /(U::UpperOrUnitUpperTriangular{<:Number}, B::Bidiagonal{<:Number})
890-
T = typeof((oneunit(eltype(U)))/oneunit(eltype(B)))
891-
A = _rdiv!(zeros(T, size(U)), U, B)
882+
function /(U::UpperTriangular, B::Bidiagonal)
883+
A = _rdiv!(_initarray(/, eltype(U), eltype(B), U), U, B)
892884
return B.uplo == 'U' ? UpperTriangular(A) : A
893885
end
894-
function /(U::UpperOrUnitUpperTriangular, B::Bidiagonal)
895-
A = _rdiv!(copy(parent(U)), U, B)
886+
function /(U::UnitUpperTriangular, B::Bidiagonal)
887+
A = _rdiv!(_initarray(/, eltype(U), eltype(B), U), U, B)
896888
return B.uplo == 'U' ? UpperTriangular(A) : A
897889
end
898-
function /(L::LowerOrUnitLowerTriangular{<:Number}, B::Bidiagonal{<:Number})
899-
T = typeof((oneunit(eltype(L)))/oneunit(eltype(B)))
900-
A = _rdiv!(zeros(T, size(L)), L, B)
890+
function /(L::LowerTriangular, B::Bidiagonal)
891+
A = _rdiv!(_initarray(/, eltype(L), eltype(B), L), L, B)
901892
return B.uplo == 'L' ? LowerTriangular(A) : A
902893
end
903-
function /(L::LowerOrUnitLowerTriangular, B::Bidiagonal)
904-
A = _rdiv!(copy(parent(L)), L, B)
894+
function /(L::UnitLowerTriangular, B::Bidiagonal)
895+
A = _rdiv!(_initarray(/, eltype(L), eltype(B), L), L, B)
905896
return B.uplo == 'L' ? LowerTriangular(A) : A
906897
end
907-
function /(B::Bidiagonal{<:Number}, U::UpperOrUnitUpperTriangular{<:Number})
908-
T = typeof((oneunit(eltype(B)))/oneunit(eltype(U)))
909-
A = rdiv!(copy_similar(B, T), U)
898+
function /(B::Bidiagonal, U::UpperTriangular)
899+
A = rdiv!(copy_similar(B, _init_eltype(/, eltype(B), eltype(U))), U)
900+
return B.uplo == 'U' ? UpperTriangular(A) : A
901+
end
902+
function /(B::Bidiagonal, U::UnitUpperTriangular)
903+
A = rdiv!(copy_similar(B, _init_eltype(/, eltype(B), eltype(U))), U)
910904
return B.uplo == 'U' ? UpperTriangular(A) : A
911905
end
912-
function /(B::Bidiagonal{<:Number}, L::LowerOrUnitLowerTriangular{<:Number})
913-
T = typeof((oneunit(eltype(B)))\oneunit(eltype(L)))
914-
A = rdiv!(copy_similar(B, T), L)
906+
function /(B::Bidiagonal, L::LowerTriangular)
907+
A = rdiv!(copy_similar(B, _init_eltype(/, eltype(B), eltype(L))), L)
908+
return B.uplo == 'L' ? LowerTriangular(A) : A
909+
end
910+
function /(B::Bidiagonal, L::UnitLowerTriangular)
911+
A = rdiv!(copy_similar(B, _init_eltype(/, eltype(B), eltype(L))), L)
915912
return B.uplo == 'L' ? LowerTriangular(A) : A
916913
end
917914
### Diagonal specialization
918-
function /(D::Diagonal{<:Number}, B::Bidiagonal{<:Number})
919-
T = typeof((oneunit(eltype(D)))/oneunit(eltype(B)))
920-
A = _rdiv!(zeros(T, size(D)), D, B)
915+
function /(D::Diagonal, B::Bidiagonal)
916+
A = _rdiv!(_initarray(/, eltype(D), eltype(B), D), D, B)
921917
return B.uplo == 'U' ? UpperTriangular(A) : LowerTriangular(A)
922918
end
923919

924920
/(A::AbstractMatrix, B::Transpose{<:Any,<:Bidiagonal}) = A / copy(B)
925921
/(A::AbstractMatrix, B::Adjoint{<:Any,<:Bidiagonal}) = A / copy(B)
926922
# disambiguation
927-
/(A::AdjointAbsVec{<:Number}, B::Bidiagonal{<:Number}) = adjoint(adjoint(B) \ parent(A))
928-
/(A::TransposeAbsVec{<:Number}, B::Bidiagonal{<:Number}) = transpose(transpose(B) \ parent(A))
929923
/(A::AdjointAbsVec, B::Bidiagonal) = adjoint(adjoint(B) \ parent(A))
930924
/(A::TransposeAbsVec, B::Bidiagonal) = transpose(transpose(B) \ parent(A))
931925
/(A::AdjointAbsVec, B::Transpose{<:Any,<:Bidiagonal}) = adjoint(adjoint(B) \ parent(A))

stdlib/LinearAlgebra/src/diagonal.jl

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -377,11 +377,8 @@ end
377377
mul!(C::AbstractMatrix, Da::Diagonal, Db::Diagonal, alpha::Number, beta::Number) =
378378
_muldiag!(C, Da, Db, alpha, beta)
379379

380-
_init(op, A::AbstractArray{<:Number}, B::AbstractArray{<:Number}) =
381-
(_ -> zero(typeof(op(oneunit(eltype(A)), oneunit(eltype(B))))))
382-
_init(op, A::AbstractArray, B::AbstractArray) = promote_op(op, eltype(A), eltype(B))
383-
384-
/(A::AbstractVecOrMat, D::Diagonal) = _rdiv!(_init(/, A, D).(A), A, D)
380+
/(A::AbstractVecOrMat, D::Diagonal) = _rdiv!(similar(A, _init_eltype(/, eltype(A), eltype(D))), A, D)
381+
/(A::HermOrSym, D::Diagonal) = _rdiv!(similar(A, _init_eltype(/, eltype(A), eltype(D)), size(A)), A, D)
385382
rdiv!(A::AbstractVecOrMat, D::Diagonal) = @inline _rdiv!(A, A, D)
386383
# avoid copy when possible via internal 3-arg backend
387384
function _rdiv!(B::AbstractVecOrMat, A::AbstractVecOrMat, D::Diagonal)
@@ -406,8 +403,8 @@ function \(D::Diagonal, B::AbstractVector)
406403
isnothing(j) || throw(SingularException(j))
407404
return D.diag .\ B
408405
end
409-
\(D::Diagonal, B::AbstractMatrix) =
410-
ldiv!(_init(\, D, B).(B), D, B)
406+
\(D::Diagonal, B::AbstractMatrix) = ldiv!(similar(B, _init_eltype(\, eltype(D), eltype(B))), D, B)
407+
\(D::Diagonal, B::HermOrSym) = ldiv!(similar(B, _init_eltype(\, eltype(D), eltype(B)), size(B)), D, B)
411408

412409
ldiv!(D::Diagonal, B::AbstractVecOrMat) = @inline ldiv!(B, D, B)
413410
function ldiv!(B::AbstractVecOrMat, D::Diagonal, A::AbstractVecOrMat)

stdlib/LinearAlgebra/test/testgroups

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,28 @@
1+
addmul
12
triangular
2-
qr
3-
dense
43
matmul
5-
schur
4+
dense
5+
symmetric
6+
diagonal
67
special
7-
eigen
8-
bunchkaufman
9-
svd
10-
lapack
11-
tridiag
128
bidiag
13-
diagonal
9+
qr
1410
cholesky
11+
blas
1512
lu
16-
symmetric
17-
generic
1813
uniformscaling
19-
lq
14+
structuredbroadcast
2015
hessenberg
21-
blas
16+
svd
17+
eigen
18+
tridiag
19+
lapack
20+
lq
2221
adjtrans
23-
pinv
22+
generic
23+
schur
24+
bunchkaufman
2425
givens
25-
structuredbroadcast
26-
addmul
27-
ldlt
26+
pinv
2827
factorization
28+
ldlt

0 commit comments

Comments
 (0)