diff --git a/.github/workflows/IntegrationTest.yml b/.github/workflows/IntegrationTest.yml index 381cd67..2e658d7 100644 --- a/.github/workflows/IntegrationTest.yml +++ b/.github/workflows/IntegrationTest.yml @@ -18,6 +18,7 @@ jobs: matrix: pkg: - 'BlockSparseArrays' + - 'KroneckerArrays' uses: "ITensor/ITensorActions/.github/workflows/IntegrationTest.yml@main" with: localregistry: "https://github.com/ITensor/ITensorRegistry.git" diff --git a/Project.toml b/Project.toml index 0909e7e..d37916b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "DiagonalArrays" uuid = "74fd4be6-21e2-4f6f-823a-4360d37c7a77" authors = ["ITensor developers and contributors"] -version = "0.3.17" +version = "0.3.18" [deps] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" diff --git a/src/abstractdiagonalarray/abstractdiagonalarray.jl b/src/abstractdiagonalarray/abstractdiagonalarray.jl index 1c23777..6fca242 100644 --- a/src/abstractdiagonalarray/abstractdiagonalarray.jl +++ b/src/abstractdiagonalarray/abstractdiagonalarray.jl @@ -4,6 +4,30 @@ abstract type AbstractDiagonalArray{T,N} <: AbstractSparseArray{T,N} end const AbstractDiagonalMatrix{T} = AbstractDiagonalArray{T,2} const AbstractDiagonalVector{T} = AbstractDiagonalArray{T,1} +# Define for type stability, for some reason the generic versions +# in SparseArraysBase.jl is not type stable. +# TODO: Investigate type stability of `iszero` in SparseArraysBase.jl. +function Base.iszero(a::AbstractDiagonalArray) + return iszero(diagview(a)) +end + +using FillArrays: AbstractFill, getindex_value +using LinearAlgebra: norm +# TODO: `_norm` works around: +# https://github.com/JuliaArrays/FillArrays.jl/issues/417 +# Change back to `norm` when that is fixed. +_norm(a, p::Int=2) = norm(a, p) +function _norm(a::AbstractFill, p::Int=2) + nrm1 = norm(getindex_value(a)) + return (length(a))^(1/oftype(nrm1, p)) * nrm1 +end +function LinearAlgebra.norm(a::AbstractDiagonalArray, p::Int=2) + # TODO: `_norm` works around: + # https://github.com/JuliaArrays/FillArrays.jl/issues/417 + # Change back to `norm` when that is fixed. + return _norm(diagview(a), p) +end + using LinearAlgebra: LinearAlgebra, ishermitian, isposdef, issymmetric LinearAlgebra.ishermitian(a::AbstractDiagonalMatrix{<:Real}) = issquare(a) function LinearAlgebra.ishermitian(a::AbstractDiagonalMatrix{<:Number}) diff --git a/src/abstractdiagonalarray/diagonalarraydiaginterface.jl b/src/abstractdiagonalarray/diagonalarraydiaginterface.jl index 3f914d3..d9f203d 100644 --- a/src/abstractdiagonalarray/diagonalarraydiaginterface.jl +++ b/src/abstractdiagonalarray/diagonalarraydiaginterface.jl @@ -43,6 +43,9 @@ function SparseArraysBase.getstoredindex( # allequal(I) || error("Not a diagonal index.") return getdiagindex(a, first(I)) end +function SparseArraysBase.getstoredindex(a::AbstractDiagonalArray{<:Any,0}) + return getdiagindex(a, 1) +end function SparseArraysBase.setstoredindex!( a::AbstractDiagonalArray{<:Any,N}, value, I::Vararg{Int,N} ) where {N} @@ -52,6 +55,10 @@ function SparseArraysBase.setstoredindex!( setdiagindex!(a, value, first(I)) return a end +function SparseArraysBase.setstoredindex!(a::AbstractDiagonalArray{<:Any,0}, value) + setdiagindex!(a, value, 1) + return a +end function SparseArraysBase.eachstoredindex(::IndexCartesian, a::AbstractDiagonalArray) return diagindices(a) end @@ -99,25 +106,3 @@ function Base.copyto!(dest::AbstractArray, bc::Broadcasted{<:DiagonalArrayStyle} copyto!(diagview(dest), broadcasted_diagview(bc)) return dest end - -## SparseArraysBase.StorageIndex(i::DiagIndex) = StorageIndex(index(i)) - -## function Base.getindex(a::AbstractDiagonalArray, i::DiagIndex) -## return a[StorageIndex(i)] -## end - -## function Base.setindex!(a::AbstractDiagonalArray, value, i::DiagIndex) -## a[StorageIndex(i)] = value -## return a -## end - -## SparseArraysBase.StorageIndices(i::DiagIndices) = StorageIndices(indices(i)) - -## function Base.getindex(a::AbstractDiagonalArray, i::DiagIndices) -## return a[StorageIndices(i)] -## end - -## function Base.setindex!(a::AbstractDiagonalArray, value, i::DiagIndices) -## a[StorageIndices(i)] = value -## return a -## end diff --git a/src/diaginterface/diaginterface.jl b/src/diaginterface/diaginterface.jl index ab572c9..dcbc331 100644 --- a/src/diaginterface/diaginterface.jl +++ b/src/diaginterface/diaginterface.jl @@ -97,6 +97,7 @@ function setdiagindex!(a::AbstractArray, v, i::Integer) end function getdiagindices(a::AbstractArray, I) + # TODO: Should this be a view? return @view diagview(a)[I] end diff --git a/src/diagonalarray/diagonalarray.jl b/src/diagonalarray/diagonalarray.jl index 47ce8b2..7e0c7cd 100644 --- a/src/diagonalarray/diagonalarray.jl +++ b/src/diagonalarray/diagonalarray.jl @@ -1,18 +1,24 @@ using FillArrays: Zeros using SparseArraysBase: Unstored, unstored -function _DiagonalArray end +diaglength_from_shape(sz::Tuple{Integer,Vararg{Integer}}) = minimum(sz) +function diaglength_from_shape( + sz::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}} +) + return minimum(length, sz) +end +diaglength_from_shape(sz::Tuple{}) = 1 -struct DiagonalArray{T,N,Diag<:AbstractVector{T},Unstored<:AbstractArray{T,N}} <: +struct DiagonalArray{T,N,D<:AbstractVector{T},U<:AbstractArray{T,N}} <: AbstractDiagonalArray{T,N} - diag::Diag - unstored::Unstored - global @inline function _DiagonalArray( - diag::Diag, unstored::Unstored - ) where {T,N,Diag<:AbstractVector{T},Unstored<:AbstractArray{T,N}} - length(diag) == minimum(size(unstored)) || + diag::D + unstored::U + function DiagonalArray{T,N,D,U}( + diag::AbstractVector, unstored::Unstored + ) where {T,N,D<:AbstractVector{T},U<:AbstractArray{T,N}} + length(diag) == diaglength_from_shape(size(unstored)) || throw(ArgumentError("Length of diagonals doesn't match dimensions")) - return new{T,N,Diag,Unstored}(diag, unstored) + return new{T,N,D,U}(diag, parent(unstored)) end end @@ -20,19 +26,87 @@ SparseArraysBase.unstored(a::DiagonalArray) = a.unstored Base.size(a::DiagonalArray) = size(unstored(a)) Base.axes(a::DiagonalArray) = axes(unstored(a)) +function DiagonalArray{T,N,D}( + diag::D, unstored::Unstored{T,N,U} +) where {T,N,D<:AbstractVector{T},U<:AbstractArray{T,N}} + return DiagonalArray{T,N,D,U}(diag, unstored) +end +function DiagonalArray{T,N}( + diag::D, unstored::Unstored{T,N} +) where {T,N,D<:AbstractVector{T}} + return DiagonalArray{T,N,D}(diag, unstored) +end +function DiagonalArray{T}(diag::AbstractVector{T}, unstored::Unstored{T,N}) where {T,N} + return DiagonalArray{T,N}(diag, unstored) +end +function DiagonalArray(diag::AbstractVector{T}, unstored::Unstored{T}) where {T} + return DiagonalArray{T}(diag, unstored) +end + function DiagonalArray(::UndefInitializer, unstored::Unstored) - return _DiagonalArray( - Vector{eltype(unstored)}(undef, minimum(size(unstored))), parent(unstored) + return DiagonalArray( + Vector{eltype(unstored)}(undef, diaglength_from_shape(size(unstored))), unstored + ) +end + +# Indicate we will construct an array just from the shape, +# for example for a Base.OneTo or FillArrays.Ones or Zeros. +# All the elements should be uniquely defined by the input axes. +struct ShapeInitializer end + +# This is used to create custom constructors for arrays, +# in this case a generic constructor of a vector from a length. +function construct(vect::Type{<:AbstractVector}, ::ShapeInitializer, len::Integer) + if applicable(vect, len) + return vect(len) + elseif applicable(vect, (Base.OneTo(len),)) + return vect((Base.OneTo(len),)) + else + error(lazy"Can't construct $(vect) from length.") + end +end + +# This helps to support diagonals where the elements are known +# from the types, for example diagonals that are `Zeros` and `Ones`. +function DiagonalArray{T,N,D}( + init::ShapeInitializer, unstored::Unstored +) where {T,N,D<:AbstractVector{T}} + return DiagonalArray{T,N,D}( + construct(D, init, diaglength_from_shape(axes(unstored))), unstored ) end -# Constructors accepting axes. +# This helps to support diagonals where the elements are known +# from the types, for example diagonals that are `Zeros` and `Ones`. +# These versions use the default unstored type `Zeros{T,N}`. +function DiagonalArray{T,N,D}( + init::ShapeInitializer, ax::Tuple{Vararg{AbstractUnitRange{<:Integer}}} +) where {T,N,D<:AbstractVector{T}} + return DiagonalArray{T,N,D}(init, Unstored(Zeros{T,N}(ax))) +end +function DiagonalArray{T,N,D}( + init::ShapeInitializer, ax::AbstractUnitRange{<:Integer}... +) where {T,N,D<:AbstractVector{T}} + return DiagonalArray{T,N,D}(init, ax) +end +function DiagonalArray{T,N,D}( + init::ShapeInitializer, sz::Tuple{Integer,Vararg{Integer}} +) where {T,N,D<:AbstractVector{T}} + return DiagonalArray{T,N,D}(init, Base.OneTo.(sz)) +end +function DiagonalArray{T,N,D}( + init::ShapeInitializer, sz1::Integer, sz_rest::Integer... +) where {T,N,D<:AbstractVector{T}} + return DiagonalArray{T,N,D}(init, (sz1, sz_rest...)) +end + +# Constructor from diagonal entries accepting axes. function DiagonalArray{T,N}( diag::AbstractVector, ax::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}}, ) where {T,N} N == length(ax) || throw(ArgumentError("Wrong number of axes")) - return _DiagonalArray(convert(AbstractVector{T}, diag), Zeros{T}(ax)) + return DiagonalArray(convert(AbstractVector{T}, diag), Unstored(Zeros{T}(ax))) end function DiagonalArray{T,N}( diag::AbstractVector, @@ -97,7 +171,7 @@ function DiagonalArray{T}( end function DiagonalArray{T,N}(diag::AbstractVector, dims::Dims{N}) where {T,N} - return _DiagonalArray(convert(AbstractVector{T}, diag), Zeros{T}(dims)) + return DiagonalArray(convert(AbstractVector{T}, diag), Unstored(Zeros{T}(dims))) end function DiagonalArray{T,N}(diag::AbstractVector, dims::Vararg{Int,N}) where {T,N} @@ -146,7 +220,7 @@ end # undef function DiagonalArray{T,N}(::UndefInitializer, dims::Dims{N}) where {T,N} - return DiagonalArray{T,N}(Vector{T}(undef, minimum(dims)), dims) + return DiagonalArray{T,N}(Vector{T}(undef, diaglength_from_shape(dims)), dims) end function DiagonalArray{T,N}(::UndefInitializer, dims::Vararg{Int,N}) where {T,N} @@ -162,8 +236,10 @@ function DiagonalArray{T}(::UndefInitializer, dims::Vararg{Int,N}) where {T,N} end # Axes version -function DiagonalArray{T}(::UndefInitializer, axes::NTuple{N,Base.OneTo{Int}}) where {T,N} - return DiagonalArray{T,N}(undef, length.(axes)) +function DiagonalArray{T}( + ::UndefInitializer, axes::Tuple{Base.OneTo{Int},Vararg{Base.OneTo{Int}}} +) where {T} + return DiagonalArray{T,length(axes)}(undef, length.(axes)) end function Base.similar(a::DiagonalArray, unstored::Unstored) @@ -197,3 +273,118 @@ function DerivableInterfaces.permuteddims(a::DiagonalArray, perm) # Unlike `permutedims(::Diagonal, perm)`, we copy here. return DiagonalArray(diagview(a), ax_perm) end + +# Scalar indexing. +using DerivableInterfaces: @interface, interface +one_based_range(r) = false +one_based_range(r::Base.OneTo) = true +one_based_range(r::Base.Slice) = true +function _diag_axes(a::DiagonalArray, I...) + return map(ntuple(identity, ndims(a))) do d + return Base.axes1(axes(a, d)[I[d]]) + end +end +# A view that preserves the diagonal structure. +function _view_diag(a::DiagonalArray, I...) + ax = _diag_axes(a, I...) + return DiagonalArray(view(diagview(a), Base.OneTo(minimum(length, I))), ax) +end +function _view_diag(a::DiagonalArray, I1::Base.Slice, Irest::Base.Slice...) + ax = _diag_axes(a, I1, Irest...) + return DiagonalArray(view(diagview(a), :), ax) +end +# A slice that preserves the diagonal structure. +function _getindex_diag(a::DiagonalArray, I...) + ax = _diag_axes(a, I...) + return DiagonalArray(diagview(a)[Base.OneTo(minimum(length, I))], ax) +end +function _getindex_diag(a::DiagonalArray, I1::Base.Slice, Irest::Base.Slice...) + ax = _diag_axes(a, I1, Irest...) + return DiagonalArray(diagview(a)[:], ax) +end +function Base.view(a::DiagonalArray, I...) + I′ = to_indices(a, I) + return if all(one_based_range, I′) + _view_diag(a, I′...) + else + invoke(view, Tuple{AbstractArray,Vararg}, a, I′...) + end +end +function Base.getindex(a::DiagonalArray, I::Int...) + return @interface interface(a) a[I...] +end +function Base.getindex(a::DiagonalArray, I::DiagIndex) + return getdiagindex(a, index(I)) +end +function Base.getindex(a::DiagonalArray, I::DiagIndices) + # TODO: Should this be a view? + return @view diagview(a)[indices(I)] +end +function Base.getindex(a::DiagonalArray, I...) + I′ = to_indices(a, I) + return if all(i -> i isa Real, I′) + # Catch scalar indexing case. + @interface interface(a) a[I...] + elseif all(one_based_range, I′) + _getindex_diag(a, I′...) + else + copy(view(a, I′...)) + end +end + +# Define in order to preserve immutable diagonals such as FillArrays. +function DiagonalArray{T,N}(a::DiagonalArray{T,N}) where {T,N} + # TODO: Should this copy? This matches the design of `LinearAlgebra.Diagonal`: + # https://github.com/JuliaLang/LinearAlgebra.jl/blob/release-1.12/src/diagonal.jl#L110-L112 + return a +end +function DiagonalArray{T,N}(a::DiagonalArray{<:Any,N}) where {T,N} + return DiagonalArray{T,N}(diagview(a)) +end +function DiagonalArray{T}(a::DiagonalArray) where {T} + return DiagonalArray{T,ndims(a)}(a) +end +function DiagonalArray(a::DiagonalArray) + return DiagonalArray{eltype(a),ndims(a)}(a) +end +function Base.AbstractArray{T,N}(a::DiagonalArray{<:Any,N}) where {T,N} + return DiagonalArray{T,N}(a) +end + +# TODO: These definitions work around this issue: +# https://github.com/JuliaArrays/FillArrays.jl/issues/416 +# when the diagonal is a FillArrays.Ones or Zeros. +using Base.Broadcast: Broadcast, broadcast, broadcasted +using FillArrays: AbstractFill, Ones, Zeros +_broadcasted(f::F, a::AbstractArray) where {F} = broadcasted(f, a) +_broadcasted(::typeof(identity), a::Ones) = a +_broadcasted(::typeof(identity), a::Zeros) = a +_broadcasted(::typeof(complex), a::Ones) = Ones{complex(eltype(a))}(axes(a)) +_broadcasted(::typeof(complex), a::Zeros) = Zeros{complex(eltype(a))}(axes(a)) +_broadcasted(elt::Type, a::Ones) = Ones{elt}(axes(a)) +_broadcasted(elt::Type, a::Zeros) = Zeros{elt}(axes(a)) +_broadcasted(::typeof(inv), a::Ones) = _broadcasted(typeof(inv(oneunit(eltype(a)))), a) +using LinearAlgebra: pinv +_broadcasted(::typeof(pinv), a::Ones) = _broadcasted(typeof(inv(oneunit(eltype(a)))), a) +_broadcasted(::typeof(pinv), a::Zeros) = _broadcasted(typeof(inv(zero(eltype(a)))), a) +_broadcasted(::typeof(sqrt), a::Ones) = _broadcasted(typeof(sqrt(one(eltype(a)))), a) +_broadcasted(::typeof(sqrt), a::Zeros) = _broadcasted(typeof(sqrt(zero(eltype(a)))), a) +_broadcasted(::typeof(cbrt), a::Ones) = _broadcasted(typeof(cbrt(one(eltype(a)))), a) +_broadcasted(::typeof(cbrt), a::Zeros) = _broadcasted(typeof(cbrt(zero(eltype(a)))), a) +_broadcasted(::typeof(exp), a::Zeros) = Ones{typeof(exp(zero(eltype(a))))}(axes(a)) +_broadcasted(::typeof(cis), a::Zeros) = Ones{typeof(cis(zero(eltype(a))))}(axes(a)) +_broadcasted(::typeof(log), a::Ones) = Zeros{typeof(log(one(eltype(a))))}(axes(a)) +_broadcasted(::typeof(cos), a::Zeros) = Ones{typeof(cos(zero(eltype(a))))}(axes(a)) +_broadcasted(::typeof(sin), a::Zeros) = _broadcasted(typeof(sin(zero(eltype(a)))), a) +_broadcasted(::typeof(tan), a::Zeros) = _broadcasted(typeof(tan(zero(eltype(a)))), a) +_broadcasted(::typeof(sec), a::Zeros) = Ones{typeof(sec(zero(eltype(a))))}(axes(a)) +_broadcasted(::typeof(cosh), a::Zeros) = Ones{typeof(cosh(zero(eltype(a))))}(axes(a)) +# Eager version of `_broadcasted`. +_broadcast(f::F, a::AbstractArray) where {F} = copy(_broadcasted(f, a)) + +function Broadcast.broadcasted( + ::DiagonalArrayStyle{N}, f::F, a::DiagonalArray{T,N,Diag} +) where {F,T,N,Diag<:AbstractFill{T}} + # TODO: Check that `f` preserves zeros? + return DiagonalArray(_broadcasted(f, diagview(a)), axes(a)) +end diff --git a/src/diagonalarray/diagonalmatrix.jl b/src/diagonalarray/diagonalmatrix.jl index 4eb6fab..d3a310d 100644 --- a/src/diagonalarray/diagonalmatrix.jl +++ b/src/diagonalarray/diagonalmatrix.jl @@ -58,3 +58,91 @@ function LinearAlgebra.mul!( d_dest .= d1 .* d2 .* α .+ d_dest .* β return a_dest end + +# Adapted from https://github.com/JuliaLang/LinearAlgebra.jl/blob/release-1.12/src/diagonal.jl#L866-L928. +function LinearAlgebra.tr(a::DiagonalMatrix) + checksquare(a) + # TODO: Define as `sum(tr, diagview(a))` like LinearAlgebra.jl? + return sum(diagview(a)) +end +# TODO: Special case for FillArrays diagonals. +function LinearAlgebra.det(a::DiagonalMatrix) + checksquare(a) + # TODO: Define as `prod(det, diagview(a))` like LinearAlgebra.jl? + return prod(diagview(a)) +end +# TODO: Special case for FillArrays diagonals. +function LinearAlgebra.logabsdet(a::DiagonalMatrix) + checksquare(a) + return mapreduce(((d1, s1), (d2, s2)) -> (d1 + d2, s1 * s2), diagview(a)) do x + return (log(abs(x)), sign(x)) + end +end +# TODO: Special case for FillArrays diagonals. +function LinearAlgebra.logdet(a::DiagonalMatrix{<:Complex}) + checksquare(a) + z = sum(log, diagview(a)) + return complex(real(z), rem2pi(imag(z), RoundNearest)) +end + +# Matrix functions +for f in [ + :exp, + :cis, + :log, + :sqrt, + :cos, + :sin, + :tan, + :csc, + :sec, + :cot, + :cosh, + :sinh, + :tanh, + :csch, + :sech, + :coth, + :acos, + :asin, + :atan, + :acsc, + :asec, + :acot, + :acosh, + :asinh, + :atanh, + :acsch, + :asech, + :acoth, +] + @eval begin + function Base.$f(a::DiagonalMatrix) + checksquare(a) + return DiagonalMatrix(_broadcast($f, diagview(a)), axes(a)) + end + end +end + +# Cube root of a real-valued diagonal matrix +function Base.cbrt(a::DiagonalMatrix{<:Real}) + checksquare(a) + return DiagonalMatrix(_broadcast(cbrt, diagview(a)), axes(a)) +end + +function LinearAlgebra.inv(a::DiagonalMatrix) + checksquare(a) + # `DiagonalArrays._broadcast` works around issues like https://github.com/JuliaArrays/FillArrays.jl/issues/416 + # when the diagonal is a FillArray or similar lazy array. + d⁻¹ = _broadcast(inv, diagview(a)) + any(isinf, d⁻¹) && error("Singular Exception") + return DiagonalMatrix(d⁻¹, axes(a)) +end + +# TODO: Support `atol` and `rtol` keyword arguments: +# https://docs.julialang.org/en/v1/stdlib/LinearAlgebra/#LinearAlgebra.pinv +using LinearAlgebra: pinv +function LinearAlgebra.pinv(a::DiagonalMatrix) + checksquare(a) + return DiagonalMatrix(_broadcast(pinv, diagview(a)), axes(a)) +end diff --git a/src/dual.jl b/src/dual.jl index 9780a3b..36b6ee5 100644 --- a/src/dual.jl +++ b/src/dual.jl @@ -1,3 +1,11 @@ # TODO: Define `TensorProducts.dual`. dual(x) = x issquare(a::AbstractMatrix) = (axes(a, 1) == dual(axes(a, 2))) +# Like `LinearAlgebra.checksquare` but based on `DiagonalArrays.issquare`, +# which checks the axes and allows customizing to check that the +# codomain is the dual of the domain. +# Returns the codomain if the check passes. +function checksquare(a::AbstractMatrix) + issquare(a) || throw(DimensionMismatch(lazy"matrix is not square: axes are $(axes(a))")) + return axes(a, 1) +end diff --git a/test/test_basics.jl b/test/test_basics.jl index 4bc24cb..de1dce7 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -1,23 +1,27 @@ -using Test: @test, @testset, @test_broken, @inferred using DerivableInterfaces: permuteddims using DiagonalArrays: DiagonalArrays, + ShapeInitializer, Delta, DeltaMatrix, DiagonalArray, DiagonalMatrix, ScaledDelta, ScaledDeltaMatrix, + Unstored, δ, delta, diagindices, diaglength, diagonal, diagonaltype, - diagview -using FillArrays: Fill, Ones -using SparseArraysBase: SparseArrayDOK, sparsezeros, storedlength -using LinearAlgebra: Diagonal, mul!, ishermitian, isposdef, issymmetric + diagview, + getdiagindices +using FillArrays: Fill, Ones, Zeros +using LinearAlgebra: + Diagonal, det, ishermitian, isposdef, issymmetric, logdet, mul!, pinv, tr +using SparseArraysBase: SparseArrayDOK, SparseMatrixDOK, sparsezeros, storedlength +using Test: @test, @test_throws, @testset, @test_broken, @inferred @testset "Test DiagonalArrays" begin @testset "DiagonalArray (eltype=$elt)" for elt in ( @@ -101,10 +105,116 @@ using LinearAlgebra: Diagonal, mul!, ishermitian, isposdef, issymmetric eltype(DiagonalArray{elt}(undef, (2, 2))) ≡ eltype(DiagonalArray{elt}(undef, Base.OneTo(2), Base.OneTo(2))) ≡ eltype(DiagonalArray{elt}(undef, (Base.OneTo(2), Base.OneTo(2)))) ≡ - eltype(DiagonalArray{elt,2}(undef, 2, 2)) ≡ - eltype(DiagonalArray{elt,2}(undef, (2, 2))) ≡ - eltype(DiagonalArray{elt,2}(undef, Base.OneTo(2), Base.OneTo(2))) ≡ - eltype(DiagonalArray{elt,2}(undef, (Base.OneTo(2), Base.OneTo(2)))) + eltype(DiagonalMatrix{elt}(undef, 2, 2)) ≡ + eltype(DiagonalMatrix{elt}(undef, (2, 2))) ≡ + eltype(DiagonalMatrix{elt}(undef, Base.OneTo(2), Base.OneTo(2))) ≡ + eltype(DiagonalMatrix{elt}(undef, (Base.OneTo(2), Base.OneTo(2)))) + + # Special constructors for immutable diagonal. + init = ShapeInitializer() + @test DiagonalMatrix(Base.OneTo(UInt32(2))) ≡ + DiagonalMatrix{UInt32,Base.OneTo{UInt32}}(init, Base.OneTo.((2, 2))) ≡ + DiagonalMatrix{UInt32,Base.OneTo{UInt32}}(init, Base.OneTo.((2, 2))...) ≡ + DiagonalMatrix{UInt32,Base.OneTo{UInt32}}(init, (2, 2)) ≡ + DiagonalMatrix{UInt32,Base.OneTo{UInt32}}(init, 2, 2) ≡ + DiagonalMatrix{UInt32,Base.OneTo{UInt32}}(init, Unstored(Zeros{UInt32}(2, 2))) + + init = ShapeInitializer() + @test DiagonalMatrix(Ones{elt}(2)) ≡ + DiagonalMatrix{elt,Ones{elt,1,Tuple{Base.OneTo{Int}}}}(init, Base.OneTo.((2, 2))) ≡ + DiagonalMatrix{elt,Ones{elt,1,Tuple{Base.OneTo{Int}}}}( + init, Base.OneTo.((2, 2))... + ) ≡ + DiagonalMatrix{elt,Ones{elt,1,Tuple{Base.OneTo{Int}}}}(init, (2, 2)) ≡ + DiagonalMatrix{elt,Ones{elt,1,Tuple{Base.OneTo{Int}}}}(init, 2, 2) ≡ + DiagonalMatrix{elt,Ones{elt,1,Tuple{Base.OneTo{Int}}}}( + init, Unstored(Zeros{elt}(2, 2)) + ) + + init = ShapeInitializer() + @test_throws ErrorException DiagonalMatrix{elt,Vector{elt}}(init, Base.OneTo.((2, 2))) + @test_throws ErrorException DiagonalMatrix{elt,Vector{elt}}( + init, Base.OneTo.((2, 2))... + ) + @test_throws ErrorException DiagonalMatrix{elt,Vector{elt}}(init, (2, 2)) + @test_throws ErrorException DiagonalMatrix{elt,Vector{elt}}(init, 2, 2) + + # 0-dim constructors + v = randn(elt, 1) + @test DiagonalArray(v) ≡ + DiagonalArray(v, ()) ≡ + DiagonalArray{elt}(v) ≡ + DiagonalArray{elt}(v, ()) ≡ + DiagonalArray{elt,0}(v) ≡ + DiagonalArray{elt,0}(v, ()) + @test size(DiagonalArray{elt}(undef)) ≡ + size(DiagonalArray{elt}(undef, ())) ≡ + size(DiagonalArray{elt,0}(undef)) ≡ + size(DiagonalArray{elt,0}(undef, ())) + @test elt ≡ + eltype(DiagonalArray{elt}(undef)) ≡ + eltype(DiagonalArray{elt}(undef, ())) ≡ + eltype(DiagonalArray{elt,0}(undef)) ≡ + eltype(DiagonalArray{elt,0}(undef, ())) + + # Special constructors for immutable diagonal. + init = ShapeInitializer() + @test DiagonalArray{<:Any,0}(Base.OneTo(UInt32(1))) ≡ + DiagonalArray{UInt32,0,Base.OneTo{UInt32}}(init, ()) ≡ + DiagonalArray{UInt32,0,Base.OneTo{UInt32}}(init) ≡ + DiagonalArray{UInt32,0,Base.OneTo{UInt32}}(init, Unstored(Zeros{UInt32}())) + end + @testset "0-dim operations" begin + diag = randn(elt, 1) + a = DiagonalArray(diag) + @test a[] == diag[1] + a[] = 2 + @test a[] == 2 + end + @testset "Conversion" begin + a = DiagonalMatrix(randn(elt, 2)) + @test DiagonalMatrix{elt}(a) ≡ a + @test DiagonalMatrix{ComplexF64}(a) == DiagonalMatrix(ComplexF64.(diagview(a))) + @test DiagonalArray{ComplexF64}(a) == DiagonalMatrix(ComplexF64.(diagview(a))) + @test DiagonalArray(a) ≡ a + @test AbstractMatrix{elt}(a) ≡ a + @test AbstractMatrix{ComplexF64}(a) == DiagonalMatrix(ComplexF64.(diagview(a))) + @test AbstractArray{elt}(a) ≡ a + @test AbstractArray{ComplexF64}(a) == DiagonalMatrix(ComplexF64.(diagview(a))) + end + @testset "Slicing" begin + # Slicing that preserves the diagonal structure. + a = DiagonalMatrix(randn(elt, 3)) + b = @view a[:, :] + @test b isa DiagonalMatrix{elt,<:SubArray{elt,1}} + @test diagview(b) ≡ view(diagview(a), :) + + a = DiagonalMatrix(randn(elt, 3)) + b = @view a[Base.OneTo(2), Base.OneTo(2)] + @test b isa DiagonalMatrix{elt,<:SubArray{elt,1}} + @test diagview(b) ≡ view(diagview(a), Base.OneTo(2)) + + a = DiagonalMatrix(randn(elt, 3)) + b = a[:, :] + @test typeof(b) == typeof(a) + @test diagview(b) == diagview(a) + + a = DiagonalMatrix(randn(elt, 3)) + b = a[Base.OneTo(2), Base.OneTo(2)] + @test typeof(b) == typeof(a) + @test diagview(b) == diagview(a)[Base.OneTo(2)] + + # Slicing that doesn't preserve the diagonal structure. + a = DiagonalMatrix(randn(elt, 3)) + b = @view a[2:3, 2:3] + @test b isa SubArray + @test b == Matrix(a)[2:3, 2:3] + + a = DiagonalMatrix(randn(elt, 3)) + b = a[2:3, 2:3] + @test b isa SparseMatrixDOK{elt} + @test b == Matrix(a)[2:3, 2:3] + @test storedlength(b) == 2 end @testset "permutedims" begin a = DiagonalArray(randn(elt, 2), (2, 3, 4)) @@ -134,6 +244,51 @@ using LinearAlgebra: Diagonal, mul!, ishermitian, isposdef, issymmetric # Non-zero-preserving functions not supported yet. c = DiagonalArray{elt}(undef, (2, 3)) @test_broken c .= a .+ 2 + + a_ones = DiagonalMatrix(Ones{elt}(2)) + a_zeros = DiagonalMatrix(Zeros{elt}(2)) + @test identity.(a_ones) ≡ DiagonalMatrix(Ones{elt}(2)) + @test identity.(a_zeros) ≡ DiagonalMatrix(Zeros{elt}(2)) + @test complex.(a_ones) ≡ DiagonalMatrix(Ones{complex(elt)}(2)) + @test complex.(a_zeros) ≡ DiagonalMatrix(Zeros{complex(elt)}(2)) + @test Float32.(a_ones) ≡ DiagonalMatrix(Ones{Float32}(2)) + @test Float32.(a_zeros) ≡ DiagonalMatrix(Zeros{Float32}(2)) + @test inv.(a_ones) ≡ DiagonalMatrix(Ones{elt}(2)) + @test inv.(a_zeros) ≡ DiagonalMatrix(Fill(inv(zero(elt)), 2)) + @test pinv.(a_ones) ≡ DiagonalMatrix(Ones{elt}(2)) + @test pinv.(a_zeros) ≡ DiagonalMatrix(Zeros{elt}(2)) + @test sqrt.(a_ones) ≡ DiagonalMatrix(Ones{elt}(2)) + @test sqrt.(a_zeros) ≡ DiagonalMatrix(Zeros{elt}(2)) + if elt <: Real + @test cbrt.(a_ones) ≡ DiagonalMatrix(Ones{elt}(2)) + @test cbrt.(a_zeros) ≡ DiagonalMatrix(Zeros{elt}(2)) + end + @test exp.(a_ones) ≡ DiagonalMatrix(Fill(exp(one(elt)), 2)) + @test exp.(a_zeros) ≡ DiagonalMatrix(Ones{typeof(exp(zero(elt)))}(2)) + @test cis.(a_ones) ≡ DiagonalMatrix(Fill(cis(one(elt)), 2)) + @test cis.(a_zeros) ≡ DiagonalMatrix(Ones{typeof(cis(zero(elt)))}(2)) + @test log.(a_ones) ≡ DiagonalMatrix(Zeros{typeof(log(one(elt)))}(2)) + @test log.(a_zeros) ≡ DiagonalMatrix(Fill(log(zero(elt)), 2)) + @test cos.(a_ones) ≡ DiagonalMatrix(Fill(cos(one(elt)), 2)) + @test cos.(a_zeros) ≡ DiagonalMatrix(Ones{typeof(cos(zero(elt)))}(2)) + @test sin.(a_ones) ≡ DiagonalMatrix(Fill(sin(one(elt)), 2)) + @test sin.(a_zeros) ≡ DiagonalMatrix(Zeros{typeof(sin(zero(elt)))}(2)) + @test tan.(a_ones) ≡ DiagonalMatrix(Fill(tan(one(elt)), 2)) + @test tan.(a_zeros) ≡ DiagonalMatrix(Zeros{typeof(tan(zero(elt)))}(2)) + @test sec.(a_ones) ≡ DiagonalMatrix(Fill(sec(one(elt)), 2)) + @test sec.(a_zeros) ≡ DiagonalMatrix(Ones{typeof(sec(zero(elt)))}(2)) + @test cosh.(a_ones) ≡ DiagonalMatrix(Fill(cosh(one(elt)), 2)) + @test cosh.(a_zeros) ≡ DiagonalMatrix(Ones{typeof(cosh(zero(elt)))}(2)) + end + @testset "Array properties" begin + a = DiagonalMatrix(randn(elt, 2)) + @test !iszero(a) + + a = DiagonalMatrix(zeros(elt, 2)) + @test iszero(a) + + a = DiagonalMatrix(Zeros{elt}(2)) + @test iszero(a) end @testset "LinearAlgebra matrix properties" begin @test ishermitian(DiagonalMatrix([1, 2])) @@ -157,6 +312,68 @@ using LinearAlgebra: Diagonal, mul!, ishermitian, isposdef, issymmetric @test !isposdef(DiagonalMatrix([randn(2, 2), randn(3, 3)])) @test !isposdef(DiagonalMatrix([randn(2, 2), randn(2, 3)])) end + @testset "LinearAlgebra matrix functions" begin + diag = randn(elt, 2) + a = DiagonalMatrix(diag) + @test tr(a) ≈ sum(diag) + @test det(a) ≈ prod(diag) + + # Use a positive diagonal in order to take the `log`. + diag = rand(elt, 2) + a = DiagonalMatrix(diag) + @test real(logdet(a)) ≈ real(sum(log, diag)) + @test imag(logdet(a)) ≈ rem2pi(imag(sum(log, diag)), RoundNearest) + + for f in [ + :exp, + :cis, + :log, + :sqrt, + :cos, + :sin, + :tan, + :csc, + :sec, + :cot, + :cosh, + :sinh, + :tanh, + :csch, + :sech, + :coth, + :acos, + :asin, + :atan, + :acot, + :asinh, + :atanh, + :acsch, + :asech, + ] + @eval begin + a = DiagonalMatrix(rand($elt, 2)) + @test $f(a) ≈ DiagonalMatrix($f.(diagview(a))) + end + end + + for f in [:acsc, :asec, :acosh, :acoth] + @eval begin + a = DiagonalMatrix(inv.(rand($elt, 2))) + @test $f(a) ≈ DiagonalMatrix($f.(diagview(a))) + end + end + + if elt <: Real + a = DiagonalMatrix(randn(elt, 2)) + @test cbrt(a) ≈ DiagonalMatrix(cbrt.(diagview(a))) + end + + a = DiagonalMatrix(randn(elt, 2)) + @test inv(a) ≈ DiagonalMatrix(inv.(diagview(a))) + + a = DiagonalMatrix(randn(elt, 2)) + @test pinv(a) ≈ DiagonalMatrix(pinv.(diagview(a))) + end @testset "Matrix multiplication" begin a1 = DiagonalArray{elt}(undef, (2, 3)) a1[1, 1] = 11 @@ -213,6 +430,9 @@ using LinearAlgebra: Diagonal, mul!, ishermitian, isposdef, issymmetric @test d isa Diagonal{eltype(v)} @test diagview(d) == diagview(a) @test diagonaltype(a) === typeof(d) + + a = randn(3, 3) + @test getdiagindices(a, 2:3) == diagview(a)[2:3] end @testset "delta" begin for (a, elt′) in (