Skip to content

Commit 29104cf

Browse files
authored
[SparseArrayInterface] [BlockSparseArrays] Sparse and block sparse dot products (#1577)
* [SparseArrayInterface] [BlockSparseArrays] Sparse and block sparse dot products * [NDTensors] Bump to v0.3.62
1 parent 0e33591 commit 29104cf

File tree

5 files changed

+71
-12
lines changed

5 files changed

+71
-12
lines changed

src/abstractblocksparsearray/arraylayouts.jl

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,27 @@
1-
using ArrayLayouts: ArrayLayouts, MemoryLayout, MulAdd
1+
using ArrayLayouts: ArrayLayouts, DualLayout, MemoryLayout, MulAdd
22
using BlockArrays: BlockLayout
33
using ..SparseArrayInterface: SparseLayout
4-
using ..TypeParameterAccessors: similartype
4+
using ..TypeParameterAccessors: parenttype, similartype
55

66
function ArrayLayouts.MemoryLayout(arraytype::Type{<:BlockSparseArrayLike})
77
outer_layout = typeof(MemoryLayout(blockstype(arraytype)))
88
inner_layout = typeof(MemoryLayout(blocktype(arraytype)))
99
return BlockLayout{outer_layout,inner_layout}()
1010
end
1111

12+
# TODO: Generalize to `BlockSparseVectorLike`/`AnyBlockSparseVector`.
13+
function ArrayLayouts.MemoryLayout(
14+
arraytype::Type{<:Adjoint{<:Any,<:AbstractBlockSparseVector}}
15+
)
16+
return DualLayout{typeof(MemoryLayout(parenttype(arraytype)))}()
17+
end
18+
# TODO: Generalize to `BlockSparseVectorLike`/`AnyBlockSparseVector`.
19+
function ArrayLayouts.MemoryLayout(
20+
arraytype::Type{<:Transpose{<:Any,<:AbstractBlockSparseVector}}
21+
)
22+
return DualLayout{typeof(MemoryLayout(parenttype(arraytype)))}()
23+
end
24+
1225
function Base.similar(
1326
mul::MulAdd{<:BlockLayout{<:SparseLayout},<:BlockLayout{<:SparseLayout},<:Any,<:Any,A,B},
1427
elt::Type,

src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,13 @@ function Base.similar(
158158
return similar(arraytype, eltype(arraytype), axes)
159159
end
160160

161+
# Fixes ambiguity error.
162+
function Base.similar(
163+
arraytype::Type{<:BlockSparseArrayLike}, axes::Tuple{Base.OneTo,Vararg{Base.OneTo}}
164+
)
165+
return similar(arraytype, eltype(arraytype), axes)
166+
end
167+
161168
# Needed by `BlockArrays` matrix multiplication interface
162169
# TODO: This fixes an ambiguity error with `OffsetArrays.jl`, but
163170
# is only appears to be needed in older versions of Julia like v1.6.
@@ -221,7 +228,7 @@ function Base.similar(
221228
end
222229

223230
# Fixes ambiguity error.
224-
function Base.similar(a::BlockSparseArrayLike{<:Any,0}, elt::Type, axes::Tuple{})
231+
function Base.similar(a::BlockSparseArrayLike, elt::Type, axes::Tuple{})
225232
return blocksparse_similar(a, elt, axes)
226233
end
227234

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,48 @@
1-
using ArrayLayouts: ArrayLayouts, MatMulMatAdd
1+
using ArrayLayouts: ArrayLayouts, Dot, MatMulMatAdd, MatMulVecAdd, MulAdd
22
using BlockArrays: BlockLayout
33
using ..SparseArrayInterface: SparseLayout
4-
using LinearAlgebra: mul!
4+
using LinearAlgebra: dot, mul!
55

66
function blocksparse_muladd!(
7-
α::Number, a1::AbstractMatrix, a2::AbstractMatrix, β::Number, a_dest::AbstractMatrix
7+
α::Number, a1::AbstractArray, a2::AbstractArray, β::Number, a_dest::AbstractArray
88
)
99
mul!(blocks(a_dest), blocks(a1), blocks(a2), α, β)
1010
return a_dest
1111
end
1212

13+
function blocksparse_matmul!(m::MulAdd)
14+
α, a1, a2, β, a_dest = m.α, m.A, m.B, m.β, m.C
15+
blocksparse_muladd!(α, a1, a2, β, a_dest)
16+
return a_dest
17+
end
18+
1319
function ArrayLayouts.materialize!(
1420
m::MatMulMatAdd{
1521
<:BlockLayout{<:SparseLayout},
1622
<:BlockLayout{<:SparseLayout},
1723
<:BlockLayout{<:SparseLayout},
1824
},
1925
)
20-
α, a1, a2, β, a_dest = m.α, m.A, m.B, m.β, m.C
21-
blocksparse_muladd!(α, a1, a2, β, a_dest)
22-
return a_dest
26+
blocksparse_matmul!(m)
27+
return m.C
28+
end
29+
function ArrayLayouts.materialize!(
30+
m::MatMulVecAdd{
31+
<:BlockLayout{<:SparseLayout},
32+
<:BlockLayout{<:SparseLayout},
33+
<:BlockLayout{<:SparseLayout},
34+
},
35+
)
36+
blocksparse_matmul!(m)
37+
return m.C
38+
end
39+
40+
function blocksparse_dot(a1::AbstractArray, a2::AbstractArray)
41+
# TODO: Add a check that the blocking of `a1` and `a2` are
42+
# the same, or the same up to a reshape.
43+
return dot(blocks(a1), blocks(a2))
44+
end
45+
46+
function Base.copy(d::Dot{<:BlockLayout{<:SparseLayout},<:BlockLayout{<:SparseLayout}})
47+
return blocksparse_dot(d.A, d.B)
2348
end

src/blocksparsearrayinterface/blocksparsearrayinterface.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ reverse_index(index::CartesianIndex) = CartesianIndex(reverse(Tuple(index)))
184184

185185
# Represents the array of arrays of a `Transpose`
186186
# wrapping a block spare array, i.e. `blocks(array)` where `a` is a `Transpose`.
187-
struct SparseTransposeBlocks{T,BlockType<:AbstractMatrix{T},Array<:Transpose{T}} <:
187+
struct SparseTransposeBlocks{T,BlockType<:AbstractArray{T},Array<:Transpose{T}} <:
188188
AbstractSparseMatrix{BlockType}
189189
array::Array
190190
end
@@ -219,7 +219,7 @@ end
219219

220220
# Represents the array of arrays of a `Adjoint`
221221
# wrapping a block spare array, i.e. `blocks(array)` where `a` is a `Adjoint`.
222-
struct SparseAdjointBlocks{T,BlockType<:AbstractMatrix{T},Array<:Adjoint{T}} <:
222+
struct SparseAdjointBlocks{T,BlockType<:AbstractArray{T},Array<:Adjoint{T}} <:
223223
AbstractSparseMatrix{BlockType}
224224
array::Array
225225
end

test/test_basics.jl

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ using BlockArrays:
1616
mortar
1717
using Compat: @compat
1818
using GPUArraysCore: @allowscalar
19-
using LinearAlgebra: Adjoint, mul!, norm
19+
using LinearAlgebra: Adjoint, dot, mul!, norm
2020
using NDTensors.BlockSparseArrays:
2121
@view!,
2222
BlockSparseArray,
@@ -575,7 +575,11 @@ using .NDTensorsTestUtils: devices_list, is_supported_eltype
575575
a[b] = randn(elt, size(a[b]))
576576
end
577577
@test isassigned(a, 1, 1)
578+
@test isassigned(a, 1, 1, 1)
579+
@test !isassigned(a, 1, 1, 2)
578580
@test isassigned(a, 5, 7)
581+
@test isassigned(a, 5, 7, 1)
582+
@test !isassigned(a, 5, 7, 2)
579583
@test !isassigned(a, 0, 1)
580584
@test !isassigned(a, 5, 8)
581585
@test isassigned(a, Block(1), Block(1))
@@ -852,6 +856,16 @@ using .NDTensorsTestUtils: devices_list, is_supported_eltype
852856
@allowscalar @test Array(a_dest) Array(a1′) * Array(a2′)
853857
end
854858
end
859+
@testset "Dot product" begin
860+
a1 = dev(BlockSparseArray{elt}([2, 3, 4]))
861+
a1[Block(1)] = dev(randn(elt, size(@view(a1[Block(1)]))))
862+
a1[Block(3)] = dev(randn(elt, size(@view(a1[Block(3)]))))
863+
a2 = dev(BlockSparseArray{elt}([2, 3, 4]))
864+
a2[Block(2)] = dev(randn(elt, size(@view(a1[Block(2)]))))
865+
a2[Block(3)] = dev(randn(elt, size(@view(a1[Block(3)]))))
866+
@test a1' * a2 Array(a1)' * Array(a2)
867+
@test dot(a1, a2) a1' * a2
868+
end
855869
@testset "TensorAlgebra" begin
856870
a1 = dev(BlockSparseArray{elt}([2, 3], [2, 3]))
857871
a1[Block(1, 1)] = dev(randn(elt, size(@view(a1[Block(1, 1)]))))

0 commit comments

Comments
 (0)