Skip to content

Commit 90161b9

Browse files
committed
Unaliasing and short-circuiting in copytrito!
1 parent 222f7f2 commit 90161b9

File tree

2 files changed

+19
-4
lines changed

2 files changed

+19
-4
lines changed

src/generic.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2075,26 +2075,32 @@ julia> copytrito!(B, A, 'L')
20752075
function copytrito!(B::AbstractMatrix, A::AbstractMatrix, uplo::AbstractChar)
20762076
require_one_based_indexing(A, B)
20772077
BLAS.chkuplo(uplo)
2078+
B === A && return B
20782079
m,n = size(A)
20792080
A = Base.unalias(B, A)
20802081
if uplo == 'U'
20812082
LAPACK.lacpy_size_check(size(B), (n < m ? n : m, n))
2083+
# extract the parents for UpperTriangular matrices
2084+
Bv, Av = uppertridata(B), uppertridata(A)
20822085
for j in axes(A,2), i in axes(A,1)[begin : min(j,end)]
2083-
# extract the parents for UpperTriangular matrices
2084-
Bv, Av = uppertridata(B), uppertridata(A)
20852086
@inbounds Bv[i,j] = Av[i,j]
20862087
end
20872088
else # uplo == 'L'
20882089
LAPACK.lacpy_size_check(size(B), (m, m < n ? m : n))
2090+
# extract the parents for LowerTriangular matrices
2091+
Bv, Av = lowertridata(B), lowertridata(A)
20892092
for j in axes(A,2), i in axes(A,1)[j:end]
2090-
# extract the parents for LowerTriangular matrices
2091-
Bv, Av = lowertridata(B), lowertridata(A)
20922093
@inbounds Bv[i,j] = Av[i,j]
20932094
end
20942095
end
20952096
return B
20962097
end
20972098
# Forward LAPACK-compatible strided matrices to lacpy
20982099
function copytrito!(B::StridedMatrixStride1{T}, A::StridedMatrixStride1{T}, uplo::AbstractChar) where {T<:BlasFloat}
2100+
require_one_based_indexing(A, B)
2101+
BLAS.chkuplo(uplo)
2102+
B === A && return B
2103+
A = Base.unalias(B, A)
20992104
LAPACK.lacpy!(B, A, uplo)
2105+
return B
21002106
end

test/generic.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -857,4 +857,13 @@ end
857857
end
858858
end
859859

860+
@testset "aliasing in copytrito! for strided matrices" begin
861+
M = rand(4, 1)
862+
A = view(M, 1:3, 1:1)
863+
A2 = copy(A)
864+
B = view(M, 2:4, 1:1)
865+
copytrito!(B, A, 'L')
866+
@test B == A2
867+
end
868+
860869
end # module TestGeneric

0 commit comments

Comments
 (0)