Skip to content

Commit 84198f0

Browse files
committed
Speicalize copy! for triangular, and use copy! in ldiv
1 parent 925acef commit 84198f0

File tree

2 files changed

+27
-16
lines changed

2 files changed

+27
-16
lines changed

src/LinearAlgebra.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ module LinearAlgebra
1010
import Base: \, /, //, *, ^, +, -, ==
1111
import Base: USE_BLAS64, abs, acos, acosh, acot, acoth, acsc, acsch, adjoint, asec, asech,
1212
asin, asinh, atan, atanh, axes, big, broadcast, cbrt, ceil, cis, collect, conj, convert,
13-
copy, copyto!, copymutable, cos, cosh, cot, coth, csc, csch, eltype, exp, fill!, floor,
13+
copy, copy!, copyto!, copymutable, cos, cosh, cot, coth, csc, csch, eltype, exp, fill!, floor,
1414
getindex, hcat, getproperty, imag, inv, invpermuterows!, isapprox, isequal, isone, iszero,
1515
IndexStyle, kron, kron!, length, log, map, ndims, one, oneunit, parent, permutecols!,
1616
permutedims, permuterows!, power_by_squaring, promote_rule, real, sec, sech, setindex!,
@@ -706,9 +706,9 @@ function ldiv(F::Factorization, B::AbstractVecOrMat)
706706

707707
if n > size(B, 1)
708708
# Underdetermined
709-
copyto!(view(BB, 1:m, :), B)
709+
copy!(view(BB, axes(B,1), ntuple(_->:, ndims(B)-1)...), B)
710710
else
711-
copyto!(BB, B)
711+
copy!(BB, B)
712712
end
713713

714714
ldiv!(FF, BB)

src/triangular.jl

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -527,30 +527,34 @@ for T in (:UpperOrUnitUpperTriangular, :LowerOrUnitLowerTriangular)
527527
if axes(dest) != axes(U)
528528
@invoke copyto!(dest::AbstractArray, U::AbstractArray)
529529
else
530-
_copyto!(dest, U)
530+
copy!(dest, U)
531531
end
532532
return dest
533533
end
534+
@eval function copy!(dest::$T, U::$T)
535+
axes(dest) == axes(U) || throw(ArgumentError(
536+
"arrays must have the same axes for copy! (consider using `copyto!`)"))
537+
_copy!(dest, U)
538+
end
534539
end
535540

536541
# copy and scale
537542
for (T, UT) in ((:UpperTriangular, :UnitUpperTriangular), (:LowerTriangular, :UnitLowerTriangular))
538-
@eval @inline function _copyto!(A::$T, B::$T)
539-
@boundscheck checkbounds(A, axes(B)...)
543+
@eval @inline function _copy!(A::$T, B::$T)
540544
copytrito!(parent(A), parent(B), uplo_char(A))
541545
return A
542546
end
543-
@eval @inline function _copyto!(A::$UT, B::$T)
547+
@eval @inline function _copy!(A::$UT, B::$T)
544548
for dind in diagind(A, IndexStyle(A))
545549
if A[dind] != B[dind]
546550
throw_nononeerror(typeof(A), B[dind], Tuple(dind)...)
547551
end
548552
end
549-
_copyto!($T(parent(A)), B)
553+
_copy!($T(parent(A)), B)
550554
return A
551555
end
552556
end
553-
@inline function _copyto!(A::UpperOrUnitUpperTriangular, B::UnitUpperTriangular)
557+
@inline function _copy!(A::UpperOrUnitUpperTriangular, B::UnitUpperTriangular)
554558
@boundscheck checkbounds(A, axes(B)...)
555559
B2 = Base.unalias(A, B)
556560
Ap = parent(A)
@@ -565,7 +569,7 @@ end
565569
end
566570
return A
567571
end
568-
@inline function _copyto!(A::LowerOrUnitLowerTriangular, B::UnitLowerTriangular)
572+
@inline function _copy!(A::LowerOrUnitLowerTriangular, B::UnitLowerTriangular)
569573
@boundscheck checkbounds(A, axes(B)...)
570574
B2 = Base.unalias(A, B)
571575
Ap = parent(A)
@@ -590,23 +594,30 @@ _triangularize!(::LowerOrUnitLowerTriangular) = tril!
590594
if axes(dest) != axes(U)
591595
@invoke copyto!(dest::StridedMatrix, U::AbstractArray)
592596
else
593-
_copyto!(dest, U)
597+
copy!(dest, U)
594598
end
595599
return dest
596600
end
597-
@propagate_inbounds function _copyto!(dest::StridedMatrix, U::UpperOrLowerTriangular)
601+
602+
function copy!(dest::StridedMatrix, U::UpperOrLowerTriangular)
603+
axes(dest) == axes(U) || throw(ArgumentError(
604+
"arrays must have the same axes for copy! (consider using `copyto!`)"))
605+
_copy!(dest, U)
606+
end
607+
608+
@propagate_inbounds function _copy!(dest::StridedMatrix, U::UpperOrLowerTriangular)
598609
copytrito!(dest, parent(U), U isa UpperOrUnitUpperTriangular ? 'U' : 'L')
599610
copytrito!(dest, U, U isa UpperOrUnitUpperTriangular ? 'L' : 'U')
600611
return dest
601612
end
602-
@propagate_inbounds function _copyto!(dest::StridedMatrix, U::UpperOrLowerTriangular{<:Any, <:StridedMatrix})
613+
@propagate_inbounds function _copy!(dest::StridedMatrix, U::UpperOrLowerTriangular{<:Any, <:StridedMatrix})
603614
U2 = Base.unalias(dest, U)
604-
copyto_unaliased!(dest, U2)
615+
copy_unaliased!(dest, U2)
605616
return dest
606617
end
607618
# for strided matrices, we explicitly loop over the arrays to improve cache locality
608619
# This fuses the copytrito! for the two halves
609-
@inline function copyto_unaliased!(dest::StridedMatrix, U::UpperOrUnitUpperTriangular{<:Any, <:StridedMatrix})
620+
@inline function copy_unaliased!(dest::StridedMatrix, U::UpperOrUnitUpperTriangular{<:Any, <:StridedMatrix})
610621
@boundscheck checkbounds(dest, axes(U)...)
611622
isunit = U isa UnitUpperTriangular
612623
for col in axes(dest,2)
@@ -619,7 +630,7 @@ end
619630
end
620631
return dest
621632
end
622-
@inline function copyto_unaliased!(dest::StridedMatrix, L::LowerOrUnitLowerTriangular{<:Any, <:StridedMatrix})
633+
@inline function copy_unaliased!(dest::StridedMatrix, L::LowerOrUnitLowerTriangular{<:Any, <:StridedMatrix})
623634
@boundscheck checkbounds(dest, axes(L)...)
624635
isunit = L isa UnitLowerTriangular
625636
for col in axes(dest,2)

0 commit comments

Comments
 (0)