Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SparseArraysBase"
uuid = "0d5efcca-f356-4864-8770-e1ed8d78f208"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.2.8"
version = "0.2.9"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Expand Down
45 changes: 26 additions & 19 deletions src/abstractsparsearrayinterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -203,15 +203,25 @@
return SparseArrayDOK{T}(size...)
end

# map over a specified subset of indices of the inputs.
function map_indices! end

@interface interface::AbstractArrayInterface function map_indices!(
f, a_dest::AbstractArray, indices, as::AbstractArray...
)
for I in indices
a_dest[I] = f(map(a -> a[I], as)...)
end
return a_dest
end

# Only map the stored values of the inputs.
function map_stored! end

@interface interface::AbstractArrayInterface function map_stored!(
f, a_dest::AbstractArray, as::AbstractArray...
)
for I in eachstoredindex(as...)
a_dest[I] = f(map(a -> a[I], as)...)
end
@interface interface map_indices!(f, a_dest, eachstoredindex(as...), as...)
return a_dest
end

Expand All @@ -221,9 +231,7 @@
@interface interface::AbstractArrayInterface function map_all!(
f, a_dest::AbstractArray, as::AbstractArray...
)
for I in eachindex(as...)
a_dest[I] = map(f, map(a -> a[I], as)...)
end
@interface interface map_indices!(f, a_dest, eachindex(as...), as...)
return a_dest
end

Expand Down Expand Up @@ -261,19 +269,18 @@
@interface interface map_all!(f, a_dest, as...)
return a_dest
end
# First zero out the destination.
# TODO: Make this more nuanced, skip when possible, for
# example if the sparsity of the destination is a subset of
# the sparsity of the sources, i.e.:
# ```julia
# if eachstoredindex(as...) ∉ eachstoredindex(a_dest)
# zero!(a_dest)
# end
# ```
# This is the safest thing to do in general, for example
# if the destination is dense but the sources are sparse.
@interface interface zero!(a_dest)
@interface interface map_stored!(f, a_dest, as...)
# Unalias the inputs from the destination
# to make sure the inputs aren't overwritten incorrectly.
# See: https://github.com/JuliaLang/julia/blob/v1.11.2/base/broadcast.jl#L935-L948
as = map(a -> a_dest === a ? a : Base.unalias(a_dest, a), as)
indices_stored = eachstoredindex(as...)
if eachstoredindex(a_dest) ⊈ indices_stored
# If not all indices being mapped over are stored in the destination,
# zero out the destination. An extreme example of this is when
# the sources are sparse but the destination is dense.
@interface interface zero!(a_dest)

Check warning on line 281 in src/abstractsparsearrayinterface.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractsparsearrayinterface.jl#L281

Added line #L281 was not covered by tests
end
@interface interface map_indices!(f, a_dest, indices_stored, as...)
return a_dest
end

Expand Down
Loading