22
33diagview (a:: AbstractDiagonalArray ) = throw (MethodError (diagview, Tuple{typeof (a)}))
44
5- using DerivableInterfaces: DerivableInterfaces, @interface
6- using SparseArraysBase:
7- SparseArraysBase, AbstractSparseArrayInterface, AbstractSparseArrayStyle
5+ using FunctionImplementations: FunctionImplementations
6+ using SparseArraysBase: SparseArraysBase as SA, AbstractSparseArrayStyle
87
9- abstract type AbstractDiagonalArrayInterface{N} <: AbstractSparseArrayInterface{N} end
8+ abstract type AbstractDiagonalArrayStyle <: AbstractSparseArrayStyle end
109
11- struct DiagonalArrayInterface{N} <: AbstractDiagonalArrayInterface{N} end
12- DiagonalArrayInterface {M} (:: Val{N} ) where {M, N} = DiagonalArrayInterface {N} ()
13- DiagionalArrayInterface (:: Val{N} ) where {N} = DiagonalArrayInterface {N} ()
14- DiagonalArrayInterface () = DiagonalArrayInterface {Any} ()
10+ struct DiagonalArrayStyle <: AbstractDiagonalArrayStyle end
11+ const diag_style = DiagonalArrayStyle ()
1512
16- function Base . similar (:: AbstractDiagonalArrayInterface , elt :: Type , ax :: Tuple )
17- return similar (DiagonalArray{elt}, ax )
13+ function FunctionImplementations . Style (:: Type{<:AbstractDiagonalArray} )
14+ return DiagonalArrayStyle ( )
1815end
19- function DerivableInterfaces. interface (:: Type{<:AbstractDiagonalArray{<:Any, N}} ) where {N}
20- return DiagonalArrayInterface {N} ()
21- end
22-
23- abstract type AbstractDiagonalArrayStyle{N} <: AbstractSparseArrayStyle{N} end
2416
25- function DerivableInterfaces. interface (:: Type{<:AbstractDiagonalArrayStyle{N}} ) where {N}
26- return DiagonalArrayInterface {N} ()
17+ module Broadcast
18+ import SparseArraysBase as SA
19+ abstract type AbstractDiagonalArrayStyle{N} <: SA.Broadcast.AbstractSparseArrayStyle{N} end
20+ struct DiagonalArrayStyle{N} <: AbstractDiagonalArrayStyle{N} end
21+ DiagonalArrayStyle {M} (:: Val{N} ) where {M, N} = DiagonalArrayStyle {N} ()
2722end
2823
29- struct DiagonalArrayStyle{N} <: AbstractDiagonalArrayStyle{N} end
30-
31- DiagonalArrayStyle {M} (:: Val{N} ) where {M, N} = DiagonalArrayStyle {N} ()
32-
33- function SparseArraysBase. isstored (
34- a:: AbstractDiagonalArray{<:Any, N} , I:: Vararg{Int, N}
35- ) where {N}
36- return allequal (I)
37- end
38- function SparseArraysBase. getstoredindex (
39- a:: AbstractDiagonalArray{<:Any, N} , I:: Vararg{Int, N}
24+ using SparseArraysBase: getstoredindex
25+ const getstoredindex_diag = diag_style (getstoredindex)
26+ function getstoredindex_diag (
27+ a:: AbstractArray{<:Any, N} , I:: Vararg{Int, N}
4028 ) where {N}
4129 # TODO : Make this check optional, define `checkstored` like `checkbounds`
4230 # in SparseArraysBase.jl.
4331 # allequal(I) || error("Not a diagonal index.")
4432 return getdiagindex (a, first (I))
4533end
46- function SparseArraysBase . getstoredindex (a:: AbstractDiagonalArray {<:Any, 0} )
34+ function getstoredindex_diag (a:: AbstractArray {<:Any, 0} )
4735 return getdiagindex (a, 1 )
4836end
49- function SparseArraysBase. setstoredindex! (
50- a:: AbstractDiagonalArray{<:Any, N} , value, I:: Vararg{Int, N}
37+ function getstoredindex_diag (a:: AbstractArray , I:: Int... )
38+ return sparse_style (getstoredindex)(a, I... )
39+ end
40+ using SparseArraysBase: setstoredindex!
41+ const setstoredindex!_diag = diag_style (setstoredindex!)
42+ function setstoredindex!_diag (
43+ a:: AbstractArray{<:Any, N} , value, I:: Vararg{Int, N}
5144 ) where {N}
5245 # TODO : Make this check optional, define `checkstored` like `checkbounds`
5346 # in SparseArraysBase.jl.
5447 # allequal(I) || error("Not a diagonal index.")
5548 setdiagindex! (a, value, first (I))
5649 return a
5750end
58- function SparseArraysBase . setstoredindex! (a:: AbstractDiagonalArray {<:Any, 0} , value)
51+ function setstoredindex!_diag (a:: AbstractArray {<:Any, 0} , value)
5952 setdiagindex! (a, value, 1 )
6053 return a
6154end
62- function SparseArraysBase. eachstoredindex (:: IndexCartesian , a:: AbstractDiagonalArray )
55+ using SparseArraysBase: eachstoredindex
56+ const eachstoredindex_diag = diag_style (eachstoredindex)
57+ function eachstoredindex_diag (:: IndexCartesian , a:: AbstractArray )
6358 return diagindices (a)
6459end
6560
@@ -84,8 +79,39 @@ function Base.setindex!(a::AbstractDiagonalArray, value, I::DiagIndex)
8479 return invoke (setindex!, Tuple{AbstractArray, Any, DiagIndex}, a, value, I)
8580end
8681
87- @interface :: AbstractDiagonalArrayInterface function Broadcast. BroadcastStyle (type:: Type )
88- return DiagonalArrayStyle {ndims(type)} ()
82+ using SparseArraysBase: sparse_style
83+ const getindex_diag = diag_style (getindex)
84+ getindex_diag (a:: AbstractArray , I... ) = sparse_style (getindex)(a, I... )
85+ const setindex!_diag = diag_style (setindex!)
86+ setindex!_diag (a:: AbstractArray , value, I... ) = sparse_style (setindex!)(a, value, I... )
87+ const copyto!_diag = diag_style (copyto!)
88+ copyto!_diag (dst:: AbstractArray , src:: AbstractArray ) = sparse_style (copyto!)(dst, src)
89+ const map_diag = diag_style (map)
90+ map_diag (f, as:: AbstractArray... ) = sparse_style (map)(f, as... )
91+ const map!_diag = diag_style (map!)
92+ map!_diag (f, dst:: AbstractArray , as:: AbstractArray... ) = sparse_style (map!)(f, dst, as... )
93+ const fill!_diag = diag_style (fill!)
94+ fill!_diag (a:: AbstractArray , value) = sparse_style (fill!)(a, value)
95+ using FunctionImplementations: zero!
96+ const zero!_diag = diag_style (zero!)
97+ zero!_diag (a:: AbstractArray ) = sparse_style (zero!)(a)
98+ using SparseArraysBase: isstored
99+ const isstored_diag = diag_style (isstored)
100+ function isstored_diag (
101+ a:: AbstractArray{<:Any, N} , I:: Vararg{Int, N}
102+ ) where {N}
103+ return allequal (I)
104+ end
105+ isstored_diag (a:: AbstractArray , I:: Int... ) = sparse_style (isstored)(a, I... )
106+ using SparseArraysBase: storedvalues
107+ const storedvalues_diag = diag_style (storedvalues)
108+ storedvalues_diag (a:: AbstractArray ) = diagview (a)
109+ using SparseArraysBase: storedpairs
110+ const storedpairs_diag = diag_style (storedpairs)
111+ storedpairs_diag (a:: AbstractArray ) = sparse_style (storedpairs)(a)
112+
113+ function Base. Broadcast. BroadcastStyle (type:: Type{<:AbstractDiagonalArray} )
114+ return Broadcast. DiagonalArrayStyle {ndims(type)} ()
89115end
90116
91117using Base. Broadcast: Broadcasted, broadcasted
@@ -99,10 +125,10 @@ function broadcasted_diagview(bc::Broadcasted)
99125 )
100126 return broadcasted (m. f, map (diagview, m. args)... )
101127end
102- function Base. copy (bc:: Broadcasted{<:DiagonalArrayStyle} )
128+ function Base. copy (bc:: Broadcasted{<:Broadcast. DiagonalArrayStyle} )
103129 return DiagonalArray (copy (broadcasted_diagview (bc)), axes (bc))
104130end
105- function Base. copyto! (dest:: AbstractArray , bc:: Broadcasted{<:DiagonalArrayStyle} )
131+ function Base. copyto! (dest:: AbstractArray , bc:: Broadcasted{<:Broadcast. DiagonalArrayStyle} )
106132 copyto! (diagview (dest), broadcasted_diagview (bc))
107133 return dest
108134end
0 commit comments