Skip to content

Commit 8ea8f55

Browse files
committed
Unaliasing and short-circuiting in copytrito!
1 parent 07725da commit 8ea8f55

File tree

2 files changed

+19
-4
lines changed

2 files changed

+19
-4
lines changed

src/generic.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2088,26 +2088,32 @@ julia> copytrito!(B, A, 'L')
20882088
function copytrito!(B::AbstractMatrix, A::AbstractMatrix, uplo::AbstractChar)
20892089
require_one_based_indexing(A, B)
20902090
BLAS.chkuplo(uplo)
2091+
B === A && return B
20912092
m,n = size(A)
20922093
A = Base.unalias(B, A)
20932094
if uplo == 'U'
20942095
LAPACK.lacpy_size_check(size(B), (n < m ? n : m, n))
2096+
# extract the parents for UpperTriangular matrices
2097+
Bv, Av = uppertridata(B), uppertridata(A)
20952098
for j in axes(A,2), i in axes(A,1)[begin : min(j,end)]
2096-
# extract the parents for UpperTriangular matrices
2097-
Bv, Av = uppertridata(B), uppertridata(A)
20982099
@inbounds Bv[i,j] = Av[i,j]
20992100
end
21002101
else # uplo == 'L'
21012102
LAPACK.lacpy_size_check(size(B), (m, m < n ? m : n))
2103+
# extract the parents for LowerTriangular matrices
2104+
Bv, Av = lowertridata(B), lowertridata(A)
21022105
for j in axes(A,2), i in axes(A,1)[j:end]
2103-
# extract the parents for LowerTriangular matrices
2104-
Bv, Av = lowertridata(B), lowertridata(A)
21052106
@inbounds Bv[i,j] = Av[i,j]
21062107
end
21072108
end
21082109
return B
21092110
end
21102111
# Forward LAPACK-compatible strided matrices to lacpy
21112112
function copytrito!(B::StridedMatrixStride1{T}, A::StridedMatrixStride1{T}, uplo::AbstractChar) where {T<:BlasFloat}
2113+
require_one_based_indexing(A, B)
2114+
BLAS.chkuplo(uplo)
2115+
B === A && return B
2116+
A = Base.unalias(B, A)
21122117
LAPACK.lacpy!(B, A, uplo)
2118+
return B
21132119
end

test/generic.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -879,4 +879,13 @@ end
879879
@test mul!(copy!(similar(v), v), v, 2, 2, 0) == 4v
880880
end
881881

882+
@testset "aliasing in copytrito! for strided matrices" begin
883+
M = rand(4, 1)
884+
A = view(M, 1:3, 1:1)
885+
A2 = copy(A)
886+
B = view(M, 2:4, 1:1)
887+
copytrito!(B, A, 'L')
888+
@test B == A2
889+
end
890+
882891
end # module TestGeneric

0 commit comments

Comments
 (0)