diff --git a/Project.toml b/Project.toml index d93698d..8a19c36 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SparseArraysBase" uuid = "0d5efcca-f356-4864-8770-e1ed8d78f208" authors = ["ITensor developers and contributors"] -version = "0.2.8" +version = "0.2.9" [deps] Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" diff --git a/src/abstractsparsearrayinterface.jl b/src/abstractsparsearrayinterface.jl index 6d83385..f8bd0e0 100644 --- a/src/abstractsparsearrayinterface.jl +++ b/src/abstractsparsearrayinterface.jl @@ -203,15 +203,25 @@ end return SparseArrayDOK{T}(size...) end +# map over a specified subset of indices of the inputs. +function map_indices! end + +@interface interface::AbstractArrayInterface function map_indices!( + indices, f, a_dest::AbstractArray, as::AbstractArray... +) + for I in indices + a_dest[I] = f(map(a -> a[I], as)...) + end + return a_dest +end + # Only map the stored values of the inputs. function map_stored! end @interface interface::AbstractArrayInterface function map_stored!( f, a_dest::AbstractArray, as::AbstractArray... ) - for I in eachstoredindex(as...) - a_dest[I] = f(map(a -> a[I], as)...) - end + @interface interface map_indices!(eachstoredindex(as...), f, a_dest, as...) return a_dest end @@ -221,9 +231,7 @@ function map_all! end @interface interface::AbstractArrayInterface function map_all!( f, a_dest::AbstractArray, as::AbstractArray... ) - for I in eachindex(as...) - a_dest[I] = map(f, map(a -> a[I], as)...) - end + @interface interface map_indices!(eachindex(as...), f, a_dest, as...) return a_dest end @@ -242,38 +250,32 @@ using ArrayLayouts: ArrayLayouts, zero! return @interface interface map_stored!(f, a, a) end +# Determines if a function preserves the stored values +# of the destination sparse array. +# The current code may be inefficient since it actually +# accesses an unstored element, which in the case of a +# sparse array of arrays can allocate an array. +# Sparse arrays could be expected to define a cheap +# unstored element allocator, for example +# `get_prototypical_unstored(a::AbstractArray)`. +function preserves_unstored(f, a_dest::AbstractArray, as::AbstractArray...) + I = first(eachindex(as...)) + return iszero(f(map(a -> getunstoredindex(a, I), as)...)) +end + @interface interface::AbstractSparseArrayInterface function Base.map!( f, a_dest::AbstractArray, as::AbstractArray... ) - # TODO: Define a function `preserves_unstored(a_dest, f, as...)` - # to determine if a function preserves the stored values - # of the destination sparse array. - # The current code may be inefficient since it actually - # accesses an unstored element, which in the case of a - # sparse array of arrays can allocate an array. - # Sparse arrays could be expected to define a cheap - # unstored element allocator, for example - # `get_prototypical_unstored(a::AbstractArray)`. - I = first(eachindex(as...)) - preserves_unstored = iszero(f(map(a -> getunstoredindex(a, I), as)...)) - if !preserves_unstored - # Doesn't preserve unstored values, loop over all elements. - @interface interface map_all!(f, a_dest, as...) - return a_dest + indices = if !preserves_unstored(f, a_dest, as...) + eachindex(a_dest) + elseif any(a -> a_dest !== a, as) + as = map(a -> Base.unalias(a_dest, a), as) + @interface interface zero!(a_dest) + eachstoredindex(as...) + else + eachstoredindex(a_dest) end - # First zero out the destination. - # TODO: Make this more nuanced, skip when possible, for - # example if the sparsity of the destination is a subset of - # the sparsity of the sources, i.e.: - # ```julia - # if eachstoredindex(as...) ∉ eachstoredindex(a_dest) - # zero!(a_dest) - # end - # ``` - # This is the safest thing to do in general, for example - # if the destination is dense but the sources are sparse. - @interface interface zero!(a_dest) - @interface interface map_stored!(f, a_dest, as...) + @interface interface map_indices!(indices, f, a_dest, as...) return a_dest end @@ -357,9 +359,7 @@ function sparse_mul!( β::Number=false; (mul!!)=(default_mul!!), ) - # TODO: Change to: `a_dest .*= β` - # once https://github.com/ITensor/SparseArraysBase.jl/issues/19 is fixed. - storedvalues(a_dest) .*= β + a_dest .*= β β′ = one(Bool) for I1 in eachstoredindex(a1) for I2 in eachstoredindex(a2) diff --git a/test/basics/test_sparsearraydok.jl b/test/basics/test_sparsearraydok.jl index 642f8c9..08cadf4 100644 --- a/test/basics/test_sparsearraydok.jl +++ b/test/basics/test_sparsearraydok.jl @@ -177,4 +177,42 @@ arrayts = (Array,) a[1, 2] = 12 @test sprint(show, "text/plain", a) == "$(summary(a)):\n ⋅ $(eltype(a)(12))\n ⋅ ⋅" end + + # Regression test for: + # https://github.com/ITensor/SparseArraysBase.jl/issues/19 + a = SparseArrayDOK{elt}(2, 2) + a[1, 1] = 1 + a .*= 2 + @test a == [2 0; 0 0] + @test storedlength(a) == 1 + + # Test aliasing behavior. + a = SparseArrayDOK{elt}(2, 2) + a[1, 1] = 11 + a[1, 2] = 12 + a[2, 2] = 22 + c1 = @view a[:, 1] + r1 = @view a[1, :] + r1 .= c1 + @test c1 == [11, 0] + @test storedlength(c1) == 1 + @test r1 == [11, 0] + @test storedlength(r1) == 2 + @test a == [11 0; 0 22] + @test storedlength(a) == 3 + + # Test aliasing behavior. + a = SparseArrayDOK{elt}(2, 2) + a[1, 1] = 11 + a[1, 2] = 12 + a[2, 2] = 22 + c1 = @view a[:, 1] + r1 = @view a[1, :] + c1 .= r1 + @test c1 == [11, 12] + @test storedlength(c1) == 2 + @test r1 == [11, 12] + @test storedlength(r1) == 2 + @test a == [11 12; 12 22] + @test storedlength(a) == 4 end