Skip to content

Commit 00dfafa

Browse files
authored
Fix empty diagindices, add more tests (#26)
1 parent b54a799 commit 00dfafa

File tree

3 files changed

+78
-10
lines changed

3 files changed

+78
-10
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DiagonalArrays"
22
uuid = "74fd4be6-21e2-4f6f-823a-4360d37c7a77"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.3.4"
4+
version = "0.3.5"
55

66
[deps]
77
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"

src/diaginterface/diaginterface.jl

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ function diaglength(a::AbstractArray)
88
return minimum(size(a))
99
end
1010

11-
function isdiagindex(a::AbstractArray{<:Any,N}, I::CartesianIndex{N}) where {N}
11+
@inline function isdiagindex(a::AbstractArray{<:Any,N}, I::CartesianIndex{N}) where {N}
1212
@boundscheck checkbounds(a, I)
1313
return allequal(Tuple(I))
1414
end
@@ -30,14 +30,17 @@ end
3030
struct DiagCartesianIndices{N} <: AbstractVector{CartesianIndex{N}}
3131
diaglength::Int
3232
end
33-
function DiagCartesianIndices(axes::Tuple{Vararg{AbstractUnitRange}})
33+
function DiagCartesianIndices(axes::Tuple{AbstractUnitRange,Vararg{AbstractUnitRange}})
3434
# Check the ranges are one-based.
3535
@assert all(isone, first.(axes))
3636
return DiagCartesianIndices{length(axes)}(minimum(length.(axes)))
3737
end
38-
function DiagCartesianIndices(dims::Tuple{Vararg{Int}})
38+
function DiagCartesianIndices(dims::Tuple{Int,Vararg{Int}})
3939
return DiagCartesianIndices(Base.OneTo.(dims))
4040
end
41+
function DiagCartesianIndices(dims::Tuple{})
42+
return DiagCartesianIndices{0}(0)
43+
end
4144
function DiagCartesianIndices(a::AbstractArray)
4245
return DiagCartesianIndices(axes(a))
4346
end
@@ -46,19 +49,31 @@ function Base.getindex(I::DiagCartesianIndices{N}, i::Int) where {N}
4649
return CartesianIndex(ntuple(Returns(i), N))
4750
end
4851

52+
function checkdiagbounds(::Type{Bool}, a::AbstractArray, i::Integer)
53+
Base.require_one_based_indexing(a)
54+
return i 1:diaglength(a)
55+
end
56+
function checkdiagbounds(a::AbstractArray, i::Integer)
57+
checkdiagbounds(Bool, a, i) || throw(BoundsError(a, ntuple(Returns(i), ndims(a))))
58+
return nothing
59+
end
60+
61+
# Convert a linear index along the diagonal to the corresponding
62+
# CartesianIndex.
63+
@inline function diagindex(a::AbstractArray, i::Integer)
64+
@boundscheck checkdiagbounds(a, i)
65+
return CartesianIndex(ntuple(Returns(i), ndims(a)))
66+
end
67+
4968
function diagindices(a::AbstractArray)
5069
return diagindices(IndexStyle(a), a)
5170
end
5271
function diagindices(::IndexLinear, a::AbstractArray)
53-
maxdiag = LinearIndices(a)[CartesianIndex(ntuple(Returns(diaglength(a)), ndims(a)))]
72+
maxdiag = isempty(a) ? 0 : @inbounds LinearIndices(a)[diagindex(a, diaglength(a))]
5473
return 1:diagstride(a):maxdiag
5574
end
5675
function diagindices(::IndexCartesian, a::AbstractArray)
5776
return DiagCartesianIndices(a)
58-
# TODO: Define a special iterator for this, i.e. `DiagCartesianIndices`?
59-
return Iterators.map(
60-
i -> CartesianIndex(ntuple(Returns(i), ndims(a))), Base.OneTo(diaglength(a))
61-
)
6277
end
6378

6479
function diagindices(a::AbstractArray{<:Any,0})

test/test_basics.jl

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,14 @@
11
using Test: @test, @testset, @test_broken, @inferred
22
using DiagonalArrays:
3-
DiagonalArrays, DiagonalArray, DiagonalMatrix, δ, delta, diaglength, diagonal, diagview
3+
DiagonalArrays,
4+
DiagonalArray,
5+
DiagonalMatrix,
6+
δ,
7+
delta,
8+
diagindices,
9+
diaglength,
10+
diagonal,
11+
diagview
412
using FillArrays: Fill, Ones
513
using SparseArraysBase: SparseArrayDOK, sparsezeros, storedlength
614
using LinearAlgebra: Diagonal
@@ -15,6 +23,51 @@ using LinearAlgebra: Diagonal
1523
a = fill(one(elt))
1624
@test diaglength(a) == 1
1725
end
26+
@testset "diagindices" begin
27+
a = randn(elt, ())
28+
@test diagindices(a) == diagindices(IndexLinear(), a) == 1:1:1
29+
@test isempty(diagindices(IndexCartesian(), a))
30+
31+
for a in (
32+
randn(elt, (0,)),
33+
randn(elt, (0, 0)),
34+
randn(elt, (0, 3)),
35+
randn(elt, (3, 0)),
36+
randn(elt, (0, 0, 0)),
37+
randn(elt, (3, 3, 0)),
38+
)
39+
@test diagindices(a) == diagindices(IndexLinear(), a) == 1:1:0
40+
@test isempty(diagindices(IndexCartesian(), a))
41+
end
42+
43+
a = randn(elt, (3,))
44+
@test diagindices(a) == diagindices(IndexLinear(), a) == 1:1:3
45+
@test diagindices(IndexCartesian(), a) == CartesianIndex.(1:3)
46+
47+
a = randn(elt, (4,))
48+
@test diagindices(a) == diagindices(IndexLinear(), a) == 1:1:4
49+
@test diagindices(IndexCartesian(), a) == CartesianIndex.(1:4)
50+
51+
for a in (randn(elt, (3, 3)), randn(elt, (3, 4)))
52+
@test diagindices(a) == diagindices(IndexLinear(), a) == 1:4:9
53+
@test diagindices(IndexCartesian(), a) == CartesianIndex.(Iterators.zip(1:3, 1:3))
54+
end
55+
56+
a = randn(elt, (4, 3))
57+
@test diagindices(a) == diagindices(IndexLinear(), a) == 1:5:11
58+
@test diagindices(IndexCartesian(), a) == CartesianIndex.(Iterators.zip(1:3, 1:3))
59+
60+
for a in (randn(elt, (3, 3, 3)), randn(elt, (3, 3, 4)))
61+
@test diagindices(a) == diagindices(IndexLinear(), a) == 1:13:27
62+
@test diagindices(IndexCartesian(), a) ==
63+
CartesianIndex.(Iterators.zip(1:3, 1:3, 1:3))
64+
end
65+
66+
a = randn(elt, (3, 4, 3))
67+
@test diagindices(a) == diagindices(IndexLinear(), a) == 1:16:33
68+
@test diagindices(IndexCartesian(), a) ==
69+
CartesianIndex.(Iterators.zip(1:3, 1:3, 1:3))
70+
end
1871
@testset "Matrix multiplication" begin
1972
a1 = DiagonalArray{elt}(undef, (2, 3))
2073
a1[1, 1] = 11

0 commit comments

Comments
 (0)