Skip to content

Commit 9424058

Browse files
committed
Reorganize map code, stricter block sparse map!
1 parent 6da5d4f commit 9424058

File tree

2 files changed

+127
-131
lines changed
  • src
    • abstractblocksparsearray
    • blocksparsearrayinterface

2 files changed

+127
-131
lines changed

src/abstractblocksparsearray/map.jl

Lines changed: 0 additions & 131 deletions
Original file line numberDiff line numberDiff line change
@@ -1,136 +1,5 @@
11
using ArrayLayouts: LayoutArray
2-
using BlockArrays: blockisequal
3-
using DerivableInterfaces: @interface, AbstractArrayInterface, interface
4-
using GPUArraysCore: @allowscalar
52
using LinearAlgebra: Adjoint, Transpose
6-
using SparseArraysBase: SparseArraysBase, SparseArrayStyle
7-
8-
# Returns `Vector{<:CartesianIndices}`
9-
function union_stored_blocked_cartesianindices(as::Vararg{AbstractArray})
10-
combined_axes = combine_axes(axes.(as)...)
11-
stored_blocked_cartesianindices_as = map(as) do a
12-
return blocked_cartesianindices(axes(a), combined_axes, eachblockstoredindex(a))
13-
end
14-
return (stored_blocked_cartesianindices_as...)
15-
end
16-
17-
# This is used by `map` to get the output axes.
18-
# This is type piracy, try to avoid this, maybe requires defining `map`.
19-
## Base.promote_shape(a1::Tuple{Vararg{BlockedUnitRange}}, a2::Tuple{Vararg{BlockedUnitRange}}) = combine_axes(a1, a2)
20-
21-
reblock(a) = a
22-
23-
# If the blocking of the slice doesn't match the blocking of the
24-
# parent array, reblock according to the blocking of the parent array.
25-
function reblock(
26-
a::SubArray{<:Any,<:Any,<:AbstractBlockSparseArray,<:Tuple{Vararg{AbstractUnitRange}}}
27-
)
28-
# TODO: This relies on the behavior that slicing a block sparse
29-
# array with a UnitRange inherits the blocking of the underlying
30-
# block sparse array, we might change that default behavior
31-
# so this might become something like `@blocked parent(a)[...]`.
32-
return @view parent(a)[UnitRange{Int}.(parentindices(a))...]
33-
end
34-
35-
function reblock(
36-
a::SubArray{<:Any,<:Any,<:AbstractBlockSparseArray,<:Tuple{Vararg{NonBlockedArray}}}
37-
)
38-
return @view parent(a)[map(I -> I.array, parentindices(a))...]
39-
end
40-
41-
function reblock(
42-
a::SubArray{
43-
<:Any,
44-
<:Any,
45-
<:AbstractBlockSparseArray,
46-
<:Tuple{Vararg{BlockIndices{<:AbstractBlockVector{<:Block{1}}}}},
47-
},
48-
)
49-
# Remove the blocking.
50-
return @view parent(a)[map(I -> Vector(I.blocks), parentindices(a))...]
51-
end
52-
53-
# `map!` specialized to zero-dimensional inputs.
54-
function map_zero_dim! end
55-
56-
@interface ::AbstractArrayInterface function map_zero_dim!(
57-
f, a_dest::AbstractArray, a_srcs::AbstractArray...
58-
)
59-
@allowscalar a_dest[] = f.(map(a_src -> a_src[], a_srcs)...)
60-
return a_dest
61-
end
62-
63-
# TODO: Move to `blocksparsearrayinterface/map.jl`.
64-
# TODO: Rewrite this so that it takes the blocking structure
65-
# made by combining the blocking of the axes (i.e. the blocking that
66-
# is used to determine `union_stored_blocked_cartesianindices(...)`).
67-
# `reblock` is a partial solution to that, but a bit ad-hoc.
68-
## TODO: Make this an `@interface AbstractBlockSparseArrayInterface` function.
69-
@interface interface::AbstractBlockSparseArrayInterface function Base.map!(
70-
f, a_dest::AbstractArray, a_srcs::AbstractArray...
71-
)
72-
if isempty(a_srcs)
73-
# Broadcast expressions of the form `a .= 2`.
74-
@interface interface fill!(a_dest, f())
75-
return a_dest
76-
end
77-
if iszero(ndims(a_dest))
78-
@interface interface map_zero_dim!(f, a_dest, a_srcs...)
79-
return a_dest
80-
end
81-
# TODO: This assumes element types are numbers, generalize this logic.
82-
f_preserves_zeros = f(zero.(eltype.(a_srcs))...) == zero(eltype(a_dest))
83-
a_dest, a_srcs = reblock(a_dest), reblock.(a_srcs)
84-
for I in union_stored_blocked_cartesianindices(a_dest, a_srcs...)
85-
BI_dest = blockindexrange(a_dest, I)
86-
BI_srcs = map(a_src -> blockindexrange(a_src, I), a_srcs)
87-
# TODO: Investigate why this doesn't work:
88-
# block_dest = @view a_dest[_block(BI_dest)]
89-
block_dest = blocks_maybe_single(a_dest)[Int.(Tuple(_block(BI_dest)))...]
90-
# TODO: Investigate why this doesn't work:
91-
# block_srcs = ntuple(i -> @view(a_srcs[i][_block(BI_srcs[i])]), length(a_srcs))
92-
block_srcs = ntuple(length(a_srcs)) do i
93-
return blocks_maybe_single(a_srcs[i])[Int.(Tuple(_block(BI_srcs[i])))...]
94-
end
95-
subblock_dest = @view block_dest[BI_dest.indices...]
96-
subblock_srcs = ntuple(i -> @view(block_srcs[i][BI_srcs[i].indices...]), length(a_srcs))
97-
I_dest = CartesianIndex(Int.(Tuple(_block(BI_dest))))
98-
# If the function preserves zero values and all of the source blocks are zero,
99-
# the output block will be zero. In that case, if the block isn't stored yet,
100-
# don't do anything.
101-
if f_preserves_zeros && all(iszero, subblock_srcs) && !isstored(blocks(a_dest), I_dest)
102-
continue
103-
end
104-
# TODO: Use `map!!` to handle immutable blocks.
105-
map!(f, subblock_dest, subblock_srcs...)
106-
# Replace the entire block, handles initializing new blocks
107-
# or if blocks are immutable.
108-
blocks(a_dest)[I_dest] = block_dest
109-
end
110-
return a_dest
111-
end
112-
113-
# TODO: Move to `blocksparsearrayinterface/map.jl`.
114-
@interface ::AbstractBlockSparseArrayInterface function Base.mapreduce(
115-
f, op, as::AbstractArray...; kwargs...
116-
)
117-
# TODO: Define an `init` value based on the element type.
118-
return @interface interface(blocks.(as)...) mapreduce(
119-
block -> mapreduce(f, op, block), op, blocks.(as)...; kwargs...
120-
)
121-
end
122-
123-
# TODO: Move to `blocksparsearrayinterface/map.jl`.
124-
@interface ::AbstractBlockSparseArrayInterface function Base.iszero(a::AbstractArray)
125-
# TODO: Just call `iszero(blocks(a))`?
126-
return @interface interface(blocks(a)) iszero(blocks(a))
127-
end
128-
129-
# TODO: Move to `blocksparsearrayinterface/map.jl`.
130-
@interface ::AbstractBlockSparseArrayInterface function Base.isreal(a::AbstractArray)
131-
# TODO: Just call `isreal(blocks(a))`?
132-
return @interface interface(blocks(a)) isreal(blocks(a))
133-
end
1343

1354
function Base.map!(f, a_dest::AbstractArray, a_srcs::AnyAbstractBlockSparseArray...)
1365
@interface interface(a_dest, a_srcs...) map!(f, a_dest, a_srcs...)

src/blocksparsearrayinterface/map.jl

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,130 @@
1+
using DerivableInterfaces: @interface, AbstractArrayInterface, interface
2+
using GPUArraysCore: @allowscalar
3+
4+
# TODO: Rewrite this so that it takes the blocking structure
5+
# made by combining the blocking of the axes (i.e. the blocking that
6+
# is used to determine `union_stored_blocked_cartesianindices(...)`).
7+
# `reblock` is a partial solution to that, but a bit ad-hoc.
8+
## TODO: Make this an `@interface AbstractBlockSparseArrayInterface` function.
9+
@interface interface::AbstractBlockSparseArrayInterface function Base.map!(
10+
f, a_dest::AbstractArray, a_srcs::AbstractArray...
11+
)
12+
if isempty(a_srcs)
13+
error("Can't call `map!` with zero source terms.")
14+
end
15+
if iszero(ndims(a_dest))
16+
@interface interface map_zero_dim!(f, a_dest, a_srcs...)
17+
return a_dest
18+
end
19+
# TODO: This assumes element types are numbers, generalize this logic.
20+
f_preserves_zeros = f(zero.(eltype.(a_srcs))...) == zero(eltype(a_dest))
21+
a_dest, a_srcs = reblock(a_dest), reblock.(a_srcs)
22+
for I in union_stored_blocked_cartesianindices(a_dest, a_srcs...)
23+
BI_dest = blockindexrange(a_dest, I)
24+
BI_srcs = map(a_src -> blockindexrange(a_src, I), a_srcs)
25+
# TODO: Investigate why this doesn't work:
26+
# block_dest = @view a_dest[_block(BI_dest)]
27+
block_dest = blocks_maybe_single(a_dest)[Int.(Tuple(_block(BI_dest)))...]
28+
# TODO: Investigate why this doesn't work:
29+
# block_srcs = ntuple(i -> @view(a_srcs[i][_block(BI_srcs[i])]), length(a_srcs))
30+
block_srcs = ntuple(length(a_srcs)) do i
31+
return blocks_maybe_single(a_srcs[i])[Int.(Tuple(_block(BI_srcs[i])))...]
32+
end
33+
subblock_dest = @view block_dest[BI_dest.indices...]
34+
subblock_srcs = ntuple(i -> @view(block_srcs[i][BI_srcs[i].indices...]), length(a_srcs))
35+
I_dest = CartesianIndex(Int.(Tuple(_block(BI_dest))))
36+
# If the function preserves zero values and all of the source blocks are zero,
37+
# the output block will be zero. In that case, if the block isn't stored yet,
38+
# don't do anything.
39+
if f_preserves_zeros && all(iszero, subblock_srcs) && !isstored(blocks(a_dest), I_dest)
40+
continue
41+
end
42+
# TODO: Use `map!!` to handle immutable blocks.
43+
map!(f, subblock_dest, subblock_srcs...)
44+
# Replace the entire block, handles initializing new blocks
45+
# or if blocks are immutable.
46+
blocks(a_dest)[I_dest] = block_dest
47+
end
48+
return a_dest
49+
end
50+
51+
@interface ::AbstractBlockSparseArrayInterface function Base.mapreduce(
52+
f, op, as::AbstractArray...; kwargs...
53+
)
54+
# TODO: Define an `init` value based on the element type.
55+
return @interface interface(blocks.(as)...) mapreduce(
56+
block -> mapreduce(f, op, block), op, blocks.(as)...; kwargs...
57+
)
58+
end
59+
60+
@interface ::AbstractBlockSparseArrayInterface function Base.iszero(a::AbstractArray)
61+
# TODO: Just call `iszero(blocks(a))`?
62+
return @interface interface(blocks(a)) iszero(blocks(a))
63+
end
64+
65+
@interface ::AbstractBlockSparseArrayInterface function Base.isreal(a::AbstractArray)
66+
# TODO: Just call `isreal(blocks(a))`?
67+
return @interface interface(blocks(a)) isreal(blocks(a))
68+
end
69+
70+
# Helper functions for block sparse map.
71+
72+
# Returns `Vector{<:CartesianIndices}`
73+
function union_stored_blocked_cartesianindices(as::Vararg{AbstractArray})
74+
combined_axes = combine_axes(axes.(as)...)
75+
stored_blocked_cartesianindices_as = map(as) do a
76+
return blocked_cartesianindices(axes(a), combined_axes, eachblockstoredindex(a))
77+
end
78+
return (stored_blocked_cartesianindices_as...)
79+
end
80+
81+
# This is used by `map` to get the output axes.
82+
# This is type piracy, try to avoid this, maybe requires defining `map`.
83+
## Base.promote_shape(a1::Tuple{Vararg{BlockedUnitRange}}, a2::Tuple{Vararg{BlockedUnitRange}}) = combine_axes(a1, a2)
84+
85+
reblock(a) = a
86+
87+
# If the blocking of the slice doesn't match the blocking of the
88+
# parent array, reblock according to the blocking of the parent array.
89+
function reblock(
90+
a::SubArray{<:Any,<:Any,<:AbstractBlockSparseArray,<:Tuple{Vararg{AbstractUnitRange}}}
91+
)
92+
# TODO: This relies on the behavior that slicing a block sparse
93+
# array with a UnitRange inherits the blocking of the underlying
94+
# block sparse array, we might change that default behavior
95+
# so this might become something like `@blocked parent(a)[...]`.
96+
return @view parent(a)[UnitRange{Int}.(parentindices(a))...]
97+
end
98+
99+
function reblock(
100+
a::SubArray{<:Any,<:Any,<:AbstractBlockSparseArray,<:Tuple{Vararg{NonBlockedArray}}}
101+
)
102+
return @view parent(a)[map(I -> I.array, parentindices(a))...]
103+
end
104+
105+
function reblock(
106+
a::SubArray{
107+
<:Any,
108+
<:Any,
109+
<:AbstractBlockSparseArray,
110+
<:Tuple{Vararg{BlockIndices{<:AbstractBlockVector{<:Block{1}}}}},
111+
},
112+
)
113+
# Remove the blocking.
114+
return @view parent(a)[map(I -> Vector(I.blocks), parentindices(a))...]
115+
end
116+
117+
# `map!` specialized to zero-dimensional inputs.
118+
function map_zero_dim! end
119+
120+
@interface ::AbstractArrayInterface function map_zero_dim!(
121+
f, a_dest::AbstractArray, a_srcs::AbstractArray...
122+
)
123+
@allowscalar a_dest[] = f.(map(a_src -> a_src[], a_srcs)...)
124+
return a_dest
125+
end
126+
127+
# TODO: Decide what to do with these.
1128
function map_stored_blocks(f, a::AbstractArray)
2129
# TODO: Implement this as:
3130
# ```julia

0 commit comments

Comments
 (0)