diff --git a/Project.toml b/Project.toml index 50a5244..7b35a89 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "DiagonalArrays" uuid = "74fd4be6-21e2-4f6f-823a-4360d37c7a77" authors = ["ITensor developers and contributors"] -version = "0.3.6" +version = "0.3.7" [deps] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" @@ -12,7 +12,7 @@ SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208" [compat] ArrayLayouts = "1.10.4" -DerivableInterfaces = "0.4" +DerivableInterfaces = "0.5" FillArrays = "1.13.0" LinearAlgebra = "1.10.0" SparseArraysBase = "0.5" diff --git a/src/abstractdiagonalarray/diagonalarraydiaginterface.jl b/src/abstractdiagonalarray/diagonalarraydiaginterface.jl index 126d702..3bc6f44 100644 --- a/src/abstractdiagonalarray/diagonalarraydiaginterface.jl +++ b/src/abstractdiagonalarray/diagonalarraydiaginterface.jl @@ -4,21 +4,26 @@ diagview(a::AbstractDiagonalArray) = throw(MethodError(diagview, Tuple{typeof(a) using DerivableInterfaces: DerivableInterfaces, @interface using SparseArraysBase: - SparseArraysBase, AbstractSparseArrayInterface, AbstractSparseArrayStyle ## , StorageIndex, StorageIndices + SparseArraysBase, AbstractSparseArrayInterface, AbstractSparseArrayStyle -abstract type AbstractDiagonalArrayInterface <: AbstractSparseArrayInterface end +abstract type AbstractDiagonalArrayInterface{N} <: AbstractSparseArrayInterface{N} end -struct DiagonalArrayInterface <: AbstractDiagonalArrayInterface end +struct DiagonalArrayInterface{N} <: AbstractDiagonalArrayInterface{N} end +DiagonalArrayInterface{M}(::Val{N}) where {M,N} = DiagonalArrayInterface{N}() +DiagionalArrayInterface(::Val{N}) where {N} = DiagonalArrayInterface{N}() +DiagonalArrayInterface() = DiagonalArrayInterface{Any}() -function DerivableInterfaces.arraytype(::AbstractDiagonalArrayInterface, elt::Type) - return DiagonalArray{elt} +function Base.similar(::AbstractDiagonalArrayInterface, elt::Type, ax::Tuple) + return similar(DiagonalArray{elt}, ax) +end +function DerivableInterfaces.interface(::Type{<:AbstractDiagonalArray{<:Any,N}}) where {N} + return DiagonalArrayInterface{N}() end -DerivableInterfaces.interface(::Type{<:AbstractDiagonalArray}) = DiagonalArrayInterface() abstract type AbstractDiagonalArrayStyle{N} <: AbstractSparseArrayStyle{N} end -function DerivableInterfaces.interface(::Type{<:AbstractDiagonalArrayStyle}) - return DiagonalArrayInterface() +function DerivableInterfaces.interface(::Type{<:AbstractDiagonalArrayStyle{N}}) where {N} + return DiagonalArrayInterface{N}() end struct DiagonalArrayStyle{N} <: AbstractDiagonalArrayStyle{N} end