diff --git a/Project.toml b/Project.toml index b2a95c3..8ffc6db 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "DiagonalArrays" uuid = "74fd4be6-21e2-4f6f-823a-4360d37c7a77" authors = ["ITensor developers and contributors"] -version = "0.3.4" +version = "0.3.5" [deps] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" diff --git a/src/diaginterface/diaginterface.jl b/src/diaginterface/diaginterface.jl index 0d96692..7baad7b 100644 --- a/src/diaginterface/diaginterface.jl +++ b/src/diaginterface/diaginterface.jl @@ -8,7 +8,7 @@ function diaglength(a::AbstractArray) return minimum(size(a)) end -function isdiagindex(a::AbstractArray{<:Any,N}, I::CartesianIndex{N}) where {N} +@inline function isdiagindex(a::AbstractArray{<:Any,N}, I::CartesianIndex{N}) where {N} @boundscheck checkbounds(a, I) return allequal(Tuple(I)) end @@ -30,14 +30,17 @@ end struct DiagCartesianIndices{N} <: AbstractVector{CartesianIndex{N}} diaglength::Int end -function DiagCartesianIndices(axes::Tuple{Vararg{AbstractUnitRange}}) +function DiagCartesianIndices(axes::Tuple{AbstractUnitRange,Vararg{AbstractUnitRange}}) # Check the ranges are one-based. @assert all(isone, first.(axes)) return DiagCartesianIndices{length(axes)}(minimum(length.(axes))) end -function DiagCartesianIndices(dims::Tuple{Vararg{Int}}) +function DiagCartesianIndices(dims::Tuple{Int,Vararg{Int}}) return DiagCartesianIndices(Base.OneTo.(dims)) end +function DiagCartesianIndices(dims::Tuple{}) + return DiagCartesianIndices{0}(0) +end function DiagCartesianIndices(a::AbstractArray) return DiagCartesianIndices(axes(a)) end @@ -46,19 +49,31 @@ function Base.getindex(I::DiagCartesianIndices{N}, i::Int) where {N} return CartesianIndex(ntuple(Returns(i), N)) end +function checkdiagbounds(::Type{Bool}, a::AbstractArray, i::Integer) + Base.require_one_based_indexing(a) + return i ∈ 1:diaglength(a) +end +function checkdiagbounds(a::AbstractArray, i::Integer) + checkdiagbounds(Bool, a, i) || throw(BoundsError(a, ntuple(Returns(i), ndims(a)))) + return nothing +end + +# Convert a linear index along the diagonal to the corresponding +# CartesianIndex. +@inline function diagindex(a::AbstractArray, i::Integer) + @boundscheck checkdiagbounds(a, i) + return CartesianIndex(ntuple(Returns(i), ndims(a))) +end + function diagindices(a::AbstractArray) return diagindices(IndexStyle(a), a) end function diagindices(::IndexLinear, a::AbstractArray) - maxdiag = LinearIndices(a)[CartesianIndex(ntuple(Returns(diaglength(a)), ndims(a)))] + maxdiag = isempty(a) ? 0 : @inbounds LinearIndices(a)[diagindex(a, diaglength(a))] return 1:diagstride(a):maxdiag end function diagindices(::IndexCartesian, a::AbstractArray) return DiagCartesianIndices(a) - # TODO: Define a special iterator for this, i.e. `DiagCartesianIndices`? - return Iterators.map( - i -> CartesianIndex(ntuple(Returns(i), ndims(a))), Base.OneTo(diaglength(a)) - ) end function diagindices(a::AbstractArray{<:Any,0}) diff --git a/test/test_basics.jl b/test/test_basics.jl index 20facc3..922958b 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -1,6 +1,14 @@ using Test: @test, @testset, @test_broken, @inferred using DiagonalArrays: - DiagonalArrays, DiagonalArray, DiagonalMatrix, δ, delta, diaglength, diagonal, diagview + DiagonalArrays, + DiagonalArray, + DiagonalMatrix, + δ, + delta, + diagindices, + diaglength, + diagonal, + diagview using FillArrays: Fill, Ones using SparseArraysBase: SparseArrayDOK, sparsezeros, storedlength using LinearAlgebra: Diagonal @@ -15,6 +23,51 @@ using LinearAlgebra: Diagonal a = fill(one(elt)) @test diaglength(a) == 1 end + @testset "diagindices" begin + a = randn(elt, ()) + @test diagindices(a) == diagindices(IndexLinear(), a) == 1:1:1 + @test isempty(diagindices(IndexCartesian(), a)) + + for a in ( + randn(elt, (0,)), + randn(elt, (0, 0)), + randn(elt, (0, 3)), + randn(elt, (3, 0)), + randn(elt, (0, 0, 0)), + randn(elt, (3, 3, 0)), + ) + @test diagindices(a) == diagindices(IndexLinear(), a) == 1:1:0 + @test isempty(diagindices(IndexCartesian(), a)) + end + + a = randn(elt, (3,)) + @test diagindices(a) == diagindices(IndexLinear(), a) == 1:1:3 + @test diagindices(IndexCartesian(), a) == CartesianIndex.(1:3) + + a = randn(elt, (4,)) + @test diagindices(a) == diagindices(IndexLinear(), a) == 1:1:4 + @test diagindices(IndexCartesian(), a) == CartesianIndex.(1:4) + + for a in (randn(elt, (3, 3)), randn(elt, (3, 4))) + @test diagindices(a) == diagindices(IndexLinear(), a) == 1:4:9 + @test diagindices(IndexCartesian(), a) == CartesianIndex.(Iterators.zip(1:3, 1:3)) + end + + a = randn(elt, (4, 3)) + @test diagindices(a) == diagindices(IndexLinear(), a) == 1:5:11 + @test diagindices(IndexCartesian(), a) == CartesianIndex.(Iterators.zip(1:3, 1:3)) + + for a in (randn(elt, (3, 3, 3)), randn(elt, (3, 3, 4))) + @test diagindices(a) == diagindices(IndexLinear(), a) == 1:13:27 + @test diagindices(IndexCartesian(), a) == + CartesianIndex.(Iterators.zip(1:3, 1:3, 1:3)) + end + + a = randn(elt, (3, 4, 3)) + @test diagindices(a) == diagindices(IndexLinear(), a) == 1:16:33 + @test diagindices(IndexCartesian(), a) == + CartesianIndex.(Iterators.zip(1:3, 1:3, 1:3)) + end @testset "Matrix multiplication" begin a1 = DiagonalArray{elt}(undef, (2, 3)) a1[1, 1] = 11