1+ using BlockArrays: BlockRange, blockisequal
12using DerivableInterfaces: @interface , AbstractArrayInterface, interface
23using GPUArraysCore: @allowscalar
34
5+ # Check if the block structures are the same.
6+ function same_block_structure (as:: AbstractArray... )
7+ isempty (as) && return true
8+ return all (
9+ ntuple (ndims (first (as))) do dim
10+ ax = map (Base. Fix2 (axes, dim), as)
11+ return blockisequal (ax... )
12+ end ,
13+ )
14+ end
15+
16+ # Find the common stored blocks, assuming the block structures are the same.
17+ function union_eachblockstoredindex (as:: AbstractArray... )
18+ return ∪ (map (eachblockstoredindex, as)... )
19+ end
20+
21+ function map_blockwise! (f, a_dest:: AbstractArray , a_srcs:: AbstractArray... )
22+ # TODO : This assumes element types are numbers, generalize this logic.
23+ f_preserves_zeros = f (zero .(eltype .(a_srcs))... ) == zero (eltype (a_dest))
24+ Is = if f_preserves_zeros
25+ union_eachblockstoredindex (a_dest, a_srcs... )
26+ else
27+ BlockRange (a_dest)
28+ end
29+ for I in Is
30+ # TODO : Use:
31+ # block_dest = @view a_dest[I]
32+ # or:
33+ # block_dest = @view! a_dest[I]
34+ block_dest = blocks_maybe_single (a_dest)[Int .(Tuple (I))... ]
35+ # TODO : Use:
36+ # block_srcs = map(a_src -> @view(a_src[I]), a_srcs)
37+ block_srcs = map (a_srcs) do a_src
38+ return blocks_maybe_single (a_src)[Int .(Tuple (I))... ]
39+ end
40+ # TODO : Use `map!!` to handle immutable blocks.
41+ map! (f, block_dest, block_srcs... )
42+ # Replace the entire block, handles initializing new blocks
43+ # or if blocks are immutable.
44+ # TODO : Use `a_dest[I] = block_dest`.
45+ blocks (a_dest)[Int .(Tuple (I))... ] = block_dest
46+ end
47+ return a_dest
48+ end
49+
450# TODO : Rewrite this so that it takes the blocking structure
551# made by combining the blocking of the axes (i.e. the blocking that
652# is used to determine `union_stored_blocked_cartesianindices(...)`).
@@ -16,6 +62,10 @@ using GPUArraysCore: @allowscalar
1662 @interface interface map_zero_dim! (f, a_dest, a_srcs... )
1763 return a_dest
1864 end
65+ if same_block_structure (a_dest, a_srcs... )
66+ map_blockwise! (f, a_dest, a_srcs... )
67+ return a_dest
68+ end
1969 # TODO : This assumes element types are numbers, generalize this logic.
2070 f_preserves_zeros = f (zero .(eltype .(a_srcs))... ) == zero (eltype (a_dest))
2171 a_dest, a_srcs = reblock (a_dest), reblock .(a_srcs)
0 commit comments