diff --git a/Project.toml b/Project.toml index 68003a5..d842cdb 100644 --- a/Project.toml +++ b/Project.toml @@ -1,12 +1,12 @@ name = "DiagonalArrays" uuid = "74fd4be6-21e2-4f6f-823a-4360d37c7a77" authors = ["ITensor developers and contributors"] -version = "0.3.26" +version = "0.3.27" [deps] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" -DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" +FunctionImplementations = "7c7cc465-9c6a-495f-bdd1-f42428e86d0c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MapBroadcast = "ebd9b9da-f48d-417c-9660-449667d60261" SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208" @@ -21,11 +21,14 @@ DiagonalArraysNamedDimsArraysExt = "NamedDimsArrays" [compat] ArrayLayouts = "1.10.4" -DerivableInterfaces = "0.5.5" FillArrays = "1.13" +FunctionImplementations = "0.3.1" LinearAlgebra = "1.10" MapBroadcast = "0.1.10" MatrixAlgebraKit = "0.2, 0.3, 0.4, 0.5, 0.6" -NamedDimsArrays = "0.10, 0.11" -SparseArraysBase = "0.7.2" +NamedDimsArrays = "0.12" +SparseArraysBase = "0.8.1" julia = "1.10" + +[workspace] +projects = ["benchmark", "dev", "docs", "examples", "test"] diff --git a/docs/Project.toml b/docs/Project.toml index 228ca1b..c006235 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -3,6 +3,9 @@ DiagonalArrays = "74fd4be6-21e2-4f6f-823a-4360d37c7a77" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" +[sources] +DiagonalArrays = {path = ".."} + [compat] DiagonalArrays = "0.3" Documenter = "1" diff --git a/examples/Project.toml b/examples/Project.toml index b559d07..a0dacaf 100644 --- a/examples/Project.toml +++ b/examples/Project.toml @@ -2,6 +2,9 @@ DiagonalArrays = "74fd4be6-21e2-4f6f-823a-4360d37c7a77" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +[sources] +DiagonalArrays = {path = ".."} + [compat] DiagonalArrays = "0.3" Test = "1" diff --git a/src/DiagonalArrays.jl b/src/DiagonalArrays.jl index 5140355..8ebb1c7 100644 --- a/src/DiagonalArrays.jl +++ b/src/DiagonalArrays.jl @@ -5,7 +5,6 @@ include("diaginterface/diaginterface.jl") include("diaginterface/diagindex.jl") include("diaginterface/diagindices.jl") include("abstractdiagonalarray/abstractdiagonalarray.jl") -include("abstractdiagonalarray/sparsearrayinterface.jl") include("abstractdiagonalarray/diagonalarraydiaginterface.jl") include("abstractdiagonalarray/arraylayouts.jl") include("diagonalarray/diagonalarray.jl") diff --git a/src/abstractdiagonalarray/diagonalarraydiaginterface.jl b/src/abstractdiagonalarray/diagonalarraydiaginterface.jl index 40b2cf9..45f819f 100644 --- a/src/abstractdiagonalarray/diagonalarraydiaginterface.jl +++ b/src/abstractdiagonalarray/diagonalarraydiaginterface.jl @@ -2,52 +2,45 @@ diagview(a::AbstractDiagonalArray) = throw(MethodError(diagview, Tuple{typeof(a)})) -using DerivableInterfaces: DerivableInterfaces, @interface -using SparseArraysBase: - SparseArraysBase, AbstractSparseArrayInterface, AbstractSparseArrayStyle +using FunctionImplementations: FunctionImplementations +using SparseArraysBase: SparseArraysBase as SA, AbstractSparseArrayStyle -abstract type AbstractDiagonalArrayInterface{N} <: AbstractSparseArrayInterface{N} end +abstract type AbstractDiagonalArrayStyle <: AbstractSparseArrayStyle end -struct DiagonalArrayInterface{N} <: AbstractDiagonalArrayInterface{N} end -DiagonalArrayInterface{M}(::Val{N}) where {M, N} = DiagonalArrayInterface{N}() -DiagionalArrayInterface(::Val{N}) where {N} = DiagonalArrayInterface{N}() -DiagonalArrayInterface() = DiagonalArrayInterface{Any}() +struct DiagonalArrayStyle <: AbstractDiagonalArrayStyle end +const diag_style = DiagonalArrayStyle() -function Base.similar(::AbstractDiagonalArrayInterface, elt::Type, ax::Tuple) - return similar(DiagonalArray{elt}, ax) +function FunctionImplementations.Style(::Type{<:AbstractDiagonalArray}) + return DiagonalArrayStyle() end -function DerivableInterfaces.interface(::Type{<:AbstractDiagonalArray{<:Any, N}}) where {N} - return DiagonalArrayInterface{N}() -end - -abstract type AbstractDiagonalArrayStyle{N} <: AbstractSparseArrayStyle{N} end -function DerivableInterfaces.interface(::Type{<:AbstractDiagonalArrayStyle{N}}) where {N} - return DiagonalArrayInterface{N}() +module Broadcast + import SparseArraysBase as SA + abstract type AbstractDiagonalArrayStyle{N} <: SA.Broadcast.AbstractSparseArrayStyle{N} end + struct DiagonalArrayStyle{N} <: AbstractDiagonalArrayStyle{N} end + DiagonalArrayStyle{M}(::Val{N}) where {M, N} = DiagonalArrayStyle{N}() end -struct DiagonalArrayStyle{N} <: AbstractDiagonalArrayStyle{N} end - -DiagonalArrayStyle{M}(::Val{N}) where {M, N} = DiagonalArrayStyle{N}() - -function SparseArraysBase.isstored( - a::AbstractDiagonalArray{<:Any, N}, I::Vararg{Int, N} - ) where {N} - return allequal(I) -end -function SparseArraysBase.getstoredindex( - a::AbstractDiagonalArray{<:Any, N}, I::Vararg{Int, N} +using SparseArraysBase: getstoredindex +const getstoredindex_diag = diag_style(getstoredindex) +function getstoredindex_diag( + a::AbstractArray{<:Any, N}, I::Vararg{Int, N} ) where {N} # TODO: Make this check optional, define `checkstored` like `checkbounds` # in SparseArraysBase.jl. # allequal(I) || error("Not a diagonal index.") return getdiagindex(a, first(I)) end -function SparseArraysBase.getstoredindex(a::AbstractDiagonalArray{<:Any, 0}) +function getstoredindex_diag(a::AbstractArray{<:Any, 0}) return getdiagindex(a, 1) end -function SparseArraysBase.setstoredindex!( - a::AbstractDiagonalArray{<:Any, N}, value, I::Vararg{Int, N} +function getstoredindex_diag(a::AbstractArray, I::Int...) + return sparse_style(getstoredindex)(a, I...) +end +using SparseArraysBase: setstoredindex! +const setstoredindex!_diag = diag_style(setstoredindex!) +function setstoredindex!_diag( + a::AbstractArray{<:Any, N}, value, I::Vararg{Int, N} ) where {N} # TODO: Make this check optional, define `checkstored` like `checkbounds` # in SparseArraysBase.jl. @@ -55,11 +48,13 @@ function SparseArraysBase.setstoredindex!( setdiagindex!(a, value, first(I)) return a end -function SparseArraysBase.setstoredindex!(a::AbstractDiagonalArray{<:Any, 0}, value) +function setstoredindex!_diag(a::AbstractArray{<:Any, 0}, value) setdiagindex!(a, value, 1) return a end -function SparseArraysBase.eachstoredindex(::IndexCartesian, a::AbstractDiagonalArray) +using SparseArraysBase: eachstoredindex +const eachstoredindex_diag = diag_style(eachstoredindex) +function eachstoredindex_diag(::IndexCartesian, a::AbstractArray) return diagindices(a) end @@ -84,8 +79,39 @@ function Base.setindex!(a::AbstractDiagonalArray, value, I::DiagIndex) return invoke(setindex!, Tuple{AbstractArray, Any, DiagIndex}, a, value, I) end -@interface ::AbstractDiagonalArrayInterface function Broadcast.BroadcastStyle(type::Type) - return DiagonalArrayStyle{ndims(type)}() +using SparseArraysBase: sparse_style +const getindex_diag = diag_style(getindex) +getindex_diag(a::AbstractArray, I...) = sparse_style(getindex)(a, I...) +const setindex!_diag = diag_style(setindex!) +setindex!_diag(a::AbstractArray, value, I...) = sparse_style(setindex!)(a, value, I...) +const copyto!_diag = diag_style(copyto!) +copyto!_diag(dst::AbstractArray, src::AbstractArray) = sparse_style(copyto!)(dst, src) +const map_diag = diag_style(map) +map_diag(f, as::AbstractArray...) = sparse_style(map)(f, as...) +const map!_diag = diag_style(map!) +map!_diag(f, dst::AbstractArray, as::AbstractArray...) = sparse_style(map!)(f, dst, as...) +const fill!_diag = diag_style(fill!) +fill!_diag(a::AbstractArray, value) = sparse_style(fill!)(a, value) +using FunctionImplementations: zero! +const zero!_diag = diag_style(zero!) +zero!_diag(a::AbstractArray) = sparse_style(zero!)(a) +using SparseArraysBase: isstored +const isstored_diag = diag_style(isstored) +function isstored_diag( + a::AbstractArray{<:Any, N}, I::Vararg{Int, N} + ) where {N} + return allequal(I) +end +isstored_diag(a::AbstractArray, I::Int...) = sparse_style(isstored)(a, I...) +using SparseArraysBase: storedvalues +const storedvalues_diag = diag_style(storedvalues) +storedvalues_diag(a::AbstractArray) = diagview(a) +using SparseArraysBase: storedpairs +const storedpairs_diag = diag_style(storedpairs) +storedpairs_diag(a::AbstractArray) = sparse_style(storedpairs)(a) + +function Base.Broadcast.BroadcastStyle(type::Type{<:AbstractDiagonalArray}) + return Broadcast.DiagonalArrayStyle{ndims(type)}() end using Base.Broadcast: Broadcasted, broadcasted @@ -99,10 +125,10 @@ function broadcasted_diagview(bc::Broadcasted) ) return broadcasted(m.f, map(diagview, m.args)...) end -function Base.copy(bc::Broadcasted{<:DiagonalArrayStyle}) +function Base.copy(bc::Broadcasted{<:Broadcast.DiagonalArrayStyle}) return DiagonalArray(copy(broadcasted_diagview(bc)), axes(bc)) end -function Base.copyto!(dest::AbstractArray, bc::Broadcasted{<:DiagonalArrayStyle}) +function Base.copyto!(dest::AbstractArray, bc::Broadcasted{<:Broadcast.DiagonalArrayStyle}) copyto!(diagview(dest), broadcasted_diagview(bc)) return dest end diff --git a/src/abstractdiagonalarray/sparsearrayinterface.jl b/src/abstractdiagonalarray/sparsearrayinterface.jl deleted file mode 100644 index 3f47a60..0000000 --- a/src/abstractdiagonalarray/sparsearrayinterface.jl +++ /dev/null @@ -1,19 +0,0 @@ -## # `SparseArraysBase` interface -## function SparseArraysBase.index_to_storage_index( -## a::AbstractDiagonalArray{<:Any,N}, I::CartesianIndex{N} -## ) where {N} -## !allequal(Tuple(I)) && return nothing -## return first(Tuple(I)) -## end -## -## function SparseArraysBase.storage_index_to_index(a::AbstractDiagonalArray, I) -## return CartesianIndex(ntuple(Returns(I), ndims(a))) -## end - -## # 1-dimensional case can be `AbstractDiagonalArray`. -## function SparseArraysBase.sparse_similar( -## a::AbstractDiagonalArray, elt::Type, dims::Tuple{Int} -## ) -## # TODO: Handle preserving zero element function. -## return similar(a, elt, dims) -## end diff --git a/src/diagonalarray/diagonalarray.jl b/src/diagonalarray/diagonalarray.jl index 8a608c4..8f6c090 100644 --- a/src/diagonalarray/diagonalarray.jl +++ b/src/diagonalarray/diagonalarray.jl @@ -22,7 +22,7 @@ struct DiagonalArray{T, N, D <: AbstractVector{T}, U <: AbstractArray{T, N}} <: end end -SparseArraysBase.unstored(a::DiagonalArray) = a.unstored +SA.unstored(a::DiagonalArray) = a.unstored Base.size(a::DiagonalArray) = size(unstored(a)) Base.axes(a::DiagonalArray) = axes(unstored(a)) @@ -291,7 +291,8 @@ function Base.permutedims(a::DiagonalArray, perm) return DiagonalArray(copy(diagview(a)), ax_perm) end -function DerivableInterfaces.permuteddims(a::DiagonalArray, perm) +using FunctionImplementations: FunctionImplementations +function FunctionImplementations.permuteddims(a::DiagonalArray, perm) ((ndims(a) == length(perm)) && isperm(perm)) || throw(ArgumentError("Not a valid permutation")) ax_perm = ntuple(d -> axes(a)[perm[d]], ndims(a)) @@ -300,7 +301,6 @@ function DerivableInterfaces.permuteddims(a::DiagonalArray, perm) end # Scalar indexing. -using DerivableInterfaces: @interface, interface one_based_range(r) = false one_based_range(r::Base.OneTo) = true one_based_range(r::Base.Slice) = true @@ -335,8 +335,10 @@ function Base.view(a::DiagonalArray, I...) invoke(view, Tuple{AbstractArray, Vararg}, a, I′...) end end +using FunctionImplementations: style +using SparseArraysBase: sparse_style function Base.getindex(a::DiagonalArray, I::Int...) - return @interface interface(a) a[I...] + return sparse_style(getindex)(a, I...) end function Base.getindex(a::DiagonalArray, I::DiagIndex) return getdiagindex(a, index(I)) @@ -349,7 +351,7 @@ function Base.getindex(a::DiagonalArray, I...) I′ = to_indices(a, I) return if all(i -> i isa Real, I′) # Catch scalar indexing case. - @interface interface(a) a[I...] + return style(a)(getindex)(a, I...) elseif all(one_based_range, I′) _getindex_diag(a, I′...) else @@ -379,7 +381,7 @@ end # TODO: These definitions work around this issue: # https://github.com/JuliaArrays/FillArrays.jl/issues/416 # when the diagonal is a FillArrays.Ones or Zeros. -using Base.Broadcast: Broadcast, broadcast, broadcasted +using Base.Broadcast: broadcast, broadcasted using FillArrays: AbstractFill, Ones, Zeros _broadcasted(f::F, a::AbstractArray) where {F} = broadcasted(f, a) _broadcasted(::typeof(identity), a::Ones) = a @@ -407,8 +409,8 @@ _broadcasted(::typeof(cosh), a::Zeros) = Ones{typeof(cosh(zero(eltype(a))))}(axe # Eager version of `_broadcasted`. _broadcast(f::F, a::AbstractArray) where {F} = copy(_broadcasted(f, a)) -function Broadcast.broadcasted( - ::DiagonalArrayStyle{N}, f::F, a::DiagonalArray{T, N, Diag} +function Base.Broadcast.broadcasted( + ::Broadcast.DiagonalArrayStyle{N}, f::F, a::DiagonalArray{T, N, Diag} ) where {F, T, N, Diag <: AbstractFill{T}} # TODO: Check that `f` preserves zeros? return DiagonalArray(_broadcasted(f, diagview(a)), axes(a)) diff --git a/test/Project.toml b/test/Project.toml index e3241d9..4d756c2 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,9 +1,9 @@ [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" -DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f" DiagonalArrays = "74fd4be6-21e2-4f6f-823a-4360d37c7a77" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" +FunctionImplementations = "7c7cc465-9c6a-495f-bdd1-f42428e86d0c" JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4" @@ -14,18 +14,21 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +[sources] +DiagonalArrays = {path = ".."} + [compat] Adapt = "4.4" Aqua = "0.8.9" -DerivableInterfaces = "0.5" DiagonalArrays = "0.3" FillArrays = "1" +FunctionImplementations = "0.3" JLArrays = "0.3" LinearAlgebra = "1" MatrixAlgebraKit = "0.2.5, 0.3, 0.4, 0.5, 0.6" -NamedDimsArrays = "0.10, 0.11" +NamedDimsArrays = "0.12" SafeTestsets = "0.1" -SparseArraysBase = "0.7.10" +SparseArraysBase = "0.8" StableRNGs = "1" Suppressor = "0.2" Test = "1" diff --git a/test/test_basics.jl b/test/test_basics.jl index 2ecd323..13b201d 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -1,4 +1,3 @@ -using DerivableInterfaces: permuteddims using DiagonalArrays: DiagonalArrays, ShapeInitializer, @@ -18,6 +17,7 @@ using DiagonalArrays: diagview, getdiagindices using FillArrays: Fill, Ones, Zeros +using FunctionImplementations: permuteddims using LinearAlgebra: Diagonal, det, ishermitian, isposdef, issymmetric, logdet, mul!, pinv, tr using SparseArraysBase: SparseArrayDOK, SparseMatrixDOK, sparsezeros, storedlength @@ -229,7 +229,7 @@ using Test: @test, @test_throws, @testset, @test_broken, @inferred @test diagview(b) ≢ diagview(a) @test size(b) === (4, 2, 3) end - @testset "DerivableInterfaces.permuteddims" begin + @testset "FunctionImplementations.permuteddims" begin a = DiagonalArray(randn(elt, 2), (2, 3, 4)) b = permuteddims(a, (3, 1, 2)) @test diagview(b) ≡ diagview(a)