Skip to content

Commit a48f7e7

Browse files
committed
[SparseArrayInterface] Add tests for dot
1 parent 1764f0a commit a48f7e7

File tree

6 files changed

+43
-8
lines changed

6 files changed

+43
-8
lines changed

NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/arraylayouts.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,14 @@ function ArrayLayouts.MemoryLayout(arraytype::Type{<:BlockSparseArrayLike})
88
inner_layout = typeof(MemoryLayout(blocktype(arraytype)))
99
return BlockLayout{outer_layout,inner_layout}()
1010
end
11+
12+
# TODO: Generalize to `BlockSparseVectorLike`/`AnyBlockSparseVector`.
1113
function ArrayLayouts.MemoryLayout(
1214
arraytype::Type{<:Adjoint{<:Any,<:AbstractBlockSparseVector}}
1315
)
1416
return DualLayout{typeof(MemoryLayout(parenttype(arraytype)))}()
1517
end
18+
# TODO: Generalize to `BlockSparseVectorLike`/`AnyBlockSparseVector`.
1619
function ArrayLayouts.MemoryLayout(
1720
arraytype::Type{<:Transpose{<:Any,<:AbstractBlockSparseVector}}
1821
)

NDTensors/src/lib/SparseArrayInterface/src/abstractsparsearray/arraylayouts.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,22 @@
1-
using ArrayLayouts: ArrayLayouts, Dot, MatMulMatAdd, MatMulVecAdd, MulAdd
1+
using ArrayLayouts: ArrayLayouts, Dot, DualLayout, MatMulMatAdd, MatMulVecAdd, MulAdd
2+
using LinearAlgebra: Adjoint, Transpose
3+
using ..TypeParameterAccessors: parenttype
24

35
function ArrayLayouts.MemoryLayout(arraytype::Type{<:SparseArrayLike})
46
return SparseLayout()
57
end
68

9+
# TODO: Generalize to `SparseVectorLike`/`AnySparseVector`.
10+
function ArrayLayouts.MemoryLayout(arraytype::Type{<:Adjoint{<:Any,<:AbstractSparseVector}})
11+
return DualLayout{typeof(MemoryLayout(parenttype(arraytype)))}()
12+
end
13+
# TODO: Generalize to `SparseVectorLike`/`AnySparseVector`.
14+
function ArrayLayouts.MemoryLayout(
15+
arraytype::Type{<:Transpose{<:Any,<:AbstractSparseVector}}
16+
)
17+
return DualLayout{typeof(MemoryLayout(parenttype(arraytype)))}()
18+
end
19+
720
function sparse_matmul!(m::MulAdd)
821
α, a1, a2, β, a_dest = m.α, m.A, m.B, m.β, m.C
922
sparse_mul!(a_dest, a1, a2, α, β)

NDTensors/src/lib/SparseArrayInterface/src/sparsearrayinterface/SparseArrayInterfaceLinearAlgebraExt.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,11 @@ function sparse_mul!(
5555
end
5656

5757
function sparse_dot(a1::AbstractArray, a2::AbstractArray)
58-
# This assumes `a1` and `a2` have the same shape.
58+
# This requires that `a1` and `a2` have the same shape.
5959
# TODO: Generalize (Base supports dot products of
6060
# arrays with the same length but different sizes).
61-
@assert size(a1) == size(a2)
61+
size(a1) == size(a2) ||
62+
throw(DimensionMismatch("Sizes $(size(a1)) and $(size(a2)) don't match."))
6263
dot_dest = zero(Base.promote_op(dot, eltype(a1), eltype(a2)))
6364
# TODO: First check if the number of stored elements (`nstored`, to be renamed
6465
# `stored_length`) is smaller in `a1` or `a2` and use whicheven one is smallar

NDTensors/src/lib/SparseArrayInterface/test/SparseArrayInterfaceTestUtils/SparseArrays.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ function LinearAlgebra.mul!(
3232
return a_dest
3333
end
3434

35+
function LinearAlgebra.dot(a1::SparseArray, a2::SparseArray)
36+
return SparseArrayInterface.sparse_dot(a1, a2)
37+
end
38+
3539
# AbstractArray interface
3640
Base.size(a::SparseArray) = a.dims
3741
function Base.similar(a::SparseArray, elt::Type, dims::Tuple{Vararg{Int}})

NDTensors/src/lib/SparseArrayInterface/test/test_abstractsparsearray.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
@eval module $(gensym())
2-
using LinearAlgebra: mul!, norm
2+
using LinearAlgebra: dot, mul!, norm
33
using NDTensors.SparseArrayInterface: SparseArrayInterface
44
include("SparseArrayInterfaceTestUtils/SparseArrayInterfaceTestUtils.jl")
55
using .SparseArrayInterfaceTestUtils.AbstractSparseArrays: AbstractSparseArrays
@@ -299,6 +299,18 @@ using Test: @test, @testset
299299
@test a_dest isa SparseArray{elt}
300300
@test SparseArrayInterface.nstored(a_dest) == 2
301301

302+
# Dot product
303+
a1 = SparseArray{elt}(4)
304+
a1[1] = randn()
305+
a1[3] = randn()
306+
a2 = SparseArray{elt}(4)
307+
a2[2] = randn()
308+
a2[3] = randn()
309+
a_dest = a1' * a2
310+
@test a_dest isa elt
311+
@test a_dest Array(a1)' * Array(a2)
312+
@test a_dest dot(a1, a2)
313+
302314
# In-place matrix multiplication
303315
a1 = SparseArray{elt}(2, 3)
304316
a1[1, 2] = 12

src/lib/ITensorsNamedDimsArraysExt/src/to_nameddimsarray.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,17 +34,19 @@ function to_nameddimsarray(x::DiagTensor)
3434
return named(DiagonalArray(data(x), size(x)), name.(inds(x)))
3535
end
3636

37-
using ..NDTensors: BlockSparseTensor
37+
using ..NDTensors.BlockSparseArrays.BlockArrays: BlockArrays, blockedrange
38+
using ..NDTensors: BlockSparseTensor, array, blockdim, datatype, nblocks, nzblocks
3839
using ..NDTensors.BlockSparseArrays: BlockSparseArray
40+
using ..NDTensors.TypeParameterAccessors: set_ndims
3941
# TODO: Delete once `BlockSparse` is removed.
4042
function to_nameddimsarray(x::BlockSparseTensor)
41-
blockinds = map(i -> [blockdim(i, b) for b in 1:nblocks(i)], inds(x))
43+
blockinds = map(i -> blockedrange([blockdim(i, b) for b in 1:nblocks(i)]), inds(x))
4244
blocktype = set_ndims(datatype(x), ndims(x))
4345
# TODO: Make a simpler constructor:
4446
# BlockSparseArray(blocktype, blockinds)
45-
arraystorage = BlockSparseArray{eltype(x),ndims(x),blocktype}(blockinds)
47+
arraystorage = BlockSparseArray{eltype(x),ndims(x),blocktype}(undef, blockinds)
4648
for b in nzblocks(x)
47-
arraystorage[BlockArrays.Block(Tuple(b)...)] = x[b]
49+
arraystorage[BlockArrays.Block(Int.(Tuple(b))...)] = array(x[b])
4850
end
4951
return named(arraystorage, name.(inds(x)))
5052
end

0 commit comments

Comments
 (0)