Skip to content

Commit f6a66da

Browse files
authored
Better DiagonalArrays broadcasting (#39)
1 parent f00267e commit f6a66da

File tree

6 files changed

+81
-17
lines changed

6 files changed

+81
-17
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,21 @@
11
name = "DiagonalArrays"
22
uuid = "74fd4be6-21e2-4f6f-823a-4360d37c7a77"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.3.14"
4+
version = "0.3.15"
55

66
[deps]
77
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
88
DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
99
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
1010
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
11+
MapBroadcast = "ebd9b9da-f48d-417c-9660-449667d60261"
1112
SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208"
1213

1314
[compat]
1415
ArrayLayouts = "1.10.4"
1516
DerivableInterfaces = "0.5.5"
1617
FillArrays = "1.13.0"
1718
LinearAlgebra = "1.10.0"
19+
MapBroadcast = "0.1.10"
1820
SparseArraysBase = "0.7.2"
1921
julia = "1.10"

src/abstractdiagonalarray/diagonalarraydiaginterface.jl

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,6 @@ struct DiagonalArrayStyle{N} <: AbstractDiagonalArrayStyle{N} end
3030

3131
DiagonalArrayStyle{M}(::Val{N}) where {M,N} = DiagonalArrayStyle{N}()
3232

33-
@interface ::AbstractDiagonalArrayInterface function Broadcast.BroadcastStyle(type::Type)
34-
return DiagonalArrayStyle{ndims(type)}()
35-
end
36-
3733
function SparseArraysBase.isstored(
3834
a::AbstractDiagonalArray{<:Any,N}, I::Vararg{Int,N}
3935
) where {N}
@@ -81,6 +77,29 @@ function Base.setindex!(a::AbstractDiagonalArray, value, I::DiagIndex)
8177
return invoke(setindex!, Tuple{AbstractArray,Any,DiagIndex}, a, value, I)
8278
end
8379

80+
@interface ::AbstractDiagonalArrayInterface function Broadcast.BroadcastStyle(type::Type)
81+
return DiagonalArrayStyle{ndims(type)}()
82+
end
83+
84+
using Base.Broadcast: Broadcasted, broadcasted
85+
using MapBroadcast: Mapped
86+
# Map to a flattened broadcast expression of the diagonals of the arrays,
87+
# also checking that the function preserves zeros.
88+
function broadcasted_diagview(bc::Broadcasted)
89+
m = Mapped(bc)
90+
iszero(m.f(map(zero eltype, m.args)...)) || error(
91+
"Broadcasting DiagonalArrays with function that doesn't preserve zeros isn't supported yet.",
92+
)
93+
return broadcasted(m.f, map(diagview, m.args)...)
94+
end
95+
function Base.copy(bc::Broadcasted{<:DiagonalArrayStyle})
96+
return DiagonalArray(copy(broadcasted_diagview(bc)), axes(bc))
97+
end
98+
function Base.copyto!(dest::AbstractArray, bc::Broadcasted{<:DiagonalArrayStyle})
99+
copyto!(diagview(dest), broadcasted_diagview(bc))
100+
return dest
101+
end
102+
84103
## SparseArraysBase.StorageIndex(i::DiagIndex) = StorageIndex(index(i))
85104

86105
## function Base.getindex(a::AbstractDiagonalArray, i::DiagIndex)

src/diagonalarray/delta.jl

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,25 @@
1-
using FillArrays: Ones, OnesVector
1+
using FillArrays: AbstractFillVector, Ones, OnesVector
2+
3+
const ScaledDelta{T,N,Diag<:AbstractFillVector{T},Unstored<:AbstractArray{T,N}} = DiagonalArray{
4+
T,N,Diag,Unstored
5+
}
6+
const ScaledDeltaVector{T,Diag<:AbstractFillVector{T},Unstored<:AbstractVector{T}} = DiagonalVector{
7+
T,Diag,Unstored
8+
}
9+
const ScaledDeltaMatrix{T,Diag<:AbstractFillVector{T},Unstored<:AbstractMatrix{T}} = DiagonalMatrix{
10+
T,Diag,Unstored
11+
}
12+
13+
const Delta{T,N,Diag<:OnesVector{T},Unstored<:AbstractArray{T,N}} = DiagonalArray{
14+
T,N,Diag,Unstored
15+
}
16+
const DeltaVector{T,Diag<:OnesVector{T},Unstored<:AbstractVector{T}} = DiagonalVector{
17+
T,Diag,Unstored
18+
}
19+
const DeltaMatrix{T,Diag<:OnesVector{T},Unstored<:AbstractMatrix{T}} = DiagonalMatrix{
20+
T,Diag,Unstored
21+
}
222

3-
const Delta{T,N,V<:OnesVector{T},Axes} = DiagonalArray{T,N,V,Axes}
423
function Delta{T}(
524
ax::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}}
625
) where {T}

src/diagonalarray/diagonalmatrix.jl

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,6 @@
1-
const DiagonalMatrix{T,Diag,Zero} = DiagonalArray{T,2,Diag,Zero}
2-
3-
function DiagonalMatrix(diag::AbstractVector)
4-
return DiagonalArray{<:Any,2}(diag)
5-
end
6-
function DiagonalMatrix(diag::AbstractVector, ax::Tuple)
7-
return DiagonalArray{<:Any,2}(diag, ax)
8-
end
1+
const DiagonalMatrix{T,Diag<:AbstractVector{T},Unstored<:AbstractMatrix{T}} = DiagonalArray{
2+
T,2,Diag,Unstored
3+
}
94

105
# LinearAlgebra
116

src/diagonalarray/diagonalvector.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
const DiagonalVector{T,Diag,Zero} = DiagonalArray{T,1,Diag,Zero}
1+
const DiagonalVector{T,Diag<:AbstractVector{T},Unstored<:AbstractVector{T}} = DiagonalArray{
2+
T,1,Diag,Unstored
3+
}
24

35
function DiagonalVector(diag::AbstractVector)
46
return DiagonalArray{<:Any,1}(diag)

test/test_basics.jl

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,11 @@ using DerivableInterfaces: permuteddims
33
using DiagonalArrays:
44
DiagonalArrays,
55
Delta,
6+
DeltaMatrix,
67
DiagonalArray,
78
DiagonalMatrix,
9+
ScaledDelta,
10+
ScaledDeltaMatrix,
811
δ,
912
delta,
1013
diagindices,
@@ -116,6 +119,22 @@ using LinearAlgebra: Diagonal, mul!
116119
@test diagview(b) diagview(a)
117120
@test size(b) === (4, 2, 3)
118121
end
122+
@testset "Broadcasting" begin
123+
a = DiagonalArray(randn(elt, 2), (2, 3))
124+
b = DiagonalArray(randn(elt, 2), (2, 3))
125+
c = a .+ 2 .* b
126+
@test c Array(a) + 2 * Array(b)
127+
# Non-zero-preserving functions not supported yet.
128+
@test_broken a .+ 2
129+
130+
c = DiagonalArray{elt}(undef, (2, 3))
131+
c .= a .+ 2 .* b
132+
@test c Array(a) + 2 * Array(b)
133+
134+
# Non-zero-preserving functions not supported yet.
135+
c = DiagonalArray{elt}(undef, (2, 3))
136+
@test_broken c .= a .+ 2
137+
end
119138
@testset "Matrix multiplication" begin
120139
a1 = DiagonalArray{elt}(undef, (2, 3))
121140
a1[1, 1] = 11
@@ -197,7 +216,9 @@ using LinearAlgebra: Diagonal, mul!
197216
@test eltype(a) === elt′
198217
@test diaglength(a) == 2
199218
@test a isa DiagonalArray{elt′,2}
219+
@test a isa DiagonalMatrix{elt′}
200220
@test a isa Delta{elt′,2}
221+
@test a isa DeltaMatrix{elt′}
201222
@test size(a) == (2, 2)
202223
@test diaglength(a) == 2
203224
@test storedlength(a) == 2
@@ -211,11 +232,17 @@ using LinearAlgebra: Diagonal, mul!
211232
# TODO: Fix this. Mapping doesn't preserve
212233
# the diagonal structure properly.
213234
# https://github.com/ITensor/DiagonalArrays.jl/issues/7
214-
@test_broken diagview(a′) isa Fill
235+
@test diagview(a′) isa Fill{promote_type(Int, elt′)}
236+
@test a′ isa ScaledDelta{promote_type(Int, elt′),2}
237+
@test a′ isa ScaledDeltaMatrix{promote_type(Int, elt′)}
215238

216239
b = randn(elt, (2, 3))
217240
a_dest = a * b
218241
@test a_dest Array(a) * Array(b)
242+
243+
a_dest = a * a
244+
@test a_dest Array(a) * Array(a)
245+
@test diagview(a_dest) isa Ones{elt′}
219246
end
220247
end
221248
end

0 commit comments

Comments
 (0)