Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
name = "DiagonalArrays"
uuid = "74fd4be6-21e2-4f6f-823a-4360d37c7a77"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.3.14"
version = "0.3.15"

[deps]
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MapBroadcast = "ebd9b9da-f48d-417c-9660-449667d60261"
SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208"

[compat]
ArrayLayouts = "1.10.4"
DerivableInterfaces = "0.5.5"
FillArrays = "1.13.0"
LinearAlgebra = "1.10.0"
MapBroadcast = "0.1.10"
SparseArraysBase = "0.7.2"
julia = "1.10"
27 changes: 23 additions & 4 deletions src/abstractdiagonalarray/diagonalarraydiaginterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,6 @@ struct DiagonalArrayStyle{N} <: AbstractDiagonalArrayStyle{N} end

DiagonalArrayStyle{M}(::Val{N}) where {M,N} = DiagonalArrayStyle{N}()

@interface ::AbstractDiagonalArrayInterface function Broadcast.BroadcastStyle(type::Type)
return DiagonalArrayStyle{ndims(type)}()
end

function SparseArraysBase.isstored(
a::AbstractDiagonalArray{<:Any,N}, I::Vararg{Int,N}
) where {N}
Expand Down Expand Up @@ -81,6 +77,29 @@ function Base.setindex!(a::AbstractDiagonalArray, value, I::DiagIndex)
return invoke(setindex!, Tuple{AbstractArray,Any,DiagIndex}, a, value, I)
end

@interface ::AbstractDiagonalArrayInterface function Broadcast.BroadcastStyle(type::Type)
return DiagonalArrayStyle{ndims(type)}()
end

using Base.Broadcast: Broadcasted, broadcasted
using MapBroadcast: Mapped
# Map to a flattened broadcast expression of the diagonals of the arrays,
# also checking that the function preserves zeros.
function broadcasted_diagview(bc::Broadcasted)
m = Mapped(bc)
iszero(m.f(map(zero ∘ eltype, m.args)...)) || error(
"Broadcasting DiagonalArrays with function that doesn't preserve zeros isn't supported yet.",
)
return broadcasted(m.f, map(diagview, m.args)...)
end
function Base.copy(bc::Broadcasted{<:DiagonalArrayStyle})
return DiagonalArray(copy(broadcasted_diagview(bc)), axes(bc))
end
function Base.copyto!(dest::AbstractArray, bc::Broadcasted{<:DiagonalArrayStyle})
copyto!(diagview(dest), broadcasted_diagview(bc))
return dest
end

## SparseArraysBase.StorageIndex(i::DiagIndex) = StorageIndex(index(i))

## function Base.getindex(a::AbstractDiagonalArray, i::DiagIndex)
Expand Down
23 changes: 21 additions & 2 deletions src/diagonalarray/delta.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,25 @@
using FillArrays: Ones, OnesVector
using FillArrays: AbstractFillVector, Ones, OnesVector

const ScaledDelta{T,N,Diag<:AbstractFillVector{T},Unstored<:AbstractArray{T,N}} = DiagonalArray{
T,N,Diag,Unstored
}
const ScaledDeltaVector{T,Diag<:AbstractFillVector{T},Unstored<:AbstractVector{T}} = DiagonalVector{
T,Diag,Unstored
}
const ScaledDeltaMatrix{T,Diag<:AbstractFillVector{T},Unstored<:AbstractMatrix{T}} = DiagonalMatrix{
T,Diag,Unstored
}

const Delta{T,N,Diag<:OnesVector{T},Unstored<:AbstractArray{T,N}} = DiagonalArray{
T,N,Diag,Unstored
}
const DeltaVector{T,Diag<:OnesVector{T},Unstored<:AbstractVector{T}} = DiagonalVector{
T,Diag,Unstored
}
const DeltaMatrix{T,Diag<:OnesVector{T},Unstored<:AbstractMatrix{T}} = DiagonalMatrix{
T,Diag,Unstored
}

const Delta{T,N,V<:OnesVector{T},Axes} = DiagonalArray{T,N,V,Axes}
function Delta{T}(
ax::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}}
) where {T}
Expand Down
11 changes: 3 additions & 8 deletions src/diagonalarray/diagonalmatrix.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,6 @@
const DiagonalMatrix{T,Diag,Zero} = DiagonalArray{T,2,Diag,Zero}

function DiagonalMatrix(diag::AbstractVector)
return DiagonalArray{<:Any,2}(diag)
end
function DiagonalMatrix(diag::AbstractVector, ax::Tuple)
return DiagonalArray{<:Any,2}(diag, ax)
end
const DiagonalMatrix{T,Diag<:AbstractVector{T},Unstored<:AbstractMatrix{T}} = DiagonalArray{
T,2,Diag,Unstored
}

# LinearAlgebra

Expand Down
4 changes: 3 additions & 1 deletion src/diagonalarray/diagonalvector.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
const DiagonalVector{T,Diag,Zero} = DiagonalArray{T,1,Diag,Zero}
const DiagonalVector{T,Diag<:AbstractVector{T},Unstored<:AbstractVector{T}} = DiagonalArray{
T,1,Diag,Unstored
}

function DiagonalVector(diag::AbstractVector)
return DiagonalArray{<:Any,1}(diag)
Expand Down
29 changes: 28 additions & 1 deletion test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@ using DerivableInterfaces: permuteddims
using DiagonalArrays:
DiagonalArrays,
Delta,
DeltaMatrix,
DiagonalArray,
DiagonalMatrix,
ScaledDelta,
ScaledDeltaMatrix,
δ,
delta,
diagindices,
Expand Down Expand Up @@ -116,6 +119,22 @@ using LinearAlgebra: Diagonal, mul!
@test diagview(b) ≡ diagview(a)
@test size(b) === (4, 2, 3)
end
@testset "Broadcasting" begin
a = DiagonalArray(randn(elt, 2), (2, 3))
b = DiagonalArray(randn(elt, 2), (2, 3))
c = a .+ 2 .* b
@test c ≈ Array(a) + 2 * Array(b)
# Non-zero-preserving functions not supported yet.
@test_broken a .+ 2

c = DiagonalArray{elt}(undef, (2, 3))
c .= a .+ 2 .* b
@test c ≈ Array(a) + 2 * Array(b)

# Non-zero-preserving functions not supported yet.
c = DiagonalArray{elt}(undef, (2, 3))
@test_broken c .= a .+ 2
end
@testset "Matrix multiplication" begin
a1 = DiagonalArray{elt}(undef, (2, 3))
a1[1, 1] = 11
Expand Down Expand Up @@ -197,7 +216,9 @@ using LinearAlgebra: Diagonal, mul!
@test eltype(a) === elt′
@test diaglength(a) == 2
@test a isa DiagonalArray{elt′,2}
@test a isa DiagonalMatrix{elt′}
@test a isa Delta{elt′,2}
@test a isa DeltaMatrix{elt′}
@test size(a) == (2, 2)
@test diaglength(a) == 2
@test storedlength(a) == 2
Expand All @@ -211,11 +232,17 @@ using LinearAlgebra: Diagonal, mul!
# TODO: Fix this. Mapping doesn't preserve
# the diagonal structure properly.
# https://github.com/ITensor/DiagonalArrays.jl/issues/7
@test_broken diagview(a′) isa Fill
@test diagview(a′) isa Fill{promote_type(Int, elt′)}
@test a′ isa ScaledDelta{promote_type(Int, elt′),2}
@test a′ isa ScaledDeltaMatrix{promote_type(Int, elt′)}

b = randn(elt, (2, 3))
a_dest = a * b
@test a_dest ≈ Array(a) * Array(b)

a_dest = a * a
@test a_dest ≈ Array(a) * Array(a)
@test diagview(a_dest) isa Ones{elt′}
end
end
end
Expand Down
Loading