diff --git a/src/bidiag.jl b/src/bidiag.jl index 19a80336..13439ff9 100644 --- a/src/bidiag.jl +++ b/src/bidiag.jl @@ -178,6 +178,19 @@ end return A end +@inline function setindex!(A::Bidiagonal, x, b::BandIndex) + @boundscheck checkbounds(A, b) + if b.band == 0 + @inbounds A.dv[b.index] = x + elseif b.band ∈ (-1,1) && b.band == _offdiagind(A.uplo) + @inbounds A.ev[b.index] = x + elseif !iszero(x) + throw(ArgumentError(LazyString(lazy"cannot set entry $(to_indices(A, (b,))) off the ", + A.uplo == 'U' ? "upper" : "lower", " bidiagonal band to a nonzero value ", x))) + end + return A +end + Base._reverse(A::Bidiagonal, dims) = reverse!(Matrix(A); dims) Base._reverse(A::Bidiagonal, ::Colon) = Bidiagonal(reverse(A.dv), reverse(A.ev), A.uplo == 'U' ? :L : :U) diff --git a/src/diagonal.jl b/src/diagonal.jl index 10b8e7b0..3f72c16c 100644 --- a/src/diagonal.jl +++ b/src/diagonal.jl @@ -218,7 +218,7 @@ zeroslike(::Type{M}, sz::Tuple{Integer, Vararg{Integer}}) where {M<:AbstractMatr r end -function setindex!(D::Diagonal, v, i::Int, j::Int) +@inline function setindex!(D::Diagonal, v, i::Int, j::Int) @boundscheck checkbounds(D, i, j) if i == j @inbounds D.diag[i] = v @@ -228,6 +228,15 @@ function setindex!(D::Diagonal, v, i::Int, j::Int) return D end +@inline function setindex!(D::Diagonal, v, b::BandIndex) + @boundscheck checkbounds(D, b) + if b.band == 0 + @inbounds D.diag[b.index] = v + elseif !iszero(v) + throw(ArgumentError(lazy"cannot set off-diagonal entry $(to_indices(D, (b,))) to a nonzero value ($v)")) + end + return D +end ## structured matrix methods ## function Base.replace_in_print_matrix(A::Diagonal,i::Integer,j::Integer,s::AbstractString) diff --git a/src/tridiag.jl b/src/tridiag.jl index 808f5cc8..c2290104 100644 --- a/src/tridiag.jl +++ b/src/tridiag.jl @@ -508,6 +508,17 @@ Base._reverse!(A::SymTridiagonal, dims::Colon) = (reverse!(A.dv); reverse!(A.ev) return A end +@inline function setindex!(A::SymTridiagonal, x, b::BandIndex) + @boundscheck checkbounds(A, b) + if b.band == 0 + issymmetric(x) || throw(ArgumentError("cannot set a diagonal entry of a SymTridiagonal to an asymmetric value")) + @inbounds A.dv[b.index] = x + else + throw(ArgumentError(lazy"cannot set off-diagonal entry $(to_indices(A, (b,)))")) + end + return A +end + ## Tridiagonal matrices ## struct Tridiagonal{T,V<:AbstractVector{T}} <: AbstractMatrix{T} dl::V # sub-diagonal @@ -775,6 +786,21 @@ end return A end +@inline function setindex!(A::Tridiagonal, x, b::BandIndex) + @boundscheck checkbounds(A, b) + if b.band == 0 + @inbounds A.d[b.index] = x + elseif b.band == -1 + @inbounds A.dl[b.index] = x + elseif b.band == 1 + @inbounds A.du[b.index] = x + elseif !iszero(x) + throw(ArgumentError(LazyString(lazy"cannot set entry $(to_indices(A, (b,))) off ", + lazy"the tridiagonal band to a nonzero value ($x)"))) + end + return A +end + ## structured matrix methods ## function Base.replace_in_print_matrix(A::Tridiagonal,i::Integer,j::Integer,s::AbstractString) i==j-1||i==j||i==j+1 ? s : Base.replace_with_centered_mark(s) diff --git a/test/bidiag.jl b/test/bidiag.jl index 2488cd3f..fcd9de77 100644 --- a/test/bidiag.jl +++ b/test/bidiag.jl @@ -1205,4 +1205,21 @@ end @test rmul!(B, D) == B2 end +@testset "setindex! with BandIndex" begin + B = Bidiagonal(zeros(3), zeros(2), :U) + B[LinearAlgebra.BandIndex(0,2)] = 1 + @test B[2,2] == 1 + B[LinearAlgebra.BandIndex(1,1)] = 2 + @test B[1,2] == 2 + @test_throws "cannot set entry $((1,3)) off the upper bidiagonal band" B[LinearAlgebra.BandIndex(2,1)] = 2 + + B = Bidiagonal(zeros(3), zeros(2), :L) + B[LinearAlgebra.BandIndex(-1,1)] = 2 + @test B[2,1] == 2 + @test_throws "cannot set entry $((3,1)) off the lower bidiagonal band" B[LinearAlgebra.BandIndex(-2,1)] = 2 + + @test_throws BoundsError B[LinearAlgebra.BandIndex(size(B,1),1)] + @test_throws BoundsError B[LinearAlgebra.BandIndex(0,size(B,1)+1)] +end + end # module TestBidiagonal diff --git a/test/diagonal.jl b/test/diagonal.jl index cbd9edb4..207e52d9 100644 --- a/test/diagonal.jl +++ b/test/diagonal.jl @@ -1489,4 +1489,13 @@ end @test !isreal(im*D) end +@testset "setindex! with BandIndex" begin + D = Diagonal(zeros(2)) + D[LinearAlgebra.BandIndex(0,2)] = 1 + @test D[2,2] == 1 + @test_throws "cannot set off-diagonal entry $((1,2))" D[LinearAlgebra.BandIndex(1,1)] = 1 + @test_throws BoundsError D[LinearAlgebra.BandIndex(size(D,1),1)] + @test_throws BoundsError D[LinearAlgebra.BandIndex(0,size(D,1)+1)] +end + end # module TestDiagonal diff --git a/test/tridiag.jl b/test/tridiag.jl index 849dfa17..115b7ab2 100644 --- a/test/tridiag.jl +++ b/test/tridiag.jl @@ -1184,4 +1184,27 @@ end @test convert(SymTridiagonal, S) == S end +@testset "setindex! with BandIndex" begin + T = Tridiagonal(zeros(3), zeros(4), zeros(3)) + T[LinearAlgebra.BandIndex(0,2)] = 1 + @test T[2,2] == 1 + T[LinearAlgebra.BandIndex(1,2)] = 2 + @test T[2,3] == 2 + T[LinearAlgebra.BandIndex(-1,2)] = 3 + @test T[3,2] == 3 + + @test_throws "cannot set entry $((1,3)) off the tridiagonal band" T[LinearAlgebra.BandIndex(2,1)] = 1 + @test_throws "cannot set entry $((3,1)) off the tridiagonal band" T[LinearAlgebra.BandIndex(-2,1)] = 1 + @test_throws BoundsError T[LinearAlgebra.BandIndex(size(T,1),1)] + @test_throws BoundsError T[LinearAlgebra.BandIndex(0,size(T,1)+1)] + + S = SymTridiagonal(zeros(4), zeros(3)) + S[LinearAlgebra.BandIndex(0,2)] = 1 + @test S[2,2] == 1 + + @test_throws "cannot set off-diagonal entry $((1,3))" S[LinearAlgebra.BandIndex(2,1)] = 1 + @test_throws BoundsError S[LinearAlgebra.BandIndex(size(S,1),1)] + @test_throws BoundsError S[LinearAlgebra.BandIndex(0,size(S,1)+1)] +end + end # module TestTridiagonal