Skip to content

Commit 9a0d209

Browse files
authored
Use IndexStyle in diagind to optionally return a range of CartesianIndexes (#52115)
1 parent 71ee30f commit 9a0d209

File tree

3 files changed

+26
-2
lines changed

3 files changed

+26
-2
lines changed

stdlib/LinearAlgebra/src/dense.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,9 +198,17 @@ function fillband!(A::AbstractMatrix{T}, x, l, u) where T
198198
return A
199199
end
200200

201-
diagind(m::Integer, n::Integer, k::Integer=0) =
201+
diagind(m::Integer, n::Integer, k::Integer=0) = diagind(IndexLinear(), m, n, k)
202+
diagind(::IndexLinear, m::Integer, n::Integer, k::Integer=0) =
202203
k <= 0 ? range(1-k, step=m+1, length=min(m+k, n)) : range(k*m+1, step=m+1, length=min(m, n-k))
203204

205+
function diagind(::IndexCartesian, m::Integer, n::Integer, k::Integer=0)
206+
Cstart = CartesianIndex(1 + max(0,-k), 1 + max(0,k))
207+
Cstep = CartesianIndex(1, 1)
208+
length = max(0, k <= 0 ? min(m+k, n) : min(m, n-k))
209+
StepRangeLen(Cstart, Cstep, length)
210+
end
211+
204212
"""
205213
diagind(M, k::Integer=0)
206214
@@ -222,7 +230,7 @@ julia> diagind(A,-1)
222230
"""
223231
function diagind(A::AbstractMatrix, k::Integer=0)
224232
require_one_based_indexing(A)
225-
diagind(size(A,1), size(A,2), k)
233+
diagind(IndexStyle(A), size(A,1), size(A,2), k)
226234
end
227235

228236
"""

stdlib/LinearAlgebra/test/bidiag.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -843,4 +843,12 @@ end
843843
@test all(iszero, diag(B, 1))
844844
end
845845

846+
@testset "diagind" begin
847+
B = Bidiagonal(1:4, 1:3, :U)
848+
M = Matrix(B)
849+
@testset for k in -4:4
850+
@test B[diagind(B,k)] == M[diagind(M,k)]
851+
end
852+
end
853+
846854
end # module TestBidiagonal

stdlib/LinearAlgebra/test/diagonal.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1227,4 +1227,12 @@ end
12271227
@test *(Diagonal(ones(n)), Diagonal(1:n), Diagonal(ones(n)), Diagonal(1:n)) isa Diagonal
12281228
end
12291229

1230+
@testset "diagind" begin
1231+
D = Diagonal(1:4)
1232+
M = Matrix(D)
1233+
@testset for k in -4:4
1234+
@test D[diagind(D,k)] == M[diagind(M,k)]
1235+
end
1236+
end
1237+
12301238
end # module TestDiagonal

0 commit comments

Comments
 (0)