Skip to content

Commit 3e0c3ce

Browse files
committed
Better code organization
1 parent 1e47d75 commit 3e0c3ce

File tree

1 file changed

+18
-8
lines changed
  • src/blocksparsearrayinterface

1 file changed

+18
-8
lines changed

src/blocksparsearrayinterface/map.jl

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,27 @@ using BlockArrays: BlockRange, blockisequal
22
using DerivableInterfaces: @interface, AbstractArrayInterface, interface
33
using GPUArraysCore: @allowscalar
44

5+
# Check if the block structures are the same.
6+
function same_block_structure(as::AbstractArray...)
7+
isempty(as) && return true
8+
return all(
9+
ntuple(ndims(first(as))) do dim
10+
ax = map(Base.Fix2(axes, dim), as)
11+
return blockisequal(ax...)
12+
end,
13+
)
14+
end
15+
16+
# Find the common stored blocks, assuming the block structures are the same.
17+
function union_eachblockstoredindex(as::AbstractArray...)
18+
return (map(eachblockstoredindex, (a_dest, a_srcs...))...)
19+
end
20+
521
function map_blockwise!(f, a_dest::AbstractArray, a_srcs::AbstractArray...)
622
# TODO: This assumes element types are numbers, generalize this logic.
723
f_preserves_zeros = f(zero.(eltype.(a_srcs))...) == zero(eltype(a_dest))
824
Is = if f_preserves_zeros
9-
(map(eachblockstoredindex, (a_dest, a_srcs...))...)
25+
union_eachblockstoredindex(a_dest, a_srcs...)
1026
else
1127
BlockRange(a_dest)
1228
end
@@ -46,13 +62,7 @@ end
4662
@interface interface map_zero_dim!(f, a_dest, a_srcs...)
4763
return a_dest
4864
end
49-
blockwise = all(
50-
ntuple(ndims(a_dest)) do dim
51-
ax = map(Base.Fix2(axes, dim), (a_dest, a_srcs...))
52-
return blockisequal(ax...)
53-
end,
54-
)
55-
if blockwise
65+
if same_block_structure(a_dest, a_srcs...)
5666
map_blockwise!(f, a_dest, a_srcs...)
5767
return a_dest
5868
end

0 commit comments

Comments
 (0)