Skip to content

Commit 0ae32a5

Browse files
committed
Fix issues with in-place map/broadcast
1 parent a57e030 commit 0ae32a5

File tree

2 files changed

+27
-20
lines changed

2 files changed

+27
-20
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SparseArraysBase"
22
uuid = "0d5efcca-f356-4864-8770-e1ed8d78f208"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.2.8"
4+
version = "0.2.9"
55

66
[deps]
77
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"

src/abstractsparsearrayinterface.jl

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -203,15 +203,25 @@ end
203203
return SparseArrayDOK{T}(size...)
204204
end
205205

206+
# map over a specified subset of indices of the inputs.
207+
function map_indices! end
208+
209+
@interface interface::AbstractArrayInterface function map_indices!(
210+
f, a_dest::AbstractArray, indices, as::AbstractArray...
211+
)
212+
for I in indices
213+
a_dest[I] = f(map(a -> a[I], as)...)
214+
end
215+
return a_dest
216+
end
217+
206218
# Only map the stored values of the inputs.
207219
function map_stored! end
208220

209221
@interface interface::AbstractArrayInterface function map_stored!(
210222
f, a_dest::AbstractArray, as::AbstractArray...
211223
)
212-
for I in eachstoredindex(as...)
213-
a_dest[I] = f(map(a -> a[I], as)...)
214-
end
224+
@interface interface map_indices!(f, a_dest, eachstoredindex(as...), as...)
215225
return a_dest
216226
end
217227

@@ -221,9 +231,7 @@ function map_all! end
221231
@interface interface::AbstractArrayInterface function map_all!(
222232
f, a_dest::AbstractArray, as::AbstractArray...
223233
)
224-
for I in eachindex(as...)
225-
a_dest[I] = map(f, map(a -> a[I], as)...)
226-
end
234+
@interface interface map_indices!(f, a_dest, eachindex(as...), as...)
227235
return a_dest
228236
end
229237

@@ -261,19 +269,18 @@ end
261269
@interface interface map_all!(f, a_dest, as...)
262270
return a_dest
263271
end
264-
# First zero out the destination.
265-
# TODO: Make this more nuanced, skip when possible, for
266-
# example if the sparsity of the destination is a subset of
267-
# the sparsity of the sources, i.e.:
268-
# ```julia
269-
# if eachstoredindex(as...) ∉ eachstoredindex(a_dest)
270-
# zero!(a_dest)
271-
# end
272-
# ```
273-
# This is the safest thing to do in general, for example
274-
# if the destination is dense but the sources are sparse.
275-
@interface interface zero!(a_dest)
276-
@interface interface map_stored!(f, a_dest, as...)
272+
# Unalias the inputs from the destination
273+
# to make sure the inputs aren't overwritten incorrectly.
274+
# See: https://github.com/JuliaLang/julia/blob/v1.11.2/base/broadcast.jl#L935-L948
275+
as = map(a -> a_dest === a ? a : Base.unalias(a_dest, a), as)
276+
indices_stored = eachstoredindex(as...)
277+
if eachstoredindex(a_dest) indices_stored
278+
# If not all indices being mapped over are stored in the destination,
279+
# zero out the destination. An extreme example of this is when
280+
# the sources are sparse but the destination is dense.
281+
@interface interface zero!(a_dest)
282+
end
283+
@interface interface map_indices!(f, a_dest, indices_stored, as...)
277284
return a_dest
278285
end
279286

0 commit comments

Comments
 (0)