Skip to content

Commit a00a9f3

Browse files
authored
Switch to FunctionImplementations (#59)
1 parent bd2d640 commit a00a9f3

File tree

9 files changed

+96
-76
lines changed

9 files changed

+96
-76
lines changed

Project.toml

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
name = "DiagonalArrays"
22
uuid = "74fd4be6-21e2-4f6f-823a-4360d37c7a77"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.3.26"
4+
version = "0.3.27"
55

66
[deps]
77
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
8-
DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
98
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
9+
FunctionImplementations = "7c7cc465-9c6a-495f-bdd1-f42428e86d0c"
1010
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1111
MapBroadcast = "ebd9b9da-f48d-417c-9660-449667d60261"
1212
SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208"
@@ -21,11 +21,14 @@ DiagonalArraysNamedDimsArraysExt = "NamedDimsArrays"
2121

2222
[compat]
2323
ArrayLayouts = "1.10.4"
24-
DerivableInterfaces = "0.5.5"
2524
FillArrays = "1.13"
25+
FunctionImplementations = "0.3.1"
2626
LinearAlgebra = "1.10"
2727
MapBroadcast = "0.1.10"
2828
MatrixAlgebraKit = "0.2, 0.3, 0.4, 0.5, 0.6"
29-
NamedDimsArrays = "0.10, 0.11"
30-
SparseArraysBase = "0.7.2"
29+
NamedDimsArrays = "0.12"
30+
SparseArraysBase = "0.8.1"
3131
julia = "1.10"
32+
33+
[workspace]
34+
projects = ["benchmark", "dev", "docs", "examples", "test"]

docs/Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@ DiagonalArrays = "74fd4be6-21e2-4f6f-823a-4360d37c7a77"
33
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
44
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
55

6+
[sources]
7+
DiagonalArrays = {path = ".."}
8+
69
[compat]
710
DiagonalArrays = "0.3"
811
Documenter = "1"

examples/Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
DiagonalArrays = "74fd4be6-21e2-4f6f-823a-4360d37c7a77"
33
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
44

5+
[sources]
6+
DiagonalArrays = {path = ".."}
7+
58
[compat]
69
DiagonalArrays = "0.3"
710
Test = "1"

src/DiagonalArrays.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ include("diaginterface/diaginterface.jl")
55
include("diaginterface/diagindex.jl")
66
include("diaginterface/diagindices.jl")
77
include("abstractdiagonalarray/abstractdiagonalarray.jl")
8-
include("abstractdiagonalarray/sparsearrayinterface.jl")
98
include("abstractdiagonalarray/diagonalarraydiaginterface.jl")
109
include("abstractdiagonalarray/arraylayouts.jl")
1110
include("diagonalarray/diagonalarray.jl")

src/abstractdiagonalarray/diagonalarraydiaginterface.jl

Lines changed: 63 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -2,64 +2,59 @@
22

33
diagview(a::AbstractDiagonalArray) = throw(MethodError(diagview, Tuple{typeof(a)}))
44

5-
using DerivableInterfaces: DerivableInterfaces, @interface
6-
using SparseArraysBase:
7-
SparseArraysBase, AbstractSparseArrayInterface, AbstractSparseArrayStyle
5+
using FunctionImplementations: FunctionImplementations
6+
using SparseArraysBase: SparseArraysBase as SA, AbstractSparseArrayStyle
87

9-
abstract type AbstractDiagonalArrayInterface{N} <: AbstractSparseArrayInterface{N} end
8+
abstract type AbstractDiagonalArrayStyle <: AbstractSparseArrayStyle end
109

11-
struct DiagonalArrayInterface{N} <: AbstractDiagonalArrayInterface{N} end
12-
DiagonalArrayInterface{M}(::Val{N}) where {M, N} = DiagonalArrayInterface{N}()
13-
DiagionalArrayInterface(::Val{N}) where {N} = DiagonalArrayInterface{N}()
14-
DiagonalArrayInterface() = DiagonalArrayInterface{Any}()
10+
struct DiagonalArrayStyle <: AbstractDiagonalArrayStyle end
11+
const diag_style = DiagonalArrayStyle()
1512

16-
function Base.similar(::AbstractDiagonalArrayInterface, elt::Type, ax::Tuple)
17-
return similar(DiagonalArray{elt}, ax)
13+
function FunctionImplementations.Style(::Type{<:AbstractDiagonalArray})
14+
return DiagonalArrayStyle()
1815
end
19-
function DerivableInterfaces.interface(::Type{<:AbstractDiagonalArray{<:Any, N}}) where {N}
20-
return DiagonalArrayInterface{N}()
21-
end
22-
23-
abstract type AbstractDiagonalArrayStyle{N} <: AbstractSparseArrayStyle{N} end
2416

25-
function DerivableInterfaces.interface(::Type{<:AbstractDiagonalArrayStyle{N}}) where {N}
26-
return DiagonalArrayInterface{N}()
17+
module Broadcast
18+
import SparseArraysBase as SA
19+
abstract type AbstractDiagonalArrayStyle{N} <: SA.Broadcast.AbstractSparseArrayStyle{N} end
20+
struct DiagonalArrayStyle{N} <: AbstractDiagonalArrayStyle{N} end
21+
DiagonalArrayStyle{M}(::Val{N}) where {M, N} = DiagonalArrayStyle{N}()
2722
end
2823

29-
struct DiagonalArrayStyle{N} <: AbstractDiagonalArrayStyle{N} end
30-
31-
DiagonalArrayStyle{M}(::Val{N}) where {M, N} = DiagonalArrayStyle{N}()
32-
33-
function SparseArraysBase.isstored(
34-
a::AbstractDiagonalArray{<:Any, N}, I::Vararg{Int, N}
35-
) where {N}
36-
return allequal(I)
37-
end
38-
function SparseArraysBase.getstoredindex(
39-
a::AbstractDiagonalArray{<:Any, N}, I::Vararg{Int, N}
24+
using SparseArraysBase: getstoredindex
25+
const getstoredindex_diag = diag_style(getstoredindex)
26+
function getstoredindex_diag(
27+
a::AbstractArray{<:Any, N}, I::Vararg{Int, N}
4028
) where {N}
4129
# TODO: Make this check optional, define `checkstored` like `checkbounds`
4230
# in SparseArraysBase.jl.
4331
# allequal(I) || error("Not a diagonal index.")
4432
return getdiagindex(a, first(I))
4533
end
46-
function SparseArraysBase.getstoredindex(a::AbstractDiagonalArray{<:Any, 0})
34+
function getstoredindex_diag(a::AbstractArray{<:Any, 0})
4735
return getdiagindex(a, 1)
4836
end
49-
function SparseArraysBase.setstoredindex!(
50-
a::AbstractDiagonalArray{<:Any, N}, value, I::Vararg{Int, N}
37+
function getstoredindex_diag(a::AbstractArray, I::Int...)
38+
return sparse_style(getstoredindex)(a, I...)
39+
end
40+
using SparseArraysBase: setstoredindex!
41+
const setstoredindex!_diag = diag_style(setstoredindex!)
42+
function setstoredindex!_diag(
43+
a::AbstractArray{<:Any, N}, value, I::Vararg{Int, N}
5144
) where {N}
5245
# TODO: Make this check optional, define `checkstored` like `checkbounds`
5346
# in SparseArraysBase.jl.
5447
# allequal(I) || error("Not a diagonal index.")
5548
setdiagindex!(a, value, first(I))
5649
return a
5750
end
58-
function SparseArraysBase.setstoredindex!(a::AbstractDiagonalArray{<:Any, 0}, value)
51+
function setstoredindex!_diag(a::AbstractArray{<:Any, 0}, value)
5952
setdiagindex!(a, value, 1)
6053
return a
6154
end
62-
function SparseArraysBase.eachstoredindex(::IndexCartesian, a::AbstractDiagonalArray)
55+
using SparseArraysBase: eachstoredindex
56+
const eachstoredindex_diag = diag_style(eachstoredindex)
57+
function eachstoredindex_diag(::IndexCartesian, a::AbstractArray)
6358
return diagindices(a)
6459
end
6560

@@ -84,8 +79,39 @@ function Base.setindex!(a::AbstractDiagonalArray, value, I::DiagIndex)
8479
return invoke(setindex!, Tuple{AbstractArray, Any, DiagIndex}, a, value, I)
8580
end
8681

87-
@interface ::AbstractDiagonalArrayInterface function Broadcast.BroadcastStyle(type::Type)
88-
return DiagonalArrayStyle{ndims(type)}()
82+
using SparseArraysBase: sparse_style
83+
const getindex_diag = diag_style(getindex)
84+
getindex_diag(a::AbstractArray, I...) = sparse_style(getindex)(a, I...)
85+
const setindex!_diag = diag_style(setindex!)
86+
setindex!_diag(a::AbstractArray, value, I...) = sparse_style(setindex!)(a, value, I...)
87+
const copyto!_diag = diag_style(copyto!)
88+
copyto!_diag(dst::AbstractArray, src::AbstractArray) = sparse_style(copyto!)(dst, src)
89+
const map_diag = diag_style(map)
90+
map_diag(f, as::AbstractArray...) = sparse_style(map)(f, as...)
91+
const map!_diag = diag_style(map!)
92+
map!_diag(f, dst::AbstractArray, as::AbstractArray...) = sparse_style(map!)(f, dst, as...)
93+
const fill!_diag = diag_style(fill!)
94+
fill!_diag(a::AbstractArray, value) = sparse_style(fill!)(a, value)
95+
using FunctionImplementations: zero!
96+
const zero!_diag = diag_style(zero!)
97+
zero!_diag(a::AbstractArray) = sparse_style(zero!)(a)
98+
using SparseArraysBase: isstored
99+
const isstored_diag = diag_style(isstored)
100+
function isstored_diag(
101+
a::AbstractArray{<:Any, N}, I::Vararg{Int, N}
102+
) where {N}
103+
return allequal(I)
104+
end
105+
isstored_diag(a::AbstractArray, I::Int...) = sparse_style(isstored)(a, I...)
106+
using SparseArraysBase: storedvalues
107+
const storedvalues_diag = diag_style(storedvalues)
108+
storedvalues_diag(a::AbstractArray) = diagview(a)
109+
using SparseArraysBase: storedpairs
110+
const storedpairs_diag = diag_style(storedpairs)
111+
storedpairs_diag(a::AbstractArray) = sparse_style(storedpairs)(a)
112+
113+
function Base.Broadcast.BroadcastStyle(type::Type{<:AbstractDiagonalArray})
114+
return Broadcast.DiagonalArrayStyle{ndims(type)}()
89115
end
90116

91117
using Base.Broadcast: Broadcasted, broadcasted
@@ -99,10 +125,10 @@ function broadcasted_diagview(bc::Broadcasted)
99125
)
100126
return broadcasted(m.f, map(diagview, m.args)...)
101127
end
102-
function Base.copy(bc::Broadcasted{<:DiagonalArrayStyle})
128+
function Base.copy(bc::Broadcasted{<:Broadcast.DiagonalArrayStyle})
103129
return DiagonalArray(copy(broadcasted_diagview(bc)), axes(bc))
104130
end
105-
function Base.copyto!(dest::AbstractArray, bc::Broadcasted{<:DiagonalArrayStyle})
131+
function Base.copyto!(dest::AbstractArray, bc::Broadcasted{<:Broadcast.DiagonalArrayStyle})
106132
copyto!(diagview(dest), broadcasted_diagview(bc))
107133
return dest
108134
end

src/abstractdiagonalarray/sparsearrayinterface.jl

Lines changed: 0 additions & 19 deletions
This file was deleted.

src/diagonalarray/diagonalarray.jl

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ struct DiagonalArray{T, N, D <: AbstractVector{T}, U <: AbstractArray{T, N}} <:
2222
end
2323
end
2424

25-
SparseArraysBase.unstored(a::DiagonalArray) = a.unstored
25+
SA.unstored(a::DiagonalArray) = a.unstored
2626
Base.size(a::DiagonalArray) = size(unstored(a))
2727
Base.axes(a::DiagonalArray) = axes(unstored(a))
2828

@@ -291,7 +291,8 @@ function Base.permutedims(a::DiagonalArray, perm)
291291
return DiagonalArray(copy(diagview(a)), ax_perm)
292292
end
293293

294-
function DerivableInterfaces.permuteddims(a::DiagonalArray, perm)
294+
using FunctionImplementations: FunctionImplementations
295+
function FunctionImplementations.permuteddims(a::DiagonalArray, perm)
295296
((ndims(a) == length(perm)) && isperm(perm)) ||
296297
throw(ArgumentError("Not a valid permutation"))
297298
ax_perm = ntuple(d -> axes(a)[perm[d]], ndims(a))
@@ -300,7 +301,6 @@ function DerivableInterfaces.permuteddims(a::DiagonalArray, perm)
300301
end
301302

302303
# Scalar indexing.
303-
using DerivableInterfaces: @interface, interface
304304
one_based_range(r) = false
305305
one_based_range(r::Base.OneTo) = true
306306
one_based_range(r::Base.Slice) = true
@@ -335,8 +335,10 @@ function Base.view(a::DiagonalArray, I...)
335335
invoke(view, Tuple{AbstractArray, Vararg}, a, I′...)
336336
end
337337
end
338+
using FunctionImplementations: style
339+
using SparseArraysBase: sparse_style
338340
function Base.getindex(a::DiagonalArray, I::Int...)
339-
return @interface interface(a) a[I...]
341+
return sparse_style(getindex)(a, I...)
340342
end
341343
function Base.getindex(a::DiagonalArray, I::DiagIndex)
342344
return getdiagindex(a, index(I))
@@ -349,7 +351,7 @@ function Base.getindex(a::DiagonalArray, I...)
349351
I′ = to_indices(a, I)
350352
return if all(i -> i isa Real, I′)
351353
# Catch scalar indexing case.
352-
@interface interface(a) a[I...]
354+
return style(a)(getindex)(a, I...)
353355
elseif all(one_based_range, I′)
354356
_getindex_diag(a, I′...)
355357
else
@@ -379,7 +381,7 @@ end
379381
# TODO: These definitions work around this issue:
380382
# https://github.com/JuliaArrays/FillArrays.jl/issues/416
381383
# when the diagonal is a FillArrays.Ones or Zeros.
382-
using Base.Broadcast: Broadcast, broadcast, broadcasted
384+
using Base.Broadcast: broadcast, broadcasted
383385
using FillArrays: AbstractFill, Ones, Zeros
384386
_broadcasted(f::F, a::AbstractArray) where {F} = broadcasted(f, a)
385387
_broadcasted(::typeof(identity), a::Ones) = a
@@ -407,8 +409,8 @@ _broadcasted(::typeof(cosh), a::Zeros) = Ones{typeof(cosh(zero(eltype(a))))}(axe
407409
# Eager version of `_broadcasted`.
408410
_broadcast(f::F, a::AbstractArray) where {F} = copy(_broadcasted(f, a))
409411

410-
function Broadcast.broadcasted(
411-
::DiagonalArrayStyle{N}, f::F, a::DiagonalArray{T, N, Diag}
412+
function Base.Broadcast.broadcasted(
413+
::Broadcast.DiagonalArrayStyle{N}, f::F, a::DiagonalArray{T, N, Diag}
412414
) where {F, T, N, Diag <: AbstractFill{T}}
413415
# TODO: Check that `f` preserves zeros?
414416
return DiagonalArray(_broadcasted(f, diagview(a)), axes(a))

test/Project.toml

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
[deps]
22
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
33
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
4-
DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
54
DiagonalArrays = "74fd4be6-21e2-4f6f-823a-4360d37c7a77"
65
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
6+
FunctionImplementations = "7c7cc465-9c6a-495f-bdd1-f42428e86d0c"
77
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
88
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
99
MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"
@@ -14,18 +14,21 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
1414
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
1515
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1616

17+
[sources]
18+
DiagonalArrays = {path = ".."}
19+
1720
[compat]
1821
Adapt = "4.4"
1922
Aqua = "0.8.9"
20-
DerivableInterfaces = "0.5"
2123
DiagonalArrays = "0.3"
2224
FillArrays = "1"
25+
FunctionImplementations = "0.3"
2326
JLArrays = "0.3"
2427
LinearAlgebra = "1"
2528
MatrixAlgebraKit = "0.2.5, 0.3, 0.4, 0.5, 0.6"
26-
NamedDimsArrays = "0.10, 0.11"
29+
NamedDimsArrays = "0.12"
2730
SafeTestsets = "0.1"
28-
SparseArraysBase = "0.7.10"
31+
SparseArraysBase = "0.8"
2932
StableRNGs = "1"
3033
Suppressor = "0.2"
3134
Test = "1"

test/test_basics.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
using DerivableInterfaces: permuteddims
21
using DiagonalArrays:
32
DiagonalArrays,
43
ShapeInitializer,
@@ -18,6 +17,7 @@ using DiagonalArrays:
1817
diagview,
1918
getdiagindices
2019
using FillArrays: Fill, Ones, Zeros
20+
using FunctionImplementations: permuteddims
2121
using LinearAlgebra:
2222
Diagonal, det, ishermitian, isposdef, issymmetric, logdet, mul!, pinv, tr
2323
using SparseArraysBase: SparseArrayDOK, SparseMatrixDOK, sparsezeros, storedlength
@@ -229,7 +229,7 @@ using Test: @test, @test_throws, @testset, @test_broken, @inferred
229229
@test diagview(b) diagview(a)
230230
@test size(b) === (4, 2, 3)
231231
end
232-
@testset "DerivableInterfaces.permuteddims" begin
232+
@testset "FunctionImplementations.permuteddims" begin
233233
a = DiagonalArray(randn(elt, 2), (2, 3, 4))
234234
b = permuteddims(a, (3, 1, 2))
235235
@test diagview(b) diagview(a)

0 commit comments

Comments
 (0)