diff --git a/Project.toml b/Project.toml index 79156fd..68d0814 100644 --- a/Project.toml +++ b/Project.toml @@ -1,13 +1,14 @@ name = "SparseArraysBase" uuid = "0d5efcca-f356-4864-8770-e1ed8d78f208" authors = ["ITensor developers and contributors"] -version = "0.2.10" +version = "0.2.11" [deps] Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f" Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4" +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MapBroadcast = "ebd9b9da-f48d-417c-9660-449667d60261" @@ -17,6 +18,7 @@ Aqua = "0.8.9" ArrayLayouts = "1.11.0" DerivableInterfaces = "0.3.7" Dictionaries = "0.4.3" +FillArrays = "1.13.0" LinearAlgebra = "1.10" MapBroadcast = "0.1.5" SafeTestsets = "0.1" diff --git a/src/SparseArraysBase.jl b/src/SparseArraysBase.jl index ac69d16..0fa0ad2 100644 --- a/src/SparseArraysBase.jl +++ b/src/SparseArraysBase.jl @@ -3,8 +3,12 @@ module SparseArraysBase export SparseArrayDOK, SparseMatrixDOK, SparseVectorDOK, + OneElementArray, + OneElementMatrix, + OneElementVector, eachstoredindex, isstored, + oneelementarray, storedlength, storedpairs, storedvalues @@ -14,5 +18,6 @@ include("sparsearrayinterface.jl") include("wrappers.jl") include("abstractsparsearray.jl") include("sparsearraydok.jl") +include("oneelementarray.jl") end diff --git a/src/oneelementarray.jl b/src/oneelementarray.jl new file mode 100644 index 0000000..122d65c --- /dev/null +++ b/src/oneelementarray.jl @@ -0,0 +1,275 @@ +using FillArrays: Fill + +# Like [`FillArrays.OneElement`](https://github.com/JuliaArrays/FillArrays.jl) +# and [`OneHotArrays.OneHotArray`](https://github.com/FluxML/OneHotArrays.jl). +struct OneElementArray{T,N,I,A,F} <: AbstractSparseArray{T,N} + value::T + index::I + axes::A + getunstoredindex::F +end + +using DerivableInterfaces: @array_aliases +# Define `OneElementMatrix`, `AnyOneElementArray`, etc. +@array_aliases OneElementArray + +function OneElementArray{T,N}( + value, index::NTuple{N,Int}, axes::NTuple{N,AbstractUnitRange}, getunstoredindex +) where {T,N} + return OneElementArray{T,N,typeof(index),typeof(axes),typeof(getunstoredindex)}( + value, index, axes, getunstoredindex + ) +end + +function OneElementArray{T,N}( + value, index::NTuple{N,Int}, axes::NTuple{N,AbstractUnitRange} +) where {T,N} + return OneElementArray{T,N}(value, index, axes, default_getunstoredindex) +end +function OneElementArray{<:Any,N}( + value::T, index::NTuple{N,Int}, axes::NTuple{N,AbstractUnitRange} +) where {T,N} + return OneElementArray{T,N}(value, index, axes) +end +function OneElementArray( + value::T, index::NTuple{N,Int}, axes::NTuple{N,AbstractUnitRange} +) where {T,N} + return OneElementArray{T,N}(value, index, axes) +end + +function OneElementArray{T,N}( + index::NTuple{N,Int}, axes::NTuple{N,AbstractUnitRange} +) where {T,N} + return OneElementArray{T,N}(one(T), index, axes) +end +function OneElementArray{<:Any,N}( + index::NTuple{N,Int}, axes::NTuple{N,AbstractUnitRange} +) where {N} + return OneElementArray{Bool,N}(index, axes) +end +function OneElementArray{T}( + index::NTuple{N,Int}, axes::NTuple{N,AbstractUnitRange} +) where {T,N} + return OneElementArray{T,N}(index, axes) +end +function OneElementArray(index::NTuple{N,Int}, axes::NTuple{N,AbstractUnitRange}) where {N} + return OneElementArray{Bool,N}(index, axes) +end + +function OneElementArray{T,N}( + value, ax_ind::Vararg{Pair{<:AbstractUnitRange,Int},N} +) where {T,N} + return OneElementArray{T,N}(value, last.(ax_ind), first.(ax_ind)) +end +function OneElementArray{<:Any,N}( + value::T, ax_ind::Vararg{Pair{<:AbstractUnitRange,Int},N} +) where {T,N} + return OneElementArray{T,N}(value, ax_ind...) +end +function OneElementArray{T}( + value, ax_ind::Vararg{Pair{<:AbstractUnitRange,Int},N} +) where {T,N} + return OneElementArray{T,N}(value, ax_ind...) +end +function OneElementArray( + value::T, ax_ind::Vararg{Pair{<:AbstractUnitRange,Int},N} +) where {T,N} + return OneElementArray{T,N}(value, ax_ind...) +end + +function OneElementArray{T,N}(ax_ind::Vararg{Pair{<:AbstractUnitRange,Int},N}) where {T,N} + return OneElementArray{T,N}(last.(ax_ind), first.(ax_ind)) +end +function OneElementArray{<:Any,N}(ax_ind::Vararg{Pair{<:AbstractUnitRange,Int},N}) where {N} + return OneElementArray{Bool,N}(ax_ind...) +end +function OneElementArray{T}(ax_ind::Vararg{Pair{<:AbstractUnitRange,Int},N}) where {T,N} + return OneElementArray{T,N}(ax_ind...) +end +function OneElementArray(ax_ind::Vararg{Pair{<:AbstractUnitRange,Int},N}) where {N} + return OneElementArray{Bool,N}(ax_ind...) +end + +# Fix ambiguity errors. +function OneElementArray{T,0}(value, index::Tuple{}, axes::Tuple{}) where {T} + return OneElementArray{T,0}(value, index, axes, default_getunstoredindex) +end +function OneElementArray{<:Any,0}(value::T, index::Tuple{}, axes::Tuple{}) where {T} + return OneElementArray{T,0}(value, index, axes) +end +function OneElementArray{T}(value, index::Tuple{}, axes::Tuple{}) where {T} + return OneElementArray{T,0}(value, index, axes) +end +function OneElementArray(value::T, index::Tuple{}, axes::Tuple{}) where {T} + return OneElementArray{T,0}(value, index, axes) +end + +# Fix ambiguity errors. +function OneElementArray{T,0}(index::Tuple{}, axes::Tuple{}) where {T} + return OneElementArray{T,0}(one(T), index, axes) +end +function OneElementArray{<:Any,0}(index::Tuple{}, axes::Tuple{}) + return OneElementArray{Bool,0}(index, axes) +end +function OneElementArray{T}(index::Tuple{}, axes::Tuple{}) where {T} + return OneElementArray{T,0}(index, axes) +end +function OneElementArray(index::Tuple{}, axes::Tuple{}) + return OneElementArray{Bool,0}(value, index, axes) +end + +function OneElementArray{T,0}(value) where {T} + return OneElementArray{T,0}(value, (), ()) +end +function OneElementArray{<:Any,0}(value::T) where {T} + return OneElementArray{T,0}(value) +end +function OneElementArray{T}(value) where {T} + return OneElementArray{T,0}(value) +end +function OneElementArray(value::T) where {T} + return OneElementArray{T}(value) +end + +function OneElementArray{T,0}() where {T} + return OneElementArray{T,0}((), ()) +end +function OneElementArray{<:Any,0}() + return OneElementArray{Bool,0}(value) +end +function OneElementArray{T}() where {T} + return OneElementArray{T,0}() +end +function OneElementArray() + return OneElementArray{Bool}() +end + +function OneElementArray{T,N}( + value, index::NTuple{N,Int}, size::NTuple{N,Integer} +) where {T,N} + return OneElementArray{T,N}(value, index, Base.oneto.(size)) +end +function OneElementArray{<:Any,N}( + value::T, index::NTuple{N,Int}, size::NTuple{N,Integer} +) where {T,N} + return OneElementArray{T,N}(value, index, size) +end +function OneElementArray{T}( + value, index::NTuple{N,Int}, size::NTuple{N,Integer} +) where {T,N} + return OneElementArray{T,N}(value, index, size) +end +function OneElementArray( + value::T, index::NTuple{N,Int}, size::NTuple{N,Integer} +) where {T,N} + return OneElementArray{T,N}(value, index, Base.oneto.(size)) +end + +function OneElementArray{T,N}(index::NTuple{N,Int}, size::NTuple{N,Integer}) where {T,N} + return OneElementArray{T,N}(one(T), index, size) +end +function OneElementArray{<:Any,N}(index::NTuple{N,Int}, size::NTuple{N,Integer}) where {N} + return OneElementArray{Bool,N}(index, size) +end +function OneElementArray{T}(index::NTuple{N,Int}, size::NTuple{N,Integer}) where {T,N} + return OneElementArray{T,N}(index, size) +end +function OneElementArray(index::NTuple{N,Int}, size::NTuple{N,Integer}) where {N} + return OneElementArray{Bool,N}(index, size) +end + +function OneElementVector{T}(value, index::Int, length::Integer) where {T} + return OneElementVector{T}(value, (index,), (length,)) +end +function OneElementVector(value::T, index::Int, length::Integer) where {T} + return OneElementVector{T}(value, index, length) +end +function OneElementArray{T}(value, index::Int, length::Integer) where {T} + return OneElementVector{T}(value, index, length) +end +function OneElementArray(value::T, index::Int, length::Integer) where {T} + return OneElementVector{T}(value, index, length) +end + +function OneElementVector{T}(index::Int, size::Integer) where {T} + return OneElementVector{T}((index,), (size,)) +end +function OneElementVector(index::Int, length::Integer) + return OneElementVector{Bool}(index, length) +end +function OneElementArray{T}(index::Int, size::Integer) where {T} + return OneElementVector{T}(index, size) +end +OneElementArray(index::Int, size::Integer) = OneElementVector{Bool}(index, size) + +# Interface to overload for constructing arrays like `OneElementArray`, +# that may not be `OneElementArray` (i.e. wrapped versions). +function oneelement( + value, index::NTuple{N,Int}, axes::NTuple{N,AbstractUnitRange} +) where {N} + return OneElementArray(value, index, axes) +end +function oneelement( + eltype::Type, index::NTuple{N,Int}, axes::NTuple{N,AbstractUnitRange} +) where {N} + return oneelement(one(eltype), index, axes) +end +function oneelement(index::NTuple{N,Int}, axes::NTuple{N,AbstractUnitRange}) where {N} + return oneelement(Bool, index, axes) +end + +function oneelement(value, index::NTuple{N,Int}, size::NTuple{N,Integer}) where {N} + return oneelement(value, index, Base.oneto.(size)) +end +function oneelement(eltype::Type, index::NTuple{N,Int}, size::NTuple{N,Integer}) where {N} + return oneelement(one(eltype), index, size) +end +function oneelement(index::NTuple{N,Int}, size::NTuple{N,Integer}) where {N} + return oneelement(Bool, index, size) +end + +function oneelement(value, ax_ind::Pair{<:AbstractUnitRange,Int}...) + return oneelement(value, last.(ax_ind), first.(ax_ind)) +end +function oneelement(eltype::Type, ax_ind::Pair{<:AbstractUnitRange,Int}...) + return oneelement(one(eltype), ax_ind...) +end +function oneelement(ax_ind::Pair{<:AbstractUnitRange,Int}...) + return oneelement(Bool, ax_ind...) +end + +function oneelement(value) + return oneelement(value, (), ()) +end +function oneelement(eltype::Type) + return oneelement(one(eltype)) +end +function oneelement() + return oneelement(Bool) +end + +Base.axes(a::OneElementArray) = getfield(a, :axes) +Base.size(a::OneElementArray) = length.(axes(a)) +storedvalue(a::OneElementArray) = getfield(a, :value) +storedvalues(a::OneElementArray) = Fill(storedvalue(a), 1) + +storedindex(a::OneElementArray) = getfield(a, :index) +function isstored(a::OneElementArray, I::Int...) + return I == storedindex(a) +end +function eachstoredindex(a::OneElementArray) + return Fill(CartesianIndex(storedindex(a)), 1) +end + +function getstoredindex(a::OneElementArray, I::Int...) + return storedvalue(a) +end +function getunstoredindex(a::OneElementArray, I::Int...) + return a.getunstoredindex(a, I...) +end +function setstoredindex!(a::OneElementArray, value, I::Int...) + return error("`OneElementArray` is immutable, you can't set elements.") +end +function setunstoredindex!(a::OneElementArray, value, I::Int...) + return error("`OneElementArray` is immutable, you can't set elements.") +end diff --git a/test/basics/test_basics.jl b/test/test_basics.jl similarity index 100% rename from test/basics/test_basics.jl rename to test/test_basics.jl diff --git a/test/basics/test_diagonal.jl b/test/test_diagonal.jl similarity index 68% rename from test/basics/test_diagonal.jl rename to test/test_diagonal.jl index ada6f0c..3c5d443 100644 --- a/test/basics/test_diagonal.jl +++ b/test/test_diagonal.jl @@ -11,13 +11,22 @@ using SparseArraysBase: using Test: @test, @testset +# compat with LTS: +@static if VERSION ≥ v"1.11" + _diagind = diagind +else + function _diagind(x::Diagonal, ::IndexCartesian) + return view(CartesianIndices(x), diagind(x)) + end +end + elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @testset "Diagonal{$T}" for T in elts L = 4 D = Diagonal(rand(T, 4)) @test storedlength(D) == 4 - @test eachstoredindex(D) == diagind(D, IndexCartesian()) + @test eachstoredindex(D) == _diagind(D, IndexCartesian()) @test isstored(D, 2, 2) @test getstoredindex(D, 2, 2) == D[2, 2] @test !isstored(D, 2, 1) diff --git a/test/basics/test_exports.jl b/test/test_exports.jl similarity index 79% rename from test/basics/test_exports.jl rename to test/test_exports.jl index 3e9a6bf..56bde16 100644 --- a/test/basics/test_exports.jl +++ b/test/test_exports.jl @@ -6,8 +6,12 @@ using Test: @test, @testset :SparseArrayDOK, :SparseMatrixDOK, :SparseVectorDOK, + :OneElementArray, + :OneElementMatrix, + :OneElementVector, :eachstoredindex, :isstored, + :oneelementarray, :storedlength, :storedpairs, :storedvalues, diff --git a/test/basics/test_linalg.jl b/test/test_linalg.jl similarity index 100% rename from test/basics/test_linalg.jl rename to test/test_linalg.jl diff --git a/test/test_oneelementarray.jl b/test/test_oneelementarray.jl new file mode 100644 index 0000000..9fc5854 --- /dev/null +++ b/test/test_oneelementarray.jl @@ -0,0 +1,111 @@ +using SparseArraysBase: + OneElementArray, + OneElementMatrix, + OneElementVector, + eachstoredindex, + isstored, + oneelement, + storedlength, + storedpairs, + storedvalues +using Test: @test, @test_broken, @testset + +elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) +@testset "OneElementArray (eltype=$elt)" for elt in elts + for a in ( + OneElementArray((1, 2), (2, 2)), + OneElementMatrix((1, 2), (2, 2)), + OneElementArray((1, 2), Base.OneTo.((2, 2))), + OneElementMatrix((1, 2), Base.OneTo.((2, 2))), + OneElementArray(Base.OneTo(2) => 1, Base.OneTo(2) => 2), + OneElementMatrix(Base.OneTo(2) => 1, Base.OneTo(2) => 2), + oneelement((1, 2), (2, 2)), + oneelement((1, 2), Base.OneTo.((2, 2))), + oneelement(Base.OneTo(2) => 1, Base.OneTo(2) => 2), + ) + @test a isa OneElementArray{Bool,2} + @test a isa OneElementMatrix{Bool} + @test eltype(a) === Bool + @test size(a) == (2, 2) + @test length(a) == 4 + @test axes(a) == Base.OneTo.((2, 2)) + @test a[1, 1] === zero(Bool) + @test a[2, 1] === zero(Bool) + @test a[1, 2] === one(Bool) + @test a[2, 2] === zero(Bool) + @test storedlength(a) == 1 + @test collect(eachstoredindex(a)) == [CartesianIndex(1, 2)] + @test storedpairs(a) == [CartesianIndex(1, 2) => 1] + @test storedvalues(a) == [1] + end + + for a in (OneElementArray(1, 2), OneElementVector(1, 2)) + @test a isa OneElementArray{Bool,1} + @test a isa OneElementVector{Bool} + @test eltype(a) === Bool + @test a[1] === one(Bool) + @test a[2] === zero(Bool) + @test storedlength(a) == 1 + @test collect(eachstoredindex(a)) == [CartesianIndex(1)] + @test storedpairs(a) == [CartesianIndex(1) => 1] + @test storedvalues(a) == [1] + end + + a = OneElementArray() + @test eltype(a) === Bool + @test size(a) == () + @test a[] === one(Bool) + + a = OneElementArray{elt}() + @test eltype(a) === elt + @test size(a) == () + @test a[] === one(elt) + + for a in ( + OneElementArray{elt}((1, 2), (2, 2)), + OneElementMatrix{elt}((1, 2), (2, 2)), + OneElementArray(one(elt), (1, 2), (2, 2)), + OneElementMatrix(one(elt), (1, 2), (2, 2)), + OneElementArray{elt}((1, 2), Base.OneTo.((2, 2))), + OneElementMatrix{elt}((1, 2), Base.OneTo.((2, 2))), + OneElementArray(one(elt), (1, 2), Base.OneTo.((2, 2))), + OneElementMatrix(one(elt), (1, 2), Base.OneTo.((2, 2))), + OneElementArray{elt}(Base.OneTo(2) => 1, Base.OneTo(2) => 2), + OneElementMatrix{elt}(Base.OneTo(2) => 1, Base.OneTo(2) => 2), + OneElementArray(one(elt), Base.OneTo(2) => 1, Base.OneTo(2) => 2), + OneElementMatrix(one(elt), Base.OneTo(2) => 1, Base.OneTo(2) => 2), + oneelement(elt, (1, 2), (2, 2)), + oneelement(one(elt), (1, 2), (2, 2)), + oneelement(elt, (1, 2), Base.OneTo.((2, 2))), + oneelement(one(elt), (1, 2), Base.OneTo.((2, 2))), + oneelement(elt, Base.OneTo(2) => 1, Base.OneTo(2) => 2), + oneelement(one(elt), Base.OneTo(2) => 1, Base.OneTo(2) => 2), + ) + @test eltype(a) === elt + @test a[1, 1] === zero(elt) + @test a[2, 1] === zero(elt) + @test a[1, 2] === one(elt) + @test a[2, 2] === zero(elt) + end + + a = OneElementArray{elt}((1, 2), (2, 2)) + b = 2a + @test eltype(b) === elt + @test storedlength(b) == 1 + # TODO: Need to preserve that it is a `OneElementArray`. + # Currently falls back to constructing a `SparseArrayDOK`. + @test_broken b isa OneElementMatrix{elt} + @test b == 2 * Array(a) + + a1 = OneElementArray{elt}(2, (1, 2), (2, 2)) + a2 = OneElementArray{elt}(3, (2, 1), (2, 2)) + b = a1 * a2 + @test eltype(b) === elt + @test b[1, 1] === elt(6) + @test storedlength(b) == 1 + @test isstored(b, 1, 1) + # TODO: Need to preserve that it is a `OneElementArray`. + # Currently falls back to constructing a `SparseArrayDOK`. + @test_broken b isa OneElementMatrix{elt} + @test b == Array(a1) * Array(a2) +end diff --git a/test/basics/test_sparsearraydok.jl b/test/test_sparsearraydok.jl similarity index 100% rename from test/basics/test_sparsearraydok.jl rename to test/test_sparsearraydok.jl