Skip to content

Commit 3e525a8

Browse files
authored
Speicalize copy! for triangular, and use copy! in ldiv (#1263)
Currently, we specialize `copyto!` for a triangular source. However, this has two branches, depending on whether the axes match. In the branch where the axes do match, we may use `copy!` instead, and specialize this in terms of the internal `_copyto!` which is identical in implementation. We also use `copy!` in `ldiv` instead of `copyto!`. These should be equivalent in most cases, barring one extra axes check in `copy!`, as the fallback method for `copy!` calls `copyto!` internally. However, the advantage comes for triangular matrices, where `copy!` doesn't have branches, as the axes necessarily match. As a consequence, this reduces the TTFX in operations like ```julia julia> using Random, LinearAlgebra julia> A = rand(4,4); julia> @time A \ UpperTriangular(A); 0.598575 seconds (1.28 M allocations: 61.788 MiB, 98.57% compilation time: 3% of which was recompilation) # master 0.487267 seconds (1.01 M allocations: 49.726 MiB, 3.54% gc time, 98.95% compilation time) # this PR ``` ```julia julia> @time UpperTriangular(A) / A; 0.826212 seconds (1.45 M allocations: 71.427 MiB, 16.96% gc time, 84.47% compilation time) # master 0.616258 seconds (1.28 M allocations: 63.135 MiB, 2.65% gc time, 99.19% compilation time) # this PR ``` I've renamed the internal functions as `_copyto!` -> `_copy!` and `copyto_unaliased!` -> `copy_unaliased!`, as these are closer to the meaning of `copy!` than to `copyto!`.
1 parent 763f19f commit 3e525a8

File tree

2 files changed

+48
-35
lines changed

2 files changed

+48
-35
lines changed

src/LinearAlgebra.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ module LinearAlgebra
1010
import Base: \, /, //, *, ^, +, -, ==
1111
import Base: USE_BLAS64, abs, acos, acosh, acot, acoth, acsc, acsch, adjoint, asec, asech,
1212
asin, asinh, atan, atanh, axes, big, broadcast, cbrt, ceil, cis, collect, conj, convert,
13-
copy, copyto!, copymutable, cos, cosh, cot, coth, csc, csch, eltype, exp, fill!, floor,
13+
copy, copy!, copyto!, copymutable, cos, cosh, cot, coth, csc, csch, eltype, exp, fill!, floor,
1414
getindex, hcat, getproperty, imag, inv, invpermuterows!, isapprox, isequal, isone, iszero,
1515
IndexStyle, kron, kron!, length, log, map, ndims, one, oneunit, parent, permutecols!,
1616
permutedims, permuterows!, power_by_squaring, promote_rule, real, isreal, sec, sech, setindex!,
@@ -706,9 +706,9 @@ function ldiv(F::Factorization, B::AbstractVecOrMat)
706706

707707
if n > size(B, 1)
708708
# Underdetermined
709-
copyto!(view(BB, 1:m, :), B)
709+
copy!(view(BB, axes(B,1), ntuple(_->:, ndims(B)-1)...), B)
710710
else
711-
copyto!(BB, B)
711+
copy!(BB, B)
712712
end
713713

714714
ldiv!(FF, BB)

src/triangular.jl

Lines changed: 45 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ end
178178
parent(A::UpperOrLowerTriangular) = A.data
179179

180180
# For strided matrices, we may only loop over the filled triangle
181-
copy(A::UpperOrLowerTriangular{<:Any, <:StridedMaybeAdjOrTransMat}) = copyto!(similar(A), A)
181+
copy(A::UpperOrLowerTriangular{<:Any, <:StridedMaybeAdjOrTransMat}) = copy!(similar(A), A)
182182

183183
# then handle all methods that requires specific handling of upper/lower and unit diagonal
184184

@@ -532,30 +532,34 @@ for T in (:UpperOrUnitUpperTriangular, :LowerOrUnitLowerTriangular)
532532
if axes(dest) != axes(U)
533533
@invoke copyto!(dest::AbstractArray, U::AbstractArray)
534534
else
535-
_copyto!(dest, U)
535+
copy!(dest, U)
536536
end
537537
return dest
538538
end
539+
@eval function copy!(dest::$T, U::$T)
540+
axes(dest) == axes(U) || throw(ArgumentError(
541+
"arrays must have the same axes for copy! (consider using `copyto!`)"))
542+
_copy!(dest, U)
543+
end
539544
end
540545

541546
# copy and scale
542547
for (T, UT) in ((:UpperTriangular, :UnitUpperTriangular), (:LowerTriangular, :UnitLowerTriangular))
543-
@eval @inline function _copyto!(A::$T, B::$T)
544-
@boundscheck checkbounds(A, axes(B)...)
548+
@eval @inline function _copy!(A::$T, B::$T)
545549
copytrito!(parent(A), parent(B), uplo_char(A))
546550
return A
547551
end
548-
@eval @inline function _copyto!(A::$UT, B::$T)
552+
@eval @inline function _copy!(A::$UT, B::$T)
549553
for dind in diagind(A, IndexStyle(A))
550554
if A[dind] != B[dind]
551555
throw_nononeerror(typeof(A), B[dind], Tuple(dind)...)
552556
end
553557
end
554-
_copyto!($T(parent(A)), B)
558+
_copy!($T(parent(A)), B)
555559
return A
556560
end
557561
end
558-
@inline function _copyto!(A::UpperOrUnitUpperTriangular, B::UnitUpperTriangular)
562+
@inline function _copy!(A::UpperOrUnitUpperTriangular, B::UnitUpperTriangular)
559563
@boundscheck checkbounds(A, axes(B)...)
560564
B2 = Base.unalias(A, B)
561565
Ap = parent(A)
@@ -570,7 +574,7 @@ end
570574
end
571575
return A
572576
end
573-
@inline function _copyto!(A::LowerOrUnitLowerTriangular, B::UnitLowerTriangular)
577+
@inline function _copy!(A::LowerOrUnitLowerTriangular, B::UnitLowerTriangular)
574578
@boundscheck checkbounds(A, axes(B)...)
575579
B2 = Base.unalias(A, B)
576580
Ap = parent(A)
@@ -595,23 +599,30 @@ _triangularize!(::LowerOrUnitLowerTriangular) = tril!
595599
if axes(dest) != axes(U)
596600
@invoke copyto!(dest::StridedMatrix, U::AbstractArray)
597601
else
598-
_copyto!(dest, U)
602+
copy!(dest, U)
599603
end
600604
return dest
601605
end
602-
@propagate_inbounds function _copyto!(dest::StridedMatrix, U::UpperOrLowerTriangular)
606+
607+
function copy!(dest::StridedMatrix, U::UpperOrLowerTriangular)
608+
axes(dest) == axes(U) || throw(ArgumentError(
609+
"arrays must have the same axes for copy! (consider using `copyto!`)"))
610+
_copy!(dest, U)
611+
end
612+
613+
@propagate_inbounds function _copy!(dest::StridedMatrix, U::UpperOrLowerTriangular)
603614
copytrito!(dest, parent(U), U isa UpperOrUnitUpperTriangular ? 'U' : 'L')
604615
copytrito!(dest, U, U isa UpperOrUnitUpperTriangular ? 'L' : 'U')
605616
return dest
606617
end
607-
@propagate_inbounds function _copyto!(dest::StridedMatrix, U::UpperOrLowerTriangular{<:Any, <:StridedMatrix})
618+
@propagate_inbounds function _copy!(dest::StridedMatrix, U::UpperOrLowerTriangular{<:Any, <:StridedMatrix})
608619
U2 = Base.unalias(dest, U)
609-
copyto_unaliased!(dest, U2)
620+
copy_unaliased!(dest, U2)
610621
return dest
611622
end
612623
# for strided matrices, we explicitly loop over the arrays to improve cache locality
613624
# This fuses the copytrito! for the two halves
614-
@inline function copyto_unaliased!(dest::StridedMatrix, U::UpperOrUnitUpperTriangular{<:Any, <:StridedMatrix})
625+
@inline function copy_unaliased!(dest::StridedMatrix, U::UpperOrUnitUpperTriangular{<:Any, <:StridedMatrix})
615626
@boundscheck checkbounds(dest, axes(U)...)
616627
isunit = U isa UnitUpperTriangular
617628
for col in axes(dest,2)
@@ -624,7 +635,7 @@ end
624635
end
625636
return dest
626637
end
627-
@inline function copyto_unaliased!(dest::StridedMatrix, L::LowerOrUnitLowerTriangular{<:Any, <:StridedMatrix})
638+
@inline function copy_unaliased!(dest::StridedMatrix, L::LowerOrUnitLowerTriangular{<:Any, <:StridedMatrix})
628639
@boundscheck checkbounds(dest, axes(L)...)
629640
isunit = L isa UnitLowerTriangular
630641
for col in axes(dest,2)
@@ -645,7 +656,7 @@ Base.@constprop :aggressive function copytrito_triangular!(Bdata, Adata, uplo, u
645656
BLAS.chkuplo(uplo)
646657
LAPACK.lacpy_size_check(size(Bdata), sz)
647658
# only the diagonal is copied in this case
648-
copyto!(diagview(Bdata), diagview(Adata))
659+
copy!(diagview(Bdata), diagview(Adata))
649660
end
650661
return Bdata
651662
end
@@ -1055,15 +1066,17 @@ isunit_char(::UnitUpperTriangular) = 'U'
10551066
isunit_char(::LowerTriangular) = 'N'
10561067
isunit_char(::UnitLowerTriangular) = 'U'
10571068

1069+
_copy_or_copyto!(dest, src) = ndims(dest) == ndims(src) ? copy!(dest, src) : copyto!(dest, src)
1070+
10581071
# generic fallback for AbstractTriangular matrices outside of the four subtypes provided here
10591072
_trimul!(C::AbstractVecOrMat, A::AbstractTriangular, B::AbstractVector) =
1060-
lmul!(A, copyto!(C, B))
1073+
lmul!(A, _copy_or_copyto!(C, B))
10611074
_trimul!(C::AbstractMatrix, A::AbstractTriangular, B::AbstractMatrix) =
1062-
lmul!(A, copyto!(C, B))
1075+
lmul!(A, copy!(C, B))
10631076
_trimul!(C::AbstractMatrix, A::AbstractMatrix, B::AbstractTriangular) =
1064-
rmul!(copyto!(C, A), B)
1077+
rmul!(copy!(C, A), B)
10651078
_trimul!(C::AbstractMatrix, A::AbstractTriangular, B::AbstractTriangular) =
1066-
lmul!(A, copyto!(C, B))
1079+
lmul!(A, copy!(C, B))
10671080
# redirect for UpperOrLowerTriangular
10681081
_trimul!(C::AbstractVecOrMat, A::UpperOrLowerTriangular, B::AbstractVector) =
10691082
generic_trimatmul!(C, uplo_char(A), isunit_char(A), wrapperop(parent(A)), _unwrap_at(parent(A)), B)
@@ -1124,9 +1137,9 @@ end
11241137
ldiv!(C::AbstractVecOrMat, A::AbstractTriangular, B::AbstractVecOrMat) = _ldiv!(C, A, B)
11251138
# generic fallback for AbstractTriangular, directs to 2-arg [l/r]div!
11261139
_ldiv!(C::AbstractVecOrMat, A::AbstractTriangular, B::AbstractVecOrMat) =
1127-
ldiv!(A, copyto!(C, B))
1140+
ldiv!(A, _copy_or_copyto!(C, B))
11281141
_rdiv!(C::AbstractMatrix, A::AbstractMatrix, B::AbstractTriangular) =
1129-
rdiv!(copyto!(C, A), B)
1142+
rdiv!(copy!(C, A), B)
11301143
# redirect for UpperOrLowerTriangular to generic_*div!
11311144
_ldiv!(C::AbstractVecOrMat, A::UpperOrLowerTriangular, B::AbstractVecOrMat) =
11321145
generic_trimatdiv!(C, uplo_char(A), isunit_char(A), wrapperop(parent(A)), _unwrap_at(parent(A)), B)
@@ -1204,40 +1217,40 @@ for (t, uploc, isunitc) in ((:LowerTriangular, 'L', 'N'),
12041217
elseif p == Inf
12051218
return inv(LAPACK.trcon!('I', $uploc, $isunitc, A.data))
12061219
else # use fallback
1207-
return cond(copyto!(similar(parent(A)), A), p)
1220+
return cond(copy!(similar(parent(A)), A), p)
12081221
end
12091222
end
12101223
end
12111224
end
12121225

12131226
# multiplication
12141227
generic_trimatmul!(c::StridedVector{T}, uploc, isunitc, tfun::Function, A::StridedMatrix{T}, b::AbstractVector{T}) where {T<:BlasFloat} =
1215-
BLAS.trmv!(uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, A, c === b ? c : copyto!(c, b))
1228+
BLAS.trmv!(uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, A, c === b ? c : copy!(c, b))
12161229
function generic_trimatmul!(C::StridedMatrix{T}, uploc, isunitc, tfun::Function, A::StridedMatrix{T}, B::AbstractMatrix{T}) where {T<:BlasFloat}
12171230
if stride(C,1) == stride(A,1) == 1
1218-
BLAS.trmm!('L', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), A, C === B ? C : copyto!(C, B))
1231+
BLAS.trmm!('L', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), A, C === B ? C : copy!(C, B))
12191232
else # incompatible with BLAS
12201233
@invoke generic_trimatmul!(C::AbstractMatrix, uploc, isunitc, tfun::Function, A::AbstractMatrix, B::AbstractMatrix)
12211234
end
12221235
end
12231236
function generic_mattrimul!(C::StridedMatrix{T}, uploc, isunitc, tfun::Function, A::AbstractMatrix{T}, B::StridedMatrix{T}) where {T<:BlasFloat}
12241237
if stride(C,1) == stride(B,1) == 1
1225-
BLAS.trmm!('R', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), B, C === A ? C : copyto!(C, A))
1238+
BLAS.trmm!('R', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), B, C === A ? C : copy!(C, A))
12261239
else # incompatible with BLAS
12271240
@invoke generic_mattrimul!(C::AbstractMatrix, uploc, isunitc, tfun::Function, A::AbstractMatrix, B::AbstractMatrix)
12281241
end
12291242
end
12301243
# division
12311244
function generic_trimatdiv!(C::StridedVecOrMat{T}, uploc, isunitc, tfun::Function, A::StridedMatrix{T}, B::AbstractVecOrMat{T}) where {T<:BlasFloat}
12321245
if stride(C,1) == stride(A,1) == 1
1233-
LAPACK.trtrs!(uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, A, C === B ? C : copyto!(C, B))
1246+
LAPACK.trtrs!(uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, A, C === B ? C : _copy_or_copyto!(C, B))
12341247
else # incompatible with LAPACK
12351248
@invoke generic_trimatdiv!(C::AbstractVecOrMat, uploc, isunitc, tfun::Function, A::AbstractMatrix, B::AbstractVecOrMat)
12361249
end
12371250
end
12381251
function generic_mattridiv!(C::StridedMatrix{T}, uploc, isunitc, tfun::Function, A::AbstractMatrix{T}, B::StridedMatrix{T}) where {T<:BlasFloat}
12391252
if stride(C,1) == stride(B,1) == 1
1240-
BLAS.trsm!('R', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), B, C === A ? C : copyto!(C, A))
1253+
BLAS.trsm!('R', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), B, C === A ? C : copy!(C, A))
12411254
else # incompatible with BLAS
12421255
@invoke generic_mattridiv!(C::AbstractMatrix, uploc, isunitc, tfun::Function, A::AbstractMatrix, B::AbstractMatrix)
12431256
end
@@ -1962,22 +1975,22 @@ function powm!(A0::UpperTriangular, p::Real)
19621975
for i in axes(S,1)
19631976
@inbounds S[i, i] = S[i, i] + 1
19641977
end
1965-
copyto!(Stmp, S)
1978+
copy!(Stmp, S)
19661979
mul!(S, A, c)
19671980
ldiv!(Stmp, S)
19681981

19691982
c = (p - j) / (j4 - 2)
19701983
for i in axes(S,1)
19711984
@inbounds S[i, i] = S[i, i] + 1
19721985
end
1973-
copyto!(Stmp, S)
1986+
copy!(Stmp, S)
19741987
mul!(S, A, c)
19751988
ldiv!(Stmp, S)
19761989
end
19771990
for i in axes(S,1)
19781991
S[i, i] = S[i, i] + 1
19791992
end
1980-
copyto!(Stmp, S)
1993+
copy!(Stmp, S)
19811994
mul!(S, A, -p)
19821995
ldiv!(Stmp, S)
19831996
for i in axes(S,1)
@@ -1987,7 +2000,7 @@ function powm!(A0::UpperTriangular, p::Real)
19872000
blockpower!(A0, S, p/(2^s))
19882001
for m = 1:s
19892002
mul!(Stmp.data, S, S)
1990-
copyto!(S, Stmp)
2003+
copy!(S, Stmp)
19912004
blockpower!(A0, S, p/(2^(s-m)))
19922005
end
19932006
rmul!(S, normA0^p)
@@ -2174,7 +2187,7 @@ function _find_params_log_quasitriu!(A)
21742187
break
21752188
end
21762189
_sqrt_quasitriu!(A isa UpperTriangular ? parent(A) : A, A)
2177-
copyto!(AmI, A)
2190+
copy!(AmI, A)
21782191
for i in axes(AmI,1)
21792192
@inbounds AmI[i,i] -= 1
21802193
end

0 commit comments

Comments
 (0)