Skip to content

Commit 62784d8

Browse files
committed
Use copy! at more places
1 parent 4f1a274 commit 62784d8

File tree

1 file changed

+17
-15
lines changed

1 file changed

+17
-15
lines changed

src/triangular.jl

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ end
178178
parent(A::UpperOrLowerTriangular) = A.data
179179

180180
# For strided matrices, we may only loop over the filled triangle
181-
copy(A::UpperOrLowerTriangular{<:Any, <:StridedMaybeAdjOrTransMat}) = copyto!(similar(A), A)
181+
copy(A::UpperOrLowerTriangular{<:Any, <:StridedMaybeAdjOrTransMat}) = copy!(similar(A), A)
182182

183183
# then handle all methods that requires specific handling of upper/lower and unit diagonal
184184

@@ -651,7 +651,7 @@ Base.@constprop :aggressive function copytrito_triangular!(Bdata, Adata, uplo, u
651651
BLAS.chkuplo(uplo)
652652
LAPACK.lacpy_size_check(size(Bdata), sz)
653653
# only the diagonal is copied in this case
654-
copyto!(diagview(Bdata), diagview(Adata))
654+
copy!(diagview(Bdata), diagview(Adata))
655655
end
656656
return Bdata
657657
end
@@ -1061,15 +1061,17 @@ isunit_char(::UnitUpperTriangular) = 'U'
10611061
isunit_char(::LowerTriangular) = 'N'
10621062
isunit_char(::UnitLowerTriangular) = 'U'
10631063

1064+
_copy_or_copyto!(dest, src) = ndims(dest) == ndims(src) ? copy!(dest, src) : copyto!(dest, src)
1065+
10641066
# generic fallback for AbstractTriangular matrices outside of the four subtypes provided here
10651067
_trimul!(C::AbstractVecOrMat, A::AbstractTriangular, B::AbstractVector) =
1066-
lmul!(A, copyto!(C, B))
1068+
lmul!(A, _copy_or_copyto!(C, B))
10671069
_trimul!(C::AbstractMatrix, A::AbstractTriangular, B::AbstractMatrix) =
1068-
lmul!(A, copyto!(C, B))
1070+
lmul!(A, copy!(C, B))
10691071
_trimul!(C::AbstractMatrix, A::AbstractMatrix, B::AbstractTriangular) =
1070-
rmul!(copyto!(C, A), B)
1072+
rmul!(copy!(C, A), B)
10711073
_trimul!(C::AbstractMatrix, A::AbstractTriangular, B::AbstractTriangular) =
1072-
lmul!(A, copyto!(C, B))
1074+
lmul!(A, copy!(C, B))
10731075
# redirect for UpperOrLowerTriangular
10741076
_trimul!(C::AbstractVecOrMat, A::UpperOrLowerTriangular, B::AbstractVector) =
10751077
generic_trimatmul!(C, uplo_char(A), isunit_char(A), wrapperop(parent(A)), _unwrap_at(parent(A)), B)
@@ -1130,9 +1132,9 @@ end
11301132
ldiv!(C::AbstractVecOrMat, A::AbstractTriangular, B::AbstractVecOrMat) = _ldiv!(C, A, B)
11311133
# generic fallback for AbstractTriangular, directs to 2-arg [l/r]div!
11321134
_ldiv!(C::AbstractVecOrMat, A::AbstractTriangular, B::AbstractVecOrMat) =
1133-
ldiv!(A, copyto!(C, B))
1135+
ldiv!(A, _copy_or_copyto!(C, B))
11341136
_rdiv!(C::AbstractMatrix, A::AbstractMatrix, B::AbstractTriangular) =
1135-
rdiv!(copyto!(C, A), B)
1137+
rdiv!(copy!(C, A), B)
11361138
# redirect for UpperOrLowerTriangular to generic_*div!
11371139
_ldiv!(C::AbstractVecOrMat, A::UpperOrLowerTriangular, B::AbstractVecOrMat) =
11381140
generic_trimatdiv!(C, uplo_char(A), isunit_char(A), wrapperop(parent(A)), _unwrap_at(parent(A)), B)
@@ -1210,7 +1212,7 @@ for (t, uploc, isunitc) in ((:LowerTriangular, 'L', 'N'),
12101212
elseif p == Inf
12111213
return inv(LAPACK.trcon!('I', $uploc, $isunitc, A.data))
12121214
else # use fallback
1213-
return cond(copyto!(similar(parent(A)), A), p)
1215+
return cond(copy!(similar(parent(A)), A), p)
12141216
end
12151217
end
12161218
end
@@ -1236,7 +1238,7 @@ end
12361238
# division
12371239
function generic_trimatdiv!(C::StridedVecOrMat{T}, uploc, isunitc, tfun::Function, A::StridedMatrix{T}, B::AbstractVecOrMat{T}) where {T<:BlasFloat}
12381240
if stride(C,1) == stride(A,1) == 1
1239-
LAPACK.trtrs!(uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, A, C === B ? C : copy!(C, B))
1241+
LAPACK.trtrs!(uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, A, C === B ? C : _copy_or_copyto!(C, B))
12401242
else # incompatible with LAPACK
12411243
@invoke generic_trimatdiv!(C::AbstractVecOrMat, uploc, isunitc, tfun::Function, A::AbstractMatrix, B::AbstractVecOrMat)
12421244
end
@@ -1968,22 +1970,22 @@ function powm!(A0::UpperTriangular, p::Real)
19681970
for i in axes(S,1)
19691971
@inbounds S[i, i] = S[i, i] + 1
19701972
end
1971-
copyto!(Stmp, S)
1973+
copy!(Stmp, S)
19721974
mul!(S, A, c)
19731975
ldiv!(Stmp, S)
19741976

19751977
c = (p - j) / (j4 - 2)
19761978
for i in axes(S,1)
19771979
@inbounds S[i, i] = S[i, i] + 1
19781980
end
1979-
copyto!(Stmp, S)
1981+
copy!(Stmp, S)
19801982
mul!(S, A, c)
19811983
ldiv!(Stmp, S)
19821984
end
19831985
for i in axes(S,1)
19841986
S[i, i] = S[i, i] + 1
19851987
end
1986-
copyto!(Stmp, S)
1988+
copy!(Stmp, S)
19871989
mul!(S, A, -p)
19881990
ldiv!(Stmp, S)
19891991
for i in axes(S,1)
@@ -1993,7 +1995,7 @@ function powm!(A0::UpperTriangular, p::Real)
19931995
blockpower!(A0, S, p/(2^s))
19941996
for m = 1:s
19951997
mul!(Stmp.data, S, S)
1996-
copyto!(S, Stmp)
1998+
copy!(S, Stmp)
19971999
blockpower!(A0, S, p/(2^(s-m)))
19982000
end
19992001
rmul!(S, normA0^p)
@@ -2180,7 +2182,7 @@ function _find_params_log_quasitriu!(A)
21802182
break
21812183
end
21822184
_sqrt_quasitriu!(A isa UpperTriangular ? parent(A) : A, A)
2183-
copyto!(AmI, A)
2185+
copy!(AmI, A)
21842186
for i in axes(AmI,1)
21852187
@inbounds AmI[i,i] -= 1
21862188
end

0 commit comments

Comments
 (0)