Skip to content

Commit b1d3723

Browse files
committed
Add source code and tests
1 parent ad98b39 commit b1d3723

File tree

10 files changed

+358
-7
lines changed

10 files changed

+358
-7
lines changed

Project.toml

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,27 @@ uuid = "b8770bf0-c4ae-4888-b9b0-956061873092"
33
authors = ["ITensor developers <[email protected]> and contributors"]
44
version = "0.1.0"
55

6+
[deps]
7+
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
8+
BroadcastMapConversion = "4a4adec5-520f-4750-bb37-d5e66b4ddeb2"
9+
Derive = "a07dfc7f-7d04-4eb5-84cc-a97f051f655a"
10+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
11+
612
[compat]
713
Aqua = "0.8.9"
14+
ArrayLayouts = "1.11.0"
15+
BroadcastMapConversion = "0.1.0"
16+
LinearAlgebra = "1.10"
817
SafeTestsets = "0.1"
918
Suppressor = "0.2"
1019
Test = "1.10"
1120
julia = "1.10"
1221

1322
[extras]
1423
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
15-
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1624
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
1725
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
26+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1827

1928
[targets]
2029
test = ["Aqua", "Test", "Suppressor", "SafeTestsets"]

README.md

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,49 @@ julia> Pkg.add("SparseArraysBase")
3232
## Examples
3333

3434
````julia
35-
using SparseArraysBase: SparseArraysBase
35+
using SparseArraysBase:
36+
SparseArrayDOK,
37+
eachstoredindex,
38+
getstoredindex,
39+
getunstoredindex,
40+
isstored,
41+
setstoredindex!,
42+
setunstoredindex!,
43+
storedlength,
44+
storedpairs,
45+
storedvalues
46+
using Test: @test, @test_throws
47+
48+
a = SparseArrayDOK{Float64}(2, 2)
49+
````
50+
51+
AbstractArray interface:
52+
53+
````julia
54+
a[1, 2] = 12
55+
@test a[1, 1] == 0
56+
@test a[2, 1] == 0
57+
@test a[1, 2] == 12
58+
@test a[2, 2] == 0
3659
````
3760

38-
Examples go here.
61+
SparseArraysBase interface:
62+
63+
````julia
64+
@test issetequal(eachstoredindex(a), [CartesianIndex(1, 2)])
65+
@test getstoredindex(a, 1, 2) == 12
66+
@test_throws KeyError getstoredindex(a, 1, 1)
67+
@test getunstoredindex(a, 1, 1) == 0
68+
@test getunstoredindex(a, 1, 2) == 0
69+
@test !isstored(a, 1, 1)
70+
@test isstored(a, 1, 2)
71+
@test setstoredindex!(copy(a), 21, 1, 2) == [0 21; 0 0]
72+
@test_throws KeyError setstoredindex!(copy(a), 21, 2, 1)
73+
@test setunstoredindex!(copy(a), 21, 1, 2) == [0 21; 0 0]
74+
@test storedlength(a) == 1
75+
@test issetequal(storedpairs(a), [CartesianIndex(1, 2) => 12])
76+
@test issetequal(storedvalues(a), [12])
77+
````
3978

4079
---
4180

examples/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
[deps]
22
SparseArraysBase = "b8770bf0-c4ae-4888-b9b0-956061873092"
3+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

examples/README.jl

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,5 +37,41 @@ julia> Pkg.add("SparseArraysBase")
3737

3838
# ## Examples
3939

40-
using SparseArraysBase: SparseArraysBase
41-
# Examples go here.
40+
using SparseArraysBase:
41+
SparseArrayDOK,
42+
eachstoredindex,
43+
getstoredindex,
44+
getunstoredindex,
45+
isstored,
46+
setstoredindex!,
47+
setunstoredindex!,
48+
storedlength,
49+
storedpairs,
50+
storedvalues
51+
using Test: @test, @test_throws
52+
53+
a = SparseArrayDOK{Float64}(2, 2)
54+
55+
# AbstractArray interface:
56+
57+
a[1, 2] = 12
58+
@test a[1, 1] == 0
59+
@test a[2, 1] == 0
60+
@test a[1, 2] == 12
61+
@test a[2, 2] == 0
62+
63+
# SparseArraysBase interface:
64+
65+
@test issetequal(eachstoredindex(a), [CartesianIndex(1, 2)])
66+
@test getstoredindex(a, 1, 2) == 12
67+
@test_throws KeyError getstoredindex(a, 1, 1)
68+
@test getunstoredindex(a, 1, 1) == 0
69+
@test getunstoredindex(a, 1, 2) == 0
70+
@test !isstored(a, 1, 1)
71+
@test isstored(a, 1, 2)
72+
@test setstoredindex!(copy(a), 21, 1, 2) == [0 21; 0 0]
73+
@test_throws KeyError setstoredindex!(copy(a), 21, 2, 1)
74+
@test setunstoredindex!(copy(a), 21, 1, 2) == [0 21; 0 0]
75+
@test storedlength(a) == 1
76+
@test issetequal(storedpairs(a), [CartesianIndex(1, 2) => 12])
77+
@test issetequal(storedvalues(a), [12])

src/SparseArraysBase.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
module SparseArraysBase
22

3-
# Write your package code here.
3+
include("sparsearrayinterface.jl")
4+
include("wrappers.jl")
5+
include("sparsearraydok.jl")
46

57
end

src/sparsearraydok.jl

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# TODO: Define `AbstractSparseArray`, make this a subtype.
2+
struct SparseArrayDOK{T,N} <: AbstractArray{T,N}
3+
storage::Dict{CartesianIndex{N},T}
4+
size::NTuple{N,Int}
5+
end
6+
7+
function SparseArrayDOK{T}(size::Int...) where {T}
8+
N = length(size)
9+
return SparseArrayDOK{T,N}(Dict{CartesianIndex{N},T}(), size)
10+
end
11+
12+
using Derive: @wrappedtype
13+
# Define `WrappedSparseArrayDOK` and `AnySparseArrayDOK`.
14+
@wrappedtype SparseArrayDOK
15+
16+
using Derive: Derive
17+
function Derive.interface(::Type{<:SparseArrayDOK})
18+
return SparseArrayInterface()
19+
end
20+
21+
using Derive: @derive
22+
@derive AnySparseArrayDOK AbstractArrayOps
23+
24+
storage(a::SparseArrayDOK) = a.storage
25+
Base.size(a::SparseArrayDOK) = a.size
26+
27+
storedvalues(a::SparseArrayDOK) = values(storage(a))
28+
function isstored(a::SparseArrayDOK, I::Int...)
29+
return CartesianIndex(I) in keys(storage(a))
30+
end
31+
function eachstoredindex(a::SparseArrayDOK)
32+
return keys(storage(a))
33+
end
34+
function getstoredindex(a::SparseArrayDOK, I::Int...)
35+
return storage(a)[CartesianIndex(I)]
36+
end
37+
function getunstoredindex(a::SparseArrayDOK, I::Int...)
38+
return zero(eltype(a))
39+
end
40+
function setstoredindex!(a::SparseArrayDOK, value, I::Int...)
41+
isstored(a, I...) || throw(KeyError(CartesianIndex(I)))
42+
storage(a)[CartesianIndex(I)] = value
43+
return a
44+
end
45+
function setunstoredindex!(a::SparseArrayDOK, value, I::Int...)
46+
storage(a)[CartesianIndex(I)] = value
47+
return a
48+
end
49+
50+
# Optional, but faster than the default.
51+
storedpairs(a::SparseArrayDOK) = storage(a)

src/sparsearrayinterface.jl

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
# Minimal interface for `SparseArrayInterface`.
2+
# TODO: Define default definitions for these based
3+
# on the dense case.
4+
storedvalues(a) = error()
5+
isstored(a, I::Int...) = error()
6+
eachstoredindex(a) = error()
7+
getstoredindex(a, I::Int...) = error()
8+
getunstoredindex(a, I::Int...) = error()
9+
setstoredindex!(a, value, I::Int...) = error()
10+
setunstoredindex!(a, value, I::Int...) = error()
11+
12+
# Derived interface.
13+
storedlength(a) = length(storedvalues(a))
14+
storedpairs(a) = map(I -> I => getstoredindex(a, I), eachstoredindex(a))
15+
16+
function eachstoredindex(a1, a2, a_rest...)
17+
# TODO: Make this more customizable, say with a function
18+
# `combine/promote_storedindices(a1, a2)`.
19+
return union(eachstoredindex.((a1, a2, a_rest...))...)
20+
end
21+
22+
# TODO: Add `ndims` type parameter.
23+
# TODO: Define `AbstractSparseArrayInterface`, make this a subtype.
24+
using Derive: Derive, @interface, AbstractArrayInterface
25+
struct SparseArrayInterface <: AbstractArrayInterface end
26+
27+
# Convenient shorthand to refer to the sparse interface.
28+
const sparse = SparseArrayInterface()
29+
30+
# TODO: Use `ArrayLayouts.layout_getindex`, `ArrayLayouts.sub_materialize`
31+
# to handle slicing (implemented by copying SubArray).
32+
@interface sparse function Base.getindex(a, I::Int...)
33+
!isstored(a, I...) && return getunstoredindex(a, I...)
34+
return getstoredindex(a, I...)
35+
end
36+
37+
@interface sparse function Base.setindex!(a, value, I::Int...)
38+
iszero(value) && return a
39+
if !isstored(a, I...)
40+
setunstoredindex!(a, value, I...)
41+
return a
42+
end
43+
setstoredindex!(a, value, I...)
44+
return a
45+
end
46+
47+
# TODO: This may need to be defined in `sparsearraydok.jl`, after `SparseArrayDOK`
48+
# is defined. And/or define `default_type(::SparseArrayStyle, T::Type) = SparseArrayDOK{T}`.
49+
@interface sparse function Base.similar(a, T::Type, size::Tuple{Vararg{Int}})
50+
return SparseArrayDOK{T}(size...)
51+
end
52+
53+
## TODO: Make this more general, handle mixtures of integers and ranges.
54+
## TODO: Make this logic generic to all `similar(::AbstractInterface, ...)`.
55+
## @interface sparse function Base.similar(a, T::Type, dims::Tuple{Vararg{Base.OneTo}})
56+
## return sparse(similar)(interface, a, T, Base.to_shape(dims))
57+
## end
58+
59+
@interface sparse function Base.map(f, as...)
60+
# This is defined in this way so we can rely on the Broadcast logic
61+
# for determining the destination of the operation (element type, shape, etc.).
62+
return f.(as...)
63+
end
64+
65+
@interface sparse function Base.map!(f, dest, as...)
66+
# Check `f` preserves zeros.
67+
# Define as `map_stored!`.
68+
# Define `eachstoredindex` promotion.
69+
for I in eachstoredindex(as...)
70+
dest[I] = f(map(a -> a[I], as)...)
71+
end
72+
return dest
73+
end
74+
75+
# TODO: Define `AbstractSparseArrayStyle`, make this a subtype.
76+
struct SparseArrayStyle{N} <: Broadcast.AbstractArrayStyle{N} end
77+
78+
SparseArrayStyle{M}(::Val{N}) where {M,N} = SparseArrayStyle{N}()
79+
80+
@interface sparse function Broadcast.BroadcastStyle(type::Type)
81+
return SparseArrayStyle{ndims(type)}()
82+
end
83+
84+
function Base.similar(bc::Broadcast.Broadcasted{<:SparseArrayStyle}, T::Type, axes::Tuple)
85+
# TODO: Allow `similar` to accept `axes` directly.
86+
return sparse(similar)(bc, T, Int.(length.(axes)))
87+
end
88+
89+
using BroadcastMapConversion: map_function, map_args
90+
# TODO: Look into `SparseArrays.capturescalars`:
91+
# https://github.com/JuliaSparse/SparseArrays.jl/blob/1beb0e4a4618b0399907b0000c43d9f66d34accc/src/higherorderfns.jl#L1092-L1102
92+
function Base.copyto!(dest::AbstractArray, bc::Broadcast.Broadcasted{<:SparseArrayStyle})
93+
sparse(map!)(map_function(bc), dest, map_args(bc)...)
94+
return dest
95+
end
96+
97+
using ArrayLayouts: ArrayLayouts, MatMulMatAdd
98+
99+
abstract type AbstractSparseLayout <: ArrayLayouts.MemoryLayout end
100+
101+
struct SparseLayout <: AbstractSparseLayout end
102+
103+
@interface sparse function ArrayLayouts.MemoryLayout(type::Type)
104+
return SparseLayout()
105+
end
106+
107+
using LinearAlgebra: LinearAlgebra
108+
@interface sparse function LinearAlgebra.mul!(a_dest, a1, a2, α, β)
109+
return ArrayLayouts.mul!(a_dest, a1, a2, α, β)
110+
end
111+
112+
function mul_indices(I1::CartesianIndex{2}, I2::CartesianIndex{2})
113+
if I1[2] I2[1]
114+
return nothing
115+
end
116+
return CartesianIndex(I1[1], I2[2])
117+
end
118+
119+
function default_mul!!(
120+
a_dest::AbstractMatrix,
121+
a1::AbstractMatrix,
122+
a2::AbstractMatrix,
123+
α::Number=true,
124+
β::Number=false,
125+
)
126+
mul!(a_dest, a1, a2, α, β)
127+
return a_dest
128+
end
129+
130+
function default_mul!!(
131+
a_dest::Number, a1::Number, a2::Number, α::Number=true, β::Number=false
132+
)
133+
return a1 * a2 * α + a_dest * β
134+
end
135+
136+
# a1 * a2 * α + a_dest * β
137+
function sparse_mul!(
138+
a_dest::AbstractArray,
139+
a1::AbstractArray,
140+
a2::AbstractArray,
141+
α::Number=true,
142+
β::Number=false;
143+
(mul!!)=(default_mul!!),
144+
)
145+
for I1 in eachstoredindex(a1)
146+
for I2 in eachstoredindex(a2)
147+
I_dest = mul_indices(I1, I2)
148+
if !isnothing(I_dest)
149+
a_dest[I_dest] = mul!!(a_dest[I_dest], a1[I1], a2[I2], α, β)
150+
end
151+
end
152+
end
153+
return a_dest
154+
end
155+
156+
function ArrayLayouts.materialize!(
157+
m::MatMulMatAdd{<:AbstractSparseLayout,<:AbstractSparseLayout,<:AbstractSparseLayout}
158+
)
159+
sparse_mul!(m.C, m.A, m.B, m.α, m.β)
160+
return m.C
161+
end

src/wrappers.jl

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
using LinearAlgebra: Adjoint
2+
storedvalues(a::Adjoint) = storedvalues(parent(a))
3+
function isstored(a::Adjoint, i::Int, j::Int)
4+
return isstored(parent(a), j, i)
5+
end
6+
function eachstoredindex(a::Adjoint)
7+
# TODO: Make lazy with `Iterators.map`.
8+
return map(CartesianIndex reverse Tuple, collect(eachstoredindex(parent(a))))
9+
end
10+
function getstoredindex(a::Adjoint, i::Int, j::Int)
11+
return getstoredindex(parent(a), j, i)'
12+
end
13+
function getunstoredindex(a::Adjoint, i::Int, j::Int)
14+
return getunstoredindex(parent(a), j, i)'
15+
end
16+
function setstoredindex!(a::Adjoint, value, i::Int, j::Int)
17+
setstoredindex!(parent(a), value', j, i)
18+
return a
19+
end
20+
function setunstoredindex!(a::Adjoint, value, i::Int, j::Int)
21+
setunstoredindex!(parent(a), value', j, i)
22+
return a
23+
end
24+
25+
using LinearAlgebra: Transpose
26+
storedvalues(a::Transpose) = storedvalues(parent(a))
27+
function isstored(a::Transpose, i::Int, j::Int)
28+
return isstored(parent(a), j, i)
29+
end
30+
function eachstoredindex(a::Transpose)
31+
# TODO: Make lazy with `Iterators.map`.
32+
return map(CartesianIndex reverse Tuple, collect(eachstoredindex(parent(a))))
33+
end
34+
function getstoredindex(a::Transpose, i::Int, j::Int)
35+
return transpose(getstoredindex(parent(a), j, i))
36+
end
37+
function getunstoredindex(a::Transpose, i::Int, j::Int)
38+
return transpose(getunstoredindex(parent(a), j, i))
39+
end
40+
function setstoredindex!(a::Transpose, value, i::Int, j::Int)
41+
setstoredindex!(parent(a), transpose(value), j, i)
42+
return a
43+
end
44+
function setunstoredindex!(a::Transpose, value, i::Int, j::Int)
45+
setunstoredindex!(parent(a), transpose(value), j, i)
46+
return a
47+
end

0 commit comments

Comments
 (0)