diff --git a/src/bidiag.jl b/src/bidiag.jl index bb5b8830..90385ab7 100644 --- a/src/bidiag.jl +++ b/src/bidiag.jl @@ -405,14 +405,15 @@ end function diag(M::Bidiagonal, n::Integer=0) # every branch call similar(..., ::Int) to make sure the # same vector type is returned independent of n - v = similar(M.dv, max(0, length(M.dv)-abs(n))) + dinds = diagind(M, n, IndexStyle(M)) + v = similar(M.dv, length(dinds)) if n == 0 copyto!(v, M.dv) elseif (n == 1 && M.uplo == 'U') || (n == -1 && M.uplo == 'L') copyto!(v, M.ev) elseif -size(M,1) <= n <= size(M,1) - for i in eachindex(v) - v[i] = M[BandIndex(n,i)] + for i in eachindex(v, dinds) + @inbounds v[i] = M[BandIndex(n,i)] end end return v diff --git a/src/diagonal.jl b/src/diagonal.jl index 9f8d54e5..df69fdf8 100644 --- a/src/diagonal.jl +++ b/src/diagonal.jl @@ -836,12 +836,13 @@ permutedims(D::Diagonal, perm) = (Base.checkdims_perm(axes(D), axes(D), perm); D function diag(D::Diagonal, k::Integer=0) # every branch call similar(..., ::Int) to make sure the # same vector type is returned independent of k - v = similar(D.diag, max(0, length(D.diag)-abs(k))) + dinds = diagind(D, k, IndexStyle(D)) + v = similar(D.diag, length(dinds)) if k == 0 copyto!(v, D.diag) else - for i in eachindex(v) - v[i] = D[BandIndex(k, i)] + for i in eachindex(v, dinds) + @inbounds v[i] = D[dinds[i]] end end return v diff --git a/src/tridiag.jl b/src/tridiag.jl index a24cc50b..5101283c 100644 --- a/src/tridiag.jl +++ b/src/tridiag.jl @@ -191,7 +191,8 @@ _eviter_transposed(M::SymTridiagonal) = (transpose(x) for x in _evview(M)) function diag(M::SymTridiagonal, n::Integer=0) # every branch call similar(..., ::Int) to make sure the # same vector type is returned independent of n - v = similar(M.dv, max(0, length(M.dv)-abs(n))) + dinds = diagind(M, n, IndexStyle(M)) + v = similar(M.dv, length(dinds)) if n == 0 return copyto!(v, _diagiter(M)) elseif n == 1 @@ -199,7 +200,7 @@ function diag(M::SymTridiagonal, n::Integer=0) elseif n == -1 return copyto!(v, _eviter_transposed(M)) else - for i in eachindex(v) + for i in eachindex(v, dinds) v[i] = M[BandIndex(n,i)] end end @@ -662,7 +663,8 @@ issymmetric(S::Tridiagonal) = all(issymmetric, S.d) && all(Iterators.map((x, y) function diag(M::Tridiagonal, n::Integer=0) # every branch call similar(..., ::Int) to make sure the # same vector type is returned independent of n - v = similar(M.d, max(0, length(M.d)-abs(n))) + dinds = diagind(M, n, IndexStyle(M)) + v = similar(M.d, length(dinds)) if n == 0 copyto!(v, M.d) elseif n == -1 @@ -670,7 +672,7 @@ function diag(M::Tridiagonal, n::Integer=0) elseif n == 1 copyto!(v, M.du) elseif abs(n) <= size(M,1) - for i in eachindex(v) + for i in eachindex(v, dinds) v[i] = M[BandIndex(n,i)] end end