Skip to content

Commit 49cc24c

Browse files
committed
Start adding tests
1 parent d1fadaf commit 49cc24c

File tree

4 files changed

+38
-20
lines changed

4 files changed

+38
-20
lines changed

NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,9 @@ function Base.similar(
189189
end
190190

191191
function blocksparse_similar(a, elt::Type, axes::Tuple)
192-
return BlockSparseArray{elt,length(axes),similartype(blocktype(a), axes)}(undef, axes)
192+
return BlockSparseArray{elt,length(axes),similartype(blocktype(a), elt, axes)}(
193+
undef, axes
194+
)
193195
end
194196

195197
# Needed by `BlockArrays` matrix multiplication interface

NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -125,12 +125,13 @@ _getindices(i::CartesianIndex, indices) = CartesianIndex(_getindices(Tuple(i), i
125125

126126
# Represents the array of arrays of a `PermutedDimsArray`
127127
# wrapping a block spare array, i.e. `blocks(array)` where `a` is a `PermutedDimsArray`.
128-
struct SparsePermutedDimsArrayBlocks{T,N,Array<:PermutedDimsArray{T,N}} <:
129-
AbstractSparseArray{T,N}
128+
struct SparsePermutedDimsArrayBlocks{
129+
T,N,BlockType<:AbstractArray{T,N},Array<:PermutedDimsArray{T,N}
130+
} <: AbstractSparseArray{BlockType,N}
130131
array::Array
131132
end
132133
function blocksparse_blocks(a::PermutedDimsArray)
133-
return SparsePermutedDimsArrayBlocks(a)
134+
return SparsePermutedDimsArrayBlocks{eltype(a),ndims(a),blocktype(parent(a)),typeof(a)}(a)
134135
end
135136
function Base.size(a::SparsePermutedDimsArrayBlocks)
136137
return _getindices(size(blocks(parent(a.array))), _perm(a.array))
@@ -158,11 +159,12 @@ reverse_index(index::CartesianIndex) = CartesianIndex(reverse(Tuple(index)))
158159

159160
# Represents the array of arrays of a `Transpose`
160161
# wrapping a block spare array, i.e. `blocks(array)` where `a` is a `Transpose`.
161-
struct SparseTransposeBlocks{T,Array<:Transpose{T}} <: AbstractSparseMatrix{T}
162+
struct SparseTransposeBlocks{T,BlockType<:AbstractMatrix{T},Array<:Transpose{T}} <:
163+
AbstractSparseMatrix{BlockType}
162164
array::Array
163165
end
164166
function blocksparse_blocks(a::Transpose)
165-
return SparseTransposeBlocks(a)
167+
return SparseTransposeBlocks{eltype(a),blocktype(parent(a)),typeof(a)}(a)
166168
end
167169
function Base.size(a::SparseTransposeBlocks)
168170
return reverse(size(blocks(parent(a.array))))
@@ -192,11 +194,12 @@ end
192194

193195
# Represents the array of arrays of a `Adjoint`
194196
# wrapping a block spare array, i.e. `blocks(array)` where `a` is a `Adjoint`.
195-
struct SparseAdjointBlocks{T,Array<:Adjoint{T}} <: AbstractSparseMatrix{T}
197+
struct SparseAdjointBlocks{T,BlockType<:AbstractMatrix{T},Array<:Adjoint{T}} <:
198+
AbstractSparseMatrix{BlockType}
196199
array::Array
197200
end
198201
function blocksparse_blocks(a::Adjoint)
199-
return SparseAdjointBlocks(a)
202+
return SparseAdjointBlocks{eltype(a),blocktype(parent(a)),typeof(a)}(a)
200203
end
201204
function Base.size(a::SparseAdjointBlocks)
202205
return reverse(size(blocks(parent(a.array))))
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
[deps]
22
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
33
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
4+
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
5+
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
46
NDTensors = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"

NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ using BlockArrays:
1515
blocksizes,
1616
mortar
1717
using Compat: @compat
18+
using GPUArraysCore: @allowscalar
1819
using LinearAlgebra: Adjoint, mul!, norm
1920
using NDTensors.BlockSparseArrays:
2021
@view!,
@@ -28,27 +29,37 @@ using NDTensors.SparseArrayInterface: nstored
2829
using NDTensors.TensorAlgebra: contract
2930
using Test: @test, @test_broken, @test_throws, @testset
3031
include("TestBlockSparseArraysUtils.jl")
31-
@testset "BlockSparseArrays (eltype=$elt)" for elt in
32-
(Float32, Float64, ComplexF32, ComplexF64)
32+
33+
using NDTensors: NDTensors
34+
include(joinpath(pkgdir(NDTensors), "test", "NDTensorsTestUtils", "NDTensorsTestUtils.jl"))
35+
using .NDTensorsTestUtils: devices_list, is_supported_eltype
36+
@testset "BlockSparseArrays (dev=$dev, eltype=$elt)" for dev in devices_list(copy(ARGS)),
37+
elt in (Float32, Float64, Complex{Float32}, Complex{Float64})
38+
39+
@show dev, elt
40+
41+
if !is_supported_eltype(dev, elt)
42+
continue
43+
end
3344
@testset "Broken" begin
3445
# TODO: Fix this and turn it into a proper test.
35-
a = BlockSparseArray{elt}([2, 3], [2, 3])
36-
a[Block(1, 1)] = randn(elt, 2, 2)
37-
a[Block(2, 2)] = randn(elt, 3, 3)
46+
a = dev(BlockSparseArray{elt}([2, 3], [2, 3]))
47+
a[Block(1, 1)] = dev(randn(elt, 2, 2))
48+
a[Block(2, 2)] = dev(randn(elt, 3, 3))
3849
@test_broken a[:, 4]
3950

4051
# TODO: Fix this and turn it into a proper test.
41-
a = BlockSparseArray{elt}([2, 3], [2, 3])
42-
a[Block(1, 1)] = randn(elt, 2, 2)
43-
a[Block(2, 2)] = randn(elt, 3, 3)
52+
a = dev(BlockSparseArray{elt}([2, 3], [2, 3]))
53+
a[Block(1, 1)] = dev(randn(elt, 2, 2))
54+
a[Block(2, 2)] = dev(randn(elt, 3, 3))
4455
@test_broken a[:, [2, 4]]
4556
@test_broken a[[3, 5], [2, 4]]
4657

4758
# TODO: Fix this and turn it into a proper test.
48-
a = BlockSparseArray{elt}([2, 3], [2, 3])
49-
a[Block(1, 1)] = randn(elt, 2, 2)
50-
a[Block(2, 2)] = randn(elt, 3, 3)
51-
@test a[2:4, 4] == Array(a)[2:4, 4]
59+
a = dev(BlockSparseArray{elt}([2, 3], [2, 3]))
60+
a[Block(1, 1)] = dev(randn(elt, 2, 2))
61+
a[Block(2, 2)] = dev(randn(elt, 3, 3))
62+
@allowscalar @test a[2:4, 4] == Array(a)[2:4, 4]
5263
@test_broken a[4, 2:4]
5364

5465
@test a[Block(1), :] isa BlockSparseArray{elt}

0 commit comments

Comments
 (0)