Skip to content

Commit 0514aa0

Browse files
committed
Update for Derive.jl v0.3
1 parent 9a432bf commit 0514aa0

File tree

10 files changed

+213
-188
lines changed

10 files changed

+213
-188
lines changed

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,15 @@ version = "0.1.0"
77
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
88
BroadcastMapConversion = "4a4adec5-520f-4750-bb37-d5e66b4ddeb2"
99
Derive = "a07dfc7f-7d04-4eb5-84cc-a97f051f655a"
10+
Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4"
1011
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1112

1213
[compat]
1314
Aqua = "0.8.9"
1415
ArrayLayouts = "1.11.0"
1516
BroadcastMapConversion = "0.1.0"
17+
Derive = "0.3.0"
18+
Dictionaries = "0.4.3"
1619
LinearAlgebra = "1.10"
1720
SafeTestsets = "0.1"
1821
Suppressor = "0.2"

TODO.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
- Updates for latest Derive.

examples/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
[deps]
2+
Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4"
23
SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208"
34
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

examples/README.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,15 +63,16 @@ a[1, 2] = 12
6363

6464
# SparseArraysBase interface:
6565

66+
using Dictionaries: IndexError
6667
@test issetequal(eachstoredindex(a), [CartesianIndex(1, 2)])
6768
@test getstoredindex(a, 1, 2) == 12
68-
@test_throws KeyError getstoredindex(a, 1, 1)
69+
@test_throws IndexError getstoredindex(a, 1, 1)
6970
@test getunstoredindex(a, 1, 1) == 0
7071
@test getunstoredindex(a, 1, 2) == 0
7172
@test !isstored(a, 1, 1)
7273
@test isstored(a, 1, 2)
7374
@test setstoredindex!(copy(a), 21, 1, 2) == [0 21; 0 0]
74-
@test_throws KeyError setstoredindex!(copy(a), 21, 2, 1)
75+
@test_throws IndexError setstoredindex!(copy(a), 21, 2, 1)
7576
@test setunstoredindex!(copy(a), 21, 1, 2) == [0 21; 0 0]
7677
@test storedlength(a) == 1
7778
@test issetequal(storedpairs(a), [CartesianIndex(1, 2) => 12])
@@ -80,8 +81,15 @@ a[1, 2] = 12
8081
# AbstractArray functionality:
8182

8283
b = a .+ 2 .* a'
84+
@test b isa SparseArrayDOK{Float64}
8385
@test b == [0 12; 24 0]
8486
@test storedlength(b) == 2
87+
88+
b = permutedims(a, (2, 1))
8589
@test b isa SparseArrayDOK{Float64}
90+
@test b[1, 1] == a[1, 1]
91+
@test b[2, 1] == a[1, 2]
92+
@test b[1, 2] == a[2, 1]
93+
@test b[2, 2] == a[2, 2]
8694

8795
a * a'

src/SparseArraysBase.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
module SparseArraysBase
22

3+
include("abstractsparsearrayinterface.jl")
34
include("sparsearrayinterface.jl")
45
include("wrappers.jl")
56
include("abstractsparsearray.jl")

src/abstractsparsearray.jl

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,17 @@ function Derive.interface(::Type{<:AbstractSparseArray})
1010
end
1111

1212
using Derive: @derive
13-
# Derive `Base.getindex`, `Base.setindex!`, etc.
14-
@derive AnyAbstractSparseArray AbstractArrayOps
1513

14+
# TODO: These need to be loaded since `AbstractArrayOps`
15+
# includes overloads of functions from these modules.
16+
# Ideally that wouldn't be needed and can be circumvented
17+
# with `GlobalRef`.
18+
using ArrayLayouts: ArrayLayouts
1619
using LinearAlgebra: LinearAlgebra
17-
@derive (T=AnyAbstractSparseVecOrMat,) begin
18-
LinearAlgebra.mul!(::AbstractMatrix, ::T, ::T, ::Number, ::Number)
19-
end
2020

21-
using ArrayLayouts: ArrayLayouts
22-
@derive (T=AnyAbstractSparseArray,) begin
23-
ArrayLayouts.MemoryLayout(::Type{<:T})
24-
end
21+
# Derive `Base.getindex`, `Base.setindex!`, etc.
22+
# TODO: Define `AbstractMatrixOps` and overload for
23+
# `AnyAbstractSparseMatrix` and `AnyAbstractSparseVector`,
24+
# which is where matrix multiplication and factorizations
25+
# shoudl go.
26+
@derive AnyAbstractSparseArray AbstractArrayOps
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
# Minimal interface for `SparseArrayInterface`.
2+
# TODO: Define default definitions for these based
3+
# on the dense case.
4+
storedvalues(a) = error()
5+
isstored(a, I::Int...) = error()
6+
eachstoredindex(a) = error()
7+
getstoredindex(a, I::Int...) = error()
8+
setstoredindex!(a, value, I::Int...) = error()
9+
setunstoredindex!(a, value, I::Int...) = error()
10+
11+
# Interface defaults.
12+
# TODO: Have a fallback that handles element types
13+
# that don't define `zero(::Type)`.
14+
getunstoredindex(a, I::Int...) = zero(eltype(a))
15+
16+
# Derived interface.
17+
storedlength(a) = length(storedvalues(a))
18+
storedpairs(a) = map(I -> I => getstoredindex(a, I), eachstoredindex(a))
19+
20+
function eachstoredindex(a1, a2, a_rest...)
21+
# TODO: Make this more customizable, say with a function
22+
# `combine/promote_storedindices(a1, a2)`.
23+
return union(eachstoredindex.((a1, a2, a_rest...))...)
24+
end
25+
26+
using Derive: Derive, @interface, AbstractArrayInterface
27+
28+
# TODO: Add `ndims` type parameter.
29+
# TODO: This isn't used to define interface functions right now.
30+
# Currently, `@interface` expects an instance, probably it should take a
31+
# type instead so fallback functions can use abstract types.
32+
abstract type AbstractSparseArrayInterface <: AbstractArrayInterface end
33+
34+
# TODO: Use `ArrayLayouts.layout_getindex`, `ArrayLayouts.sub_materialize`
35+
# to handle slicing (implemented by copying SubArray).
36+
@interface AbstractSparseArrayInterface function Base.getindex(a, I::Int...)
37+
!isstored(a, I...) && return getunstoredindex(a, I...)
38+
return getstoredindex(a, I...)
39+
end
40+
41+
@interface AbstractSparseArrayInterface function Base.setindex!(a, value, I::Int...)
42+
iszero(value) && return a
43+
if !isstored(a, I...)
44+
setunstoredindex!(a, value, I...)
45+
return a
46+
end
47+
setstoredindex!(a, value, I...)
48+
return a
49+
end
50+
51+
# TODO: This may need to be defined in `sparsearraydok.jl`, after `SparseArrayDOK`
52+
# is defined. And/or define `default_type(::SparseArrayStyle, T::Type) = SparseArrayDOK{T}`.
53+
@interface AbstractSparseArrayInterface function Base.similar(
54+
a, T::Type, size::Tuple{Vararg{Int}}
55+
)
56+
# TODO: Define `default_similartype` or something like that?
57+
return SparseArrayDOK{T}(size...)
58+
end
59+
60+
@interface AbstractSparseArrayInterface function Base.map!(f, dest, as...)
61+
# Check `f` preserves zeros.
62+
# Define as `map_stored!`.
63+
# Define `eachstoredindex` promotion.
64+
for I in eachstoredindex(as...)
65+
dest[I] = f(map(a -> a[I], as)...)
66+
end
67+
return dest
68+
end
69+
70+
# TODO: Make this a subtype of `Derive.AbstractArrayStyle{N}` instead.
71+
using Derive: Derive
72+
abstract type AbstractSparseArrayStyle{N} <: Derive.AbstractArrayStyle{N} end
73+
74+
struct SparseArrayStyle{N} <: AbstractSparseArrayStyle{N} end
75+
76+
SparseArrayStyle{M}(::Val{N}) where {M,N} = SparseArrayStyle{N}()
77+
78+
@interface AbstractSparseArrayInterface function Broadcast.BroadcastStyle(type::Type)
79+
return SparseArrayStyle{ndims(type)}()
80+
end
81+
82+
using ArrayLayouts: ArrayLayouts, MatMulMatAdd
83+
84+
abstract type AbstractSparseLayout <: ArrayLayouts.MemoryLayout end
85+
86+
struct SparseLayout <: AbstractSparseLayout end
87+
88+
@interface AbstractSparseArrayInterface function ArrayLayouts.MemoryLayout(type::Type)
89+
return SparseLayout()
90+
end
91+
92+
function mul_indices(I1::CartesianIndex{2}, I2::CartesianIndex{2})
93+
if I1[2] I2[1]
94+
return nothing
95+
end
96+
return CartesianIndex(I1[1], I2[2])
97+
end
98+
99+
function default_mul!!(
100+
a_dest::AbstractMatrix,
101+
a1::AbstractMatrix,
102+
a2::AbstractMatrix,
103+
α::Number=true,
104+
β::Number=false,
105+
)
106+
mul!(a_dest, a1, a2, α, β)
107+
return a_dest
108+
end
109+
110+
function default_mul!!(
111+
a_dest::Number, a1::Number, a2::Number, α::Number=true, β::Number=false
112+
)
113+
return a1 * a2 * α + a_dest * β
114+
end
115+
116+
# a1 * a2 * α + a_dest * β
117+
function sparse_mul!(
118+
a_dest::AbstractArray,
119+
a1::AbstractArray,
120+
a2::AbstractArray,
121+
α::Number=true,
122+
β::Number=false;
123+
(mul!!)=(default_mul!!),
124+
)
125+
for I1 in eachstoredindex(a1)
126+
for I2 in eachstoredindex(a2)
127+
I_dest = mul_indices(I1, I2)
128+
if !isnothing(I_dest)
129+
a_dest[I_dest] = mul!!(a_dest[I_dest], a1[I1], a2[I2], α, β)
130+
end
131+
end
132+
end
133+
return a_dest
134+
end
135+
136+
function ArrayLayouts.materialize!(
137+
m::MatMulMatAdd{<:AbstractSparseLayout,<:AbstractSparseLayout,<:AbstractSparseLayout}
138+
)
139+
sparse_mul!(m.C, m.A, m.B, m.α, m.β)
140+
return m.C
141+
end

src/sparsearraydok.jl

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,19 @@
1-
# TODO: Rewrite to use `Dictionary`.
2-
struct SparseArrayDOK{T,N} <: AbstractSparseArray{T,N}
3-
storage::Dict{CartesianIndex{N},T}
1+
using Dictionaries: Dictionary, IndexError, set!
2+
3+
function default_getunstoredindex(a::AbstractArray, I::Int...)
4+
return zero(eltype(a))
5+
end
6+
7+
struct SparseArrayDOK{T,N,F} <: AbstractSparseArray{T,N}
8+
storage::Dictionary{CartesianIndex{N},T}
49
size::NTuple{N,Int}
10+
getunstoredindex::F
511
end
612

713
function SparseArrayDOK{T,N}(size::Vararg{Int,N}) where {T,N}
8-
return SparseArrayDOK{T,N}(Dict{CartesianIndex{N},T}(), size)
14+
getunstoredindex = default_getunstoredindex
15+
F = typeof(getunstoredindex)
16+
return SparseArrayDOK{T,N,F}(Dictionary{CartesianIndex{N},T}(), size, getunstoredindex)
917
end
1018

1119
function SparseArrayDOK{T}(size::Int...) where {T}
@@ -30,17 +38,17 @@ function getstoredindex(a::SparseArrayDOK, I::Int...)
3038
return storage(a)[CartesianIndex(I)]
3139
end
3240
function getunstoredindex(a::SparseArrayDOK, I::Int...)
33-
return zero(eltype(a))
41+
return a.getunstoredindex(a, I...)
3442
end
3543
function setstoredindex!(a::SparseArrayDOK, value, I::Int...)
36-
isstored(a, I...) || throw(KeyError(CartesianIndex(I)))
44+
isstored(a, I...) || throw(IndexError("key $(CartesianIndex(I)) not found"))
3745
storage(a)[CartesianIndex(I)] = value
3846
return a
3947
end
4048
function setunstoredindex!(a::SparseArrayDOK, value, I::Int...)
41-
storage(a)[CartesianIndex(I)] = value
49+
set!(storage(a), CartesianIndex(I), value)
4250
return a
4351
end
4452

4553
# Optional, but faster than the default.
46-
storedpairs(a::SparseArrayDOK) = storage(a)
54+
storedpairs(a::SparseArrayDOK) = pairs(storage(a))

0 commit comments

Comments
 (0)