Skip to content

Commit c4f011c

Browse files
committed
Add tests
1 parent 83f904f commit c4f011c

File tree

3 files changed

+36
-4
lines changed

3 files changed

+36
-4
lines changed

src/sparsearrays.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ function eachstoredindex(m::AbstractSparseMatrixCSC)
55
# TODO: This loses the compile time element type, is there a better lazy way?
66
return Iterators.map(CartesianIndex, zip(I, J))
77
end
8-
function eachstoredindex(a::Base.ReshapedArray{<:Any, <:Any, <: AbstractSparseMatrixCSC})
8+
function eachstoredindex(a::Base.ReshapedArray{<:Any, <:Any, <:AbstractSparseMatrixCSC})
99
return @interface SparseArrayInterface() eachstoredindex(a)
1010
end
1111

@@ -17,16 +17,16 @@ function SparseArrays.SparseMatrixCSC{Tv, Ti}(m::AnyAbstractSparseMatrix) where
1717
return m′
1818
end
1919

20-
function SparseArrayDOK(a::Base.ReshapedArray{<:Any, <:Any, <: AbstractSparseMatrixCSC})
20+
function SparseArrayDOK(a::Base.ReshapedArray{<:Any, <:Any, <:AbstractSparseMatrixCSC})
2121
return SparseArrayDOK{eltype(a), ndims(a)}(a)
2222
end
2323
function SparseArrayDOK{T}(
24-
a::Base.ReshapedArray{<:Any, <:Any, <: AbstractSparseMatrixCSC}
24+
a::Base.ReshapedArray{<:Any, <:Any, <:AbstractSparseMatrixCSC}
2525
) where {T}
2626
return SparseArrayDOK{T, ndims(a)}(a)
2727
end
2828
function SparseArrayDOK{T, N}(
29-
a::Base.ReshapedArray{<:Any, N, <: AbstractSparseMatrixCSC}
29+
a::Base.ReshapedArray{<:Any, N, <:AbstractSparseMatrixCSC}
3030
) where {T, N}
3131
a′ = SparseArrayDOK{T, N}(undef, size(a))
3232
for I in eachstoredindex(a)

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
1111
SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208"
1212
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
1313
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
14+
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
1415
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1516

1617
[sources]
@@ -29,4 +30,5 @@ SafeTestsets = "0.1.0"
2930
SparseArraysBase = "0.7.0"
3031
StableRNGs = "1.0.2"
3132
Suppressor = "0.2.8"
33+
TensorAlgebra = "0.6"
3234
Test = "<0.0.1, 1"

test/test_tensoralgebraext.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
using SparseArrays: SparseMatrixCSC, findnz, nnz
2+
using SparseArraysBase: SparseMatrixDOK, sparsezeros
3+
using TensorAlgebra: contract, matricize
4+
using Test: @testset, @test
5+
6+
@testset "TensorAlgebraExt (eltype = $elt)" for elt in (Float32, ComplexF64)
7+
a = sparsezeros(elt, (2, 2, 2))
8+
a[1, 1, 1] = 1
9+
a[2, 1, 2] = 2
10+
11+
# matricize
12+
m = matricize(a, (1, 3), (2,))
13+
@test m isa SparseMatrixCSC{elt}
14+
@test nnz(m) == 2
15+
@test isstored(m, 1, 1)
16+
@test m[1, 1] elt(1)
17+
@test isstored(m, 4, 1)
18+
@test m[4, 1] elt(2)
19+
@test issetequal(eachstoredindex(m), [CartesianIndex(1, 1), CartesianIndex(4, 1)])
20+
for I in setdiff(CartesianIndices(m), [CartesianIndex(1, 1), CartesianIndex(4, 1)])
21+
@test m[I] zero(elt)
22+
end
23+
24+
# contract
25+
b, l = contract(a, ("i", "j", "k"), a, ("j", "k", "l"))
26+
@test b isa SparseMatrixDOK{elt}
27+
@test storedlength(b) == 1
28+
@test only(eachstoredindex(b)) == CartesianIndex(1, 1)
29+
@test b[1, 1] elt(1)
30+
end

0 commit comments

Comments
 (0)