@@ -2,11 +2,27 @@ using BlockArrays: BlockRange, blockisequal
2
2
using DerivableInterfaces: @interface , AbstractArrayInterface, interface
3
3
using GPUArraysCore: @allowscalar
4
4
5
"""
    same_block_structure(as::AbstractArray...)

Return `true` when, for every dimension of the first array, the axes of all the
arrays along that dimension compare equal under `blockisequal`.

Returns `true` when called with no arrays. Only the dimensions of the first array
are inspected.
"""
function same_block_structure(as::AbstractArray...)
    if isempty(as)
        return true
    end
    rank = ndims(first(as))
    # One Bool per dimension: do all arrays agree on the blocking of that axis?
    matches = ntuple(d -> blockisequal(map(a -> axes(a, d), as)...), rank)
    return all(matches)
end
15
+
16
"""
    union_eachblockstoredindex(as::AbstractArray...)

Return the union (`∪`) of `eachblockstoredindex(a)` over every array `a` in `as`.

The arrays are assumed to share the same block structure; this is not checked here.
Expects at least one array.
"""
function union_eachblockstoredindex(as::AbstractArray...)
    # Use the function's own varargs `as`; `a_dest`/`a_srcs` are not defined in
    # this scope and referencing them would raise an `UndefVarError`.
    return ∪(map(eachblockstoredindex, as)...)
end
20
+
5
21
function map_blockwise! (f, a_dest:: AbstractArray , a_srcs:: AbstractArray... )
6
22
# TODO : This assumes element types are numbers, generalize this logic.
7
23
f_preserves_zeros = f (zero .(eltype .(a_srcs))... ) == zero (eltype (a_dest))
8
24
Is = if f_preserves_zeros
9
- ∪ ( map (eachblockstoredindex, ( a_dest, a_srcs... )) ... )
25
+ union_eachblockstoredindex ( a_dest, a_srcs... )
10
26
else
11
27
BlockRange (a_dest)
12
28
end
46
62
@interface interface map_zero_dim! (f, a_dest, a_srcs... )
47
63
return a_dest
48
64
end
49
- blockwise = all (
50
- ntuple (ndims (a_dest)) do dim
51
- ax = map (Base. Fix2 (axes, dim), (a_dest, a_srcs... ))
52
- return blockisequal (ax... )
53
- end ,
54
- )
55
- if blockwise
65
+ if same_block_structure (a_dest, a_srcs... )
56
66
map_blockwise! (f, a_dest, a_srcs... )
57
67
return a_dest
58
68
end
0 commit comments