Skip to content

Commit 3d774aa

Browse files
authored
Fix issues with in-place map/broadcast (#21)
1 parent a57e030 commit 3d774aa

File tree

3 files changed

+76
-38
lines changed

3 files changed

+76
-38
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SparseArraysBase"
22
uuid = "0d5efcca-f356-4864-8770-e1ed8d78f208"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.2.8"
4+
version = "0.2.9"
55

66
[deps]
77
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"

src/abstractsparsearrayinterface.jl

Lines changed: 37 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -203,15 +203,25 @@ end
203203
return SparseArrayDOK{T}(size...)
204204
end
205205

206+
# map over a specified subset of indices of the inputs.
207+
function map_indices! end
208+
209+
@interface interface::AbstractArrayInterface function map_indices!(
210+
indices, f, a_dest::AbstractArray, as::AbstractArray...
211+
)
212+
for I in indices
213+
a_dest[I] = f(map(a -> a[I], as)...)
214+
end
215+
return a_dest
216+
end
217+
206218
# Only map the stored values of the inputs.
207219
function map_stored! end
208220

209221
@interface interface::AbstractArrayInterface function map_stored!(
210222
f, a_dest::AbstractArray, as::AbstractArray...
211223
)
212-
for I in eachstoredindex(as...)
213-
a_dest[I] = f(map(a -> a[I], as)...)
214-
end
224+
@interface interface map_indices!(eachstoredindex(as...), f, a_dest, as...)
215225
return a_dest
216226
end
217227

@@ -221,9 +231,7 @@ function map_all! end
221231
@interface interface::AbstractArrayInterface function map_all!(
222232
f, a_dest::AbstractArray, as::AbstractArray...
223233
)
224-
for I in eachindex(as...)
225-
a_dest[I] = map(f, map(a -> a[I], as)...)
226-
end
234+
@interface interface map_indices!(eachindex(as...), f, a_dest, as...)
227235
return a_dest
228236
end
229237

@@ -242,38 +250,32 @@ using ArrayLayouts: ArrayLayouts, zero!
242250
return @interface interface map_stored!(f, a, a)
243251
end
244252

253+
# Determines if a function preserves the stored values
254+
# of the destination sparse array.
255+
# The current code may be inefficient since it actually
256+
# accesses an unstored element, which in the case of a
257+
# sparse array of arrays can allocate an array.
258+
# Sparse arrays could be expected to define a cheap
259+
# unstored element allocator, for example
260+
# `get_prototypical_unstored(a::AbstractArray)`.
261+
function preserves_unstored(f, a_dest::AbstractArray, as::AbstractArray...)
262+
I = first(eachindex(as...))
263+
return iszero(f(map(a -> getunstoredindex(a, I), as)...))
264+
end
265+
245266
@interface interface::AbstractSparseArrayInterface function Base.map!(
246267
f, a_dest::AbstractArray, as::AbstractArray...
247268
)
248-
# TODO: Define a function `preserves_unstored(a_dest, f, as...)`
249-
# to determine if a function preserves the stored values
250-
# of the destination sparse array.
251-
# The current code may be inefficient since it actually
252-
# accesses an unstored element, which in the case of a
253-
# sparse array of arrays can allocate an array.
254-
# Sparse arrays could be expected to define a cheap
255-
# unstored element allocator, for example
256-
# `get_prototypical_unstored(a::AbstractArray)`.
257-
I = first(eachindex(as...))
258-
preserves_unstored = iszero(f(map(a -> getunstoredindex(a, I), as)...))
259-
if !preserves_unstored
260-
# Doesn't preserve unstored values, loop over all elements.
261-
@interface interface map_all!(f, a_dest, as...)
262-
return a_dest
269+
indices = if !preserves_unstored(f, a_dest, as...)
270+
eachindex(a_dest)
271+
elseif any(a -> a_dest !== a, as)
272+
as = map(a -> Base.unalias(a_dest, a), as)
273+
@interface interface zero!(a_dest)
274+
eachstoredindex(as...)
275+
else
276+
eachstoredindex(a_dest)
263277
end
264-
# First zero out the destination.
265-
# TODO: Make this more nuanced, skip when possible, for
266-
# example if the sparsity of the destination is a subset of
267-
# the sparsity of the sources, i.e.:
268-
# ```julia
269-
# if eachstoredindex(as...) ∉ eachstoredindex(a_dest)
270-
# zero!(a_dest)
271-
# end
272-
# ```
273-
# This is the safest thing to do in general, for example
274-
# if the destination is dense but the sources are sparse.
275-
@interface interface zero!(a_dest)
276-
@interface interface map_stored!(f, a_dest, as...)
278+
@interface interface map_indices!(indices, f, a_dest, as...)
277279
return a_dest
278280
end
279281

@@ -357,9 +359,7 @@ function sparse_mul!(
357359
β::Number=false;
358360
(mul!!)=(default_mul!!),
359361
)
360-
# TODO: Change to: `a_dest .*= β`
361-
# once https://github.com/ITensor/SparseArraysBase.jl/issues/19 is fixed.
362-
storedvalues(a_dest) .*= β
362+
a_dest .*= β
363363
β′ = one(Bool)
364364
for I1 in eachstoredindex(a1)
365365
for I2 in eachstoredindex(a2)

test/basics/test_sparsearraydok.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,4 +177,42 @@ arrayts = (Array,)
177177
a[1, 2] = 12
178178
@test sprint(show, "text/plain", a) == "$(summary(a)):\n$(eltype(a)(12))\n ⋅ ⋅"
179179
end
180+
181+
# Regression test for:
182+
# https://github.com/ITensor/SparseArraysBase.jl/issues/19
183+
a = SparseArrayDOK{elt}(2, 2)
184+
a[1, 1] = 1
185+
a .*= 2
186+
@test a == [2 0; 0 0]
187+
@test storedlength(a) == 1
188+
189+
# Test aliasing behavior.
190+
a = SparseArrayDOK{elt}(2, 2)
191+
a[1, 1] = 11
192+
a[1, 2] = 12
193+
a[2, 2] = 22
194+
c1 = @view a[:, 1]
195+
r1 = @view a[1, :]
196+
r1 .= c1
197+
@test c1 == [11, 0]
198+
@test storedlength(c1) == 1
199+
@test r1 == [11, 0]
200+
@test storedlength(r1) == 2
201+
@test a == [11 0; 0 22]
202+
@test storedlength(a) == 3
203+
204+
# Test aliasing behavior.
205+
a = SparseArrayDOK{elt}(2, 2)
206+
a[1, 1] = 11
207+
a[1, 2] = 12
208+
a[2, 2] = 22
209+
c1 = @view a[:, 1]
210+
r1 = @view a[1, :]
211+
c1 .= r1
212+
@test c1 == [11, 12]
213+
@test storedlength(c1) == 2
214+
@test r1 == [11, 12]
215+
@test storedlength(r1) == 2
216+
@test a == [11 12; 12 22]
217+
@test storedlength(a) == 4
180218
end

0 commit comments

Comments
 (0)