Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
name = "DiagonalArrays"
uuid = "74fd4be6-21e2-4f6f-823a-4360d37c7a77"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.3.26"
version = "0.3.27"

[deps]
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
FunctionImplementations = "7c7cc465-9c6a-495f-bdd1-f42428e86d0c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MapBroadcast = "ebd9b9da-f48d-417c-9660-449667d60261"
SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208"
Expand All @@ -21,11 +21,14 @@ DiagonalArraysNamedDimsArraysExt = "NamedDimsArrays"

[compat]
ArrayLayouts = "1.10.4"
DerivableInterfaces = "0.5.5"
FillArrays = "1.13"
FunctionImplementations = "0.3.1"
LinearAlgebra = "1.10"
MapBroadcast = "0.1.10"
MatrixAlgebraKit = "0.2, 0.3, 0.4, 0.5, 0.6"
NamedDimsArrays = "0.10, 0.11"
SparseArraysBase = "0.7.2"
NamedDimsArrays = "0.12"
SparseArraysBase = "0.8.1"
julia = "1.10"

[workspace]
projects = ["benchmark", "dev", "docs", "examples", "test"]
3 changes: 3 additions & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ DiagonalArrays = "74fd4be6-21e2-4f6f-823a-4360d37c7a77"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"

[sources]
DiagonalArrays = {path = ".."}

[compat]
DiagonalArrays = "0.3"
Documenter = "1"
Expand Down
3 changes: 3 additions & 0 deletions examples/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
DiagonalArrays = "74fd4be6-21e2-4f6f-823a-4360d37c7a77"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[sources]
DiagonalArrays = {path = ".."}

[compat]
DiagonalArrays = "0.3"
Test = "1"
1 change: 0 additions & 1 deletion src/DiagonalArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ include("diaginterface/diaginterface.jl")
include("diaginterface/diagindex.jl")
include("diaginterface/diagindices.jl")
include("abstractdiagonalarray/abstractdiagonalarray.jl")
include("abstractdiagonalarray/sparsearrayinterface.jl")
include("abstractdiagonalarray/diagonalarraydiaginterface.jl")
include("abstractdiagonalarray/arraylayouts.jl")
include("diagonalarray/diagonalarray.jl")
Expand Down
100 changes: 63 additions & 37 deletions src/abstractdiagonalarray/diagonalarraydiaginterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,64 +2,59 @@

diagview(a::AbstractDiagonalArray) = throw(MethodError(diagview, Tuple{typeof(a)}))

using DerivableInterfaces: DerivableInterfaces, @interface
using SparseArraysBase:
SparseArraysBase, AbstractSparseArrayInterface, AbstractSparseArrayStyle
using FunctionImplementations: FunctionImplementations
using SparseArraysBase: SparseArraysBase as SA, AbstractSparseArrayStyle

abstract type AbstractDiagonalArrayInterface{N} <: AbstractSparseArrayInterface{N} end
abstract type AbstractDiagonalArrayStyle <: AbstractSparseArrayStyle end

struct DiagonalArrayInterface{N} <: AbstractDiagonalArrayInterface{N} end
DiagonalArrayInterface{M}(::Val{N}) where {M, N} = DiagonalArrayInterface{N}()
DiagionalArrayInterface(::Val{N}) where {N} = DiagonalArrayInterface{N}()
DiagonalArrayInterface() = DiagonalArrayInterface{Any}()
struct DiagonalArrayStyle <: AbstractDiagonalArrayStyle end
const diag_style = DiagonalArrayStyle()

function Base.similar(::AbstractDiagonalArrayInterface, elt::Type, ax::Tuple)
return similar(DiagonalArray{elt}, ax)
function FunctionImplementations.Style(::Type{<:AbstractDiagonalArray})
return DiagonalArrayStyle()
end
function DerivableInterfaces.interface(::Type{<:AbstractDiagonalArray{<:Any, N}}) where {N}
return DiagonalArrayInterface{N}()
end

abstract type AbstractDiagonalArrayStyle{N} <: AbstractSparseArrayStyle{N} end

function DerivableInterfaces.interface(::Type{<:AbstractDiagonalArrayStyle{N}}) where {N}
return DiagonalArrayInterface{N}()
module Broadcast
import SparseArraysBase as SA
abstract type AbstractDiagonalArrayStyle{N} <: SA.Broadcast.AbstractSparseArrayStyle{N} end
struct DiagonalArrayStyle{N} <: AbstractDiagonalArrayStyle{N} end
DiagonalArrayStyle{M}(::Val{N}) where {M, N} = DiagonalArrayStyle{N}()
end

struct DiagonalArrayStyle{N} <: AbstractDiagonalArrayStyle{N} end

DiagonalArrayStyle{M}(::Val{N}) where {M, N} = DiagonalArrayStyle{N}()

function SparseArraysBase.isstored(
a::AbstractDiagonalArray{<:Any, N}, I::Vararg{Int, N}
) where {N}
return allequal(I)
end
function SparseArraysBase.getstoredindex(
a::AbstractDiagonalArray{<:Any, N}, I::Vararg{Int, N}
using SparseArraysBase: getstoredindex
const getstoredindex_diag = diag_style(getstoredindex)
function getstoredindex_diag(
a::AbstractArray{<:Any, N}, I::Vararg{Int, N}
) where {N}
# TODO: Make this check optional, define `checkstored` like `checkbounds`
# in SparseArraysBase.jl.
# allequal(I) || error("Not a diagonal index.")
return getdiagindex(a, first(I))
end
function SparseArraysBase.getstoredindex(a::AbstractDiagonalArray{<:Any, 0})
function getstoredindex_diag(a::AbstractArray{<:Any, 0})
return getdiagindex(a, 1)
end
function SparseArraysBase.setstoredindex!(
a::AbstractDiagonalArray{<:Any, N}, value, I::Vararg{Int, N}
function getstoredindex_diag(a::AbstractArray, I::Int...)
return sparse_style(getstoredindex)(a, I...)
end
using SparseArraysBase: setstoredindex!
const setstoredindex!_diag = diag_style(setstoredindex!)
function setstoredindex!_diag(
a::AbstractArray{<:Any, N}, value, I::Vararg{Int, N}
) where {N}
# TODO: Make this check optional, define `checkstored` like `checkbounds`
# in SparseArraysBase.jl.
# allequal(I) || error("Not a diagonal index.")
setdiagindex!(a, value, first(I))
return a
end
function SparseArraysBase.setstoredindex!(a::AbstractDiagonalArray{<:Any, 0}, value)
function setstoredindex!_diag(a::AbstractArray{<:Any, 0}, value)
setdiagindex!(a, value, 1)
return a
end
function SparseArraysBase.eachstoredindex(::IndexCartesian, a::AbstractDiagonalArray)
using SparseArraysBase: eachstoredindex
const eachstoredindex_diag = diag_style(eachstoredindex)
function eachstoredindex_diag(::IndexCartesian, a::AbstractArray)
return diagindices(a)
end

Expand All @@ -84,8 +79,39 @@ function Base.setindex!(a::AbstractDiagonalArray, value, I::DiagIndex)
return invoke(setindex!, Tuple{AbstractArray, Any, DiagIndex}, a, value, I)
end

@interface ::AbstractDiagonalArrayInterface function Broadcast.BroadcastStyle(type::Type)
return DiagonalArrayStyle{ndims(type)}()
using SparseArraysBase: sparse_style
const getindex_diag = diag_style(getindex)
getindex_diag(a::AbstractArray, I...) = sparse_style(getindex)(a, I...)
const setindex!_diag = diag_style(setindex!)
setindex!_diag(a::AbstractArray, value, I...) = sparse_style(setindex!)(a, value, I...)
const copyto!_diag = diag_style(copyto!)
copyto!_diag(dst::AbstractArray, src::AbstractArray) = sparse_style(copyto!)(dst, src)
const map_diag = diag_style(map)
map_diag(f, as::AbstractArray...) = sparse_style(map)(f, as...)
const map!_diag = diag_style(map!)
map!_diag(f, dst::AbstractArray, as::AbstractArray...) = sparse_style(map!)(f, dst, as...)
const fill!_diag = diag_style(fill!)
fill!_diag(a::AbstractArray, value) = sparse_style(fill!)(a, value)
using FunctionImplementations: zero!
const zero!_diag = diag_style(zero!)
zero!_diag(a::AbstractArray) = sparse_style(zero!)(a)
using SparseArraysBase: isstored
const isstored_diag = diag_style(isstored)
function isstored_diag(
a::AbstractArray{<:Any, N}, I::Vararg{Int, N}
) where {N}
return allequal(I)
end
isstored_diag(a::AbstractArray, I::Int...) = sparse_style(isstored)(a, I...)
using SparseArraysBase: storedvalues
const storedvalues_diag = diag_style(storedvalues)
storedvalues_diag(a::AbstractArray) = diagview(a)
using SparseArraysBase: storedpairs
const storedpairs_diag = diag_style(storedpairs)
storedpairs_diag(a::AbstractArray) = sparse_style(storedpairs)(a)

function Base.Broadcast.BroadcastStyle(type::Type{<:AbstractDiagonalArray})
return Broadcast.DiagonalArrayStyle{ndims(type)}()
end

using Base.Broadcast: Broadcasted, broadcasted
Expand All @@ -99,10 +125,10 @@ function broadcasted_diagview(bc::Broadcasted)
)
return broadcasted(m.f, map(diagview, m.args)...)
end
function Base.copy(bc::Broadcasted{<:DiagonalArrayStyle})
function Base.copy(bc::Broadcasted{<:Broadcast.DiagonalArrayStyle})
return DiagonalArray(copy(broadcasted_diagview(bc)), axes(bc))
end
function Base.copyto!(dest::AbstractArray, bc::Broadcasted{<:DiagonalArrayStyle})
function Base.copyto!(dest::AbstractArray, bc::Broadcasted{<:Broadcast.DiagonalArrayStyle})
copyto!(diagview(dest), broadcasted_diagview(bc))
return dest
end
19 changes: 0 additions & 19 deletions src/abstractdiagonalarray/sparsearrayinterface.jl

This file was deleted.

18 changes: 10 additions & 8 deletions src/diagonalarray/diagonalarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ struct DiagonalArray{T, N, D <: AbstractVector{T}, U <: AbstractArray{T, N}} <:
end
end

SparseArraysBase.unstored(a::DiagonalArray) = a.unstored
SA.unstored(a::DiagonalArray) = a.unstored
Base.size(a::DiagonalArray) = size(unstored(a))
Base.axes(a::DiagonalArray) = axes(unstored(a))

Expand Down Expand Up @@ -291,7 +291,8 @@ function Base.permutedims(a::DiagonalArray, perm)
return DiagonalArray(copy(diagview(a)), ax_perm)
end

function DerivableInterfaces.permuteddims(a::DiagonalArray, perm)
using FunctionImplementations: FunctionImplementations
function FunctionImplementations.permuteddims(a::DiagonalArray, perm)
((ndims(a) == length(perm)) && isperm(perm)) ||
throw(ArgumentError("Not a valid permutation"))
ax_perm = ntuple(d -> axes(a)[perm[d]], ndims(a))
Expand All @@ -300,7 +301,6 @@ function DerivableInterfaces.permuteddims(a::DiagonalArray, perm)
end

# Scalar indexing.
using DerivableInterfaces: @interface, interface
one_based_range(r) = false
one_based_range(r::Base.OneTo) = true
one_based_range(r::Base.Slice) = true
Expand Down Expand Up @@ -335,8 +335,10 @@ function Base.view(a::DiagonalArray, I...)
invoke(view, Tuple{AbstractArray, Vararg}, a, I′...)
end
end
using FunctionImplementations: style
using SparseArraysBase: sparse_style
function Base.getindex(a::DiagonalArray, I::Int...)
return @interface interface(a) a[I...]
return sparse_style(getindex)(a, I...)
end
function Base.getindex(a::DiagonalArray, I::DiagIndex)
return getdiagindex(a, index(I))
Expand All @@ -349,7 +351,7 @@ function Base.getindex(a::DiagonalArray, I...)
I′ = to_indices(a, I)
return if all(i -> i isa Real, I′)
# Catch scalar indexing case.
@interface interface(a) a[I...]
return style(a)(getindex)(a, I...)
elseif all(one_based_range, I′)
_getindex_diag(a, I′...)
else
Expand Down Expand Up @@ -379,7 +381,7 @@ end
# TODO: These definitions work around this issue:
# https://github.com/JuliaArrays/FillArrays.jl/issues/416
# when the diagonal is a FillArrays.Ones or Zeros.
using Base.Broadcast: Broadcast, broadcast, broadcasted
using Base.Broadcast: broadcast, broadcasted
using FillArrays: AbstractFill, Ones, Zeros
_broadcasted(f::F, a::AbstractArray) where {F} = broadcasted(f, a)
_broadcasted(::typeof(identity), a::Ones) = a
Expand Down Expand Up @@ -407,8 +409,8 @@ _broadcasted(::typeof(cosh), a::Zeros) = Ones{typeof(cosh(zero(eltype(a))))}(axe
# Eager version of `_broadcasted`.
_broadcast(f::F, a::AbstractArray) where {F} = copy(_broadcasted(f, a))

function Broadcast.broadcasted(
::DiagonalArrayStyle{N}, f::F, a::DiagonalArray{T, N, Diag}
function Base.Broadcast.broadcasted(
::Broadcast.DiagonalArrayStyle{N}, f::F, a::DiagonalArray{T, N, Diag}
) where {F, T, N, Diag <: AbstractFill{T}}
# TODO: Check that `f` preserves zeros?
return DiagonalArray(_broadcasted(f, diagview(a)), axes(a))
Expand Down
11 changes: 7 additions & 4 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
DiagonalArrays = "74fd4be6-21e2-4f6f-823a-4360d37c7a77"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
FunctionImplementations = "7c7cc465-9c6a-495f-bdd1-f42428e86d0c"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"
Expand All @@ -14,18 +14,21 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[sources]
DiagonalArrays = {path = ".."}

[compat]
Adapt = "4.4"
Aqua = "0.8.9"
DerivableInterfaces = "0.5"
DiagonalArrays = "0.3"
FillArrays = "1"
FunctionImplementations = "0.3"
JLArrays = "0.3"
LinearAlgebra = "1"
MatrixAlgebraKit = "0.2.5, 0.3, 0.4, 0.5, 0.6"
NamedDimsArrays = "0.10, 0.11"
NamedDimsArrays = "0.12"
SafeTestsets = "0.1"
SparseArraysBase = "0.7.10"
SparseArraysBase = "0.8"
StableRNGs = "1"
Suppressor = "0.2"
Test = "1"
4 changes: 2 additions & 2 deletions test/test_basics.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
using DerivableInterfaces: permuteddims
using DiagonalArrays:
DiagonalArrays,
ShapeInitializer,
Expand All @@ -18,6 +17,7 @@ using DiagonalArrays:
diagview,
getdiagindices
using FillArrays: Fill, Ones, Zeros
using FunctionImplementations: permuteddims
using LinearAlgebra:
Diagonal, det, ishermitian, isposdef, issymmetric, logdet, mul!, pinv, tr
using SparseArraysBase: SparseArrayDOK, SparseMatrixDOK, sparsezeros, storedlength
Expand Down Expand Up @@ -229,7 +229,7 @@ using Test: @test, @test_throws, @testset, @test_broken, @inferred
@test diagview(b) ≢ diagview(a)
@test size(b) === (4, 2, 3)
end
@testset "DerivableInterfaces.permuteddims" begin
@testset "FunctionImplementations.permuteddims" begin
a = DiagonalArray(randn(elt, 2), (2, 3, 4))
b = permuteddims(a, (3, 1, 2))
@test diagview(b) ≡ diagview(a)
Expand Down
Loading