Skip to content

Commit f64d7fd

Browse files
committed
[BlockSparseArrays] GPU support
1 parent 8bb156a commit f64d7fd

File tree

5 files changed

+77
-31
lines changed

5 files changed

+77
-31
lines changed

NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/abstractblocksparsearray.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,15 @@ abstract type AbstractBlockSparseArray{T,N} <: AbstractBlockArray{T,N} end
1212

1313
Base.axes(::AbstractBlockSparseArray) = error("Not implemented")
1414

15-
blockstype(::Type{<:AbstractBlockSparseArray}) = error("Not implemented")
15+
# TODO: Add some logic to unwrapping wrapped arrays.
16+
# TODO: Decide what a good default is.
17+
blockstype(arraytype::Type{<:AbstractBlockSparseArray}) = SparseArrayDOK{AbstractArray}
18+
function blockstype(arraytype::Type{<:AbstractBlockSparseArray{T}}) where {T}
19+
return SparseArrayDOK{AbstractArray{T}}
20+
end
21+
function blockstype(arraytype::Type{<:AbstractBlockSparseArray{T,N}}) where {T,N}
22+
return SparseArrayDOK{AbstractArray{T,N},N}
23+
end
1624

1725
## # Specialized in order to fix ambiguity error with `BlockArrays`.
1826
function Base.getindex(a::AbstractBlockSparseArray{<:Any,N}, I::Vararg{Int,N}) where {N}

NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/arraylayouts.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
using ArrayLayouts: ArrayLayouts, MemoryLayout, MulAdd
22
using BlockArrays: BlockLayout
33
using ..SparseArrayInterface: SparseLayout
4+
# TODO: Move to `NDTensors.TypeParameterAccessors`.
5+
using ..NDTensors: similartype
46

57
function ArrayLayouts.MemoryLayout(arraytype::Type{<:BlockSparseArrayLike})
68
outer_layout = typeof(MemoryLayout(blockstype(arraytype)))
@@ -9,9 +11,14 @@ function ArrayLayouts.MemoryLayout(arraytype::Type{<:BlockSparseArrayLike})
911
end
1012

1113
function Base.similar(
12-
::MulAdd{<:BlockLayout{<:SparseLayout},<:BlockLayout{<:SparseLayout}}, elt::Type, axes
13-
)
14-
return similar(BlockSparseArray{elt}, axes)
14+
mul::MulAdd{<:BlockLayout{<:SparseLayout},<:BlockLayout{<:SparseLayout},<:Any,<:Any,A,B},
15+
elt::Type,
16+
axes,
17+
) where {A,B}
18+
# TODO: Check that this equals `similartype(blocktype(B), elt, axes)`,
19+
# or maybe promote them?
20+
output_blocktype = similartype(blocktype(A), elt, axes)
21+
return similar(BlockSparseArray{elt,length(axes),output_blocktype}, axes)
1522
end
1623

1724
# Materialize a SubArray view.

NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl

Lines changed: 13 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ using BlockArrays:
99
mortar,
1010
unblock
1111
using SplitApplyCombine: groupcount
12+
# TODO: Move to `NDTensors.TypeParameterAccessors`.
13+
using ..NDTensors: similartype
1214

1315
const WrappedAbstractBlockSparseArray{T,N} = WrappedArray{
1416
T,N,AbstractBlockSparseArray,AbstractBlockSparseArray{T,N}
@@ -187,28 +189,27 @@ function Base.similar(
187189
return similar(arraytype, eltype(arraytype), axes)
188190
end
189191

192+
function blocksparse_similar(a, elt::Type, axes::Tuple)
193+
return BlockSparseArray{elt,length(axes),similartype(blocktype(a), axes)}(undef, axes)
194+
end
195+
190196
# Needed by `BlockArrays` matrix multiplication interface
191197
# TODO: Define a `blocksparse_similar` function.
192198
function Base.similar(
193199
arraytype::Type{<:BlockSparseArrayLike},
194200
elt::Type,
195201
axes::Tuple{Vararg{AbstractUnitRange{<:Integer}}},
196202
)
197-
# TODO: Make generic for GPU, maybe using `blocktype`.
198-
# TODO: For non-block axes this should output `Array`.
199-
return BlockSparseArray{elt}(undef, axes)
203+
return blocksparse_similar(arraytype, elt, axes)
200204
end
201205

202206
# TODO: Define a `blocksparse_similar` function.
203207
function Base.similar(
204208
a::BlockSparseArrayLike, elt::Type, axes::Tuple{Vararg{AbstractUnitRange{<:Integer}}}
205209
)
206-
# TODO: Make generic for GPU, maybe using `blocktype`.
207-
# TODO: For non-block axes this should output `Array`.
208-
return BlockSparseArray{elt}(undef, axes)
210+
return blocksparse_similar(a, elt, axes)
209211
end
210212

211-
# TODO: Define a `blocksparse_similar` function.
212213
# Fixes ambiguity error with `BlockArrays`.
213214
function Base.similar(
214215
a::BlockSparseArrayLike,
@@ -217,21 +218,16 @@ function Base.similar(
217218
AbstractBlockedUnitRange{<:Integer},Vararg{AbstractBlockedUnitRange{<:Integer}}
218219
},
219220
)
220-
# TODO: Make generic for GPU, maybe using `blocktype`.
221-
# TODO: For non-block axes this should output `Array`.
222-
return BlockSparseArray{elt}(undef, axes)
221+
return blocksparse_similar(a, elt, axes)
223222
end
224223

225-
# TODO: Define a `blocksparse_similar` function.
226224
# Fixes ambiguity error with `OffsetArrays`.
227225
function Base.similar(
228226
a::BlockSparseArrayLike,
229227
elt::Type,
230228
axes::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}},
231229
)
232-
# TODO: Make generic for GPU, maybe using `blocktype`.
233-
# TODO: For non-block axes this should output `Array`.
234-
return BlockSparseArray{elt}(undef, axes)
230+
return blocksparse_similar(a, elt, axes)
235231
end
236232

237233
# Fixes ambiguity error with `BlockArrays`.
@@ -240,9 +236,7 @@ function Base.similar(
240236
elt::Type,
241237
axes::Tuple{AbstractBlockedUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}},
242238
)
243-
# TODO: Make generic for GPU, maybe using `blocktype`.
244-
# TODO: For non-block axes this should output `Array`.
245-
return BlockSparseArray{elt}(undef, axes)
239+
return blocksparse_similar(a, elt, axes)
246240
end
247241

248242
# Fixes ambiguity errors with BlockArrays.
@@ -255,15 +249,12 @@ function Base.similar(
255249
Vararg{AbstractUnitRange{<:Integer}},
256250
},
257251
)
258-
return BlockSparseArray{elt}(undef, axes)
252+
return blocksparse_similar(a, elt, axes)
259253
end
260254

261-
# TODO: Define a `blocksparse_similar` function.
262255
# Fixes ambiguity error with `StaticArrays`.
263256
function Base.similar(
264257
a::BlockSparseArrayLike, elt::Type, axes::Tuple{Base.OneTo,Vararg{Base.OneTo}}
265258
)
266-
# TODO: Make generic for GPU, maybe using `blocktype`.
267-
# TODO: For non-block axes this should output `Array`.
268-
return BlockSparseArray{elt}(undef, axes)
259+
return blocksparse_similar(a, elt, axes)
269260
end

NDTensors/src/lib/BlockSparseArrays/src/blocksparsearray/blocksparsearray.jl

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,18 @@ function BlockSparseArray(
3535
return BlockSparseArray(Dictionary(block_indices, block_data), axes)
3636
end
3737

38+
function BlockSparseArray{T,N,A,Blocks}(
39+
blocks::AbstractArray{<:AbstractArray{T,N},N}, axes::Tuple{Vararg{AbstractUnitRange,N}}
40+
) where {T,N,A<:AbstractArray{T,N},Blocks<:AbstractArray{A,N}}
41+
return BlockSparseArray{T,N,A,Blocks,typeof(axes)}(blocks, axes)
42+
end
43+
44+
function BlockSparseArray{T,N,A}(
45+
blocks::AbstractArray{<:AbstractArray{T,N},N}, axes::Tuple{Vararg{AbstractUnitRange,N}}
46+
) where {T,N,A<:AbstractArray{T,N}}
47+
return BlockSparseArray{T,N,A,typeof(blocks)}(blocks, axes)
48+
end
49+
3850
function BlockSparseArray{T,N}(
3951
blocks::AbstractArray{<:AbstractArray{T,N},N}, axes::Tuple{Vararg{AbstractUnitRange,N}}
4052
) where {T,N}
@@ -49,9 +61,15 @@ function BlockSparseArray{T,N}(
4961
return BlockSparseArray{T,N}(blocks, axes)
5062
end
5163

64+
function BlockSparseArray{T,N,A}(
65+
axes::Tuple{Vararg{AbstractUnitRange,N}}
66+
) where {T,N,A<:AbstractArray{T,N}}
67+
blocks = default_blocks(A, axes)
68+
return BlockSparseArray{T,N,A}(blocks, axes)
69+
end
70+
5271
function BlockSparseArray{T,N}(axes::Tuple{Vararg{AbstractUnitRange,N}}) where {T,N}
53-
blocks = default_blocks(T, axes)
54-
return BlockSparseArray{T,N}(blocks, axes)
72+
return BlockSparseArray{T,N,default_arraytype(T, axes)}(axes)
5573
end
5674

5775
function BlockSparseArray{T,N}(dims::Tuple{Vararg{Vector{Int},N}}) where {T,N}
@@ -74,6 +92,12 @@ function BlockSparseArray{T}(axes::Vararg{AbstractUnitRange}) where {T}
7492
return BlockSparseArray{T}(axes)
7593
end
7694

95+
function BlockSparseArray{T,N,A}(
96+
::UndefInitializer, dims::Tuple
97+
) where {T,N,A<:AbstractArray{T,N}}
98+
return BlockSparseArray{T,N,A}(dims)
99+
end
100+
77101
# undef
78102
function BlockSparseArray{T,N}(
79103
::UndefInitializer, axes::Tuple{Vararg{AbstractUnitRange,N}}
@@ -109,7 +133,23 @@ Base.axes(a::BlockSparseArray) = a.axes
109133
blocksparse_blocks(a::BlockSparseArray) = a.blocks
110134

111135
# TODO: Use `TypeParameterAccessors`.
112-
blockstype(::Type{<:BlockSparseArray{<:Any,<:Any,<:Any,B}}) where {B} = B
136+
function blockstype(
137+
arraytype::Type{<:BlockSparseArray{T,N,A,Blocks}}
138+
) where {T,N,A<:AbstractArray{T,N},Blocks<:AbstractArray{A,N}}
139+
return Blocks
140+
end
141+
function blockstype(
142+
arraytype::Type{<:BlockSparseArray{T,N,A}}
143+
) where {T,N,A<:AbstractArray{T,N}}
144+
return SparseArrayDOK{A,N}
145+
end
146+
function blockstype(arraytype::Type{<:BlockSparseArray{T,N}}) where {T,N}
147+
return SparseArrayDOK{AbstractArray{T,N},N}
148+
end
149+
function blockstype(arraytype::Type{<:BlockSparseArray{T}}) where {T}
150+
return SparseArrayDOK{AbstractArray{T}}
151+
end
152+
blockstype(arraytype::Type{<:BlockSparseArray}) = SparseArrayDOK{AbstractArray}
113153

114154
## # Base interface
115155
## function Base.similar(

NDTensors/src/lib/BlockSparseArrays/src/blocksparsearray/defaults.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ function default_arraytype(elt::Type, axes::Tuple{Vararg{AbstractUnitRange}})
2929
return Array{elt,length(axes)}
3030
end
3131

32-
function default_blocks(elt::Type, axes::Tuple{Vararg{AbstractUnitRange}})
33-
block_data = Dictionary{Block{length(axes),Int},default_arraytype(elt, axes)}()
32+
function default_blocks(blocktype::Type, axes::Tuple{Vararg{AbstractUnitRange}})
33+
block_data = Dictionary{Block{length(axes),Int},blocktype}()
3434
return default_blocks(block_data, axes)
3535
end
3636

0 commit comments

Comments
 (0)