@@ -2,11 +2,27 @@ using BlockArrays: BlockRange, blockisequal
22using DerivableInterfaces: @interface , AbstractArrayInterface, interface
33using GPUArraysCore: @allowscalar
44
5+ # Check if the block structures are the same.
6+ function same_block_structure (as:: AbstractArray... )
7+ isempty (as) && return true
8+ return all (
9+ ntuple (ndims (first (as))) do dim
10+ ax = map (Base. Fix2 (axes, dim), as)
11+ return blockisequal (ax... )
12+ end ,
13+ )
14+ end
15+
16+ # Find the common stored blocks, assuming the block structures are the same.
17+ function union_eachblockstoredindex (as:: AbstractArray... )
18+ return ∪ (map (eachblockstoredindex, (a_dest, a_srcs... ))... )
19+ end
20+
521function map_blockwise! (f, a_dest:: AbstractArray , a_srcs:: AbstractArray... )
622 # TODO : This assumes element types are numbers, generalize this logic.
723 f_preserves_zeros = f (zero .(eltype .(a_srcs))... ) == zero (eltype (a_dest))
824 Is = if f_preserves_zeros
9- ∪ ( map (eachblockstoredindex, ( a_dest, a_srcs... )) ... )
25+ union_eachblockstoredindex ( a_dest, a_srcs... )
1026 else
1127 BlockRange (a_dest)
1228 end
4662 @interface interface map_zero_dim! (f, a_dest, a_srcs... )
4763 return a_dest
4864 end
49- blockwise = all (
50- ntuple (ndims (a_dest)) do dim
51- ax = map (Base. Fix2 (axes, dim), (a_dest, a_srcs... ))
52- return blockisequal (ax... )
53- end ,
54- )
55- if blockwise
65+ if same_block_structure (a_dest, a_srcs... )
5666 map_blockwise! (f, a_dest, a_srcs... )
5767 return a_dest
5868 end
0 commit comments