Skip to content

Commit 68b3f0f

Browse files
authored
Merge branch 'master' into sgj/evalpoly
2 parents 374d8fe + 61e444d commit 68b3f0f

File tree

10 files changed

+105
-68
lines changed

10 files changed

+105
-68
lines changed

src/LinearAlgebra.jl

Lines changed: 11 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ module LinearAlgebra
1010
import Base: \, /, //, *, ^, +, -, ==
1111
import Base: USE_BLAS64, abs, acos, acosh, acot, acoth, acsc, acsch, adjoint, asec, asech,
1212
asin, asinh, atan, atanh, axes, big, broadcast, cbrt, ceil, cis, collect, conj, convert,
13-
copy, copyto!, copymutable, cos, cosh, cot, coth, csc, csch, eltype, exp, fill!, floor,
13+
copy, copy!, copyto!, copymutable, cos, cosh, cot, coth, csc, csch, eltype, exp, fill!, floor,
1414
getindex, hcat, getproperty, imag, inv, invpermuterows!, isapprox, isequal, isone, iszero,
1515
IndexStyle, kron, kron!, length, log, map, ndims, one, oneunit, parent, permutecols!,
1616
permutedims, permuterows!, power_by_squaring, promote_rule, real, isreal, sec, sech, setindex!,
@@ -708,9 +708,9 @@ function ldiv(F::Factorization, B::AbstractVecOrMat)
708708

709709
if n > size(B, 1)
710710
# Underdetermined
711-
copyto!(view(BB, 1:m, :), B)
711+
copy!(view(BB, axes(B,1), ntuple(_->:, ndims(B)-1)...), B)
712712
else
713-
copyto!(BB, B)
713+
copy!(BB, B)
714714
end
715715

716716
ldiv!(FF, BB)
@@ -790,31 +790,22 @@ function versioninfo(io::IO=stdout)
790790
println(io, indent, "LinearAlgebra.BLAS.get_num_threads() = ", BLAS.get_num_threads())
791791
println(io, "Relevant environment variables:")
792792
env_var_names = [
793-
"JULIA_NUM_THREADS",
794-
"MKL_DYNAMIC",
795-
"MKL_NUM_THREADS",
793+
["JULIA_NUM_THREADS"],
794+
["MKL_DYNAMIC"],
795+
["MKL_NUM_THREADS"],
796796
# OpenBLAS has a hierarchy of environment variables for setting the
797797
# number of threads, see
798798
# https://github.com/xianyi/OpenBLAS/blob/c43ec53bdd00d9423fc609d7b7ecb35e7bf41b85/README.md#setting-the-number-of-threads-using-environment-variables
799-
("OPENBLAS_NUM_THREADS", "GOTO_NUM_THREADS", "OMP_NUM_THREADS"),
799+
["OPENBLAS_NUM_THREADS", "GOTO_NUM_THREADS", "OMP_NUM_THREADS"],
800800
]
801801
printed_at_least_one_env_var = false
802802
print_var(io, indent, name) = println(io, indent, name, " = ", ENV[name])
803803
for name in env_var_names
804-
if name isa Tuple
805-
# If `name` is a Tuple, then find the first environment which is
806-
# defined, and disregard the following ones.
807-
for nm in name
808-
if haskey(ENV, nm)
809-
print_var(io, indent, nm)
810-
printed_at_least_one_env_var = true
811-
break
812-
end
813-
end
814-
else
815-
if haskey(ENV, name)
816-
print_var(io, indent, name)
804+
for nm in name
805+
if haskey(ENV, nm)
806+
print_var(io, indent, nm)
817807
printed_at_least_one_env_var = true
808+
break
818809
end
819810
end
820811
end

src/bidiag.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -538,8 +538,9 @@ function lmul!(D::Diagonal, B::Bidiagonal)
538538
matmul_size_check(size(D), size(B))
539539
(; dv, ev) = B
540540
isL = B.uplo == 'L'
541-
dv[1] = D.diag[1] * dv[1]
542-
for i in axes(ev,1)
541+
iszero(size(D,1)) && return B
542+
@inbounds dv[1] = D.diag[1] * dv[1]
543+
@inbounds for i in axes(ev,1)
543544
ev[i] = D.diag[i + isL] * ev[i]
544545
dv[i+1] = D.diag[i+1] * dv[i+1]
545546
end
@@ -575,8 +576,9 @@ function rmul!(B::Bidiagonal, D::Diagonal)
575576
matmul_size_check(size(B), size(D))
576577
(; dv, ev) = B
577578
isU = B.uplo == 'U'
578-
dv[1] *= D.diag[1]
579-
for i in axes(ev,1)
579+
iszero(size(D,1)) && return B
580+
@inbounds dv[1] *= D.diag[1]
581+
@inbounds for i in axes(ev,1)
580582
ev[i] *= D.diag[i + isU]
581583
dv[i+1] *= D.diag[i+1]
582584
end

src/dense.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -720,7 +720,7 @@ julia> ℯ^[1 2; 0 3]
720720
0.0 20.0855
721721
```
722722
"""
723-
Base.:^(b::Number, A::AbstractMatrix) = exp!(log(b)*A)
723+
Base.:^(b::Number, A::AbstractMatrix) = exp_maybe_inplace(log(b)*A)
724724
# method for ℯ to explicitly elide the log(b) multiplication
725725
Base.:^(::Irrational{:ℯ}, A::AbstractMatrix) = exp(A)
726726

src/diagonal.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ end
191191
Return the appropriate zero element `A[i, j]` corresponding to a banded matrix `A`.
192192
"""
193193
diagzero(A::AbstractMatrix, i, j) = zero(eltype(A))
194-
diagzero(A::AbstractMatrix{M}, i, j) where {M<:AbstractMatrix} =
194+
@propagate_inbounds diagzero(A::AbstractMatrix{M}, i, j) where {M<:AbstractMatrix} =
195195
zeroslike(M, axes(A[i,i], 1), axes(A[j,j], 2))
196196
diagzero(A::AbstractMatrix, inds...) = diagzero(A, to_indices(A, inds)...)
197197
# dispatching on the axes permits specializing on the axis types to return something other than an Array

src/matmul.jl

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -725,8 +725,9 @@ syrk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat{T}, _add
725725

726726
# the aggressive constprop pushes tA and tB into gemm_wrapper!, which is needed for wrap calls within it
727727
# to be concretely inferred
728-
Base.@constprop :aggressive function herk_wrapper!(C::Union{StridedMatrix{T}, StridedMatrix{Complex{T}}}, tA::AbstractChar, A::Union{StridedVecOrMat{T}, StridedVecOrMat{Complex{T}}},
729-
α::Number, β::Number) where {T<:BlasReal}
728+
Base.@constprop :aggressive function herk_wrapper!(C::StridedMatrix{TC}, tA::AbstractChar, A::StridedVecOrMat{TC},
729+
α::Number, β::Number) where {TC<:BlasComplex}
730+
T = real(TC)
730731
nC = checksquare(C)
731732
tA_uc = uppercase(tA) # potentially convert a WrapperChar to a Char
732733
if tA_uc == 'C'
@@ -740,13 +741,10 @@ Base.@constprop :aggressive function herk_wrapper!(C::Union{StridedMatrix{T}, St
740741
throw(DimensionMismatch(lazy"output matrix has size: $(size(C)), but should have size $((mA, mA))"))
741742
end
742743

743-
# Result array does not need to be initialized as long as beta==0
744-
# C = Matrix{T}(undef, mA, mA)
745-
744+
# BLAS.herk! only updates hermitian C, alpha and beta need to be real
746745
if iszero(β) || ishermitian(C)
747746
alpha, beta = promote(α, β, zero(T))
748-
if (alpha isa Union{Bool,T} &&
749-
beta isa Union{Bool,T} &&
747+
if (alpha isa T && beta isa T &&
750748
stride(A, 1) == stride(C, 1) == 1 &&
751749
_fullstride2(A) && _fullstride2(C))
752750
return copytri!(BLAS.herk!('U', tA, alpha, A, beta, C), 'U', true)

src/triangular.jl

Lines changed: 45 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ end
178178
parent(A::UpperOrLowerTriangular) = A.data
179179

180180
# For strided matrices, we may only loop over the filled triangle
181-
copy(A::UpperOrLowerTriangular{<:Any, <:StridedMaybeAdjOrTransMat}) = copyto!(similar(A), A)
181+
copy(A::UpperOrLowerTriangular{<:Any, <:StridedMaybeAdjOrTransMat}) = copy!(similar(A), A)
182182

183183
# then handle all methods that requires specific handling of upper/lower and unit diagonal
184184

@@ -532,30 +532,34 @@ for T in (:UpperOrUnitUpperTriangular, :LowerOrUnitLowerTriangular)
532532
if axes(dest) != axes(U)
533533
@invoke copyto!(dest::AbstractArray, U::AbstractArray)
534534
else
535-
_copyto!(dest, U)
535+
copy!(dest, U)
536536
end
537537
return dest
538538
end
539+
@eval function copy!(dest::$T, U::$T)
540+
axes(dest) == axes(U) || throw(ArgumentError(
541+
"arrays must have the same axes for copy! (consider using `copyto!`)"))
542+
_copy!(dest, U)
543+
end
539544
end
540545

541546
# copy and scale
542547
for (T, UT) in ((:UpperTriangular, :UnitUpperTriangular), (:LowerTriangular, :UnitLowerTriangular))
543-
@eval @inline function _copyto!(A::$T, B::$T)
544-
@boundscheck checkbounds(A, axes(B)...)
548+
@eval @inline function _copy!(A::$T, B::$T)
545549
copytrito!(parent(A), parent(B), uplo_char(A))
546550
return A
547551
end
548-
@eval @inline function _copyto!(A::$UT, B::$T)
552+
@eval @inline function _copy!(A::$UT, B::$T)
549553
for dind in diagind(A, IndexStyle(A))
550554
if A[dind] != B[dind]
551555
throw_nononeerror(typeof(A), B[dind], Tuple(dind)...)
552556
end
553557
end
554-
_copyto!($T(parent(A)), B)
558+
_copy!($T(parent(A)), B)
555559
return A
556560
end
557561
end
558-
@inline function _copyto!(A::UpperOrUnitUpperTriangular, B::UnitUpperTriangular)
562+
@inline function _copy!(A::UpperOrUnitUpperTriangular, B::UnitUpperTriangular)
559563
@boundscheck checkbounds(A, axes(B)...)
560564
B2 = Base.unalias(A, B)
561565
Ap = parent(A)
@@ -570,7 +574,7 @@ end
570574
end
571575
return A
572576
end
573-
@inline function _copyto!(A::LowerOrUnitLowerTriangular, B::UnitLowerTriangular)
577+
@inline function _copy!(A::LowerOrUnitLowerTriangular, B::UnitLowerTriangular)
574578
@boundscheck checkbounds(A, axes(B)...)
575579
B2 = Base.unalias(A, B)
576580
Ap = parent(A)
@@ -595,23 +599,30 @@ _triangularize!(::LowerOrUnitLowerTriangular) = tril!
595599
if axes(dest) != axes(U)
596600
@invoke copyto!(dest::StridedMatrix, U::AbstractArray)
597601
else
598-
_copyto!(dest, U)
602+
copy!(dest, U)
599603
end
600604
return dest
601605
end
602-
@propagate_inbounds function _copyto!(dest::StridedMatrix, U::UpperOrLowerTriangular)
606+
607+
function copy!(dest::StridedMatrix, U::UpperOrLowerTriangular)
608+
axes(dest) == axes(U) || throw(ArgumentError(
609+
"arrays must have the same axes for copy! (consider using `copyto!`)"))
610+
_copy!(dest, U)
611+
end
612+
613+
@propagate_inbounds function _copy!(dest::StridedMatrix, U::UpperOrLowerTriangular)
603614
copytrito!(dest, parent(U), U isa UpperOrUnitUpperTriangular ? 'U' : 'L')
604615
copytrito!(dest, U, U isa UpperOrUnitUpperTriangular ? 'L' : 'U')
605616
return dest
606617
end
607-
@propagate_inbounds function _copyto!(dest::StridedMatrix, U::UpperOrLowerTriangular{<:Any, <:StridedMatrix})
618+
@propagate_inbounds function _copy!(dest::StridedMatrix, U::UpperOrLowerTriangular{<:Any, <:StridedMatrix})
608619
U2 = Base.unalias(dest, U)
609-
copyto_unaliased!(dest, U2)
620+
copy_unaliased!(dest, U2)
610621
return dest
611622
end
612623
# for strided matrices, we explicitly loop over the arrays to improve cache locality
613624
# This fuses the copytrito! for the two halves
614-
@inline function copyto_unaliased!(dest::StridedMatrix, U::UpperOrUnitUpperTriangular{<:Any, <:StridedMatrix})
625+
@inline function copy_unaliased!(dest::StridedMatrix, U::UpperOrUnitUpperTriangular{<:Any, <:StridedMatrix})
615626
@boundscheck checkbounds(dest, axes(U)...)
616627
isunit = U isa UnitUpperTriangular
617628
for col in axes(dest,2)
@@ -624,7 +635,7 @@ end
624635
end
625636
return dest
626637
end
627-
@inline function copyto_unaliased!(dest::StridedMatrix, L::LowerOrUnitLowerTriangular{<:Any, <:StridedMatrix})
638+
@inline function copy_unaliased!(dest::StridedMatrix, L::LowerOrUnitLowerTriangular{<:Any, <:StridedMatrix})
628639
@boundscheck checkbounds(dest, axes(L)...)
629640
isunit = L isa UnitLowerTriangular
630641
for col in axes(dest,2)
@@ -645,7 +656,7 @@ Base.@constprop :aggressive function copytrito_triangular!(Bdata, Adata, uplo, u
645656
BLAS.chkuplo(uplo)
646657
LAPACK.lacpy_size_check(size(Bdata), sz)
647658
# only the diagonal is copied in this case
648-
copyto!(diagview(Bdata), diagview(Adata))
659+
copy!(diagview(Bdata), diagview(Adata))
649660
end
650661
return Bdata
651662
end
@@ -1055,15 +1066,17 @@ isunit_char(::UnitUpperTriangular) = 'U'
10551066
isunit_char(::LowerTriangular) = 'N'
10561067
isunit_char(::UnitLowerTriangular) = 'U'
10571068

1069+
_copy_or_copyto!(dest, src) = ndims(dest) == ndims(src) ? copy!(dest, src) : copyto!(dest, src)
1070+
10581071
# generic fallback for AbstractTriangular matrices outside of the four subtypes provided here
10591072
_trimul!(C::AbstractVecOrMat, A::AbstractTriangular, B::AbstractVector) =
1060-
lmul!(A, copyto!(C, B))
1073+
lmul!(A, _copy_or_copyto!(C, B))
10611074
_trimul!(C::AbstractMatrix, A::AbstractTriangular, B::AbstractMatrix) =
1062-
lmul!(A, copyto!(C, B))
1075+
lmul!(A, copy!(C, B))
10631076
_trimul!(C::AbstractMatrix, A::AbstractMatrix, B::AbstractTriangular) =
1064-
rmul!(copyto!(C, A), B)
1077+
rmul!(copy!(C, A), B)
10651078
_trimul!(C::AbstractMatrix, A::AbstractTriangular, B::AbstractTriangular) =
1066-
lmul!(A, copyto!(C, B))
1079+
lmul!(A, copy!(C, B))
10671080
# redirect for UpperOrLowerTriangular
10681081
_trimul!(C::AbstractVecOrMat, A::UpperOrLowerTriangular, B::AbstractVector) =
10691082
generic_trimatmul!(C, uplo_char(A), isunit_char(A), wrapperop(parent(A)), _unwrap_at(parent(A)), B)
@@ -1124,9 +1137,9 @@ end
11241137
ldiv!(C::AbstractVecOrMat, A::AbstractTriangular, B::AbstractVecOrMat) = _ldiv!(C, A, B)
11251138
# generic fallback for AbstractTriangular, directs to 2-arg [l/r]div!
11261139
_ldiv!(C::AbstractVecOrMat, A::AbstractTriangular, B::AbstractVecOrMat) =
1127-
ldiv!(A, copyto!(C, B))
1140+
ldiv!(A, _copy_or_copyto!(C, B))
11281141
_rdiv!(C::AbstractMatrix, A::AbstractMatrix, B::AbstractTriangular) =
1129-
rdiv!(copyto!(C, A), B)
1142+
rdiv!(copy!(C, A), B)
11301143
# redirect for UpperOrLowerTriangular to generic_*div!
11311144
_ldiv!(C::AbstractVecOrMat, A::UpperOrLowerTriangular, B::AbstractVecOrMat) =
11321145
generic_trimatdiv!(C, uplo_char(A), isunit_char(A), wrapperop(parent(A)), _unwrap_at(parent(A)), B)
@@ -1204,40 +1217,40 @@ for (t, uploc, isunitc) in ((:LowerTriangular, 'L', 'N'),
12041217
elseif p == Inf
12051218
return inv(LAPACK.trcon!('I', $uploc, $isunitc, A.data))
12061219
else # use fallback
1207-
return cond(copyto!(similar(parent(A)), A), p)
1220+
return cond(copy!(similar(parent(A)), A), p)
12081221
end
12091222
end
12101223
end
12111224
end
12121225

12131226
# multiplication
12141227
generic_trimatmul!(c::StridedVector{T}, uploc, isunitc, tfun::Function, A::StridedMatrix{T}, b::AbstractVector{T}) where {T<:BlasFloat} =
1215-
BLAS.trmv!(uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, A, c === b ? c : copyto!(c, b))
1228+
BLAS.trmv!(uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, A, c === b ? c : copy!(c, b))
12161229
function generic_trimatmul!(C::StridedMatrix{T}, uploc, isunitc, tfun::Function, A::StridedMatrix{T}, B::AbstractMatrix{T}) where {T<:BlasFloat}
12171230
if stride(C,1) == stride(A,1) == 1
1218-
BLAS.trmm!('L', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), A, C === B ? C : copyto!(C, B))
1231+
BLAS.trmm!('L', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), A, C === B ? C : copy!(C, B))
12191232
else # incompatible with BLAS
12201233
@invoke generic_trimatmul!(C::AbstractMatrix, uploc, isunitc, tfun::Function, A::AbstractMatrix, B::AbstractMatrix)
12211234
end
12221235
end
12231236
function generic_mattrimul!(C::StridedMatrix{T}, uploc, isunitc, tfun::Function, A::AbstractMatrix{T}, B::StridedMatrix{T}) where {T<:BlasFloat}
12241237
if stride(C,1) == stride(B,1) == 1
1225-
BLAS.trmm!('R', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), B, C === A ? C : copyto!(C, A))
1238+
BLAS.trmm!('R', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), B, C === A ? C : copy!(C, A))
12261239
else # incompatible with BLAS
12271240
@invoke generic_mattrimul!(C::AbstractMatrix, uploc, isunitc, tfun::Function, A::AbstractMatrix, B::AbstractMatrix)
12281241
end
12291242
end
12301243
# division
12311244
function generic_trimatdiv!(C::StridedVecOrMat{T}, uploc, isunitc, tfun::Function, A::StridedMatrix{T}, B::AbstractVecOrMat{T}) where {T<:BlasFloat}
12321245
if stride(C,1) == stride(A,1) == 1
1233-
LAPACK.trtrs!(uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, A, C === B ? C : copyto!(C, B))
1246+
LAPACK.trtrs!(uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, A, C === B ? C : _copy_or_copyto!(C, B))
12341247
else # incompatible with LAPACK
12351248
@invoke generic_trimatdiv!(C::AbstractVecOrMat, uploc, isunitc, tfun::Function, A::AbstractMatrix, B::AbstractVecOrMat)
12361249
end
12371250
end
12381251
function generic_mattridiv!(C::StridedMatrix{T}, uploc, isunitc, tfun::Function, A::AbstractMatrix{T}, B::StridedMatrix{T}) where {T<:BlasFloat}
12391252
if stride(C,1) == stride(B,1) == 1
1240-
BLAS.trsm!('R', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), B, C === A ? C : copyto!(C, A))
1253+
BLAS.trsm!('R', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), B, C === A ? C : copy!(C, A))
12411254
else # incompatible with BLAS
12421255
@invoke generic_mattridiv!(C::AbstractMatrix, uploc, isunitc, tfun::Function, A::AbstractMatrix, B::AbstractMatrix)
12431256
end
@@ -1962,22 +1975,22 @@ function powm!(A0::UpperTriangular, p::Real)
19621975
for i in axes(S,1)
19631976
@inbounds S[i, i] = S[i, i] + 1
19641977
end
1965-
copyto!(Stmp, S)
1978+
copy!(Stmp, S)
19661979
mul!(S, A, c)
19671980
ldiv!(Stmp, S)
19681981

19691982
c = (p - j) / (j4 - 2)
19701983
for i in axes(S,1)
19711984
@inbounds S[i, i] = S[i, i] + 1
19721985
end
1973-
copyto!(Stmp, S)
1986+
copy!(Stmp, S)
19741987
mul!(S, A, c)
19751988
ldiv!(Stmp, S)
19761989
end
19771990
for i in axes(S,1)
19781991
S[i, i] = S[i, i] + 1
19791992
end
1980-
copyto!(Stmp, S)
1993+
copy!(Stmp, S)
19811994
mul!(S, A, -p)
19821995
ldiv!(Stmp, S)
19831996
for i in axes(S,1)
@@ -1987,7 +2000,7 @@ function powm!(A0::UpperTriangular, p::Real)
19872000
blockpower!(A0, S, p/(2^s))
19882001
for m = 1:s
19892002
mul!(Stmp.data, S, S)
1990-
copyto!(S, Stmp)
2003+
copy!(S, Stmp)
19912004
blockpower!(A0, S, p/(2^(s-m)))
19922005
end
19932006
rmul!(S, normA0^p)
@@ -2174,7 +2187,7 @@ function _find_params_log_quasitriu!(A)
21742187
break
21752188
end
21762189
_sqrt_quasitriu!(A isa UpperTriangular ? parent(A) : A, A)
2177-
copyto!(AmI, A)
2190+
copy!(AmI, A)
21782191
for i in axes(AmI,1)
21792192
@inbounds AmI[i,i] -= 1
21802193
end

test/bidiag.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1192,4 +1192,13 @@ end
11921192
@test_throws msg ldiv!(C, B, zeros(2,1))
11931193
end
11941194

1195+
@testset "l/rmul with 0-sized matrices" begin
1196+
n = 0
1197+
B = Bidiagonal(ones(n), ones(max(n-1,0)), :U)
1198+
B2 = copy(B)
1199+
D = Diagonal(ones(n))
1200+
@test lmul!(D, B) == B2
1201+
@test rmul!(B, D) == B2
1202+
end
1203+
11951204
end # module TestBidiagonal

0 commit comments

Comments
 (0)