diff --git a/Project.toml b/Project.toml index 396e917..39688f5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,13 +1,14 @@ name = "DiagonalArrays" uuid = "74fd4be6-21e2-4f6f-823a-4360d37c7a77" authors = ["ITensor developers and contributors"] -version = "0.3.14" +version = "0.3.15" [deps] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +MapBroadcast = "ebd9b9da-f48d-417c-9660-449667d60261" SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208" [compat] @@ -15,5 +16,6 @@ ArrayLayouts = "1.10.4" DerivableInterfaces = "0.5.5" FillArrays = "1.13.0" LinearAlgebra = "1.10.0" +MapBroadcast = "0.1.10" SparseArraysBase = "0.7.2" julia = "1.10" diff --git a/src/abstractdiagonalarray/diagonalarraydiaginterface.jl b/src/abstractdiagonalarray/diagonalarraydiaginterface.jl index 7b2d389..3f914d3 100644 --- a/src/abstractdiagonalarray/diagonalarraydiaginterface.jl +++ b/src/abstractdiagonalarray/diagonalarraydiaginterface.jl @@ -30,10 +30,6 @@ struct DiagonalArrayStyle{N} <: AbstractDiagonalArrayStyle{N} end DiagonalArrayStyle{M}(::Val{N}) where {M,N} = DiagonalArrayStyle{N}() -@interface ::AbstractDiagonalArrayInterface function Broadcast.BroadcastStyle(type::Type) - return DiagonalArrayStyle{ndims(type)}() -end - function SparseArraysBase.isstored( a::AbstractDiagonalArray{<:Any,N}, I::Vararg{Int,N} ) where {N} @@ -81,6 +77,29 @@ function Base.setindex!(a::AbstractDiagonalArray, value, I::DiagIndex) return invoke(setindex!, Tuple{AbstractArray,Any,DiagIndex}, a, value, I) end +@interface ::AbstractDiagonalArrayInterface function Broadcast.BroadcastStyle(type::Type) + return DiagonalArrayStyle{ndims(type)}() +end + +using Base.Broadcast: Broadcasted, broadcasted +using MapBroadcast: Mapped +# Map to a flattened broadcast expression of the diagonals of the arrays, +# also checking that the function preserves zeros. +function broadcasted_diagview(bc::Broadcasted) + m = Mapped(bc) + iszero(m.f(map(zero ∘ eltype, m.args)...)) || error( + "Broadcasting DiagonalArrays with function that doesn't preserve zeros isn't supported yet.", + ) + return broadcasted(m.f, map(diagview, m.args)...) +end +function Base.copy(bc::Broadcasted{<:DiagonalArrayStyle}) + return DiagonalArray(copy(broadcasted_diagview(bc)), axes(bc)) +end +function Base.copyto!(dest::AbstractArray, bc::Broadcasted{<:DiagonalArrayStyle}) + copyto!(diagview(dest), broadcasted_diagview(bc)) + return dest +end + ## SparseArraysBase.StorageIndex(i::DiagIndex) = StorageIndex(index(i)) ## function Base.getindex(a::AbstractDiagonalArray, i::DiagIndex) diff --git a/src/diagonalarray/delta.jl b/src/diagonalarray/delta.jl index b2dc860..f91e312 100644 --- a/src/diagonalarray/delta.jl +++ b/src/diagonalarray/delta.jl @@ -1,6 +1,25 @@ -using FillArrays: Ones, OnesVector +using FillArrays: AbstractFillVector, Ones, OnesVector + +const ScaledDelta{T,N,Diag<:AbstractFillVector{T},Unstored<:AbstractArray{T,N}} = DiagonalArray{ + T,N,Diag,Unstored +} +const ScaledDeltaVector{T,Diag<:AbstractFillVector{T},Unstored<:AbstractVector{T}} = DiagonalVector{ + T,Diag,Unstored +} +const ScaledDeltaMatrix{T,Diag<:AbstractFillVector{T},Unstored<:AbstractMatrix{T}} = DiagonalMatrix{ + T,Diag,Unstored +} + +const Delta{T,N,Diag<:OnesVector{T},Unstored<:AbstractArray{T,N}} = DiagonalArray{ + T,N,Diag,Unstored +} +const DeltaVector{T,Diag<:OnesVector{T},Unstored<:AbstractVector{T}} = DiagonalVector{ + T,Diag,Unstored +} +const DeltaMatrix{T,Diag<:OnesVector{T},Unstored<:AbstractMatrix{T}} = DiagonalMatrix{ + T,Diag,Unstored +} -const Delta{T,N,V<:OnesVector{T},Axes} = DiagonalArray{T,N,V,Axes} function Delta{T}( ax::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}} ) where {T} diff --git a/src/diagonalarray/diagonalmatrix.jl b/src/diagonalarray/diagonalmatrix.jl index 9c77c25..2fff225 100644 --- a/src/diagonalarray/diagonalmatrix.jl +++ b/src/diagonalarray/diagonalmatrix.jl @@ -1,11 +1,6 @@ -const DiagonalMatrix{T,Diag,Zero} = DiagonalArray{T,2,Diag,Zero} - -function DiagonalMatrix(diag::AbstractVector) - return DiagonalArray{<:Any,2}(diag) -end -function DiagonalMatrix(diag::AbstractVector, ax::Tuple) - return DiagonalArray{<:Any,2}(diag, ax) -end +const DiagonalMatrix{T,Diag<:AbstractVector{T},Unstored<:AbstractMatrix{T}} = DiagonalArray{ + T,2,Diag,Unstored +} # LinearAlgebra diff --git a/src/diagonalarray/diagonalvector.jl b/src/diagonalarray/diagonalvector.jl index 40e35e4..ec3cde8 100644 --- a/src/diagonalarray/diagonalvector.jl +++ b/src/diagonalarray/diagonalvector.jl @@ -1,4 +1,6 @@ -const DiagonalVector{T,Diag,Zero} = DiagonalArray{T,1,Diag,Zero} +const DiagonalVector{T,Diag<:AbstractVector{T},Unstored<:AbstractVector{T}} = DiagonalArray{ + T,1,Diag,Unstored +} function DiagonalVector(diag::AbstractVector) return DiagonalArray{<:Any,1}(diag) diff --git a/test/test_basics.jl b/test/test_basics.jl index a7f1c00..dbc5f16 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -3,8 +3,11 @@ using DerivableInterfaces: permuteddims using DiagonalArrays: DiagonalArrays, Delta, + DeltaMatrix, DiagonalArray, DiagonalMatrix, + ScaledDelta, + ScaledDeltaMatrix, δ, delta, diagindices, @@ -116,6 +119,22 @@ using LinearAlgebra: Diagonal, mul! @test diagview(b) ≡ diagview(a) @test size(b) === (4, 2, 3) end + @testset "Broadcasting" begin + a = DiagonalArray(randn(elt, 2), (2, 3)) + b = DiagonalArray(randn(elt, 2), (2, 3)) + c = a .+ 2 .* b + @test c ≈ Array(a) + 2 * Array(b) + # Non-zero-preserving functions not supported yet. + @test_broken a .+ 2 + + c = DiagonalArray{elt}(undef, (2, 3)) + c .= a .+ 2 .* b + @test c ≈ Array(a) + 2 * Array(b) + + # Non-zero-preserving functions not supported yet. + c = DiagonalArray{elt}(undef, (2, 3)) + @test_broken c .= a .+ 2 + end @testset "Matrix multiplication" begin a1 = DiagonalArray{elt}(undef, (2, 3)) a1[1, 1] = 11 @@ -197,7 +216,9 @@ using LinearAlgebra: Diagonal, mul! @test eltype(a) === elt′ @test diaglength(a) == 2 @test a isa DiagonalArray{elt′,2} + @test a isa DiagonalMatrix{elt′} @test a isa Delta{elt′,2} + @test a isa DeltaMatrix{elt′} @test size(a) == (2, 2) @test diaglength(a) == 2 @test storedlength(a) == 2 @@ -211,11 +232,17 @@ using LinearAlgebra: Diagonal, mul! # TODO: Fix this. Mapping doesn't preserve # the diagonal structure properly. # https://github.com/ITensor/DiagonalArrays.jl/issues/7 - @test_broken diagview(a′) isa Fill + @test diagview(a′) isa Fill{promote_type(Int, elt′)} + @test a′ isa ScaledDelta{promote_type(Int, elt′),2} + @test a′ isa ScaledDeltaMatrix{promote_type(Int, elt′)} b = randn(elt, (2, 3)) a_dest = a * b @test a_dest ≈ Array(a) * Array(b) + + a_dest = a * a + @test a_dest ≈ Array(a) * Array(a) + @test diagview(a_dest) isa Ones{elt′} end end end