Skip to content

Commit 3318622

Browse files
committed
Make blockwise map friendlier to abstract block types
1 parent 3a199fb commit 3318622

File tree

7 files changed

+899
-845
lines changed

7 files changed

+899
-845
lines changed

src/blocksparsearrayinterface/blocksparsearrayinterface.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,15 @@ function eachblockstoredindex(a::AbstractArray)
4242
return Block.(Tuple.(eachstoredindex(blocks(a))))
4343
end
4444

45+
function SparseArraysBase.isstored(
46+
a::AbstractArray{<:Any,N}, I::Vararg{Block{1},N}
47+
) where {N}
48+
return isstored(blocks(a), Int.(I)...)
49+
end
50+
function SparseArraysBase.isstored(a::AbstractArray{<:Any,N}, I::Block{N}) where {N}
51+
return isstored(a, Tuple(I)...)
52+
end
53+
4554
using DiagonalArrays: diagindices
4655
# Block version of `DiagonalArrays.diagindices`.
4756
function blockdiagindices(a::AbstractArray)

src/blocksparsearrayinterface/map.jl

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,38 @@ function union_eachblockstoredindex(as::AbstractArray...)
1818
return (map(eachblockstoredindex, as)...)
1919
end
2020

21+
# Get a view of a block assuming it is stored.
22+
function viewblock_stored(a::AbstractArray{<:Any,N}, I::Vararg{Block{1},N}) where {N}
23+
return blocks(a)[Int.(I)...]
24+
end
25+
function viewblock_stored(a::AbstractArray{<:Any,N}, I::Block{N}) where {N}
26+
return viewblock_stored(a, Tuple(I)...)
27+
end
28+
29+
using FillArrays: Zeros
30+
# Get a view of a block if it is stored, otherwise return a lazy zeros.
31+
function viewblock_or_zeros(a::AbstractArray{<:Any,N}, I::Vararg{Block{1},N}) where {N}
32+
if isstored(a, I...)
33+
return viewblock_stored(a, I...)
34+
else
35+
block_ax = map((ax, i) -> eachblockaxis(ax)[Int(i)], axes(a), I)
36+
return Zeros{eltype(a)}(block_ax)
37+
end
38+
end
39+
function viewblock_or_zeros(a::AbstractArray{<:Any,N}, I::Block{N}) where {N}
40+
return viewblock_or_zeros(a, Tuple(I)...)
41+
end
42+
43+
function map_block!(f, a_dest::AbstractArray, I::Block, a_srcs::AbstractArray...)
44+
a_srcs_I = map(a_src -> viewblock_or_zeros(a_src, I), a_srcs)
45+
if isstored(a_dest, I)
46+
a_dest[I] .= f.(a_srcs_I...)
47+
else
48+
a_dest[I] = Broadcast.broadcast_preserving_zero_d(f, a_srcs_I...)
49+
end
50+
return a_dest
51+
end
52+
2153
function map_blockwise!(f, a_dest::AbstractArray, a_srcs::AbstractArray...)
2254
# TODO: This assumes element types are numbers, generalize this logic.
2355
f_preserves_zeros = f(zero.(eltype.(a_srcs))...) == zero(eltype(a_dest))
@@ -27,22 +59,7 @@ function map_blockwise!(f, a_dest::AbstractArray, a_srcs::AbstractArray...)
2759
BlockRange(a_dest)
2860
end
2961
for I in Is
30-
# TODO: Use:
31-
# block_dest = @view a_dest[I]
32-
# or:
33-
# block_dest = @view! a_dest[I]
34-
block_dest = blocks_maybe_single(a_dest)[Int.(Tuple(I))...]
35-
# TODO: Use:
36-
# block_srcs = map(a_src -> @view(a_src[I]), a_srcs)
37-
block_srcs = map(a_srcs) do a_src
38-
return blocks_maybe_single(a_src)[Int.(Tuple(I))...]
39-
end
40-
# TODO: Use `map!!` to handle immutable blocks.
41-
map!(f, block_dest, block_srcs...)
42-
# Replace the entire block, handles initializing new blocks
43-
# or if blocks are immutable.
44-
# TODO: Use `a_dest[I] = block_dest`.
45-
blocks(a_dest)[Int.(Tuple(I))...] = block_dest
62+
map_block!(f, a_dest, I, a_srcs...)
4663
end
4764
return a_dest
4865
end

test/test_abstract_blocktype.jl

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
using Adapt: adapt
2+
using BlockArrays: Block
3+
using BlockSparseArrays: BlockSparseMatrix, blockstoredlength
4+
using JLArrays: JLArray
5+
using SparseArraysBase: storedlength
6+
using Test: @test, @test_broken, @testset
7+
8+
elts = (Float32, Float64, ComplexF32)
9+
arrayts = (Array, JLArray)
10+
@testset "Abstract block type (arraytype=$arrayt, eltype=$elt)" for arrayt in arrayts,
11+
elt in elts
12+
13+
dev = adapt(arrayt)
14+
a = BlockSparseMatrix{elt,AbstractMatrix{elt}}(undef, [2, 3], [2, 3])
15+
@test sprint(show, MIME"text/plain"(), a) isa String
16+
@test iszero(storedlength(a))
17+
@test iszero(blockstoredlength(a))
18+
a[Block(1, 1)] = dev(randn(elt, 2, 2))
19+
a[Block(2, 2)] = dev(randn(elt, 3, 3))
20+
@test !iszero(a[Block(1, 1)])
21+
@test a[Block(1, 1)] isa arrayt{elt,2}
22+
@test !iszero(a[Block(2, 2)])
23+
@test a[Block(2, 2)] isa arrayt{elt,2}
24+
@test iszero(a[Block(2, 1)])
25+
@test a[Block(2, 1)] isa Matrix{elt}
26+
@test iszero(a[Block(1, 2)])
27+
@test a[Block(1, 2)] isa Matrix{elt}
28+
29+
b = copy(a)
30+
@test Array(b) Array(a)
31+
32+
b = a + a
33+
@test Array(b) Array(a) + Array(a)
34+
35+
b = 3a
36+
@test Array(b) 3Array(a)
37+
38+
if arrayt === Array
39+
b = a * a
40+
@test Array(b) Array(a) * Array(a)
41+
else
42+
@test_broken a * a
43+
end
44+
end

0 commit comments

Comments
 (0)