Skip to content

Commit f73cffe

Browse files
authored
[BlockSparseArrays] Initial support for more general blocks, such as GPU blocks (#1560)
1 parent b41b9c3 commit f73cffe

File tree

13 files changed

+257
-120
lines changed

13 files changed

+257
-120
lines changed
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
module BlockSparseArraysAdaptExt
2+
using Adapt: Adapt, adapt
3+
using ..BlockSparseArrays: AbstractBlockSparseArray, map_stored_blocks
4+
Adapt.adapt_structure(to, x::AbstractBlockSparseArray) = map_stored_blocks(adapt(to), x)
5+
end

src/BlockArraysExtensions/BlockArraysExtensions.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
using ArrayLayouts: ArrayLayouts, MemoryLayout, sub_materialize
12
using BlockArrays:
23
BlockArrays,
34
AbstractBlockArray,
@@ -537,6 +538,20 @@ function SparseArrayInterface.nstored(a::BlockView)
537538
return 0
538539
end
539540

541+
## # Allow more fine-grained control:
542+
## function ArrayLayouts.sub_materialize(layout, a::BlockView, ax)
543+
## return blocks(a.array)[Int.(a.block)...]
544+
## end
545+
## function ArrayLayouts.sub_materialize(layout, a::BlockView)
546+
## return sub_materialize(layout, a, axes(a))
547+
## end
548+
## function ArrayLayouts.sub_materialize(a::BlockView)
549+
## return sub_materialize(MemoryLayout(a), a)
550+
## end
551+
function ArrayLayouts.sub_materialize(a::BlockView)
552+
return blocks(a.array)[Int.(a.block)...]
553+
end
554+
540555
function view!(a::AbstractArray{<:Any,N}, index::Block{N}) where {N}
541556
return view!(a, Tuple(index)...)
542557
end

src/BlockSparseArrays.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ include("blocksparsearrayinterface/blocksparsearrayinterface.jl")
44
include("blocksparsearrayinterface/linearalgebra.jl")
55
include("blocksparsearrayinterface/blockzero.jl")
66
include("blocksparsearrayinterface/broadcast.jl")
7+
include("blocksparsearrayinterface/map.jl")
78
include("blocksparsearrayinterface/arraylayouts.jl")
89
include("blocksparsearrayinterface/views.jl")
910
include("abstractblocksparsearray/abstractblocksparsearray.jl")
@@ -20,4 +21,5 @@ include("blocksparsearray/blocksparsearray.jl")
2021
include("BlockArraysSparseArrayInterfaceExt/BlockArraysSparseArrayInterfaceExt.jl")
2122
include("../ext/BlockSparseArraysTensorAlgebraExt/src/BlockSparseArraysTensorAlgebraExt.jl")
2223
include("../ext/BlockSparseArraysGradedAxesExt/src/BlockSparseArraysGradedAxesExt.jl")
24+
include("../ext/BlockSparseArraysAdaptExt/src/BlockSparseArraysAdaptExt.jl")
2325
end

src/abstractblocksparsearray/abstractblocksparsearray.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,15 @@ abstract type AbstractBlockSparseArray{T,N} <: AbstractBlockArray{T,N} end
1212

1313
Base.axes(::AbstractBlockSparseArray) = error("Not implemented")
1414

15-
blockstype(::Type{<:AbstractBlockSparseArray}) = error("Not implemented")
15+
# TODO: Add some logic to unwrapping wrapped arrays.
16+
# TODO: Decide what a good default is.
17+
blockstype(arraytype::Type{<:AbstractBlockSparseArray}) = SparseArrayDOK{AbstractArray}
18+
function blockstype(arraytype::Type{<:AbstractBlockSparseArray{T}}) where {T}
19+
return SparseArrayDOK{AbstractArray{T}}
20+
end
21+
function blockstype(arraytype::Type{<:AbstractBlockSparseArray{T,N}}) where {T,N}
22+
return SparseArrayDOK{AbstractArray{T,N},N}
23+
end
1624

1725
## # Specialized in order to fix ambiguity error with `BlockArrays`.
1826
function Base.getindex(a::AbstractBlockSparseArray{<:Any,N}, I::Vararg{Int,N}) where {N}
Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using ArrayLayouts: ArrayLayouts, MemoryLayout, MulAdd
22
using BlockArrays: BlockLayout
33
using ..SparseArrayInterface: SparseLayout
4+
using ..TypeParameterAccessors: similartype
45

56
function ArrayLayouts.MemoryLayout(arraytype::Type{<:BlockSparseArrayLike})
67
outer_layout = typeof(MemoryLayout(blockstype(arraytype)))
@@ -9,15 +10,22 @@ function ArrayLayouts.MemoryLayout(arraytype::Type{<:BlockSparseArrayLike})
910
end
1011

1112
function Base.similar(
12-
::MulAdd{<:BlockLayout{<:SparseLayout},<:BlockLayout{<:SparseLayout}}, elt::Type, axes
13-
)
14-
return similar(BlockSparseArray{elt}, axes)
13+
mul::MulAdd{<:BlockLayout{<:SparseLayout},<:BlockLayout{<:SparseLayout},<:Any,<:Any,A,B},
14+
elt::Type,
15+
axes,
16+
) where {A,B}
17+
# TODO: Check that this equals `similartype(blocktype(B), elt, axes)`,
18+
# or maybe promote them?
19+
output_blocktype = similartype(blocktype(A), elt, axes)
20+
return similar(BlockSparseArray{elt,length(axes),output_blocktype}, axes)
1521
end
1622

1723
# Materialize a SubArray view.
1824
function ArrayLayouts.sub_materialize(layout::BlockLayout{<:SparseLayout}, a, axes)
19-
# TODO: Make more generic for GPU.
20-
a_dest = BlockSparseArray{eltype(a)}(axes)
25+
# TODO: Define `blocktype`/`blockstype` for `SubArray` wrapping `BlockSparseArray`.
26+
# TODO: Use `similar`?
27+
blocktype_a = blocktype(parent(a))
28+
a_dest = BlockSparseArray{eltype(a),length(axes),blocktype_a}(axes)
2129
a_dest .= a
2230
return a_dest
2331
end
@@ -26,8 +34,7 @@ end
2634
function ArrayLayouts.sub_materialize(
2735
layout::BlockLayout{<:SparseLayout}, a, axes::Tuple{Vararg{Base.OneTo}}
2836
)
29-
# TODO: Make more generic for GPU.
30-
a_dest = Array{eltype(a)}(undef, length.(axes))
37+
a_dest = blocktype(a)(undef, length.(axes))
3138
a_dest .= a
3239
return a_dest
3340
end

src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl

Lines changed: 14 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ using BlockArrays:
99
mortar,
1010
unblock
1111
using SplitApplyCombine: groupcount
12+
using ..TypeParameterAccessors: similartype
1213

1314
const WrappedAbstractBlockSparseArray{T,N} = WrappedArray{
1415
T,N,AbstractBlockSparseArray,AbstractBlockSparseArray{T,N}
@@ -187,28 +188,29 @@ function Base.similar(
187188
return similar(arraytype, eltype(arraytype), axes)
188189
end
189190

191+
function blocksparse_similar(a, elt::Type, axes::Tuple)
192+
return BlockSparseArray{elt,length(axes),similartype(blocktype(a), elt, axes)}(
193+
undef, axes
194+
)
195+
end
196+
190197
# Needed by `BlockArrays` matrix multiplication interface
191198
# TODO: Define a `blocksparse_similar` function.
192199
function Base.similar(
193200
arraytype::Type{<:BlockSparseArrayLike},
194201
elt::Type,
195202
axes::Tuple{Vararg{AbstractUnitRange{<:Integer}}},
196203
)
197-
# TODO: Make generic for GPU, maybe using `blocktype`.
198-
# TODO: For non-block axes this should output `Array`.
199-
return BlockSparseArray{elt}(undef, axes)
204+
return blocksparse_similar(arraytype, elt, axes)
200205
end
201206

202207
# TODO: Define a `blocksparse_similar` function.
203208
function Base.similar(
204209
a::BlockSparseArrayLike, elt::Type, axes::Tuple{Vararg{AbstractUnitRange{<:Integer}}}
205210
)
206-
# TODO: Make generic for GPU, maybe using `blocktype`.
207-
# TODO: For non-block axes this should output `Array`.
208-
return BlockSparseArray{elt}(undef, axes)
211+
return blocksparse_similar(a, elt, axes)
209212
end
210213

211-
# TODO: Define a `blocksparse_similar` function.
212214
# Fixes ambiguity error with `BlockArrays`.
213215
function Base.similar(
214216
a::BlockSparseArrayLike,
@@ -217,21 +219,16 @@ function Base.similar(
217219
AbstractBlockedUnitRange{<:Integer},Vararg{AbstractBlockedUnitRange{<:Integer}}
218220
},
219221
)
220-
# TODO: Make generic for GPU, maybe using `blocktype`.
221-
# TODO: For non-block axes this should output `Array`.
222-
return BlockSparseArray{elt}(undef, axes)
222+
return blocksparse_similar(a, elt, axes)
223223
end
224224

225-
# TODO: Define a `blocksparse_similar` function.
226225
# Fixes ambiguity error with `OffsetArrays`.
227226
function Base.similar(
228227
a::BlockSparseArrayLike,
229228
elt::Type,
230229
axes::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}},
231230
)
232-
# TODO: Make generic for GPU, maybe using `blocktype`.
233-
# TODO: For non-block axes this should output `Array`.
234-
return BlockSparseArray{elt}(undef, axes)
231+
return blocksparse_similar(a, elt, axes)
235232
end
236233

237234
# Fixes ambiguity error with `BlockArrays`.
@@ -240,9 +237,7 @@ function Base.similar(
240237
elt::Type,
241238
axes::Tuple{AbstractBlockedUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}},
242239
)
243-
# TODO: Make generic for GPU, maybe using `blocktype`.
244-
# TODO: For non-block axes this should output `Array`.
245-
return BlockSparseArray{elt}(undef, axes)
240+
return blocksparse_similar(a, elt, axes)
246241
end
247242

248243
# Fixes ambiguity errors with BlockArrays.
@@ -255,15 +250,12 @@ function Base.similar(
255250
Vararg{AbstractUnitRange{<:Integer}},
256251
},
257252
)
258-
return BlockSparseArray{elt}(undef, axes)
253+
return blocksparse_similar(a, elt, axes)
259254
end
260255

261-
# TODO: Define a `blocksparse_similar` function.
262256
# Fixes ambiguity error with `StaticArrays`.
263257
function Base.similar(
264258
a::BlockSparseArrayLike, elt::Type, axes::Tuple{Base.OneTo,Vararg{Base.OneTo}}
265259
)
266-
# TODO: Make generic for GPU, maybe using `blocktype`.
267-
# TODO: For non-block axes this should output `Array`.
268-
return BlockSparseArray{elt}(undef, axes)
260+
return blocksparse_similar(a, elt, axes)
269261
end

src/blocksparsearray/blocksparsearray.jl

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,18 @@ function BlockSparseArray(
3535
return BlockSparseArray(Dictionary(block_indices, block_data), axes)
3636
end
3737

38+
function BlockSparseArray{T,N,A,Blocks}(
39+
blocks::AbstractArray{<:AbstractArray{T,N},N}, axes::Tuple{Vararg{AbstractUnitRange,N}}
40+
) where {T,N,A<:AbstractArray{T,N},Blocks<:AbstractArray{A,N}}
41+
return BlockSparseArray{T,N,A,Blocks,typeof(axes)}(blocks, axes)
42+
end
43+
44+
function BlockSparseArray{T,N,A}(
45+
blocks::AbstractArray{<:AbstractArray{T,N},N}, axes::Tuple{Vararg{AbstractUnitRange,N}}
46+
) where {T,N,A<:AbstractArray{T,N}}
47+
return BlockSparseArray{T,N,A,typeof(blocks)}(blocks, axes)
48+
end
49+
3850
function BlockSparseArray{T,N}(
3951
blocks::AbstractArray{<:AbstractArray{T,N},N}, axes::Tuple{Vararg{AbstractUnitRange,N}}
4052
) where {T,N}
@@ -49,9 +61,15 @@ function BlockSparseArray{T,N}(
4961
return BlockSparseArray{T,N}(blocks, axes)
5062
end
5163

64+
function BlockSparseArray{T,N,A}(
65+
axes::Tuple{Vararg{AbstractUnitRange,N}}
66+
) where {T,N,A<:AbstractArray{T,N}}
67+
blocks = default_blocks(A, axes)
68+
return BlockSparseArray{T,N,A}(blocks, axes)
69+
end
70+
5271
function BlockSparseArray{T,N}(axes::Tuple{Vararg{AbstractUnitRange,N}}) where {T,N}
53-
blocks = default_blocks(T, axes)
54-
return BlockSparseArray{T,N}(blocks, axes)
72+
return BlockSparseArray{T,N,default_arraytype(T, axes)}(axes)
5573
end
5674

5775
function BlockSparseArray{T,N}(dims::Tuple{Vararg{Vector{Int},N}}) where {T,N}
@@ -74,6 +92,12 @@ function BlockSparseArray{T}(axes::Vararg{AbstractUnitRange}) where {T}
7492
return BlockSparseArray{T}(axes)
7593
end
7694

95+
function BlockSparseArray{T,N,A}(
96+
::UndefInitializer, dims::Tuple
97+
) where {T,N,A<:AbstractArray{T,N}}
98+
return BlockSparseArray{T,N,A}(dims)
99+
end
100+
77101
# undef
78102
function BlockSparseArray{T,N}(
79103
::UndefInitializer, axes::Tuple{Vararg{AbstractUnitRange,N}}
@@ -109,7 +133,23 @@ Base.axes(a::BlockSparseArray) = a.axes
109133
blocksparse_blocks(a::BlockSparseArray) = a.blocks
110134

111135
# TODO: Use `TypeParameterAccessors`.
112-
blockstype(::Type{<:BlockSparseArray{<:Any,<:Any,<:Any,B}}) where {B} = B
136+
function blockstype(
137+
arraytype::Type{<:BlockSparseArray{T,N,A,Blocks}}
138+
) where {T,N,A<:AbstractArray{T,N},Blocks<:AbstractArray{A,N}}
139+
return Blocks
140+
end
141+
function blockstype(
142+
arraytype::Type{<:BlockSparseArray{T,N,A}}
143+
) where {T,N,A<:AbstractArray{T,N}}
144+
return SparseArrayDOK{A,N}
145+
end
146+
function blockstype(arraytype::Type{<:BlockSparseArray{T,N}}) where {T,N}
147+
return SparseArrayDOK{AbstractArray{T,N},N}
148+
end
149+
function blockstype(arraytype::Type{<:BlockSparseArray{T}}) where {T}
150+
return SparseArrayDOK{AbstractArray{T}}
151+
end
152+
blockstype(arraytype::Type{<:BlockSparseArray}) = SparseArrayDOK{AbstractArray}
113153

114154
## # Base interface
115155
## function Base.similar(

src/blocksparsearray/defaults.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ function default_arraytype(elt::Type, axes::Tuple{Vararg{AbstractUnitRange}})
2929
return Array{elt,length(axes)}
3030
end
3131

32-
function default_blocks(elt::Type, axes::Tuple{Vararg{AbstractUnitRange}})
33-
block_data = Dictionary{Block{length(axes),Int},default_arraytype(elt, axes)}()
32+
function default_blocks(blocktype::Type, axes::Tuple{Vararg{AbstractUnitRange}})
33+
block_data = Dictionary{Block{length(axes),Int},blocktype}()
3434
return default_blocks(block_data, axes)
3535
end
3636

src/blocksparsearrayinterface/blocksparsearrayinterface.jl

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -125,12 +125,13 @@ _getindices(i::CartesianIndex, indices) = CartesianIndex(_getindices(Tuple(i), i
125125

126126
# Represents the array of arrays of a `PermutedDimsArray`
127127
# wrapping a block spare array, i.e. `blocks(array)` where `a` is a `PermutedDimsArray`.
128-
struct SparsePermutedDimsArrayBlocks{T,N,Array<:PermutedDimsArray{T,N}} <:
129-
AbstractSparseArray{T,N}
128+
struct SparsePermutedDimsArrayBlocks{
129+
T,N,BlockType<:AbstractArray{T,N},Array<:PermutedDimsArray{T,N}
130+
} <: AbstractSparseArray{BlockType,N}
130131
array::Array
131132
end
132133
function blocksparse_blocks(a::PermutedDimsArray)
133-
return SparsePermutedDimsArrayBlocks(a)
134+
return SparsePermutedDimsArrayBlocks{eltype(a),ndims(a),blocktype(parent(a)),typeof(a)}(a)
134135
end
135136
function Base.size(a::SparsePermutedDimsArrayBlocks)
136137
return _getindices(size(blocks(parent(a.array))), _perm(a.array))
@@ -158,11 +159,12 @@ reverse_index(index::CartesianIndex) = CartesianIndex(reverse(Tuple(index)))
158159

159160
# Represents the array of arrays of a `Transpose`
160161
# wrapping a block spare array, i.e. `blocks(array)` where `a` is a `Transpose`.
161-
struct SparseTransposeBlocks{T,Array<:Transpose{T}} <: AbstractSparseMatrix{T}
162+
struct SparseTransposeBlocks{T,BlockType<:AbstractMatrix{T},Array<:Transpose{T}} <:
163+
AbstractSparseMatrix{BlockType}
162164
array::Array
163165
end
164166
function blocksparse_blocks(a::Transpose)
165-
return SparseTransposeBlocks(a)
167+
return SparseTransposeBlocks{eltype(a),blocktype(parent(a)),typeof(a)}(a)
166168
end
167169
function Base.size(a::SparseTransposeBlocks)
168170
return reverse(size(blocks(parent(a.array))))
@@ -192,11 +194,12 @@ end
192194

193195
# Represents the array of arrays of a `Adjoint`
194196
# wrapping a block spare array, i.e. `blocks(array)` where `a` is a `Adjoint`.
195-
struct SparseAdjointBlocks{T,Array<:Adjoint{T}} <: AbstractSparseMatrix{T}
197+
struct SparseAdjointBlocks{T,BlockType<:AbstractMatrix{T},Array<:Adjoint{T}} <:
198+
AbstractSparseMatrix{BlockType}
196199
array::Array
197200
end
198201
function blocksparse_blocks(a::Adjoint)
199-
return SparseAdjointBlocks(a)
202+
return SparseAdjointBlocks{eltype(a),blocktype(parent(a)),typeof(a)}(a)
200203
end
201204
function Base.size(a::SparseAdjointBlocks)
202205
return reverse(size(blocks(parent(a.array))))
@@ -230,9 +233,13 @@ end
230233

231234
# Represents the array of arrays of a `SubArray`
232235
# wrapping a block spare array, i.e. `blocks(array)` where `a` is a `SubArray`.
233-
struct SparseSubArrayBlocks{T,N,Array<:SubArray{T,N}} <: AbstractSparseArray{T,N}
236+
struct SparseSubArrayBlocks{T,N,BlockType<:AbstractArray{T,N},Array<:SubArray{T,N}} <:
237+
AbstractSparseArray{BlockType,N}
234238
array::Array
235239
end
240+
function blocksparse_blocks(a::SubArray)
241+
return SparseSubArrayBlocks{eltype(a),ndims(a),blocktype(parent(a)),typeof(a)}(a)
242+
end
236243
# TODO: Define this as `blockrange(a::AbstractArray, indices::Tuple{Vararg{AbstractUnitRange}})`.
237244
function blockrange(a::SparseSubArrayBlocks)
238245
blockranges = blockrange.(axes(parent(a.array)), a.array.indices)
@@ -291,8 +298,10 @@ function SparseArrayInterface.sparse_storage(a::SparseSubArrayBlocks)
291298
return map(I -> a[I], stored_indices(a))
292299
end
293300

294-
function blocksparse_blocks(a::SubArray)
295-
return SparseSubArrayBlocks(a)
301+
function SparseArrayInterface.getindex_zero_function(a::SparseSubArrayBlocks)
302+
# TODO: Base it off of `getindex_zero_function(blocks(parent(a.array))`, but replace the
303+
# axes with `axes(a.array)`.
304+
return BlockZero(axes(a.array))
296305
end
297306

298307
to_blocks_indices(I::BlockSlice{<:BlockRange{1}}) = Int.(I.block)

src/blocksparsearrayinterface/blockzero.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@ end
2929
function (f::BlockZero)(arraytype::Type{<:AbstractArray}, I)
3030
# TODO: Make sure this works for sparse or block sparse blocks, immutable
3131
# blocks, diagonal blocks, etc.!
32-
return fill!(arraytype(undef, block_size(f.axes, Block(Tuple(I)))), false)
32+
blck_size = block_size(f.axes, Block(Tuple(I)))
33+
blck_type = similartype(arraytype, blck_size)
34+
return fill!(blck_type(undef, blck_size), false)
3335
end
3436

3537
# Fallback so that `SparseArray` with scalar elements works.

0 commit comments

Comments
 (0)