Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/IntegrationTest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ jobs:
matrix:
pkg:
- 'BlockSparseArrays'
- 'KroneckerArrays'
uses: "ITensor/ITensorActions/.github/workflows/IntegrationTest.yml@main"
with:
localregistry: "https://github.com/ITensor/ITensorRegistry.git"
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DiagonalArrays"
uuid = "74fd4be6-21e2-4f6f-823a-4360d37c7a77"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.3.17"
version = "0.3.18"

[deps]
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
Expand Down
24 changes: 24 additions & 0 deletions src/abstractdiagonalarray/abstractdiagonalarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,30 @@ abstract type AbstractDiagonalArray{T,N} <: AbstractSparseArray{T,N} end
const AbstractDiagonalMatrix{T} = AbstractDiagonalArray{T,2}
const AbstractDiagonalVector{T} = AbstractDiagonalArray{T,1}

# Define for type stability, for some reason the generic versions
# in SparseArraysBase.jl is not type stable.
# TODO: Investigate type stability of `iszero` in SparseArraysBase.jl.
function Base.iszero(a::AbstractDiagonalArray)
return iszero(diagview(a))
end

using FillArrays: AbstractFill, getindex_value
using LinearAlgebra: norm
# TODO: `_norm` works around:
# https://github.com/JuliaArrays/FillArrays.jl/issues/417
# Change back to `norm` when that is fixed.
_norm(a, p::Int=2) = norm(a, p)
function _norm(a::AbstractFill, p::Int=2)
nrm1 = norm(getindex_value(a))
return (length(a))^(1/oftype(nrm1, p)) * nrm1
end
function LinearAlgebra.norm(a::AbstractDiagonalArray, p::Int=2)
# TODO: `_norm` works around:
# https://github.com/JuliaArrays/FillArrays.jl/issues/417
# Change back to `norm` when that is fixed.
return _norm(diagview(a), p)
end

using LinearAlgebra: LinearAlgebra, ishermitian, isposdef, issymmetric
LinearAlgebra.ishermitian(a::AbstractDiagonalMatrix{<:Real}) = issquare(a)
function LinearAlgebra.ishermitian(a::AbstractDiagonalMatrix{<:Number})
Expand Down
29 changes: 7 additions & 22 deletions src/abstractdiagonalarray/diagonalarraydiaginterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ function SparseArraysBase.getstoredindex(
# allequal(I) || error("Not a diagonal index.")
return getdiagindex(a, first(I))
end
function SparseArraysBase.getstoredindex(a::AbstractDiagonalArray{<:Any,0})
return getdiagindex(a, 1)
end
function SparseArraysBase.setstoredindex!(
a::AbstractDiagonalArray{<:Any,N}, value, I::Vararg{Int,N}
) where {N}
Expand All @@ -52,6 +55,10 @@ function SparseArraysBase.setstoredindex!(
setdiagindex!(a, value, first(I))
return a
end
function SparseArraysBase.setstoredindex!(a::AbstractDiagonalArray{<:Any,0}, value)
setdiagindex!(a, value, 1)
return a
end
function SparseArraysBase.eachstoredindex(::IndexCartesian, a::AbstractDiagonalArray)
return diagindices(a)
end
Expand Down Expand Up @@ -99,25 +106,3 @@ function Base.copyto!(dest::AbstractArray, bc::Broadcasted{<:DiagonalArrayStyle}
copyto!(diagview(dest), broadcasted_diagview(bc))
return dest
end

## SparseArraysBase.StorageIndex(i::DiagIndex) = StorageIndex(index(i))

## function Base.getindex(a::AbstractDiagonalArray, i::DiagIndex)
## return a[StorageIndex(i)]
## end

## function Base.setindex!(a::AbstractDiagonalArray, value, i::DiagIndex)
## a[StorageIndex(i)] = value
## return a
## end

## SparseArraysBase.StorageIndices(i::DiagIndices) = StorageIndices(indices(i))

## function Base.getindex(a::AbstractDiagonalArray, i::DiagIndices)
## return a[StorageIndices(i)]
## end

## function Base.setindex!(a::AbstractDiagonalArray, value, i::DiagIndices)
## a[StorageIndices(i)] = value
## return a
## end
1 change: 1 addition & 0 deletions src/diaginterface/diaginterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ function setdiagindex!(a::AbstractArray, v, i::Integer)
end

function getdiagindices(a::AbstractArray, I)
# TODO: Should this be a view?
return @view diagview(a)[I]
end

Expand Down
225 changes: 208 additions & 17 deletions src/diagonalarray/diagonalarray.jl
Original file line number Diff line number Diff line change
@@ -1,38 +1,112 @@
using FillArrays: Zeros
using SparseArraysBase: Unstored, unstored

function _DiagonalArray end
diaglength_from_shape(sz::Tuple{Integer,Vararg{Integer}}) = minimum(sz)
function diaglength_from_shape(
sz::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}}
)
return minimum(length, sz)
end
diaglength_from_shape(sz::Tuple{}) = 1

struct DiagonalArray{T,N,Diag<:AbstractVector{T},Unstored<:AbstractArray{T,N}} <:
struct DiagonalArray{T,N,D<:AbstractVector{T},U<:AbstractArray{T,N}} <:
AbstractDiagonalArray{T,N}
diag::Diag
unstored::Unstored
global @inline function _DiagonalArray(
diag::Diag, unstored::Unstored
) where {T,N,Diag<:AbstractVector{T},Unstored<:AbstractArray{T,N}}
length(diag) == minimum(size(unstored)) ||
diag::D
unstored::U
function DiagonalArray{T,N,D,U}(
diag::AbstractVector, unstored::Unstored
) where {T,N,D<:AbstractVector{T},U<:AbstractArray{T,N}}
length(diag) == diaglength_from_shape(size(unstored)) ||
throw(ArgumentError("Length of diagonals doesn't match dimensions"))
return new{T,N,Diag,Unstored}(diag, unstored)
return new{T,N,D,U}(diag, parent(unstored))
end
end

SparseArraysBase.unstored(a::DiagonalArray) = a.unstored
Base.size(a::DiagonalArray) = size(unstored(a))
Base.axes(a::DiagonalArray) = axes(unstored(a))

function DiagonalArray{T,N,D}(
diag::D, unstored::Unstored{T,N,U}
) where {T,N,D<:AbstractVector{T},U<:AbstractArray{T,N}}
return DiagonalArray{T,N,D,U}(diag, unstored)
end
function DiagonalArray{T,N}(
diag::D, unstored::Unstored{T,N}
) where {T,N,D<:AbstractVector{T}}
return DiagonalArray{T,N,D}(diag, unstored)
end
function DiagonalArray{T}(diag::AbstractVector{T}, unstored::Unstored{T,N}) where {T,N}
return DiagonalArray{T,N}(diag, unstored)
end
function DiagonalArray(diag::AbstractVector{T}, unstored::Unstored{T}) where {T}
return DiagonalArray{T}(diag, unstored)
end

function DiagonalArray(::UndefInitializer, unstored::Unstored)
return _DiagonalArray(
Vector{eltype(unstored)}(undef, minimum(size(unstored))), parent(unstored)
return DiagonalArray(
Vector{eltype(unstored)}(undef, diaglength_from_shape(size(unstored))), unstored
)
end

# Indicate we will construct an array just from the shape,
# for example for a Base.OneTo or FillArrays.Ones or Zeros.
# All the elements should be uniquely defined by the input axes.
struct ShapeInitializer end

# This is used to create custom constructors for arrays,
# in this case a generic constructor of a vector from a length.
function construct(vect::Type{<:AbstractVector}, ::ShapeInitializer, len::Integer)
if applicable(vect, len)
return vect(len)
elseif applicable(vect, (Base.OneTo(len),))
return vect((Base.OneTo(len),))
else
error(lazy"Can't construct $(vect) from length.")
end
end

# This helps to support diagonals where the elements are known
# from the types, for example diagonals that are `Zeros` and `Ones`.
function DiagonalArray{T,N,D}(
init::ShapeInitializer, unstored::Unstored
) where {T,N,D<:AbstractVector{T}}
return DiagonalArray{T,N,D}(
construct(D, init, diaglength_from_shape(axes(unstored))), unstored
)
end

# Constructors accepting axes.
# This helps to support diagonals where the elements are known
# from the types, for example diagonals that are `Zeros` and `Ones`.
# These versions use the default unstored type `Zeros{T,N}`.
function DiagonalArray{T,N,D}(
init::ShapeInitializer, ax::Tuple{Vararg{AbstractUnitRange{<:Integer}}}
) where {T,N,D<:AbstractVector{T}}
return DiagonalArray{T,N,D}(init, Unstored(Zeros{T,N}(ax)))
end
function DiagonalArray{T,N,D}(
init::ShapeInitializer, ax::AbstractUnitRange{<:Integer}...
) where {T,N,D<:AbstractVector{T}}
return DiagonalArray{T,N,D}(init, ax)
end
function DiagonalArray{T,N,D}(
init::ShapeInitializer, sz::Tuple{Integer,Vararg{Integer}}
) where {T,N,D<:AbstractVector{T}}
return DiagonalArray{T,N,D}(init, Base.OneTo.(sz))
end
function DiagonalArray{T,N,D}(
init::ShapeInitializer, sz1::Integer, sz_rest::Integer...
) where {T,N,D<:AbstractVector{T}}
return DiagonalArray{T,N,D}(init, (sz1, sz_rest...))
end

# Constructor from diagonal entries accepting axes.
function DiagonalArray{T,N}(
diag::AbstractVector,
ax::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}},
) where {T,N}
N == length(ax) || throw(ArgumentError("Wrong number of axes"))
return _DiagonalArray(convert(AbstractVector{T}, diag), Zeros{T}(ax))
return DiagonalArray(convert(AbstractVector{T}, diag), Unstored(Zeros{T}(ax)))
end
function DiagonalArray{T,N}(
diag::AbstractVector,
Expand Down Expand Up @@ -97,7 +171,7 @@ function DiagonalArray{T}(
end

function DiagonalArray{T,N}(diag::AbstractVector, dims::Dims{N}) where {T,N}
return _DiagonalArray(convert(AbstractVector{T}, diag), Zeros{T}(dims))
return DiagonalArray(convert(AbstractVector{T}, diag), Unstored(Zeros{T}(dims)))
end

function DiagonalArray{T,N}(diag::AbstractVector, dims::Vararg{Int,N}) where {T,N}
Expand Down Expand Up @@ -146,7 +220,7 @@ end

# undef
function DiagonalArray{T,N}(::UndefInitializer, dims::Dims{N}) where {T,N}
return DiagonalArray{T,N}(Vector{T}(undef, minimum(dims)), dims)
return DiagonalArray{T,N}(Vector{T}(undef, diaglength_from_shape(dims)), dims)
end

function DiagonalArray{T,N}(::UndefInitializer, dims::Vararg{Int,N}) where {T,N}
Expand All @@ -162,8 +236,10 @@ function DiagonalArray{T}(::UndefInitializer, dims::Vararg{Int,N}) where {T,N}
end

# Axes version
function DiagonalArray{T}(::UndefInitializer, axes::NTuple{N,Base.OneTo{Int}}) where {T,N}
return DiagonalArray{T,N}(undef, length.(axes))
function DiagonalArray{T}(
::UndefInitializer, axes::Tuple{Base.OneTo{Int},Vararg{Base.OneTo{Int}}}
) where {T}
return DiagonalArray{T,length(axes)}(undef, length.(axes))
end

function Base.similar(a::DiagonalArray, unstored::Unstored)
Expand Down Expand Up @@ -197,3 +273,118 @@ function DerivableInterfaces.permuteddims(a::DiagonalArray, perm)
# Unlike `permutedims(::Diagonal, perm)`, we copy here.
return DiagonalArray(diagview(a), ax_perm)
end

# Scalar indexing.
using DerivableInterfaces: @interface, interface
one_based_range(r) = false
one_based_range(r::Base.OneTo) = true
one_based_range(r::Base.Slice) = true
function _diag_axes(a::DiagonalArray, I...)
return map(ntuple(identity, ndims(a))) do d
return Base.axes1(axes(a, d)[I[d]])
end
end
# A view that preserves the diagonal structure.
function _view_diag(a::DiagonalArray, I...)
ax = _diag_axes(a, I...)
return DiagonalArray(view(diagview(a), Base.OneTo(minimum(length, I))), ax)
end
function _view_diag(a::DiagonalArray, I1::Base.Slice, Irest::Base.Slice...)
ax = _diag_axes(a, I1, Irest...)
return DiagonalArray(view(diagview(a), :), ax)
end
# A slice that preserves the diagonal structure.
function _getindex_diag(a::DiagonalArray, I...)
ax = _diag_axes(a, I...)
return DiagonalArray(diagview(a)[Base.OneTo(minimum(length, I))], ax)
end
function _getindex_diag(a::DiagonalArray, I1::Base.Slice, Irest::Base.Slice...)
ax = _diag_axes(a, I1, Irest...)
return DiagonalArray(diagview(a)[:], ax)
end
function Base.view(a::DiagonalArray, I...)
I′ = to_indices(a, I)
return if all(one_based_range, I′)
_view_diag(a, I′...)
else
invoke(view, Tuple{AbstractArray,Vararg}, a, I′...)
end
end
function Base.getindex(a::DiagonalArray, I::Int...)
return @interface interface(a) a[I...]
end
function Base.getindex(a::DiagonalArray, I::DiagIndex)
return getdiagindex(a, index(I))
end
function Base.getindex(a::DiagonalArray, I::DiagIndices)
# TODO: Should this be a view?
return @view diagview(a)[indices(I)]
end
function Base.getindex(a::DiagonalArray, I...)
I′ = to_indices(a, I)
return if all(i -> i isa Real, I′)
# Catch scalar indexing case.
@interface interface(a) a[I...]
elseif all(one_based_range, I′)
_getindex_diag(a, I′...)
else
copy(view(a, I′...))
end
end

# Define in order to preserve immutable diagonals such as FillArrays.
function DiagonalArray{T,N}(a::DiagonalArray{T,N}) where {T,N}
# TODO: Should this copy? This matches the design of `LinearAlgebra.Diagonal`:
# https://github.com/JuliaLang/LinearAlgebra.jl/blob/release-1.12/src/diagonal.jl#L110-L112
return a
end
function DiagonalArray{T,N}(a::DiagonalArray{<:Any,N}) where {T,N}
return DiagonalArray{T,N}(diagview(a))
end
function DiagonalArray{T}(a::DiagonalArray) where {T}
return DiagonalArray{T,ndims(a)}(a)
end
function DiagonalArray(a::DiagonalArray)
return DiagonalArray{eltype(a),ndims(a)}(a)
end
function Base.AbstractArray{T,N}(a::DiagonalArray{<:Any,N}) where {T,N}
return DiagonalArray{T,N}(a)
end

# TODO: These definitions work around this issue:
# https://github.com/JuliaArrays/FillArrays.jl/issues/416
# when the diagonal is a FillArrays.Ones or Zeros.
using Base.Broadcast: Broadcast, broadcast, broadcasted
using FillArrays: AbstractFill, Ones, Zeros
_broadcasted(f::F, a::AbstractArray) where {F} = broadcasted(f, a)
_broadcasted(::typeof(identity), a::Ones) = a
_broadcasted(::typeof(identity), a::Zeros) = a
_broadcasted(::typeof(complex), a::Ones) = Ones{complex(eltype(a))}(axes(a))
_broadcasted(::typeof(complex), a::Zeros) = Zeros{complex(eltype(a))}(axes(a))
_broadcasted(elt::Type, a::Ones) = Ones{elt}(axes(a))
_broadcasted(elt::Type, a::Zeros) = Zeros{elt}(axes(a))
_broadcasted(::typeof(inv), a::Ones) = _broadcasted(typeof(inv(oneunit(eltype(a)))), a)
using LinearAlgebra: pinv
_broadcasted(::typeof(pinv), a::Ones) = _broadcasted(typeof(inv(oneunit(eltype(a)))), a)
_broadcasted(::typeof(pinv), a::Zeros) = _broadcasted(typeof(inv(zero(eltype(a)))), a)
_broadcasted(::typeof(sqrt), a::Ones) = _broadcasted(typeof(sqrt(one(eltype(a)))), a)
_broadcasted(::typeof(sqrt), a::Zeros) = _broadcasted(typeof(sqrt(zero(eltype(a)))), a)
_broadcasted(::typeof(cbrt), a::Ones) = _broadcasted(typeof(cbrt(one(eltype(a)))), a)
_broadcasted(::typeof(cbrt), a::Zeros) = _broadcasted(typeof(cbrt(zero(eltype(a)))), a)
_broadcasted(::typeof(exp), a::Zeros) = Ones{typeof(exp(zero(eltype(a))))}(axes(a))
_broadcasted(::typeof(cis), a::Zeros) = Ones{typeof(cis(zero(eltype(a))))}(axes(a))
_broadcasted(::typeof(log), a::Ones) = Zeros{typeof(log(one(eltype(a))))}(axes(a))
_broadcasted(::typeof(cos), a::Zeros) = Ones{typeof(cos(zero(eltype(a))))}(axes(a))
_broadcasted(::typeof(sin), a::Zeros) = _broadcasted(typeof(sin(zero(eltype(a)))), a)
_broadcasted(::typeof(tan), a::Zeros) = _broadcasted(typeof(tan(zero(eltype(a)))), a)
_broadcasted(::typeof(sec), a::Zeros) = Ones{typeof(sec(zero(eltype(a))))}(axes(a))
_broadcasted(::typeof(cosh), a::Zeros) = Ones{typeof(cosh(zero(eltype(a))))}(axes(a))
# Eager version of `_broadcasted`.
_broadcast(f::F, a::AbstractArray) where {F} = copy(_broadcasted(f, a))

function Broadcast.broadcasted(
::DiagonalArrayStyle{N}, f::F, a::DiagonalArray{T,N,Diag}
) where {F,T,N,Diag<:AbstractFill{T}}
# TODO: Check that `f` preserves zeros?
return DiagonalArray(_broadcasted(f, diagview(a)), axes(a))
end
Loading
Loading