Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/symmetriceigen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

# preserve HermOrSym wrapper
# Call `copytrito!` instead of `copy_similar` to only copy the matching triangular half
eigencopy_oftype(A::Hermitian, S) = Hermitian(copytrito!(similar(parent(A), S, size(A)), A.data, A.uplo), sym_uplo(A.uplo))
eigencopy_oftype(A::Symmetric, S) = Symmetric(copytrito!(similar(parent(A), S, size(A)), A.data, A.uplo), sym_uplo(A.uplo))
eigencopy_oftype(A::Symmetric{<:Complex}, S) = copyto!(similar(parent(A), S), A)
eigencopy_oftype(A::Hermitian, ::Type{S}) where S = Hermitian(copytrito!(similar(parent(A), S, size(A)), A.data, A.uplo), sym_uplo(A.uplo))
eigencopy_oftype(A::Symmetric, ::Type{S}) where S = Symmetric(copytrito!(similar(parent(A), S, size(A)), A.data, A.uplo), sym_uplo(A.uplo))
eigencopy_oftype(A::Symmetric{<:Complex}, ::Type{S}) where S = copyto!(similar(parent(A), S), A)

"""
default_eigen_alg(A)
Expand Down
119 changes: 83 additions & 36 deletions src/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -230,10 +230,22 @@ Base.isassigned(A::UpperOrLowerTriangular, i::Int, j::Int) =
Base.isstored(A::UpperOrLowerTriangular, i::Int, j::Int) =
_shouldforwardindex(A, i, j) ? Base.isstored(A.data, i, j) : false

@propagate_inbounds getindex(A::Union{UnitLowerTriangular{T}, UnitUpperTriangular{T}}, i::Int, j::Int) where {T} =
_shouldforwardindex(A, i, j) ? A.data[i,j] : ifelse(i == j, oneunit(T), zero(T))
@propagate_inbounds getindex(A::Union{LowerTriangular, UpperTriangular}, i::Int, j::Int) =
_shouldforwardindex(A, i, j) ? A.data[i,j] : diagzero(A,i,j)
@propagate_inbounds function getindex(A::Union{UnitLowerTriangular{T}, UnitUpperTriangular{T}}, i::Int, j::Int) where {T}
if _shouldforwardindex(A, i, j)
A.data[i,j]
else
@boundscheck checkbounds(A, i, j)
ifelse(i == j, oneunit(T), zero(T))
end
end
@propagate_inbounds function getindex(A::Union{LowerTriangular, UpperTriangular}, i::Int, j::Int)
if _shouldforwardindex(A, i, j)
A.data[i,j]
else
@boundscheck checkbounds(A, i, j)
@inbounds diagzero(A,i,j)
end
end

_shouldforwardindex(U::UpperTriangular, b::BandIndex) = b.band >= 0
_shouldforwardindex(U::LowerTriangular, b::BandIndex) = b.band <= 0
Expand All @@ -242,62 +254,97 @@ _shouldforwardindex(U::UnitLowerTriangular, b::BandIndex) = b.band < 0

# these specialized getindex methods enable constant-propagation of the band
Base.@constprop :aggressive @propagate_inbounds function getindex(A::Union{UnitLowerTriangular{T}, UnitUpperTriangular{T}}, b::BandIndex) where {T}
_shouldforwardindex(A, b) ? A.data[b] : ifelse(b.band == 0, oneunit(T), zero(T))
if _shouldforwardindex(A, b)
A.data[b]
else
@boundscheck checkbounds(A, b)
ifelse(b.band == 0, oneunit(T), zero(T))
end
end
Base.@constprop :aggressive @propagate_inbounds function getindex(A::Union{LowerTriangular, UpperTriangular}, b::BandIndex)
_shouldforwardindex(A, b) ? A.data[b] : diagzero(A.data, b)
if _shouldforwardindex(A, b)
A.data[b]
else
@boundscheck checkbounds(A, b)
@inbounds diagzero(A, b)
end
end

_zero_triangular_half_str(T::Type) = T <: UpperOrUnitUpperTriangular ? "lower" : "upper"

@noinline function throw_nonzeroerror(T::DataType, @nospecialize(x), i, j)
Ts = _zero_triangular_half_str(T)
Tn = nameof(T)
@noinline function throw_nonzeroerror(Tn::Symbol, @nospecialize(x), i, j)
zero_half = Tn in (:UpperTriangular, :UnitUpperTriangular) ? "lower" : "upper"
nstr = Tn === :UpperTriangular ? "n" : ""
throw(ArgumentError(
lazy"cannot set index in the $Ts triangular part ($i, $j) of an $Tn matrix to a nonzero value ($x)"))
LazyString(
lazy"cannot set index ($i, $j) in the $zero_half triangular part ",
lazy"of a$nstr $Tn matrix to a nonzero value ($x)")
)
)
end
@noinline function throw_nononeerror(T::DataType, @nospecialize(x), i, j)
Tn = nameof(T)
@noinline function throw_nonuniterror(Tn::Symbol, @nospecialize(x), i, j)
throw(ArgumentError(
lazy"cannot set index on the diagonal ($i, $j) of an $Tn matrix to a non-unit value ($x)"))
lazy"cannot set index ($i, $j) on the diagonal of a $Tn matrix to a non-unit value ($x)"))
end

@propagate_inbounds function setindex!(A::UpperTriangular, x, i::Integer, j::Integer)
if i > j
iszero(x) || throw_nonzeroerror(typeof(A), x, i, j)
else
if _shouldforwardindex(A, i, j)
A.data[i,j] = x
else
@boundscheck checkbounds(A, i, j)
# the value must be convertible to the eltype for setindex! to be meaningful
# however, the converted value is unused, and the compiler is free to remove
# the conversion if the call is guaranteed to succeed
convert(eltype(A), x)
iszero(x) || throw_nonzeroerror(nameof(typeof(A)), x, i, j)
end
return A
end

@propagate_inbounds function setindex!(A::UnitUpperTriangular, x, i::Integer, j::Integer)
if i > j
iszero(x) || throw_nonzeroerror(typeof(A), x, i, j)
elseif i == j
x == oneunit(x) || throw_nononeerror(typeof(A), x, i, j)
else
if _shouldforwardindex(A, i, j)
A.data[i,j] = x
else
@boundscheck checkbounds(A, i, j)
# the value must be convertible to the eltype for setindex! to be meaningful
# however, the converted value is unused, and the compiler is free to remove
# the conversion if the call is guaranteed to succeed
convert(eltype(A), x)
if i == j # diagonal
x == oneunit(eltype(A)) || throw_nonuniterror(nameof(typeof(A)), x, i, j)
else
iszero(x) || throw_nonzeroerror(nameof(typeof(A)), x, i, j)
end
end
return A
end

@propagate_inbounds function setindex!(A::LowerTriangular, x, i::Integer, j::Integer)
if i < j
iszero(x) || throw_nonzeroerror(typeof(A), x, i, j)
else
if _shouldforwardindex(A, i, j)
A.data[i,j] = x
else
@boundscheck checkbounds(A, i, j)
# the value must be convertible to the eltype for setindex! to be meaningful
# however, the converted value is unused, and the compiler is free to remove
# the conversion if the call is guaranteed to succeed
convert(eltype(A), x)
iszero(x) || throw_nonzeroerror(nameof(typeof(A)), x, i, j)
end
return A
end

@propagate_inbounds function setindex!(A::UnitLowerTriangular, x, i::Integer, j::Integer)
if i < j
iszero(x) || throw_nonzeroerror(typeof(A), x, i, j)
elseif i == j
x == oneunit(x) || throw_nononeerror(typeof(A), x, i, j)
else
if _shouldforwardindex(A, i, j)
A.data[i,j] = x
else
@boundscheck checkbounds(A, i, j)
# the value must be convertible to the eltype for setindex! to be meaningful
# however, the converted value is unused, and the compiler is free to remove
# the conversion if the call is guaranteed to succeed
convert(eltype(A), x)
if i == j # diagonal
x == oneunit(eltype(A)) || throw_nonuniterror(nameof(typeof(A)), x, i, j)
else
iszero(x) || throw_nonzeroerror(nameof(typeof(A)), x, i, j)
end
end
return A
end
Expand Down Expand Up @@ -542,7 +589,7 @@ for (T, UT) in ((:UpperTriangular, :UnitUpperTriangular), (:LowerTriangular, :Un
@eval @inline function _copyto!(A::$UT, B::$T)
for dind in diagind(A, IndexStyle(A))
if A[dind] != B[dind]
throw_nononeerror(typeof(A), B[dind], Tuple(dind)...)
throw_nonuniterror(nameof(typeof(A)), B[dind], Tuple(dind)...)
end
end
_copyto!($T(parent(A)), B)
Expand Down Expand Up @@ -696,7 +743,7 @@ function _triscale!(A::UpperOrUnitUpperTriangular, B::UnitUpperTriangular, c::Nu
checksize1(A, B)
iszero(_add.alpha) && return _rmul_or_fill!(A, _add.beta)
for j in axes(B.data,2)
@inbounds _modify!(_add, c, A, (j,j))
@inbounds _modify!(_add, B[BandIndex(0,j)] * c, A, (j,j))
for i in firstindex(B.data,1):(j - 1)
@inbounds _modify!(_add, B.data[i,j] * c, A.data, (i,j))
end
Expand All @@ -707,7 +754,7 @@ function _triscale!(A::UpperOrUnitUpperTriangular, c::Number, B::UnitUpperTriang
checksize1(A, B)
iszero(_add.alpha) && return _rmul_or_fill!(A, _add.beta)
for j in axes(B.data,2)
@inbounds _modify!(_add, c, A, (j,j))
@inbounds _modify!(_add, c * B[BandIndex(0,j)], A, (j,j))
for i in firstindex(B.data,1):(j - 1)
@inbounds _modify!(_add, c * B.data[i,j], A.data, (i,j))
end
Expand Down Expand Up @@ -738,7 +785,7 @@ function _triscale!(A::LowerOrUnitLowerTriangular, B::UnitLowerTriangular, c::Nu
checksize1(A, B)
iszero(_add.alpha) && return _rmul_or_fill!(A, _add.beta)
for j in axes(B.data,2)
@inbounds _modify!(_add, c, A, (j,j))
@inbounds _modify!(_add, B[BandIndex(0,j)] * c, A, (j,j))
for i in (j + 1):lastindex(B.data,1)
@inbounds _modify!(_add, B.data[i,j] * c, A.data, (i,j))
end
Expand All @@ -749,7 +796,7 @@ function _triscale!(A::LowerOrUnitLowerTriangular, c::Number, B::UnitLowerTriang
checksize1(A, B)
iszero(_add.alpha) && return _rmul_or_fill!(A, _add.beta)
for j in axes(B.data,2)
@inbounds _modify!(_add, c, A, (j,j))
@inbounds _modify!(_add, c * B[BandIndex(0,j)], A, (j,j))
for i in (j + 1):lastindex(B.data,1)
@inbounds _modify!(_add, c * B.data[i,j], A.data, (i,j))
end
Expand Down
79 changes: 77 additions & 2 deletions test/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -644,11 +644,11 @@ end
@testset "error message" begin
A = UpperTriangular(Ap)
B = UpperTriangular(Bp)
@test_throws "cannot set index in the lower triangular part" copyto!(A, B)
@test_throws "cannot set index (3, 1) in the lower triangular part" copyto!(A, B)

A = LowerTriangular(Ap)
B = LowerTriangular(Bp)
@test_throws "cannot set index in the upper triangular part" copyto!(A, B)
@test_throws "cannot set index (1, 2) in the upper triangular part" copyto!(A, B)
end
end

Expand Down Expand Up @@ -944,6 +944,81 @@ end
@test 2\U == 2\M
@test U*2 == M*2
@test 2*U == 2*M

U2 = copy(U)
@test rmul!(U, 1) == U2
@test lmul!(1, U) == U2
end

@testset "indexing checks" begin
P = [1 2; 3 4]
@testset "getindex" begin
U = UnitUpperTriangular(P)
@test_throws BoundsError U[0,0]
@test_throws BoundsError U[1,0]
@test_throws BoundsError U[BandIndex(0,0)]
@test_throws BoundsError U[BandIndex(-1,0)]

U = UpperTriangular(P)
@test_throws BoundsError U[1,0]
@test_throws BoundsError U[BandIndex(-1,0)]

L = UnitLowerTriangular(P)
@test_throws BoundsError L[0,0]
@test_throws BoundsError L[0,1]
@test_throws BoundsError U[BandIndex(0,0)]
@test_throws BoundsError U[BandIndex(1,0)]

L = LowerTriangular(P)
@test_throws BoundsError L[0,1]
@test_throws BoundsError L[BandIndex(1,0)]
end
@testset "setindex!" begin
A = SizedArrays.SizedArray{(2,2)}(P)
M = fill(A, 2, 2)
U = UnitUpperTriangular(M)
@test_throws "Cannot `convert` an object of type $Int" U[1,1] = 1
non_unit_msg = "cannot set index $((1,1)) on the diagonal of a UnitUpperTriangular matrix to a non-unit value"
@test_throws non_unit_msg U[1,1] = A
L = UnitLowerTriangular(M)
@test_throws "Cannot `convert` an object of type $Int" L[1,1] = 1
non_unit_msg = "cannot set index $((1,1)) on the diagonal of a UnitLowerTriangular matrix to a non-unit value"
@test_throws non_unit_msg L[1,1] = A

for UT in (UnitUpperTriangular, UpperTriangular)
U = UT(M)
@test_throws "Cannot `convert` an object of type $Int" U[2,1] = 0
end
for LT in (UnitLowerTriangular, LowerTriangular)
L = LT(M)
@test_throws "Cannot `convert` an object of type $Int" L[1,2] = 0
end

U = UnitUpperTriangular(P)
@test_throws BoundsError U[0,0] = 1
@test_throws BoundsError U[1,0] = 0

U = UpperTriangular(P)
@test_throws BoundsError U[1,0] = 0

L = UnitLowerTriangular(P)
@test_throws BoundsError L[0,0] = 1
@test_throws BoundsError L[0,1] = 0

L = LowerTriangular(P)
@test_throws BoundsError L[0,1] = 0
end
end

@testset "unit triangular l/rdiv!" begin
A = rand(3,3)
@testset for (UT,T) in ((UnitUpperTriangular, UpperTriangular),
(UnitLowerTriangular, LowerTriangular))
UnitTri = UT(A)
Tri = T(LinearAlgebra.full(UnitTri))
@test 2 \ UnitTri ≈ 2 \ Tri
@test UnitTri / 2 ≈ Tri / 2
end
end

end # module TestTriangular