Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DiagonalArrays"
uuid = "74fd4be6-21e2-4f6f-823a-4360d37c7a77"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.3.4"
version = "0.3.5"

[deps]
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
Expand Down
31 changes: 23 additions & 8 deletions src/diaginterface/diaginterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
return minimum(size(a))
end

function isdiagindex(a::AbstractArray{<:Any,N}, I::CartesianIndex{N}) where {N}
@inline function isdiagindex(a::AbstractArray{<:Any,N}, I::CartesianIndex{N}) where {N}
@boundscheck checkbounds(a, I)
return allequal(Tuple(I))
end
Expand All @@ -30,14 +30,17 @@
struct DiagCartesianIndices{N} <: AbstractVector{CartesianIndex{N}}
diaglength::Int
end
function DiagCartesianIndices(axes::Tuple{Vararg{AbstractUnitRange}})
function DiagCartesianIndices(axes::Tuple{AbstractUnitRange,Vararg{AbstractUnitRange}})
# Check the ranges are one-based.
@assert all(isone, first.(axes))
return DiagCartesianIndices{length(axes)}(minimum(length.(axes)))
end
function DiagCartesianIndices(dims::Tuple{Vararg{Int}})
function DiagCartesianIndices(dims::Tuple{Int,Vararg{Int}})

Check warning on line 38 in src/diaginterface/diaginterface.jl

View check run for this annotation

Codecov / codecov/patch

src/diaginterface/diaginterface.jl#L38

Added line #L38 was not covered by tests
return DiagCartesianIndices(Base.OneTo.(dims))
end
function DiagCartesianIndices(dims::Tuple{})
return DiagCartesianIndices{0}(0)
end
function DiagCartesianIndices(a::AbstractArray)
return DiagCartesianIndices(axes(a))
end
Expand All @@ -46,19 +49,31 @@
return CartesianIndex(ntuple(Returns(i), N))
end

function checkdiagbounds(::Type{Bool}, a::AbstractArray, i::Integer)
Base.require_one_based_indexing(a)
return i ∈ 1:diaglength(a)
end
function checkdiagbounds(a::AbstractArray, i::Integer)
checkdiagbounds(Bool, a, i) || throw(BoundsError(a, ntuple(Returns(i), ndims(a))))
return nothing
end

# Convert a linear index along the diagonal to the corresponding
# CartesianIndex.
@inline function diagindex(a::AbstractArray, i::Integer)
@boundscheck checkdiagbounds(a, i)
return CartesianIndex(ntuple(Returns(i), ndims(a)))
end

function diagindices(a::AbstractArray)
return diagindices(IndexStyle(a), a)
end
function diagindices(::IndexLinear, a::AbstractArray)
maxdiag = LinearIndices(a)[CartesianIndex(ntuple(Returns(diaglength(a)), ndims(a)))]
maxdiag = isempty(a) ? 0 : @inbounds LinearIndices(a)[diagindex(a, diaglength(a))]
return 1:diagstride(a):maxdiag
end
function diagindices(::IndexCartesian, a::AbstractArray)
return DiagCartesianIndices(a)
# TODO: Define a special iterator for this, i.e. `DiagCartesianIndices`?
return Iterators.map(
i -> CartesianIndex(ntuple(Returns(i), ndims(a))), Base.OneTo(diaglength(a))
)
end

function diagindices(a::AbstractArray{<:Any,0})
Expand Down
55 changes: 54 additions & 1 deletion test/test_basics.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
using Test: @test, @testset, @test_broken, @inferred
using DiagonalArrays:
DiagonalArrays, DiagonalArray, DiagonalMatrix, δ, delta, diaglength, diagonal, diagview
DiagonalArrays,
DiagonalArray,
DiagonalMatrix,
δ,
delta,
diagindices,
diaglength,
diagonal,
diagview
using FillArrays: Fill, Ones
using SparseArraysBase: SparseArrayDOK, sparsezeros, storedlength
using LinearAlgebra: Diagonal
Expand All @@ -15,6 +23,51 @@ using LinearAlgebra: Diagonal
a = fill(one(elt))
@test diaglength(a) == 1
end
@testset "diagindices" begin
a = randn(elt, ())
@test diagindices(a) == diagindices(IndexLinear(), a) == 1:1:1
@test isempty(diagindices(IndexCartesian(), a))

for a in (
randn(elt, (0,)),
randn(elt, (0, 0)),
randn(elt, (0, 3)),
randn(elt, (3, 0)),
randn(elt, (0, 0, 0)),
randn(elt, (3, 3, 0)),
)
@test diagindices(a) == diagindices(IndexLinear(), a) == 1:1:0
@test isempty(diagindices(IndexCartesian(), a))
end

a = randn(elt, (3,))
@test diagindices(a) == diagindices(IndexLinear(), a) == 1:1:3
@test diagindices(IndexCartesian(), a) == CartesianIndex.(1:3)

a = randn(elt, (4,))
@test diagindices(a) == diagindices(IndexLinear(), a) == 1:1:4
@test diagindices(IndexCartesian(), a) == CartesianIndex.(1:4)

for a in (randn(elt, (3, 3)), randn(elt, (3, 4)))
@test diagindices(a) == diagindices(IndexLinear(), a) == 1:4:9
@test diagindices(IndexCartesian(), a) == CartesianIndex.(Iterators.zip(1:3, 1:3))
end

a = randn(elt, (4, 3))
@test diagindices(a) == diagindices(IndexLinear(), a) == 1:5:11
@test diagindices(IndexCartesian(), a) == CartesianIndex.(Iterators.zip(1:3, 1:3))

for a in (randn(elt, (3, 3, 3)), randn(elt, (3, 3, 4)))
@test diagindices(a) == diagindices(IndexLinear(), a) == 1:13:27
@test diagindices(IndexCartesian(), a) ==
CartesianIndex.(Iterators.zip(1:3, 1:3, 1:3))
end

a = randn(elt, (3, 4, 3))
@test diagindices(a) == diagindices(IndexLinear(), a) == 1:16:33
@test diagindices(IndexCartesian(), a) ==
CartesianIndex.(Iterators.zip(1:3, 1:3, 1:3))
end
@testset "Matrix multiplication" begin
a1 = DiagonalArray{elt}(undef, (2, 3))
a1[1, 1] = 11
Expand Down
Loading