Skip to content

Commit 96d3ac2

Browse files
committed
Rewrite map_stored_blocks
1 parent 307437d commit 96d3ac2

File tree

1 file changed

+13
-8
lines changed
  • src/blocksparsearrayinterface

1 file changed

+13
-8
lines changed

src/blocksparsearrayinterface/map.jl

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
using BlockArrays: blocks, eachstoredindex, undef_blocks
12
using DerivableInterfaces: @interface, AbstractArrayInterface, interface
23
using GPUArraysCore: @allowscalar
4+
using SparseArraysBase: SparseArrayDOK
35

46
# TODO: Rewrite this so that it takes the blocking structure
57
# made by combining the blocking of the axes (i.e. the blocking that
@@ -94,14 +96,17 @@ function map_zero_dim! end
9496
return a_dest
9597
end
9698

97-
# TODO: Decide what to do with these.
99+
# TODO: Do we need this function or can we just use `map`?
100+
# Probably it should be a special version of `map` where we
101+
# specify the function preserves zeros, i.e.
102+
# `map(f, a; preserves_zeros=true)` or `@preserves_zeros map(f, a)`.
98103
function map_stored_blocks(f, a::AbstractArray)
99-
bs = collect(eachblockstoredindex(a))
100-
ds = map(b -> f(@view(a[b])), bs)
101-
# TODO: Use `similartype` instead?
102-
a= BlockSparseArray{eltype(eltype(ds)),ndims(a),eltype(ds)}(undef, axes(a))
103-
for (b, d) in zip(bs, ds)
104-
a′[b] = d
104+
blocks_a = blocks(a)
105+
stored_indices = collect(eachstoredindex(a))
106+
stored_blocks = map(I -> f(blocks_a[I]), stored_indices)
107+
blocks_a= SparseArrayDOK{eltype(stored_blocks)}(undef_blocks, axes(a))
108+
for (I, b) in zip(stored_indices, stored_blocks)
109+
blocks_a′[I] = b
105110
end
106-
return a′
111+
return sparsemortar(blocks_a′, axes(a))
107112
end

0 commit comments

Comments
 (0)