Skip to content

Commit e112a01

Browse files
authored
Add support for linear indexing (#10)
1 parent c67b83a commit e112a01

File tree

4 files changed

+124
-22
lines changed

4 files changed

+124
-22
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SparseArraysBase"
22
uuid = "0d5efcca-f356-4864-8770-e1ed8d78f208"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.2.0"
4+
version = "0.2.1"
55

66
[deps]
77
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
@@ -14,7 +14,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1414
Aqua = "0.8.9"
1515
ArrayLayouts = "1.11.0"
1616
BroadcastMapConversion = "0.1.0"
17-
Derive = "0.3.0"
17+
Derive = "0.3.6"
1818
Dictionaries = "0.4.3"
1919
LinearAlgebra = "1.10"
2020
SafeTestsets = "0.1"

src/abstractsparsearrayinterface.jl

Lines changed: 75 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,27 @@ end
104104
# type instead so fallback functions can use abstract types.
105105
abstract type AbstractSparseArrayInterface <: AbstractArrayInterface end
106106

107+
function Derive.combine_interface_rule(
108+
interface1::AbstractSparseArrayInterface, interface2::AbstractSparseArrayInterface
109+
)
110+
return error("Rule not defined.")
111+
end
112+
function Derive.combine_interface_rule(
113+
interface1::Interface, interface2::Interface
114+
) where {Interface<:AbstractSparseArrayInterface}
115+
return interface1
116+
end
117+
function Derive.combine_interface_rule(
118+
interface1::AbstractSparseArrayInterface, interface2::AbstractArrayInterface
119+
)
120+
return interface1
121+
end
122+
function Derive.combine_interface_rule(
123+
interface1::AbstractArrayInterface, interface2::AbstractSparseArrayInterface
124+
)
125+
return interface2
126+
end
127+
107128
to_vec(x) = vec(collect(x))
108129
to_vec(x::AbstractArray) = vec(x)
109130

@@ -178,7 +199,46 @@ end
178199
return SparseArrayDOK{T}(size...)
179200
end
180201

181-
@interface ::AbstractSparseArrayInterface function Base.map!(
202+
# Only map the stored values of the inputs.
203+
function map_stored! end
204+
205+
@interface interface::AbstractArrayInterface function map_stored!(
206+
f, a_dest::AbstractArray, as::AbstractArray...
207+
)
208+
for I in eachstoredindex(as...)
209+
a_dest[I] = f(map(a -> a[I], as)...)
210+
end
211+
return a_dest
212+
end
213+
214+
# Only map all values, not just the stored ones.
215+
function map_all! end
216+
217+
@interface interface::AbstractArrayInterface function map_all!(
218+
f, a_dest::AbstractArray, as::AbstractArray...
219+
)
220+
for I in eachindex(as...)
221+
a_dest[I] = map(f, map(a -> a[I], as)...)
222+
end
223+
return a_dest
224+
end
225+
226+
using ArrayLayouts: ArrayLayouts, zero!
227+
228+
# `zero!` isn't defined in `Base`, but it is defined in `ArrayLayouts`
229+
# and is useful for sparse array logic, since it can be used to empty
230+
# the sparse array storage.
231+
# We use a single function definition to minimize method ambiguities.
232+
@interface interface::AbstractSparseArrayInterface function ArrayLayouts.zero!(
233+
a::AbstractArray
234+
)
235+
# More generally, this codepath could be taking if `zero(eltype(a))`
236+
# is defined and the elements are immutable.
237+
f = eltype(a) <: Number ? Returns(zero(eltype(a))) : zero!
238+
return @interface interface map_stored!(f, a, a)
239+
end
240+
241+
@interface interface::AbstractSparseArrayInterface function Base.map!(
182242
f, a_dest::AbstractArray, as::AbstractArray...
183243
)
184244
# TODO: Define a function `preserves_unstored(a_dest, f, as...)`
@@ -194,15 +254,22 @@ end
194254
preserves_unstored = iszero(f(map(a -> getunstoredindex(a, I), as)...))
195255
if !preserves_unstored
196256
# Doesn't preserve unstored values, loop over all elements.
197-
for I in eachindex(as...)
198-
a_dest[I] = map(f, map(a -> a[I], as)...)
199-
end
257+
@interface interface map_all!(f, a_dest, as...)
200258
return a_dest
201259
end
202-
# Define `eachstoredindex` promotion.
203-
for I in eachstoredindex(as...)
204-
a_dest[I] = f(map(a -> a[I], as)...)
205-
end
260+
# First zero out the destination.
261+
# TODO: Make this more nuanced, skip when possible, for
262+
# example if the sparsity of the destination is a subset of
263+
# the sparsity of the sources, i.e.:
264+
# ```julia
265+
# if eachstoredindex(as...) ∉ eachstoredindex(a_dest)
266+
# zero!(a_dest)
267+
# end
268+
# ```
269+
# This is the safest thing to do in general, for example
270+
# if the destination is dense but the sources are sparse.
271+
@interface interface zero!(a_dest)
272+
@interface interface map_stored!(f, a_dest, as...)
206273
return a_dest
207274
end
208275

src/sparsearrayinterface.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,21 @@ using Derive: Derive
22

33
struct SparseArrayInterface <: AbstractSparseArrayInterface end
44

5+
# Fix ambiguity error.
6+
function Derive.combine_interface_rule(::SparseArrayInterface, ::SparseArrayInterface)
7+
return SparseArrayInterface()
8+
end
9+
function Derive.combine_interface_rule(
10+
interface1::SparseArrayInterface, interface2::AbstractSparseArrayInterface
11+
)
12+
return interface1
13+
end
14+
function Derive.combine_interface_rule(
15+
interface1::AbstractSparseArrayInterface, interface2::SparseArrayInterface
16+
)
17+
return interface2
18+
end
19+
520
# Convenient shorthand to refer to the sparse interface.
621
# Can turn a function into a sparse function with the syntax `sparse(f)`,
722
# i.e. `sparse(map)(x -> 2x, randn(2, 2))` while use the sparse

src/wrappers.jl

Lines changed: 32 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,28 @@ parentvalue_to_value(a::AbstractArray, value) = value
22
value_to_parentvalue(a::AbstractArray, value) = value
33
eachstoredparentindex(a::AbstractArray) = eachstoredindex(parent(a))
44
storedparentvalues(a::AbstractArray) = storedvalues(parent(a))
5-
parentindex_to_index(a::AbstractArray, I::CartesianIndex) = error()
6-
function parentindex_to_index(a::AbstractArray, I::Int...)
5+
6+
function parentindex_to_index(a::AbstractArray{<:Any,N}, I::CartesianIndex{N}) where {N}
7+
return throw(MethodError(parentindex_to_index, Tuple{typeof(a),typeof(I)}))
8+
end
9+
function parentindex_to_index(a::AbstractArray{<:Any,N}, I::Vararg{Int,N}) where {N}
710
return Tuple(parentindex_to_index(a, CartesianIndex(I)))
811
end
9-
index_to_parentindex(a::AbstractArray, I::CartesianIndex) = error()
10-
function index_to_parentindex(a::AbstractArray, I::Int...)
12+
# Handle linear indexing.
13+
function parentindex_to_index(a::AbstractArray, I::Int)
14+
return parentindex_to_index(a, CartesianIndices(parent(a))[I])
15+
end
16+
17+
function index_to_parentindex(a::AbstractArray{<:Any,N}, I::CartesianIndex{N}) where {N}
18+
return throw(MethodError(index_to_parentindex, Tuple{typeof(a),typeof(I)}))
19+
end
20+
function index_to_parentindex(a::AbstractArray{<:Any,N}, I::Vararg{Int,N}) where {N}
1121
return Tuple(index_to_parentindex(a, CartesianIndex(I)))
1222
end
23+
# Handle linear indexing.
24+
function index_to_parentindex(a::AbstractArray, I::Int)
25+
return index_to_parentindex(a, CartesianIndices(a)[I])
26+
end
1327

1428
function cartesianindex_reverse(I::CartesianIndex)
1529
return CartesianIndex(reverse(Tuple(I)))
@@ -21,10 +35,10 @@ tuple_oneto(n) = ntuple(identity, n)
2135
genperm(v, perm) = map(j -> v[j], perm)
2236

2337
using LinearAlgebra: Adjoint
24-
function parentindex_to_index(a::Adjoint, I::CartesianIndex)
38+
function parentindex_to_index(a::Adjoint, I::CartesianIndex{2})
2539
return cartesianindex_reverse(I)
2640
end
27-
function index_to_parentindex(a::Adjoint, I::CartesianIndex)
41+
function index_to_parentindex(a::Adjoint, I::CartesianIndex{2})
2842
return cartesianindex_reverse(I)
2943
end
3044
function parentvalue_to_value(a::Adjoint, value)
@@ -36,18 +50,18 @@ end
3650

3751
perm(::PermutedDimsArray{<:Any,<:Any,p}) where {p} = p
3852
iperm(::PermutedDimsArray{<:Any,<:Any,<:Any,ip}) where {ip} = ip
39-
function index_to_parentindex(a::PermutedDimsArray, I::CartesianIndex)
53+
function index_to_parentindex(a::PermutedDimsArray{<:Any,N}, I::CartesianIndex{N}) where {N}
4054
return CartesianIndex(genperm(I, iperm(a)))
4155
end
42-
function parentindex_to_index(a::PermutedDimsArray, I::CartesianIndex)
56+
function parentindex_to_index(a::PermutedDimsArray{<:Any,N}, I::CartesianIndex{N}) where {N}
4357
return CartesianIndex(genperm(I, perm(a)))
4458
end
4559

4660
using Base: ReshapedArray
47-
function parentindex_to_index(a::ReshapedArray, I::CartesianIndex)
61+
function parentindex_to_index(a::ReshapedArray{<:Any,N}, I::CartesianIndex{N}) where {N}
4862
return CartesianIndices(size(a))[LinearIndices(parent(a))[I]]
4963
end
50-
function index_to_parentindex(a::ReshapedArray, I::CartesianIndex)
64+
function index_to_parentindex(a::ReshapedArray{<:Any,N}, I::CartesianIndex{N}) where {N}
5165
return CartesianIndices(parent(a))[LinearIndices(size(a))[I]]
5266
end
5367

@@ -56,9 +70,15 @@ function eachstoredparentindex(a::SubArray)
5670
return all(d -> I[d] parentindices(a)[d], 1:ndims(parent(a)))
5771
end
5872
end
73+
# Don't constrain the number of dimensions of the array
74+
# and index since the parent array can have a different
75+
# number of dimensions than the `SubArray`.
5976
function index_to_parentindex(a::SubArray, I::CartesianIndex)
6077
return CartesianIndex(Base.reindex(parentindices(a), Tuple(I)))
6178
end
79+
# Don't constrain the number of dimensions of the array
80+
# and index since the parent array can have a different
81+
# number of dimensions than the `SubArray`.
6282
function parentindex_to_index(a::SubArray, I::CartesianIndex)
6383
nonscalardims = filter(tuple_oneto(ndims(parent(a)))) do d
6484
return !(parentindices(a)[d] isa Real)
@@ -81,10 +101,10 @@ function storedparentvalues(a::SubArray)
81101
end
82102

83103
using LinearAlgebra: Transpose
84-
function parentindex_to_index(a::Transpose, I::CartesianIndex)
104+
function parentindex_to_index(a::Transpose, I::CartesianIndex{2})
85105
return cartesianindex_reverse(I)
86106
end
87-
function index_to_parentindex(a::Transpose, I::CartesianIndex)
107+
function index_to_parentindex(a::Transpose, I::CartesianIndex{2})
88108
return cartesianindex_reverse(I)
89109
end
90110
function parentvalue_to_value(a::Transpose, value)

0 commit comments

Comments
 (0)