Skip to content

Commit ade63a0

Browse files
authored
Merge branch 'master' into dk/diag_svd_eigen
2 parents 530fcee + b464203 commit ade63a0

File tree

16 files changed

+152
-66
lines changed

16 files changed

+152
-66
lines changed

.ci/Manifest.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ version = "1.11.0"
5252

5353
[[deps.JuliaSyntaxHighlighting]]
5454
deps = ["StyledStrings"]
55-
uuid = "dc6e5ff7-fb65-4e79-a425-ec3bc9c03011"
55+
uuid = "ac6e5ff7-fb65-4e79-a425-ec3bc9c03011"
5656
version = "1.12.0"
5757

5858
[[deps.LazyArtifacts]]
@@ -93,7 +93,7 @@ version = "1.11.0"
9393
deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"]
9494
path = ".."
9595
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
96-
version = "1.11.0"
96+
version = "1.12.0"
9797

9898
[[deps.Logging]]
9999
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"

src/LinearAlgebra.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import Base: USE_BLAS64, abs, acos, acosh, acot, acoth, acsc, acsch, adjoint, as
1616
permutedims, permuterows!, power_by_squaring, promote_rule, real, sec, sech, setindex!,
1717
show, similar, sin, sincos, sinh, size, sqrt, strides, stride, tan, tanh, transpose, trunc,
1818
typed_hcat, vec, view, zero
19+
import Base: AbstractArray, AbstractMatrix, Array, Matrix
1920
using Base: IndexLinear, promote_eltype, promote_op, print_matrix,
2021
@propagate_inbounds, reduce, typed_hvcat, typed_vcat, require_one_based_indexing,
2122
splat, BitInteger

src/blas.jl

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,8 @@ export
8484
trsm!,
8585
trsm
8686

87-
using ..LinearAlgebra: libblastrampoline, BlasReal, BlasComplex, BlasFloat, BlasInt, DimensionMismatch, checksquare, chkstride1
87+
using ..LinearAlgebra: libblastrampoline, BlasReal, BlasComplex, BlasFloat, BlasInt,
88+
DimensionMismatch, checksquare, chkstride1, SingularException
8889

8990
include("lbt.jl")
9091

@@ -128,6 +129,10 @@ Set the number of threads the BLAS library should use equal to `n::Integer`.
128129
129130
Also accepts `nothing`, in which case julia tries to guess the default number of threads.
130131
Passing `nothing` is discouraged and mainly exists for historical reasons.
132+
133+
!!! note
134+
Some BLAS libraries, such as Apple Accelerate, cannot be configured to use a fixed number of threads.
135+
For these backends, `set_num_threads()` is a no-op. See also [`get_num_threads`](@ref).
131136
"""
132137
set_num_threads(nt::Integer)::Nothing = lbt_set_num_threads(Int32(nt))
133138
function set_num_threads(::Nothing)
@@ -147,6 +152,10 @@ Get the number of threads the BLAS library is using.
147152
148153
!!! compat "Julia 1.6"
149154
`get_num_threads` requires at least Julia 1.6.
155+
156+
!!! note
157+
Some BLAS libraries, such as Apple Accelerate, cannot be configured to use a fixed number of threads.
158+
For these backends, `get_num_threads()` always returns `1`. See also [`set_num_threads`](@ref).
150159
"""
151160
get_num_threads()::Int = lbt_get_num_threads()
152161

@@ -1369,6 +1378,11 @@ for (fname, elty) in ((:dtrsv_,:Float64),
13691378
throw(DimensionMismatch(lazy"size of A is $n != length(x) = $(length(x))"))
13701379
end
13711380
chkstride1(A)
1381+
if diag == 'N'
1382+
for i in 1:n
1383+
iszero(A[i,i]) && throw(SingularException(i))
1384+
end
1385+
end
13721386
px, stx = vec_pointer_stride(x, ArgumentError("input vector with 0 stride is not allowed"))
13731387
GC.@preserve x ccall((@blasfunc($fname), libblastrampoline), Cvoid,
13741388
(Ref{UInt8}, Ref{UInt8}, Ref{UInt8}, Ref{BlasInt},
@@ -2217,6 +2231,11 @@ for (mmname, smname, elty) in
22172231
end
22182232
chkstride1(A)
22192233
chkstride1(B)
2234+
if diag == 'N'
2235+
for i in 1:k
2236+
iszero(A[i,i]) && throw(SingularException(i))
2237+
end
2238+
end
22202239
ccall((@blasfunc($smname), libblastrampoline), Cvoid,
22212240
(Ref{UInt8}, Ref{UInt8}, Ref{UInt8}, Ref{UInt8},
22222241
Ref{BlasInt}, Ref{BlasInt}, Ref{$elty}, Ptr{$elty},

src/cholesky.jl

Lines changed: 11 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ julia> C.U
6565
⋅ ⋅ 3.0
6666
6767
julia> C.L
68-
3×3 LowerTriangular{Float64, Matrix{Float64}}:
68+
3×3 LowerTriangular{Float64, Adjoint{Matrix{Float64}}}:
6969
2.0 ⋅ ⋅
7070
6.0 1.0 ⋅
7171
-8.0 5.0 3.0
@@ -305,7 +305,7 @@ function _cholpivoted!(A::AbstractMatrix, ::Type{UpperTriangular}, tol::Real, ch
305305
rTA = real(eltype(A))
306306
# checks
307307
Base.require_one_based_indexing(A)
308-
n = LinearAlgebra.checksquare(A)
308+
n = checksquare(A)
309309
# initialization
310310
piv = collect(1:n)
311311
dots = zeros(rTA, n)
@@ -354,7 +354,7 @@ function _cholpivoted!(A::AbstractMatrix, ::Type{LowerTriangular}, tol::Real, ch
354354
rTA = real(eltype(A))
355355
# checks
356356
Base.require_one_based_indexing(A)
357-
n = LinearAlgebra.checksquare(A)
357+
n = checksquare(A)
358358
# initialization
359359
piv = collect(1:n)
360360
dots = zeros(rTA, n)
@@ -530,7 +530,7 @@ julia> C.U
530530
⋅ ⋅ 3.0
531531
532532
julia> C.L
533-
3×3 LowerTriangular{Float64, Matrix{Float64}}:
533+
3×3 LowerTriangular{Float64, Adjoint{Matrix{Float64}}}:
534534
2.0 ⋅ ⋅
535535
6.0 1.0 ⋅
536536
-8.0 5.0 3.0
@@ -664,30 +664,15 @@ copy(C::CholeskyPivoted) = CholeskyPivoted(copy(C.factors), C.uplo, C.piv, C.ran
664664
size(C::Union{Cholesky, CholeskyPivoted}) = size(C.factors)
665665
size(C::Union{Cholesky, CholeskyPivoted}, d::Integer) = size(C.factors, d)
666666

667-
function _choleskyUfactor(Cfactors, Cuplo)
668-
if Cuplo === 'U'
669-
return UpperTriangular(Cfactors)
670-
else
671-
return copy(LowerTriangular(Cfactors)')
672-
end
673-
end
674-
function _choleskyLfactor(Cfactors, Cuplo)
675-
if Cuplo === 'L'
676-
return LowerTriangular(Cfactors)
677-
else
678-
return copy(UpperTriangular(Cfactors)')
679-
end
680-
end
681-
682667
function getproperty(C::Cholesky, d::Symbol)
683668
Cfactors = getfield(C, :factors)
684669
Cuplo = getfield(C, :uplo)
685670
if d === :U
686-
_choleskyUfactor(Cfactors, Cuplo)
671+
UpperTriangular(Cuplo == 'U' ? Cfactors : Cfactors')
687672
elseif d === :L
688-
_choleskyLfactor(Cfactors, Cuplo)
673+
LowerTriangular(Cuplo == 'L' ? Cfactors : Cfactors')
689674
elseif d === :UL
690-
return (Cuplo === 'U' ? UpperTriangular(Cfactors) : LowerTriangular(Cfactors))
675+
return (Cuplo == 'U' ? UpperTriangular(Cfactors) : LowerTriangular(Cfactors))
691676
else
692677
return getfield(C, d)
693678
end
@@ -704,9 +689,9 @@ function getproperty(C::CholeskyPivoted{T}, d::Symbol) where {T}
704689
Cfactors = getfield(C, :factors)
705690
Cuplo = getfield(C, :uplo)
706691
if d === :U
707-
_choleskyUfactor(Cfactors, Cuplo)
692+
UpperTriangular(Cuplo == 'U' ? Cfactors : Cfactors')
708693
elseif d === :L
709-
_choleskyLfactor(Cfactors, Cuplo)
694+
LowerTriangular(Cuplo == 'L' ? Cfactors : Cfactors')
710695
elseif d === :p
711696
return getfield(C, :piv)
712697
elseif d === :P
@@ -813,7 +798,7 @@ function rdiv!(B::AbstractMatrix, C::Cholesky)
813798
end
814799
end
815800

816-
function LinearAlgebra.rdiv!(B::AbstractMatrix, C::CholeskyPivoted)
801+
function rdiv!(B::AbstractMatrix, C::CholeskyPivoted)
817802
n = size(C, 2)
818803
for i in 1:size(B, 1)
819804
permute!(view(B, i, 1:n), C.piv)
@@ -965,7 +950,7 @@ function lowrankdowndate!(C::Cholesky, v::AbstractVector)
965950
s = conj(v[i]/Aii)
966951
s2 = abs2(s)
967952
if s2 > 1
968-
throw(LinearAlgebra.PosDefException(i))
953+
throw(PosDefException(i))
969954
end
970955
c = sqrt(1 - abs2(s))
971956

src/dense.jl

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -972,7 +972,8 @@ sqrt(::AbstractMatrix)
972972
function sqrt(A::AbstractMatrix{T}) where {T<:Union{Real,Complex}}
973973
if checksquare(A) == 0
974974
return copy(float(A))
975-
elseif isdiag(A)
975+
elseif isdiag(A) && (T <: Complex || all(x -> x zero(x), diagview(A)))
976+
# Real Diagonal sqrt requires each diagonal element to be positive
976977
return applydiagonal(sqrt, A)
977978
elseif ishermitian(A)
978979
sqrtHermA = sqrt(Hermitian(A))
@@ -1587,24 +1588,25 @@ factorize(A::Transpose) = transpose(factorize(parent(A)))
15871588
factorize(a::Number) = a # same as how factorize behaves on Diagonal types
15881589

15891590
function getstructure(A::StridedMatrix)
1591+
require_one_based_indexing(A)
15901592
m, n = size(A)
15911593
if m == 1 return A[1] end
15921594
utri = true
15931595
utri1 = true
15941596
herm = true
15951597
sym = true
1596-
for j = 1:n-1, i = j+1:m
1597-
if utri1
1598+
for j = 1:n, i = j:m
1599+
if (j < n) && (i > j) && utri1 # indices are off-diagonal
15981600
if A[i,j] != 0
15991601
utri1 = i == j + 1
16001602
utri = false
16011603
end
16021604
end
16031605
if sym
1604-
sym &= A[i,j] == A[j,i]
1606+
sym &= A[i,j] == transpose(A[j,i])
16051607
end
16061608
if herm
1607-
herm &= A[i,j] == conj(A[j,i])
1609+
herm &= A[i,j] == adjoint(A[j,i])
16081610
end
16091611
if !(utri1|herm|sym) break end
16101612
end
@@ -1617,10 +1619,12 @@ function getstructure(A::StridedMatrix)
16171619
if ltri1
16181620
for i = 1:n-1
16191621
if A[i,i+1] != 0
1620-
ltri &= false
1622+
ltri = false
16211623
break
16221624
end
16231625
end
1626+
else
1627+
ltri = false
16241628
end
16251629
return (utri, utri1, ltri, ltri1, sym, herm)
16261630
end
@@ -1779,6 +1783,10 @@ Condition number of the matrix `M`, computed using the operator `p`-norm. Valid
17791783
"""
17801784
function cond(A::AbstractMatrix, p::Real=2)
17811785
if p == 2
1786+
if isempty(A)
1787+
checksquare(A)
1788+
return zero(real(eigtype(eltype(A))))
1789+
end
17821790
v = svdvals(A)
17831791
maxv = maximum(v)
17841792
return iszero(maxv) ? oftype(real(maxv), Inf) : maxv / minimum(v)

src/exceptions.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@ struct SingularException <: Exception
2828
info::BlasInt
2929
end
3030

31+
function Base.showerror(io::IO, ex::SingularException)
32+
print(io, "SingularException: matrix is singular; factorization failed. Zero pivot found at index ", ex.info)
33+
end
34+
3135
"""
3236
PosDefException
3337

src/lapack.jl

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2975,10 +2975,10 @@ for (orglq, orgqr, orgql, orgrq, ormlq, ormqr, ormql, ormrq, gemqrt, elty) in
29752975
mA = size(A, 1)
29762976
k = length(tau)
29772977
if side == 'L' && m != mA
2978-
throw(DimensionMismatch(lazy"for a left-sided multiplication, the first dimension of C, $m, must equal the second dimension of A, $mA"))
2978+
throw(DimensionMismatch(lazy"for a left-sided multiplication, the first dimension of C, $m, must equal the first dimension of A, $mA"))
29792979
end
29802980
if side == 'R' && n != mA
2981-
throw(DimensionMismatch(lazy"for a right-sided multiplication, the second dimension of C, $m, must equal the second dimension of A, $mA"))
2981+
throw(DimensionMismatch(lazy"for a right-sided multiplication, the second dimension of C, $n, must equal the first dimension of A, $mA"))
29822982
end
29832983
if side == 'L' && k > m
29842984
throw(DimensionMismatch(lazy"invalid number of reflectors: k = $k should be <= m = $m"))
@@ -3025,10 +3025,10 @@ for (orglq, orgqr, orgql, orgrq, ormlq, ormqr, ormql, ormrq, gemqrt, elty) in
30253025
mA = size(A, 1)
30263026
k = length(tau)
30273027
if side == 'L' && m != mA
3028-
throw(DimensionMismatch(lazy"for a left-sided multiplication, the first dimension of C, $m, must equal the second dimension of A, $mA"))
3028+
throw(DimensionMismatch(lazy"for a left-sided multiplication, the first dimension of C, $m, must equal the first dimension of A, $mA"))
30293029
end
30303030
if side == 'R' && n != mA
3031-
throw(DimensionMismatch(lazy"for a right-sided multiplication, the second dimension of C, $m, must equal the second dimension of A, $mA"))
3031+
throw(DimensionMismatch(lazy"for a right-sided multiplication, the second dimension of C, $n, must equal the first dimension of A, $mA"))
30323032
end
30333033
if side == 'L' && k > m
30343034
throw(DimensionMismatch(lazy"invalid number of reflectors: k = $k should be <= m = $m"))
@@ -3078,7 +3078,7 @@ for (orglq, orgqr, orgql, orgrq, ormlq, ormqr, ormql, ormrq, gemqrt, elty) in
30783078
throw(DimensionMismatch(lazy"for a left-sided multiplication, the first dimension of C, $m, must equal the second dimension of A, $nA"))
30793079
end
30803080
if side == 'R' && n != nA
3081-
throw(DimensionMismatch(lazy"for a right-sided multiplication, the second dimension of C, $m, must equal the second dimension of A, $nA"))
3081+
throw(DimensionMismatch(lazy"for a right-sided multiplication, the second dimension of C, $n, must equal the second dimension of A, $nA"))
30823082
end
30833083
if side == 'L' && k > m
30843084
throw(DimensionMismatch(lazy"invalid number of reflectors: k = $k should be <= m = $m"))
@@ -5391,9 +5391,9 @@ for (syev, syevr, syevd, sygvd, elty) in
53915391
W = similar(A, $elty, n)
53925392
ldz = n
53935393
if jobz == 'N'
5394-
Z = similar(A, $elty, ldz, 0)
5394+
Z = similar(A, $elty, 0)
53955395
elseif jobz == 'V'
5396-
Z = similar(A, $elty, ldz, n)
5396+
Z = similar(A, $elty, ldz * n)
53975397
end
53985398
isuppz = similar(A, BlasInt, 2*n)
53995399
work = Vector{$elty}(undef, 1)
@@ -5423,7 +5423,8 @@ for (syev, syevr, syevd, sygvd, elty) in
54235423
resize!(iwork, liwork)
54245424
end
54255425
end
5426-
W[1:m[]], Z[:,1:(jobz == 'V' ? m[] : 0)]
5426+
zm = jobz == 'V' ? m[] : 0
5427+
resize!(W, m[]), reshape(resize!(Z, ldz * zm), ldz, zm)
54275428
end
54285429
syevr!(jobz::AbstractChar, A::AbstractMatrix{$elty}) =
54295430
syevr!(jobz, 'A', 'U', A, 0.0, 0.0, 0, 0, -1.0)
@@ -5593,10 +5594,10 @@ for (syev, syevr, syevd, sygvd, elty, relty) in
55935594
W = similar(A, $relty, n)
55945595
if jobz == 'N'
55955596
ldz = 1
5596-
Z = similar(A, $elty, ldz, 0)
5597+
Z = similar(A, $elty, 0)
55975598
elseif jobz == 'V'
55985599
ldz = n
5599-
Z = similar(A, $elty, ldz, n)
5600+
Z = similar(A, $elty, ldz * n)
56005601
end
56015602
isuppz = similar(A, BlasInt, 2*n)
56025603
work = Vector{$elty}(undef, 1)
@@ -5632,7 +5633,8 @@ for (syev, syevr, syevd, sygvd, elty, relty) in
56325633
resize!(iwork, liwork)
56335634
end
56345635
end
5635-
W[1:m[]], Z[:,1:(jobz == 'V' ? m[] : 0)]
5636+
zm = jobz == 'V' ? m[] : 0
5637+
resize!(W, m[]), reshape(resize!(Z, ldz * zm), ldz, zm)
56365638
end
56375639
syevr!(jobz::AbstractChar, A::AbstractMatrix{$elty}) =
56385640
syevr!(jobz, 'A', 'U', A, 0.0, 0.0, 0, 0, -1.0)

src/matmul.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1021,6 +1021,7 @@ function _generic_matmatmul_nonadjtrans!(C, A, B, alpha, beta)
10211021
@inbounds for n in axes(B, 2), k in axes(B, 1)
10221022
# Balpha = B[k,n] * alpha, but we skip the multiplication in case isone(alpha)
10231023
Balpha = @stable_muladdmul MulAddMul(alpha, false)(B[k,n])
1024+
!ismissing(Balpha) && iszero(Balpha) && continue
10241025
@simd for m in axes(A, 1)
10251026
C[m,n] = muladd(A[m,k], Balpha, C[m,n])
10261027
end

src/symmetric.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -703,6 +703,19 @@ for f in (:+, :-)
703703
end
704704

705705
*(A::HermOrSym, B::HermOrSym) = A * copyto!(similar(parent(B)), B)
706+
# catch a few potential BLAS-cases
707+
function *(A::HermOrSym{<:BlasFloat,<:StridedMatrix}, B::AdjOrTrans{<:BlasFloat,<:StridedMatrix})
708+
T = promote_type(eltype(A), eltype(B))
709+
mul!(similar(B, T, (size(A, 1), size(B, 2))),
710+
convert(AbstractMatrix{T}, A),
711+
copy_oftype(B, T)) # make sure the AdjOrTrans wrapper is resolved
712+
end
713+
function *(A::AdjOrTrans{<:BlasFloat,<:StridedMatrix}, B::HermOrSym{<:BlasFloat,<:StridedMatrix})
714+
T = promote_type(eltype(A), eltype(B))
715+
mul!(similar(B, T, (size(A, 1), size(B, 2))),
716+
copy_oftype(A, T), # make sure the AdjOrTrans wrapper is resolved
717+
convert(AbstractMatrix{T}, B))
718+
end
706719

707720
function dot(x::AbstractVector, A::RealHermSymComplexHerm, y::AbstractVector)
708721
require_one_based_indexing(x, y)

src/triangular.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1223,11 +1223,13 @@ function generic_mattrimul!(C::StridedMatrix{T}, uploc, isunitc, tfun::Function,
12231223
end
12241224
end
12251225
# division
1226-
function generic_trimatdiv!(C::StridedVecOrMat{T}, uploc, isunitc, tfun::Function, A::StridedMatrix{T}, B::AbstractVecOrMat{T}) where {T<:BlasFloat}
1226+
generic_trimatdiv!(C::StridedVector{T}, uploc, isunitc, tfun::Function, A::StridedMatrix{T}, B::AbstractVector{T}) where {T<:BlasFloat} =
1227+
BLAS.trsv!(uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, A, C === B ? C : copyto!(C, B))
1228+
function generic_trimatdiv!(C::StridedMatrix{T}, uploc, isunitc, tfun::Function, A::StridedMatrix{T}, B::AbstractMatrix{T}) where {T<:BlasFloat}
12271229
if stride(C,1) == stride(A,1) == 1
1228-
LAPACK.trtrs!(uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, A, C === B ? C : copyto!(C, B))
1230+
BLAS.trsm!('L', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), A, C === B ? C : copyto!(C, B))
12291231
else # incompatible with LAPACK
1230-
@invoke generic_trimatdiv!(C::AbstractVecOrMat, uploc, isunitc, tfun::Function, A::AbstractMatrix, B::AbstractVecOrMat)
1232+
@invoke generic_trimatdiv!(C::AbstractVecOrMat, uploc, isunitc, tfun::Function, A::AbstractMatrix, B::AbstractMatrix)
12311233
end
12321234
end
12331235
function generic_mattridiv!(C::StridedMatrix{T}, uploc, isunitc, tfun::Function, A::AbstractMatrix{T}, B::StridedMatrix{T}) where {T<:BlasFloat}

0 commit comments

Comments
 (0)