Commit 1764f0a (parent 99513c3)

[SparseArrayInterface] [BlockSparseArrays] Sparse and block sparse dot products

9 files changed: +141 -21 lines changed


NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/arraylayouts.jl

Lines changed: 12 additions & 2 deletions
@@ -1,13 +1,23 @@
-using ArrayLayouts: ArrayLayouts, MemoryLayout, MulAdd
+using ArrayLayouts: ArrayLayouts, DualLayout, MemoryLayout, MulAdd
 using BlockArrays: BlockLayout
 using ..SparseArrayInterface: SparseLayout
-using ..TypeParameterAccessors: similartype
+using ..TypeParameterAccessors: parenttype, similartype

 function ArrayLayouts.MemoryLayout(arraytype::Type{<:BlockSparseArrayLike})
   outer_layout = typeof(MemoryLayout(blockstype(arraytype)))
   inner_layout = typeof(MemoryLayout(blocktype(arraytype)))
   return BlockLayout{outer_layout,inner_layout}()
 end
+function ArrayLayouts.MemoryLayout(
+  arraytype::Type{<:Adjoint{<:Any,<:AbstractBlockSparseVector}}
+)
+  return DualLayout{typeof(MemoryLayout(parenttype(arraytype)))}()
+end
+function ArrayLayouts.MemoryLayout(
+  arraytype::Type{<:Transpose{<:Any,<:AbstractBlockSparseVector}}
+)
+  return DualLayout{typeof(MemoryLayout(parenttype(arraytype)))}()
+end

 function Base.similar(
   mul::MulAdd{<:BlockLayout{<:SparseLayout},<:BlockLayout{<:SparseLayout},<:Any,<:Any,A,B},
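With the `DualLayout` methods above, the adjoint or transpose of a block sparse vector advertises a dual memory layout to `ArrayLayouts`, which is what lets `v' * w` be materialized through the dot-product path added in this commit. A minimal usage sketch, mirroring the constructors in the new tests below (block sizes and stored blocks are illustrative):

```julia
using NDTensors.BlockSparseArrays: BlockSparseArray
using BlockArrays: Block
using LinearAlgebra: dot

# Block vectors with block sizes (2, 3, 4); only some blocks are stored.
v = BlockSparseArray{Float64}([2, 3, 4])
v[Block(1)] = randn(2)
w = BlockSparseArray{Float64}([2, 3, 4])
w[Block(3)] = randn(4)

# `v'` now advertises a `DualLayout` wrapping the block sparse layout, so
# `v' * w` is handled by `ArrayLayouts` and agrees with the sparse `dot`:
v' * w ≈ dot(v, w)
```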

NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl

Lines changed: 8 additions & 1 deletion
@@ -158,6 +158,13 @@ function Base.similar(
   return similar(arraytype, eltype(arraytype), axes)
 end

+# Fixes ambiguity error.
+function Base.similar(
+  arraytype::Type{<:BlockSparseArrayLike}, axes::Tuple{Base.OneTo,Vararg{Base.OneTo}}
+)
+  return similar(arraytype, eltype(arraytype), axes)
+end
+
 # Needed by `BlockArrays` matrix multiplication interface
 # TODO: This fixes an ambiguity error with `OffsetArrays.jl`, but
 # is only appears to be needed in older versions of Julia like v1.6.
@@ -221,7 +228,7 @@ function Base.similar(
 end

 # Fixes ambiguity error.
-function Base.similar(a::BlockSparseArrayLike{<:Any,0}, elt::Type, axes::Tuple{})
+function Base.similar(a::BlockSparseArrayLike, elt::Type, axes::Tuple{})
   return blocksparse_similar(a, elt, axes)
 end
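The added `similar` method forwards axis-only calls on the array type to the `eltype`-aware method. A hypothetical call pattern (array and axes are made up for illustration; this assumes `blocksparse_similar` accepts plain `Base.OneTo` axes, as the forwarding above implies):

```julia
using NDTensors.BlockSparseArrays: BlockSparseArray

a = BlockSparseArray{Float64}([2, 3], [2, 3])
# Axis-only `similar` on the array type now forwards to
# `similar(arraytype, eltype(arraytype), axes)` without ambiguity.
b = similar(typeof(a), (Base.OneTo(4), Base.OneTo(4)))
```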

Lines changed: 31 additions & 6 deletions
@@ -1,23 +1,48 @@
-using ArrayLayouts: ArrayLayouts, MatMulMatAdd
+using ArrayLayouts: ArrayLayouts, Dot, MatMulMatAdd, MatMulVecAdd, MulAdd
 using BlockArrays: BlockLayout
 using ..SparseArrayInterface: SparseLayout
-using LinearAlgebra: mul!
+using LinearAlgebra: dot, mul!

 function blocksparse_muladd!(
-  α::Number, a1::AbstractMatrix, a2::AbstractMatrix, β::Number, a_dest::AbstractMatrix
+  α::Number, a1::AbstractArray, a2::AbstractArray, β::Number, a_dest::AbstractArray
 )
   mul!(blocks(a_dest), blocks(a1), blocks(a2), α, β)
   return a_dest
 end

+function blocksparse_matmul!(m::MulAdd)
+  α, a1, a2, β, a_dest = m.α, m.A, m.B, m.β, m.C
+  blocksparse_muladd!(α, a1, a2, β, a_dest)
+  return a_dest
+end
+
 function ArrayLayouts.materialize!(
   m::MatMulMatAdd{
     <:BlockLayout{<:SparseLayout},
     <:BlockLayout{<:SparseLayout},
     <:BlockLayout{<:SparseLayout},
   },
 )
-  α, a1, a2, β, a_dest = m.α, m.A, m.B, m.β, m.C
-  blocksparse_muladd!(α, a1, a2, β, a_dest)
-  return a_dest
+  blocksparse_matmul!(m)
+  return m.C
+end
+function ArrayLayouts.materialize!(
+  m::MatMulVecAdd{
+    <:BlockLayout{<:SparseLayout},
+    <:BlockLayout{<:SparseLayout},
+    <:BlockLayout{<:SparseLayout},
+  },
+)
+  blocksparse_matmul!(m)
+  return m.C
+end
+
+function blocksparse_dot(a1::AbstractArray, a2::AbstractArray)
+  # TODO: Add a check that the blocking of `a1` and `a2` are
+  # the same, or the same up to a reshape.
+  return dot(blocks(a1), blocks(a2))
+end
+
+function Base.copy(d::Dot{<:BlockLayout{<:SparseLayout},<:BlockLayout{<:SparseLayout}})
+  return blocksparse_dot(d.A, d.B)
 end
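With `Base.copy` overloaded for `Dot` of two block sparse layouts, the dot product reduces to `dot(blocks(a1), blocks(a2))`: only blocks stored in both operands contribute. A standalone sketch of that reduction, using `Dict`s of stored blocks in place of the package's block sparse types:

```julia
using LinearAlgebra: dot

# Stored blocks of two block vectors with block sizes (2, 3, 4).
blocks1 = Dict(1 => randn(2), 3 => randn(4))
blocks2 = Dict(2 => randn(3), 3 => randn(4))

# Unstored blocks are zero, so only blocks stored in both operands contribute.
function block_dot(b1, b2)
  return sum(dot(b1[k], b2[k]) for k in intersect(keys(b1), keys(b2)); init=0.0)
end

block_dot(blocks1, blocks2)  # equals the dot product of the dense block vectors
```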

NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl

Lines changed: 2 additions & 2 deletions
@@ -184,7 +184,7 @@ reverse_index(index::CartesianIndex) = CartesianIndex(reverse(Tuple(index)))

 # Represents the array of arrays of a `Transpose`
 # wrapping a block spare array, i.e. `blocks(array)` where `a` is a `Transpose`.
-struct SparseTransposeBlocks{T,BlockType<:AbstractMatrix{T},Array<:Transpose{T}} <:
+struct SparseTransposeBlocks{T,BlockType<:AbstractArray{T},Array<:Transpose{T}} <:
       AbstractSparseMatrix{BlockType}
   array::Array
 end
@@ -219,7 +219,7 @@ end

 # Represents the array of arrays of a `Adjoint`
 # wrapping a block spare array, i.e. `blocks(array)` where `a` is a `Adjoint`.
-struct SparseAdjointBlocks{T,BlockType<:AbstractMatrix{T},Array<:Adjoint{T}} <:
+struct SparseAdjointBlocks{T,BlockType<:AbstractArray{T},Array<:Adjoint{T}} <:
       AbstractSparseMatrix{BlockType}
   array::Array
 end

NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl

Lines changed: 15 additions & 1 deletion
@@ -16,7 +16,7 @@ using BlockArrays:
   mortar
 using Compat: @compat
 using GPUArraysCore: @allowscalar
-using LinearAlgebra: Adjoint, mul!, norm
+using LinearAlgebra: Adjoint, dot, mul!, norm
 using NDTensors.BlockSparseArrays:
   @view!,
   BlockSparseArray,
@@ -575,7 +575,11 @@ using .NDTensorsTestUtils: devices_list, is_supported_eltype
      a[b] = randn(elt, size(a[b]))
    end
    @test isassigned(a, 1, 1)
+   @test isassigned(a, 1, 1, 1)
+   @test !isassigned(a, 1, 1, 2)
    @test isassigned(a, 5, 7)
+   @test isassigned(a, 5, 7, 1)
+   @test !isassigned(a, 5, 7, 2)
    @test !isassigned(a, 0, 1)
    @test !isassigned(a, 5, 8)
    @test isassigned(a, Block(1), Block(1))
@@ -852,6 +856,16 @@ using .NDTensorsTestUtils: devices_list, is_supported_eltype
      @allowscalar @test Array(a_dest) ≈ Array(a1′) * Array(a2′)
    end
  end
+ @testset "Dot product" begin
+   a1 = dev(BlockSparseArray{elt}([2, 3, 4]))
+   a1[Block(1)] = dev(randn(elt, size(@view(a1[Block(1)]))))
+   a1[Block(3)] = dev(randn(elt, size(@view(a1[Block(3)]))))
+   a2 = dev(BlockSparseArray{elt}([2, 3, 4]))
+   a2[Block(2)] = dev(randn(elt, size(@view(a1[Block(2)]))))
+   a2[Block(3)] = dev(randn(elt, size(@view(a1[Block(3)]))))
+   @test a1' * a2 ≈ Array(a1)' * Array(a2)
+   @test dot(a1, a2) ≈ a1' * a2
+ end
 @testset "TensorAlgebra" begin
   a1 = dev(BlockSparseArray{elt}([2, 3], [2, 3]))
   a1[Block(1, 1)] = dev(randn(elt, size(@view(a1[Block(1, 1)]))))
Lines changed: 19 additions & 4 deletions
@@ -1,13 +1,28 @@
-using ArrayLayouts: ArrayLayouts, MatMulMatAdd
+using ArrayLayouts: ArrayLayouts, Dot, MatMulMatAdd, MatMulVecAdd, MulAdd

 function ArrayLayouts.MemoryLayout(arraytype::Type{<:SparseArrayLike})
   return SparseLayout()
 end

-function ArrayLayouts.materialize!(
-  m::MatMulMatAdd{<:AbstractSparseLayout,<:AbstractSparseLayout,<:AbstractSparseLayout}
-)
+function sparse_matmul!(m::MulAdd)
   α, a1, a2, β, a_dest = m.α, m.A, m.B, m.β, m.C
   sparse_mul!(a_dest, a1, a2, α, β)
   return a_dest
 end
+
+function ArrayLayouts.materialize!(
+  m::MatMulMatAdd{<:AbstractSparseLayout,<:AbstractSparseLayout,<:AbstractSparseLayout}
+)
+  sparse_matmul!(m)
+  return m.C
+end
+function ArrayLayouts.materialize!(
+  m::MatMulVecAdd{<:AbstractSparseLayout,<:AbstractSparseLayout,<:AbstractSparseLayout}
+)
+  sparse_matmul!(m)
+  return m.C
+end
+
+function Base.copy(d::Dot{<:SparseLayout,<:SparseLayout})
+  return sparse_dot(d.A, d.B)
+end

NDTensors/src/lib/SparseArrayInterface/src/sparsearrayinterface/SparseArrayInterfaceLinearAlgebraExt.jl

Lines changed: 34 additions & 4 deletions
@@ -1,4 +1,4 @@
-using LinearAlgebra: mul!, norm
+using LinearAlgebra: dot, mul!, norm

 sparse_norm(a::AbstractArray, p::Real=2) = norm(sparse_storage(a))

@@ -9,6 +9,14 @@ function mul_indices(I1::CartesianIndex{2}, I2::CartesianIndex{2})
   return CartesianIndex(I1[1], I2[2])
 end

+# TODO: Is this needed? Maybe when multiplying vectors?
+function mul_indices(I1::CartesianIndex{1}, I2::CartesianIndex{1})
+  if I1 ≠ I2
+    return nothing
+  end
+  return CartesianIndex(I1)
+end
+
 function default_mul!!(
   a_dest::AbstractMatrix,
   a1::AbstractMatrix,
@@ -28,9 +36,9 @@ end

 # a1 * a2 * α + a_dest * β
 function sparse_mul!(
-  a_dest::AbstractMatrix,
-  a1::AbstractMatrix,
-  a2::AbstractMatrix,
+  a_dest::AbstractArray,
+  a1::AbstractArray,
+  a2::AbstractArray,
   α::Number=true,
   β::Number=false;
   (mul!!)=(default_mul!!),
@@ -45,3 +53,25 @@ function sparse_mul!(
   end
   return a_dest
 end
+
+function sparse_dot(a1::AbstractArray, a2::AbstractArray)
+  # This assumes `a1` and `a2` have the same shape.
+  # TODO: Generalize (Base supports dot products of
+  # arrays with the same length but different sizes).
+  @assert size(a1) == size(a2)
+  dot_dest = zero(Base.promote_op(dot, eltype(a1), eltype(a2)))
+  # TODO: First check if the number of stored elements (`nstored`, to be renamed
+  # `stored_length`) is smaller in `a1` or `a2` and use whichever one is smaller
+  # as the outer loop.
+  for I1 in stored_indices(a1)
+    # TODO: Overload and use `Base.isstored(a, I) = I in stored_indices(a)` instead.
+    # TODO: This assumes fast lookup of indices, which may not always be the case.
+    # It could be better to loop over `stored_indices(a2)` and check that
+    # `I1 == I2` instead (say using `mul_indices(I1, I2)`). We could have a trait
+    # `HasFastIsStored(a::AbstractArray)` to choose between the two.
+    if I1 in stored_indices(a2)
+      dot_dest += dot(a1[I1], a2[I1])
+    end
+  end
+  return dot_dest
+end
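At the element level, `sparse_dot` loops over the stored indices of the first operand and accumulates `dot` of the values at indices stored in both, since unstored entries are zero. A self-contained sketch of the same loop, with `Dict`-backed storage standing in for `stored_indices`/`getindex`:

```julia
using LinearAlgebra: dot

# Stored entries of two length-5 sparse vectors, keyed by index.
stored1 = Dict(CartesianIndex(1) => 2.0, CartesianIndex(4) => -1.0)
stored2 = Dict(CartesianIndex(4) => 3.0, CartesianIndex(5) => 7.0)

function sparse_dot_sketch(s1, s2)
  acc = zero(promote_type(valtype(s1), valtype(s2)))
  for (I, v1) in s1            # outer loop over stored entries of the first array
    haskey(s2, I) || continue  # entries not stored in `s2` contribute zero
    acc += dot(v1, s2[I])
  end
  return acc
end

sparse_dot_sketch(stored1, stored2)  # only index 4 overlaps: dot(-1.0, 3.0) == -3.0
```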

NDTensors/src/lib/SparseArrayInterface/src/sparsearrayinterface/indexing.jl

Lines changed: 4 additions & 1 deletion
@@ -166,7 +166,10 @@ end
 function sparse_isassigned(a::AbstractArray{<:Any,N}, I::CartesianIndex{N}) where {N}
   return sparse_isassigned(a, Tuple(I)...)
 end
-function sparse_isassigned(a::AbstractArray{<:Any,N}, I::Vararg{Integer,N}) where {N}
+function sparse_isassigned(a::AbstractArray, I::Integer...)
+  # Check trailing dimensions are one. This is needed in generic
+  # AbstractArray show when `a isa AbstractVector`.
+  all(d -> isone(I[d]), (ndims(a) + 1):length(I)) || return false
   return all(dim -> I[dim] ∈ axes(a, dim), 1:ndims(a))
 end
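The trailing-index rule added here mirrors what Base arrays already do (and is what the new `isassigned` tests in `test_basics.jl` exercise): indices beyond `ndims(a)` are accepted only if they equal 1. For comparison, with a plain `Vector`:

```julia
v = collect(1:5)

isassigned(v, 3)     # true
isassigned(v, 3, 1)  # true:  trailing indices equal to 1 are allowed
isassigned(v, 3, 2)  # false: a trailing index other than 1 is out of bounds
```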

NDTensors/src/lib/TensorAlgebra/src/LinearAlgebraExtensions/qr.jl

Lines changed: 16 additions & 0 deletions
@@ -1,3 +1,4 @@
+using ArrayLayouts: LayoutMatrix
 using LinearAlgebra: LinearAlgebra, qr
 using ..TensorAlgebra:
   TensorAlgebra,
@@ -8,6 +9,8 @@ using ..TensorAlgebra:
   fusedims,
   splitdims

+# TODO: Define as `tensor_qr`.
+# TODO: This looks generic but doesn't work for `BlockSparseArrays`.
 function _qr(a::AbstractArray, biperm::BlockedPermutation{2})
   a_matricized = fusedims(a, biperm)

@@ -38,6 +41,12 @@ function LinearAlgebra.qr(a::AbstractMatrix, biperm::BlockedPermutation{2})
   return _qr(a, biperm)
 end

+# Fix ambiguity error with `ArrayLayouts`.
+function LinearAlgebra.qr(a::LayoutMatrix, biperm::BlockedPermutation{2})
+  return _qr(a, biperm)
+end
+
+# TODO: Define in terms of an inner function `_qr` or `tensor_qr`.
 function LinearAlgebra.qr(
   a::AbstractArray, labels_a::Tuple, labels_q::Tuple, labels_r::Tuple
 )
@@ -50,3 +59,10 @@ function LinearAlgebra.qr(
 )
   return qr(a, blockedperm_indexin(labels_a, labels_q, labels_r))
 end
+
+# Fix ambiguity error with `ArrayLayouts`.
+function LinearAlgebra.qr(
+  a::LayoutMatrix, labels_a::Tuple, labels_q::Tuple, labels_r::Tuple
+)
+  return qr(a, blockedperm_indexin(labels_a, labels_q, labels_r))
+end
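The `LayoutMatrix` methods exist only to break dispatch ambiguities with `ArrayLayouts`' own `qr`; they forward to the same labeled QR as the `AbstractMatrix`/`AbstractArray` paths. A hypothetical call pattern for the labeled interface (the labels and array shape are made up; this assumes the extension is loaded with `NDTensors.TensorAlgebra`):

```julia
using NDTensors.TensorAlgebra: TensorAlgebra  # loads the labeled `qr` methods
using LinearAlgebra: qr

a = randn(2, 3, 4)
# Fuse the (:i, :j) legs into rows and the (:k,) leg into columns, then factorize.
fact = qr(a, (:i, :j, :k), (:i, :j), (:k,))
```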
