1+ using BlockArrays: BlockRange, blockisequal
12using DerivableInterfaces: @interface , AbstractArrayInterface, interface
23using GPUArraysCore: @allowscalar
34
"""
    map_blockwise!(f, a_dest::AbstractArray, a_srcs::AbstractArray...)

Apply `f` one block at a time: for each visited block index `I`, call
`map!(f, view(a_dest, I), view.(a_srcs, I)...)`. Return `a_dest`.

The same block index is used to take views of `a_dest` and of every array in
`a_srcs`, so the arrays are expected to share the same blocking. This function
does not verify that itself.

Which blocks are visited depends on whether `f` maps zeros to zero:

  - if `f(zero.(eltype.(a_srcs))...) == zero(eltype(a_dest))`, only the union of
    the indices returned by `eachblockstoredindex` for `a_dest` and each source
    is visited;
  - otherwise every block in `BlockRange(a_dest)` is visited.
"""
function map_blockwise!(f, a_dest::AbstractArray, a_srcs::AbstractArray...)
    # TODO: This assumes element types are numbers, generalize this logic.
    f_preserves_zeros = f(zero.(eltype.(a_srcs))...) == zero(eltype(a_dest))
    block_indices = if f_preserves_zeros
        # Blocks stored in none of the arrays are skipped on this path.
        ∪(map(eachblockstoredindex, (a_dest, a_srcs...))...)
    else
        # `f` may produce nonzero values from zeros, so walk every block.
        BlockRange(a_dest)
    end
    for I in block_indices
        map!(f, view(a_dest, I), map(Base.Fix2(view, I), a_srcs)...)
    end
    return a_dest
end
18+
# TODO: Rewrite this so that it takes the blocking structure
# made by combining the blocking of the axes (i.e. the blocking that
# is used to determine `union_stored_blocked_cartesianindices(...)`).
@@ -16,6 +31,16 @@ using GPUArraysCore: @allowscalar
1631 @interface interface map_zero_dim! (f, a_dest, a_srcs... )
1732 return a_dest
1833 end
34+ blockwise = all (
35+ ntuple (ndims (a_dest)) do dim
36+ ax = map (Base. Fix2 (axes, dim), (a_dest, a_srcs... ))
37+ return blockisequal (ax... )
38+ end ,
39+ )
40+ if blockwise
41+ map_blockwise! (f, a_dest, a_srcs... )
42+ return a_dest
43+ end
1944 # TODO : This assumes element types are numbers, generalize this logic.
2045 f_preserves_zeros = f (zero .(eltype .(a_srcs))... ) == zero (eltype (a_dest))
2146 a_dest, a_srcs = reblock (a_dest), reblock .(a_srcs)
0 commit comments