From 453f69467a211afb61c962f53942099e421807b4 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 20 Aug 2025 16:25:49 -0400 Subject: [PATCH 1/4] Better DiagonalArrays broadcasting --- Project.toml | 4 ++- .../diagonalarraydiaginterface.jl | 27 ++++++++++++++++--- test/test_basics.jl | 22 ++++++++++++++- 3 files changed, 47 insertions(+), 6 deletions(-) diff --git a/Project.toml b/Project.toml index 396e917..39688f5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,13 +1,14 @@ name = "DiagonalArrays" uuid = "74fd4be6-21e2-4f6f-823a-4360d37c7a77" authors = ["ITensor developers and contributors"] -version = "0.3.14" +version = "0.3.15" [deps] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +MapBroadcast = "ebd9b9da-f48d-417c-9660-449667d60261" SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208" [compat] @@ -15,5 +16,6 @@ ArrayLayouts = "1.10.4" DerivableInterfaces = "0.5.5" FillArrays = "1.13.0" LinearAlgebra = "1.10.0" +MapBroadcast = "0.1.10" SparseArraysBase = "0.7.2" julia = "1.10" diff --git a/src/abstractdiagonalarray/diagonalarraydiaginterface.jl b/src/abstractdiagonalarray/diagonalarraydiaginterface.jl index 7b2d389..cfdbbdf 100644 --- a/src/abstractdiagonalarray/diagonalarraydiaginterface.jl +++ b/src/abstractdiagonalarray/diagonalarraydiaginterface.jl @@ -30,10 +30,6 @@ struct DiagonalArrayStyle{N} <: AbstractDiagonalArrayStyle{N} end DiagonalArrayStyle{M}(::Val{N}) where {M,N} = DiagonalArrayStyle{N}() -@interface ::AbstractDiagonalArrayInterface function Broadcast.BroadcastStyle(type::Type) - return DiagonalArrayStyle{ndims(type)}() -end - function SparseArraysBase.isstored( a::AbstractDiagonalArray{<:Any,N}, I::Vararg{Int,N} ) where {N} @@ -81,6 +77,29 @@ function Base.setindex!(a::AbstractDiagonalArray, value, I::DiagIndex) return invoke(setindex!, Tuple{AbstractArray,Any,DiagIndex}, a, value, I) end +@interface ::AbstractDiagonalArrayInterface function Broadcast.BroadcastStyle(type::Type) + return DiagonalArrayStyle{ndims(type)}() +end + +using Base.Broadcast: Broadcasted, broadcasted, flatten +using MapBroadcast: Mapped +# Map to a flattened broadcast expression of the diagonals of the arrays, +# also checking that the function preserves zeros. +function broadcasted_diagview(bc::Broadcasted) + m = Mapped(bc) + iszero(m.f(map(zero ∘ eltype, m.args)...)) || error( + "Broadcasting DiagonalArrays with function that doesn't preserve zeros isn't supported yet.", + ) + return broadcasted(m.f, map(diagview, m.args)...) +end +function Base.copy(bc::Broadcasted{<:DiagonalArrayStyle}) + return DiagonalArray(copy(broadcasted_diagview(bc)), axes(bc)) +end +function Base.copyto!(dest::AbstractArray, bc::Broadcasted{<:DiagonalArrayStyle}) + copyto!(diagview(dest), broadcasted_diagview(bc)) + return dest +end + ## SparseArraysBase.StorageIndex(i::DiagIndex) = StorageIndex(index(i)) ## function Base.getindex(a::AbstractDiagonalArray, i::DiagIndex) diff --git a/test/test_basics.jl b/test/test_basics.jl index a7f1c00..2eb2772 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -116,6 +116,22 @@ using LinearAlgebra: Diagonal, mul! @test diagview(b) ≡ diagview(a) @test size(b) === (4, 2, 3) end + @testset "Broadcasting" begin + a = DiagonalArray(randn(elt, 2), (2, 3)) + b = DiagonalArray(randn(elt, 2), (2, 3)) + c = a .+ 2 .* b + @test c ≈ Array(a) + 2 * Array(b) + # Non-zero-preserving functions not supported yet. + @test_broken a .+ 2 + + c = DiagonalArray{elt}(undef, (2, 3)) + c .= a .+ 2 .* b + @test c ≈ Array(a) + 2 * Array(b) + + # Non-zero-preserving functions not supported yet. + c = DiagonalArray{elt}(undef, (2, 3)) + @test_broken c .= a .+ 2 + end @testset "Matrix multiplication" begin a1 = DiagonalArray{elt}(undef, (2, 3)) a1[1, 1] = 11 @@ -211,11 +227,15 @@ using LinearAlgebra: Diagonal, mul! # TODO: Fix this. Mapping doesn't preserve # the diagonal structure properly. # https://github.com/ITensor/DiagonalArrays.jl/issues/7 - @test_broken diagview(a′) isa Fill + @test diagview(a′) isa Fill{promote_type(Int, elt′)} b = randn(elt, (2, 3)) a_dest = a * b @test a_dest ≈ Array(a) * Array(b) + + a_dest = a * a + @test a_dest ≈ Array(a) * Array(a) + @test diagview(a_dest) isa Ones{elt′} end end end From 622f142740405e71fbec9ac37acb0cc9a5d298ad Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 20 Aug 2025 16:26:54 -0400 Subject: [PATCH 2/4] Remove stale import --- src/abstractdiagonalarray/diagonalarraydiaginterface.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/abstractdiagonalarray/diagonalarraydiaginterface.jl b/src/abstractdiagonalarray/diagonalarraydiaginterface.jl index cfdbbdf..3f914d3 100644 --- a/src/abstractdiagonalarray/diagonalarraydiaginterface.jl +++ b/src/abstractdiagonalarray/diagonalarraydiaginterface.jl @@ -81,7 +81,7 @@ end return DiagonalArrayStyle{ndims(type)}() end -using Base.Broadcast: Broadcasted, broadcasted, flatten +using Base.Broadcast: Broadcasted, broadcasted using MapBroadcast: Mapped # Map to a flattened broadcast expression of the diagonals of the arrays, # also checking that the function preserves zeros. From 5de388d2f222405820ff0348b7254aa12ea94ed4 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 20 Aug 2025 16:32:49 -0400 Subject: [PATCH 3/4] Introduce ScaledDelt --- src/diagonalarray/delta.jl | 4 +++- test/test_basics.jl | 2 ++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/diagonalarray/delta.jl b/src/diagonalarray/delta.jl index b2dc860..ade97ea 100644 --- a/src/diagonalarray/delta.jl +++ b/src/diagonalarray/delta.jl @@ -1,4 +1,6 @@ -using FillArrays: Ones, OnesVector +using FillArrays: AbstractFillVector, Ones, OnesVector + +const ScaledDelta{T,N,V<:AbstractFillVector{T},Axes} = DiagonalArray{T,N,V,Axes} const Delta{T,N,V<:OnesVector{T},Axes} = DiagonalArray{T,N,V,Axes} function Delta{T}( diff --git a/test/test_basics.jl b/test/test_basics.jl index 2eb2772..f3b8774 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -5,6 +5,7 @@ using DiagonalArrays: Delta, DiagonalArray, DiagonalMatrix, + ScaledDelta, δ, delta, diagindices, @@ -228,6 +229,7 @@ using LinearAlgebra: Diagonal, mul! # the diagonal structure properly. # https://github.com/ITensor/DiagonalArrays.jl/issues/7 @test diagview(a′) isa Fill{promote_type(Int, elt′)} + @test a′ isa ScaledDelta{promote_type(Int, elt′)} b = randn(elt, (2, 3)) a_dest = a * b From 73945130e0bb7144f426ea2b1b7ab1106a5337f3 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 20 Aug 2025 16:43:40 -0400 Subject: [PATCH 4/4] Better type aliases --- src/diagonalarray/delta.jl | 21 +++++++++++++++++++-- src/diagonalarray/diagonalmatrix.jl | 11 +++-------- src/diagonalarray/diagonalvector.jl | 4 +++- test/test_basics.jl | 7 ++++++- 4 files changed, 31 insertions(+), 12 deletions(-) diff --git a/src/diagonalarray/delta.jl b/src/diagonalarray/delta.jl index ade97ea..f91e312 100644 --- a/src/diagonalarray/delta.jl +++ b/src/diagonalarray/delta.jl @@ -1,8 +1,25 @@ using FillArrays: AbstractFillVector, Ones, OnesVector -const ScaledDelta{T,N,V<:AbstractFillVector{T},Axes} = DiagonalArray{T,N,V,Axes} +const ScaledDelta{T,N,Diag<:AbstractFillVector{T},Unstored<:AbstractArray{T,N}} = DiagonalArray{ + T,N,Diag,Unstored +} +const ScaledDeltaVector{T,Diag<:AbstractFillVector{T},Unstored<:AbstractVector{T}} = DiagonalVector{ + T,Diag,Unstored +} +const ScaledDeltaMatrix{T,Diag<:AbstractFillVector{T},Unstored<:AbstractMatrix{T}} = DiagonalMatrix{ + T,Diag,Unstored +} + +const Delta{T,N,Diag<:OnesVector{T},Unstored<:AbstractArray{T,N}} = DiagonalArray{ + T,N,Diag,Unstored +} +const DeltaVector{T,Diag<:OnesVector{T},Unstored<:AbstractVector{T}} = DiagonalVector{ + T,Diag,Unstored +} +const DeltaMatrix{T,Diag<:OnesVector{T},Unstored<:AbstractMatrix{T}} = DiagonalMatrix{ + T,Diag,Unstored +} -const Delta{T,N,V<:OnesVector{T},Axes} = DiagonalArray{T,N,V,Axes} function Delta{T}( ax::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}} ) where {T} diff --git a/src/diagonalarray/diagonalmatrix.jl b/src/diagonalarray/diagonalmatrix.jl index 9c77c25..2fff225 100644 --- a/src/diagonalarray/diagonalmatrix.jl +++ b/src/diagonalarray/diagonalmatrix.jl @@ -1,11 +1,6 @@ -const DiagonalMatrix{T,Diag,Zero} = DiagonalArray{T,2,Diag,Zero} - -function DiagonalMatrix(diag::AbstractVector) - return DiagonalArray{<:Any,2}(diag) -end -function DiagonalMatrix(diag::AbstractVector, ax::Tuple) - return DiagonalArray{<:Any,2}(diag, ax) -end +const DiagonalMatrix{T,Diag<:AbstractVector{T},Unstored<:AbstractMatrix{T}} = DiagonalArray{ + T,2,Diag,Unstored +} # LinearAlgebra diff --git a/src/diagonalarray/diagonalvector.jl b/src/diagonalarray/diagonalvector.jl index 40e35e4..ec3cde8 100644 --- a/src/diagonalarray/diagonalvector.jl +++ b/src/diagonalarray/diagonalvector.jl @@ -1,4 +1,6 @@ -const DiagonalVector{T,Diag,Zero} = DiagonalArray{T,1,Diag,Zero} +const DiagonalVector{T,Diag<:AbstractVector{T},Unstored<:AbstractVector{T}} = DiagonalArray{ + T,1,Diag,Unstored +} function DiagonalVector(diag::AbstractVector) return DiagonalArray{<:Any,1}(diag) diff --git a/test/test_basics.jl b/test/test_basics.jl index f3b8774..dbc5f16 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -3,9 +3,11 @@ using DerivableInterfaces: permuteddims using DiagonalArrays: DiagonalArrays, Delta, + DeltaMatrix, DiagonalArray, DiagonalMatrix, ScaledDelta, + ScaledDeltaMatrix, δ, delta, diagindices, @@ -214,7 +216,9 @@ using LinearAlgebra: Diagonal, mul! @test eltype(a) === elt′ @test diaglength(a) == 2 @test a isa DiagonalArray{elt′,2} + @test a isa DiagonalMatrix{elt′} @test a isa Delta{elt′,2} + @test a isa DeltaMatrix{elt′} @test size(a) == (2, 2) @test diaglength(a) == 2 @test storedlength(a) == 2 @@ -229,7 +233,8 @@ using LinearAlgebra: Diagonal, mul! # the diagonal structure properly. # https://github.com/ITensor/DiagonalArrays.jl/issues/7 @test diagview(a′) isa Fill{promote_type(Int, elt′)} - @test a′ isa ScaledDelta{promote_type(Int, elt′)} + @test a′ isa ScaledDelta{promote_type(Int, elt′),2} + @test a′ isa ScaledDeltaMatrix{promote_type(Int, elt′)} b = randn(elt, (2, 3)) a_dest = a * b