Skip to content

Commit 261c2ec

Browse files
committed
Slicing, trailing indices
1 parent 0514aa0 commit 261c2ec

File tree

6 files changed

+157
-76
lines changed

6 files changed

+157
-76
lines changed

examples/README.jl

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ julia> Pkg.add("SparseArraysBase")
3939

4040
using SparseArraysBase:
4141
SparseArrayDOK,
42+
SparseMatrixDOK,
43+
SparseVectorDOK,
4244
eachstoredindex,
4345
getstoredindex,
4446
getunstoredindex,
@@ -81,15 +83,24 @@ using Dictionaries: IndexError
8183
# AbstractArray functionality:
8284

8385
b = a .+ 2 .* a'
84-
@test b isa SparseArrayDOK{Float64}
86+
@test b isa SparseMatrixDOK{Float64}
8587
@test b == [0 12; 24 0]
8688
@test storedlength(b) == 2
8789

8890
b = permutedims(a, (2, 1))
89-
@test b isa SparseArrayDOK{Float64}
91+
@test b isa SparseMatrixDOK{Float64}
9092
@test b[1, 1] == a[1, 1]
9193
@test b[2, 1] == a[1, 2]
9294
@test b[1, 2] == a[2, 1]
9395
@test b[2, 2] == a[2, 2]
9496

95-
a * a'
97+
b = a * a'
98+
@test b isa SparseMatrixDOK{Float64}
99+
@test b == [144 0; 0 0]
100+
@test storedlength(b) == 1
101+
102+
# Second column.
103+
b = a[1:2, 2]
104+
@test b isa SparseVectorDOK{Float64}
105+
@test b == [12, 0]
106+
@test storedlength(b) == 1

src/abstractsparsearray.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,5 +22,5 @@ using LinearAlgebra: LinearAlgebra
2222
# TODO: Define `AbstractMatrixOps` and overload for
2323
# `AnyAbstractSparseMatrix` and `AnyAbstractSparseVector`,
2424
# which is where matrix multiplication and factorizations
25-
# shoudl go.
25+
# should go.
2626
@derive AnyAbstractSparseArray AbstractArrayOps

src/abstractsparsearrayinterface.jl

Lines changed: 48 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,66 @@
11
# Minimal interface for `SparseArrayInterface`.
22
# TODO: Define default definitions for these based
33
# on the dense case.
4+
# TODO: Define as `MethodError`.
45
storedvalues(a) = error()
56
isstored(a, I::Int...) = error()
67
eachstoredindex(a) = error()
78
getstoredindex(a, I::Int...) = error()
9+
getunstoredindex(a, I::Int...) = error()
810
setstoredindex!(a, value, I::Int...) = error()
911
setunstoredindex!(a, value, I::Int...) = error()
1012

13+
# TODO: Use `Base.to_indices`?
14+
isstored(a::AbstractArray, I::CartesianIndex) = isstored(a, Tuple(I)...)
15+
getstoredindex(a::AbstractArray, I::CartesianIndex) = getstoredindex(a, Tuple(I)...)
16+
getunstoredindex(a::AbstractArray, I::CartesianIndex) = getunstoredindex(a, Tuple(I)...)
17+
function setstoredindex!(a::AbstractArray, value, I::CartesianIndex)
18+
return setstoredindex!(a, value, Tuple(I)...)
19+
end
20+
function setunstoredindex!(a::AbstractArray, value, I::CartesianIndex)
21+
return setunstoredindex!(a, value, Tuple(I)...)
22+
end
23+
1124
# Interface defaults.
1225
# TODO: Have a fallback that handles element types
1326
# that don't define `zero(::Type)`.
1427
getunstoredindex(a, I::Int...) = zero(eltype(a))
1528

1629
# Derived interface.
17-
storedlength(a) = length(storedvalues(a))
18-
storedpairs(a) = map(I -> I => getstoredindex(a, I), eachstoredindex(a))
30+
storedlength(a::AbstractArray) = length(storedvalues(a))
31+
storedpairs(a::AbstractArray) = map(I -> I => getstoredindex(a, I), eachstoredindex(a))
32+
function storedvalues(a::AbstractArray)
33+
return @view a[collect(eachstoredindex(a))]
34+
end
1935

2036
function eachstoredindex(a1, a2, a_rest...)
2137
# TODO: Make this more customizable, say with a function
2238
# `combine/promote_storedindices(a1, a2)`.
2339
return union(eachstoredindex.((a1, a2, a_rest...))...)
2440
end
2541

26-
using Derive: Derive, @interface, AbstractArrayInterface
42+
using Derive: Derive, @derive, @interface, AbstractArrayInterface
2743

2844
# TODO: Add `ndims` type parameter.
2945
# TODO: This isn't used to define interface functions right now.
3046
# Currently, `@interface` expects an instance, probably it should take a
3147
# type instead so fallback functions can use abstract types.
3248
abstract type AbstractSparseArrayInterface <: AbstractArrayInterface end
3349

34-
# TODO: Use `ArrayLayouts.layout_getindex`, `ArrayLayouts.sub_materialize`
35-
# to handle slicing (implemented by copying SubArray).
36-
@interface AbstractSparseArrayInterface function Base.getindex(a, I::Int...)
50+
# We restrict to `I::Vararg{Int,N}` to allow more general functions to handle trailing
51+
# indices and linear indices.
52+
@interface ::AbstractSparseArrayInterface function Base.getindex(
53+
a::AbstractArray{<:Any,N}, I::Vararg{Int,N}
54+
) where {N}
3755
!isstored(a, I...) && return getunstoredindex(a, I...)
3856
return getstoredindex(a, I...)
3957
end
4058

41-
@interface AbstractSparseArrayInterface function Base.setindex!(a, value, I::Int...)
59+
# We restrict to `I::Vararg{Int,N}` to allow more general functions to handle trailing
60+
# indices and linear indices.
61+
@interface ::AbstractSparseArrayInterface function Base.setindex!(
62+
a::AbstractArray{<:Any,N}, value, I::Vararg{Int,N}
63+
) where {N}
4264
iszero(value) && return a
4365
if !isstored(a, I...)
4466
setunstoredindex!(a, value, I...)
@@ -50,14 +72,16 @@ end
5072

5173
# TODO: This may need to be defined in `sparsearraydok.jl`, after `SparseArrayDOK`
5274
# is defined. And/or define `default_type(::SparseArrayStyle, T::Type) = SparseArrayDOK{T}`.
53-
@interface AbstractSparseArrayInterface function Base.similar(
54-
a, T::Type, size::Tuple{Vararg{Int}}
75+
@interface ::AbstractSparseArrayInterface function Base.similar(
76+
a::AbstractArray, T::Type, size::Tuple{Vararg{Int}}
5577
)
5678
# TODO: Define `default_similartype` or something like that?
5779
return SparseArrayDOK{T}(size...)
5880
end
5981

60-
@interface AbstractSparseArrayInterface function Base.map!(f, dest, as...)
82+
@interface ::AbstractSparseArrayInterface function Base.map!(
83+
f, dest::AbstractArray, as::AbstractArray...
84+
)
6185
# Check `f` preserves zeros.
6286
# Define as `map_stored!`.
6387
# Define `eachstoredindex` promotion.
@@ -67,26 +91,26 @@ end
6791
return dest
6892
end
6993

70-
# TODO: Make this a subtype of `Derive.AbstractArrayStyle{N}` instead.
71-
using Derive: Derive
72-
abstract type AbstractSparseArrayStyle{N} <: Derive.AbstractArrayStyle{N} end
94+
abstract type AbstractSparseArrayStyle{N} <: Broadcast.AbstractArrayStyle{N} end
95+
96+
@derive AbstractSparseArrayStyle AbstractArrayStyleOps
7397

7498
struct SparseArrayStyle{N} <: AbstractSparseArrayStyle{N} end
7599

76100
SparseArrayStyle{M}(::Val{N}) where {M,N} = SparseArrayStyle{N}()
77101

78-
@interface AbstractSparseArrayInterface function Broadcast.BroadcastStyle(type::Type)
102+
@interface ::AbstractSparseArrayInterface function Broadcast.BroadcastStyle(type::Type)
79103
return SparseArrayStyle{ndims(type)}()
80104
end
81105

82106
using ArrayLayouts: ArrayLayouts, MatMulMatAdd
83107

84108
abstract type AbstractSparseLayout <: ArrayLayouts.MemoryLayout end
85109

86-
struct SparseLayout <: AbstractSparseLayout end
87-
88-
@interface AbstractSparseArrayInterface function ArrayLayouts.MemoryLayout(type::Type)
89-
return SparseLayout()
110+
function ArrayLayouts.sub_materialize(::AbstractSparseLayout, a::AbstractArray, axes::Tuple)
111+
a_dest = similar(a)
112+
a_dest .= a
113+
return a_dest
90114
end
91115

92116
function mul_indices(I1::CartesianIndex{2}, I2::CartesianIndex{2})
@@ -139,3 +163,9 @@ function ArrayLayouts.materialize!(
139163
sparse_mul!(m.C, m.A, m.B, m.α, m.β)
140164
return m.C
141165
end
166+
167+
struct SparseLayout <: AbstractSparseLayout end
168+
169+
@interface ::AbstractSparseArrayInterface function ArrayLayouts.MemoryLayout(type::Type)
170+
return SparseLayout()
171+
end

src/sparsearraydok.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,20 @@ struct SparseArrayDOK{T,N,F} <: AbstractSparseArray{T,N}
1010
getunstoredindex::F
1111
end
1212

13+
using Derive: Derive
14+
# This defines the destination type of various operations in Derive.jl.
15+
Derive.arraytype(::AbstractSparseArrayInterface, T::Type) = SparseArrayDOK{T}
16+
1317
function SparseArrayDOK{T,N}(size::Vararg{Int,N}) where {T,N}
1418
getunstoredindex = default_getunstoredindex
1519
F = typeof(getunstoredindex)
1620
return SparseArrayDOK{T,N,F}(Dictionary{CartesianIndex{N},T}(), size, getunstoredindex)
1721
end
1822

23+
function SparseArrayDOK{T}(::UndefInitializer, size::Tuple{Vararg{Int}}) where {T}
24+
return SparseArrayDOK{T,length(size)}(size...)
25+
end
26+
1927
function SparseArrayDOK{T}(size::Int...) where {T}
2028
return SparseArrayDOK{T,length(size)}(size...)
2129
end

src/sparsearrayinterface.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,4 @@ struct SparseArrayInterface <: AbstractSparseArrayInterface end
88
# version of `map`.
99
# const sparse = SparseArrayInterface()
1010

11-
Derive.interface(::AbstractSparseArrayStyle) = SparseArrayInterface()
11+
Derive.interface(::Type{<:AbstractSparseArrayStyle}) = SparseArrayInterface()

src/wrappers.jl

Lines changed: 85 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,70 +1,102 @@
1+
function cartesianindex_reverse(I::CartesianIndex)
2+
return CartesianIndex(reverse(Tuple(I)))
3+
end
4+
tuple_oneto(n) = ntuple(identity, n)
5+
6+
# TODO: Use `Base.PermutedDimsArrays.genperm` or
7+
# https://github.com/jipolanco/StaticPermutations.jl?
8+
genperm(v, perm) = map(j -> v[j], perm)
9+
10+
## TODO: Use this and something similar for `Dictionary` to make a faster
11+
## implementation of `storedvalues(::SubArray)`.
12+
## function valuesview(d::Dict, keys)
13+
## return @view d.vals[[Base.ht_keyindex(d, key) for key in keys]]
14+
## end
15+
16+
function eachstoredparentindex(a::SubArray)
17+
return filter(eachstoredindex(parent(a))) do I
18+
return all(d -> I[d] parentindices(a)[d], 1:ndims(parent(a)))
19+
end
20+
end
21+
function storedvalues(a::SubArray)
22+
return @view parent(a)[collect(eachstoredparentindex(a))]
23+
end
24+
function isstored(a::SubArray, I::Int...)
25+
return isstored(parent(a), Base.reindex(parentindices(a), I)...)
26+
end
27+
function getstoredindex(a::SubArray, I::Int...)
28+
return getstoredindex(parent(a), Base.reindex(parentindices(a), I)...)
29+
end
30+
function getunstoredindex(a::SubArray, I::Int...)
31+
return getunstoredindex(parent(a), Base.reindex(parentindices(a), I)...)
32+
end
33+
function eachstoredindex(a::SubArray)
34+
nonscalardims = filter(tuple_oneto(ndims(parent(a)))) do d
35+
return !(parentindices(a)[d] isa Real)
36+
end
37+
return collect((
38+
CartesianIndex(
39+
map(nonscalardims) do d
40+
return findfirst(==(I[d]), parentindices(a)[d])
41+
end,
42+
) for I in eachstoredparentindex(a)
43+
))
44+
end
45+
46+
perm(::PermutedDimsArray{<:Any,<:Any,p}) where {p} = p
47+
iperm(::PermutedDimsArray{<:Any,<:Any,<:Any,ip}) where {ip} = ip
48+
149
storedvalues(a::PermutedDimsArray) = storedvalues(parent(a))
250
function isstored(a::PermutedDimsArray, I::Int...)
3-
return isstored(parent(a), reverse(I)...)
4-
end
5-
function eachstoredindex(a::PermutedDimsArray)
6-
# TODO: Make lazy with `Iterators.map`.
7-
return map(CartesianIndex reverse Tuple, collect(eachstoredindex(parent(a))))
51+
return isstored(parent(a), genperm(I, iperm(a))...)
852
end
953
function getstoredindex(a::PermutedDimsArray, I::Int...)
10-
return getstoredindex(parent(a), reverse(I)...)
54+
return getstoredindex(parent(a), genperm(I, iperm(a))...)
1155
end
1256
function getunstoredindex(a::PermutedDimsArray, I::Int...)
13-
return getunstoredindex(parent(a), reverse(I)...)
57+
return getunstoredindex(parent(a), genperm(I, iperm(a))...)
1458
end
1559
function setstoredindex!(a::PermutedDimsArray, value, I::Int...)
16-
setstoredindex!(parent(a), value, reverse(I)...)
60+
# TODO: Should this be `iperm(a)`?
61+
setstoredindex!(parent(a), value, genperm(I, perm(a))...)
1762
return a
1863
end
1964
function setunstoredindex!(a::PermutedDimsArray, value, I::Int...)
20-
setunstoredindex!(parent(a), value, reverse(I)...)
65+
# TODO: Should this be `iperm(a)`?
66+
setunstoredindex!(parent(a), value, genperm(I, perm(a))...)
2167
return a
2268
end
23-
24-
using LinearAlgebra: Adjoint
25-
storedvalues(a::Adjoint) = storedvalues(parent(a))
26-
function isstored(a::Adjoint, i::Int, j::Int)
27-
return isstored(parent(a), j, i)
28-
end
29-
function eachstoredindex(a::Adjoint)
69+
function eachstoredindex(a::PermutedDimsArray)
3070
# TODO: Make lazy with `Iterators.map`.
31-
return map(CartesianIndex reverse Tuple, collect(eachstoredindex(parent(a))))
32-
end
33-
function getstoredindex(a::Adjoint, i::Int, j::Int)
34-
return getstoredindex(parent(a), j, i)'
35-
end
36-
function getunstoredindex(a::Adjoint, i::Int, j::Int)
37-
return getunstoredindex(parent(a), j, i)'
38-
end
39-
function setstoredindex!(a::Adjoint, value, i::Int, j::Int)
40-
setstoredindex!(parent(a), value', j, i)
41-
return a
42-
end
43-
function setunstoredindex!(a::Adjoint, value, i::Int, j::Int)
44-
setunstoredindex!(parent(a), value', j, i)
45-
return a
71+
return map(collect(eachstoredindex(parent(a)))) do I
72+
return CartesianIndex(genperm(I, perm(a)))
73+
end
4674
end
4775

48-
using LinearAlgebra: Transpose
49-
storedvalues(a::Transpose) = storedvalues(parent(a))
50-
function isstored(a::Transpose, i::Int, j::Int)
51-
return isstored(parent(a), j, i)
52-
end
53-
function eachstoredindex(a::Transpose)
54-
# TODO: Make lazy with `Iterators.map`.
55-
return map(CartesianIndex reverse Tuple, collect(eachstoredindex(parent(a))))
56-
end
57-
function getstoredindex(a::Transpose, i::Int, j::Int)
58-
return transpose(getstoredindex(parent(a), j, i))
59-
end
60-
function getunstoredindex(a::Transpose, i::Int, j::Int)
61-
return transpose(getunstoredindex(parent(a), j, i))
62-
end
63-
function setstoredindex!(a::Transpose, value, i::Int, j::Int)
64-
setstoredindex!(parent(a), transpose(value), j, i)
65-
return a
66-
end
67-
function setunstoredindex!(a::Transpose, value, i::Int, j::Int)
68-
setunstoredindex!(parent(a), transpose(value), j, i)
69-
return a
76+
for (type, func) in ((:Adjoint, :adjoint), (:Transpose, :transpose))
77+
@eval begin
78+
using LinearAlgebra: $type
79+
storedvalues(a::$type) = storedvalues(parent(a))
80+
function isstored(a::$type, i::Int, j::Int)
81+
return isstored(parent(a), j, i)
82+
end
83+
function eachstoredindex(a::$type)
84+
# TODO: Make lazy with `Iterators.map`.
85+
return map(cartesianindex_reverse, collect(eachstoredindex(parent(a))))
86+
end
87+
function getstoredindex(a::$type, i::Int, j::Int)
88+
return $func(getstoredindex(parent(a), j, i))
89+
end
90+
function getunstoredindex(a::$type, i::Int, j::Int)
91+
return $func(getunstoredindex(parent(a), j, i))
92+
end
93+
function setstoredindex!(a::$type, value, i::Int, j::Int)
94+
setstoredindex!(parent(a), $func(value), j, i)
95+
return a
96+
end
97+
function setunstoredindex!(a::$type, value, i::Int, j::Int)
98+
setunstoredindex!(parent(a), $func(value), j, i)
99+
return a
100+
end
101+
end
70102
end

0 commit comments

Comments
 (0)