Skip to content

Commit 74c992a

Browse files
committed
Increase test coverage for ArrayInterface
Tests pass for ArrayInterface now (locally). Also added missing tests for matrix colors.
1 parent c28d427 commit 74c992a

File tree

5 files changed

+101
-73
lines changed

5 files changed

+101
-73
lines changed

ArrayInterfaceCore/src/ArrayInterfaceCore.jl

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -346,22 +346,10 @@ The color vector for dense matrix and triangular matrix is simply
346346
function matrix_colors(A::Union{Array,UpperTriangular,LowerTriangular})
347347
eachindex(1:Base.size(A, 2)) # Vector Base.size matches number of rows
348348
end
349-
350-
function _cycle(repetend, len)
351-
repeat(repetend, div(len, length(repetend)) + 1)[1:len]
352-
end
353-
354-
function matrix_colors(A::Diagonal)
355-
fill(1, Base.size(A, 2))
356-
end
357-
358-
function matrix_colors(A::Bidiagonal)
359-
_cycle(1:2, Base.size(A, 2))
360-
end
361-
362-
function matrix_colors(A::Union{Tridiagonal,SymTridiagonal})
363-
_cycle(1:3, Base.size(A, 2))
364-
end
349+
matrix_colors(A::Diagonal) = fill(1, Base.size(A, 2))
350+
matrix_colors(A::Bidiagonal) = _cycle(1:2, Base.size(A, 2))
351+
matrix_colors(A::Union{Tridiagonal,SymTridiagonal}) = _cycle(1:3, Base.size(A, 2))
352+
_cycle(repetend, len) = repeat(repetend, div(len, length(repetend)) + 1)[1:len]
365353

366354
"""
367355
lu_instance(A) -> lu_factorization_instance

ArrayInterfaceCore/test/array_index.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,4 +51,3 @@ STri=SymTridiagonal([1,2,3,4],[5,6,7])
5151
rowind,colind=findstructralnz(STri)
5252
@test [STri[rowind[i],colind[i]] for i in 1:length(rowind)]==[1,2,3,4,5,6,7,5,6,7]
5353

54-

ArrayInterfaceCore/test/runtests.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,17 @@ include("setup.jl")
1717

1818
@test zeromatrix(rand(4,4,4)) == zeros(4*4*4,4*4*4)
1919

20+
@testset "matrix colors" begin
21+
@test ArrayInterfaceCore.fast_matrix_colors(1) == false
22+
@test ArrayInterfaceCore.fast_matrix_colors(Diagonal{Int,Vector{Int}})
23+
24+
@test ArrayInterfaceCore.matrix_colors(Diagonal([1,2,3,4])) == [1, 1, 1, 1]
25+
@test ArrayInterfaceCore.matrix_colors(Bidiagonal([1,2,3,4], [7,8,9], :U)) == [1, 2, 1, 2]
26+
@test ArrayInterfaceCore.matrix_colors(Tridiagonal([1,2,3],[1,2,3,4],[4,5,6])) == [1, 2, 3, 1]
27+
@test ArrayInterfaceCore.matrix_colors(SymTridiagonal([1,2,3,4],[5,6,7])) == [1, 2, 3, 1]
28+
@test ArrayInterfaceCore.matrix_colors(rand(4,4)) == Base.OneTo(4)
29+
end
30+
2031
@testset "parent_type" begin
2132
x = ones(4, 4)
2233
@test parent_type(view(x, 1:2, 1:2)) <: typeof(x)

src/ArrayInterface.jl

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
module ArrayInterface
22

33
using ArrayInterfaceCore
4-
import ArrayInterfaceCore: axes, axes_types, can_setindex, contiguous_axis, contiguous_batch_size,
5-
defines_strides, dense_dims, device, dimnames, fast_scalar_indexing, findstructralnz,
4+
import ArrayInterfaceCore: axes, axes_types, buffer, can_setindex, contiguous_axis, contiguous_batch_size,
5+
defines_strides, dense_dims, dimnames, fast_matrix_colors, fast_scalar_indexing, findstructralnz,
66
is_lazy_conjugate, length,
7-
has_sparsestruct, lu_instance, matrix_colors, ismutable, restructure, known_first,
7+
has_sparsestruct, isstructured, lu_instance, matrix_colors, ismutable, restructure, known_first,
88
known_last, known_length, known_step, known_size, known_strides, known_offsets, offsets,
99
parent_type, size, strides, stride_rank, to_dims, to_indices, to_index, zeromatrix
10+
import ArrayInterfaceCore: ArrayIndex, MatrixIndex, VectorIndex, BidiagonalIndex, TridiagonalIndex, StrideIndex
11+
import ArrayInterfaceCore: AbstractDevice, AbstractCPU, CPUTuple, CPUPointer, GPU, CPUIndex, CheckParent, device
12+
using LinearAlgebra
1013
using Requires
1114
using Static
1215
using Static: Zero, One, nstatic, eq, ne, gt, ge, lt, le, eachop, eachop_tuple,
@@ -15,10 +18,10 @@ using Static: Zero, One, nstatic, eq, ne, gt, ge, lt, le, eachop, eachop_tuple,
1518
function __init__()
1619
@require StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" begin
1720
ismutable(::Type{<:StaticArrays.StaticArray}) = false
18-
can_setindex(::Type{<:StaticArrays.StaticArray}) = false
1921
ismutable(::Type{<:StaticArrays.MArray}) = true
2022
ismutable(::Type{<:StaticArrays.SizedArray}) = true
2123

24+
can_setindex(::Type{<:StaticArrays.StaticArray}) = false
2225
buffer(A::Union{StaticArrays.SArray,StaticArrays.MArray}) = getfield(A, :data)
2326

2427
function lu_instance(_A::StaticArrays.StaticMatrix{N,N}) where {N}
@@ -48,7 +51,7 @@ function __init__()
4851
Static.nstatic(Val(N))
4952
end
5053
function dense_dims(::Type{<:StaticArrays.StaticArray{S,T,N}}) where {S,T,N}
51-
return ArrayInterface._all_dense(Val(N))
54+
return ArrayInterfaceCore._all_dense(Val(N))
5255
end
5356
defines_strides(::Type{<:StaticArrays.SArray}) = true
5457
defines_strides(::Type{<:StaticArrays.MArray}) = true
@@ -448,29 +451,29 @@ function __init__()
448451
return getfield(relative_offsets(A), dim)
449452
end
450453
end
451-
ArrayInterface.parent_type(::Type{<:OffsetArrays.OffsetArray{T,N,A}}) where {T,N,A} = A
454+
ArrayInterfaceCore.parent_type(::Type{<:OffsetArrays.OffsetArray{T,N,A}}) where {T,N,A} = A
452455
function _offset_axis_type(::Type{T}, dim::StaticInt{D}) where {T,D}
453-
OffsetArrays.IdOffsetRange{Int,ArrayInterface.axes_types(T, dim)}
456+
OffsetArrays.IdOffsetRange{Int,ArrayInterfaceCore.axes_types(T, dim)}
454457
end
455-
function ArrayInterface.axes_types(::Type{T}) where {T<:OffsetArrays.OffsetArray}
456-
Static.eachop_tuple(_offset_axis_type, Static.nstatic(Val(ndims(T))), ArrayInterface.parent_type(T))
458+
function ArrayInterfaceCore.axes_types(::Type{T}) where {T<:OffsetArrays.OffsetArray}
459+
Static.eachop_tuple(_offset_axis_type, Static.nstatic(Val(ndims(T))), ArrayInterfaceCore.parent_type(T))
457460
end
458-
function ArrayInterface.known_offsets(::Type{A}) where {A<:OffsetArrays.OffsetArray}
461+
function ArrayInterfaceCore.known_offsets(::Type{A}) where {A<:OffsetArrays.OffsetArray}
459462
ntuple(identity -> nothing, Val(ndims(A)))
460463
end
461-
function ArrayInterface.offsets(A::OffsetArrays.OffsetArray)
462-
map(+, ArrayInterface.offsets(parent(A)), relative_offsets(A))
464+
function ArrayInterfaceCore.offsets(A::OffsetArrays.OffsetArray)
465+
map(+, ArrayInterfaceCore.offsets(parent(A)), relative_offsets(A))
463466
end
464-
@inline function ArrayInterface.offsets(A::OffsetArrays.OffsetArray, dim)
465-
d = ArrayInterface.to_dims(A, dim)
466-
ArrayInterface.offsets(parent(A), d) + relative_offsets(A, d)
467+
@inline function ArrayInterfaceCore.offsets(A::OffsetArrays.OffsetArray, dim)
468+
d = ArrayInterfaceCore.to_dims(A, dim)
469+
ArrayInterfaceCore.offsets(parent(A), d) + relative_offsets(A, d)
467470
end
468-
@inline function ArrayInterface.axes(A::OffsetArrays.OffsetArray)
469-
map(OffsetArrays.IdOffsetRange, ArrayInterface.axes(parent(A)), relative_offsets(A))
471+
@inline function ArrayInterfaceCore.axes(A::OffsetArrays.OffsetArray)
472+
map(OffsetArrays.IdOffsetRange, ArrayInterfaceCore.axes(parent(A)), relative_offsets(A))
470473
end
471-
@inline function ArrayInterface.axes(A::OffsetArrays.OffsetArray, dim)
474+
@inline function ArrayInterfaceCore.axes(A::OffsetArrays.OffsetArray, dim)
472475
d = to_dims(A, dim)
473-
OffsetArrays.IdOffsetRange(ArrayInterface.axes(parent(A), d), relative_offsets(A, d))
476+
OffsetArrays.IdOffsetRange(ArrayInterfaceCore.axes(parent(A), d), relative_offsets(A, d))
474477
end
475478
end
476479
end

test/runtests.jl

Lines changed: 64 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -14,44 +14,62 @@ using StaticArrays
1414
using Test
1515

1616
@testset "StaticArrays" begin
17-
@test isone(ArrayInterface.known_first(typeof(StaticArrays.SOneTo(7))))
18-
@test ArrayInterface.known_last(typeof(StaticArrays.SOneTo(7))) == 7
19-
@test ArrayInterface.known_length(typeof(StaticArrays.SOneTo(7))) == 7
17+
x = @SVector [1,2,3]
18+
@test ArrayInterfaceCore.ismutable(x) == false
19+
@test ArrayInterfaceCore.ismutable(view(x, 1:2)) == false
20+
@test ArrayInterfaceCore.can_setindex(typeof(x)) == false
21+
@test ArrayInterfaceCore.buffer(x) == x.data
22+
@test @inferred(ArrayInterfaceCore.device(typeof(x))) === ArrayInterfaceCore.CPUTuple()
23+
24+
x = @MVector [1,2,3]
25+
@test ArrayInterfaceCore.ismutable(x) == true
26+
@test ArrayInterfaceCore.ismutable(view(x, 1:2)) == true
27+
@test @inferred(ArrayInterfaceCore.device(typeof(x))) === ArrayInterfaceCore.CPUPointer()
28+
29+
A = @SMatrix(randn(5, 5))
30+
@test lu_instance(A) isa typeof(lu(A))
31+
A = @MMatrix(randn(5, 5))
32+
@test lu_instance(A) isa typeof(lu(A))
33+
34+
@test isone(ArrayInterfaceCore.known_first(typeof(StaticArrays.SOneTo(7))))
35+
@test ArrayInterfaceCore.known_last(typeof(StaticArrays.SOneTo(7))) == 7
36+
@test ArrayInterfaceCore.known_length(typeof(StaticArrays.SOneTo(7))) == 7
2037

2138
@test parent_type(SizedVector{1, Int, Vector{Int}}) <: Vector{Int}
22-
@test ArrayInterface.known_length(@inferred(ArrayInterface.indices(SOneTo(7)))) == 7
39+
@test ArrayInterfaceCore.known_length(@inferred(ArrayInterfaceCore.indices(SOneTo(7)))) == 7
2340

2441
x = view(SArray{Tuple{3,3,3}}(ones(3,3,3)), :, SOneTo(2), 2)
25-
@test @inferred(ArrayInterface.known_length(x)) == 6
26-
@test @inferred(ArrayInterface.known_length(x')) == 6
42+
@test @inferred(ArrayInterfaceCore.known_length(x)) == 6
43+
@test @inferred(ArrayInterfaceCore.known_length(x')) == 6
2744

2845
v = @SVector rand(8);
2946
A = @MMatrix rand(7, 6);
3047
T = SizedArray{Tuple{5,4,3}}(zeros(5,4,3));
31-
@test @inferred(ArrayInterface.length(v)) === StaticInt(8)
32-
@test @inferred(ArrayInterface.length(A)) === StaticInt(42)
33-
@test @inferred(ArrayInterface.length(T)) === StaticInt(60)
48+
@test @inferred(ArrayInterfaceCore.length(v)) === StaticInt(8)
49+
@test @inferred(ArrayInterfaceCore.length(A)) === StaticInt(42)
50+
@test @inferred(ArrayInterfaceCore.length(T)) === StaticInt(60)
51+
52+
x = @SMatrix rand(Float32, 2, 2)
53+
y = @SVector rand(4)
54+
yr = ArrayInterfaceCore.restructure(x, y)
55+
@test yr isa SMatrix{2, 2}
56+
@test Base.size(yr) == (2,2)
57+
@test vec(yr) == vec(y)
58+
z = rand(4)
59+
zr = ArrayInterfaceCore.restructure(x, z)
60+
@test zr isa SMatrix{2, 2}
61+
@test Base.size(zr) == (2,2)
62+
@test vec(zr) == vec(z)
3463

35-
A = @SMatrix(randn(5, 5))
36-
@test lu_instance(A) isa typeof(lu(A))
37-
A = @MMatrix(randn(5, 5))
38-
@test lu_instance(A) isa typeof(lu(A))
3964
Am = @MMatrix rand(2,10);
40-
@test @inferred(ArrayInterface.strides(view(Am,1,:))) === (StaticInt(2),)
65+
@test @inferred(ArrayInterfaceCore.strides(view(Am,1,:))) === (StaticInt(2),)
4166

42-
@test @inferred(contiguous_axis(@SArray(zeros(2,2,2)))) === ArrayInterface.StaticInt(1)
43-
@test @inferred(ArrayInterface.contiguous_axis_indicator(@SArray(zeros(2,2,2)))) == (true,false,false)
44-
@test @inferred(contiguous_batch_size(@SArray(zeros(2,2,2)))) === ArrayInterface.StaticInt(0)
67+
@test @inferred(contiguous_axis(@SArray(zeros(2,2,2)))) === ArrayInterfaceCore.StaticInt(1)
68+
@test @inferred(ArrayInterfaceCore.contiguous_axis_indicator(@SArray(zeros(2,2,2)))) == (true,false,false)
69+
@test @inferred(contiguous_batch_size(@SArray(zeros(2,2,2)))) === ArrayInterfaceCore.StaticInt(0)
4570
@test @inferred(stride_rank(@SArray(zeros(2,2,2)))) == (1, 2, 3)
46-
@test @inferred(ArrayInterface.is_column_major(@SArray(zeros(2,2,2)))) === True()
71+
@test @inferred(ArrayInterfaceCore.is_column_major(@SArray(zeros(2,2,2)))) === True()
4772
@test @inferred(dense_dims(@SArray(zeros(2,2,2)))) == (true,true,true)
48-
49-
x = @SVector [1,2,3]
50-
@test ArrayInterfaceCore.ismutable(x) == false
51-
@test ArrayInterfaceCore.ismutable(view(x, 1:2)) == false
52-
x = @MVector [1,2,3]
53-
@test ArrayInterfaceCore.ismutable(x) == true
54-
@test ArrayInterfaceCore.ismutable(view(x, 1:2)) == true
5573
end
5674

5775
@testset "BandedMatrices" begin
@@ -66,6 +84,8 @@ end
6684
B[band(2)].=[5,6,7,8]
6785
rowind,colind=findstructralnz(B)
6886
@test [B[rowind[i],colind[i]] for i in 1:length(rowind)]==[5,6,7,8,1,2,3,4]
87+
@test ArrayInterfaceCore.isstructured(typeof(B))
88+
@test ArrayInterfaceCore.fast_matrix_colors(typeof(B))
6989
end
7090

7191
@testset "BlockBandedMatrices" begin
@@ -78,6 +98,9 @@ end
7898
1,1,1,1,1,1,1,1,1,1,1,1,1,1,
7999
1,1,1,1,1,1,1,1,1,1,1,1,1,1,
80100
1,1,1,1,1]
101+
@test ArrayInterfaceCore.isstructured(typeof(BB))
102+
@test has_sparsestruct(typeof(BB))
103+
@test ArrayInterfaceCore.fast_matrix_colors(typeof(BB))
81104

82105
dense=collect(Ones(8,8))
83106
for i in 1:8
@@ -88,27 +111,31 @@ end
88111
@test [BBB[rowind[i],colind[i]] for i in 1:length(rowind)]==
89112
[1,2,3,1,2,3,4,2,3,4,5,6,7,5,6,7,8,6,7,8,
90113
1,2,3,1,2,3,4,2,3,4,5,6,7,5,6,7,8,6,7,8]
114+
@test ArrayInterfaceCore.isstructured(typeof(BBB))
115+
@test has_sparsestruct(typeof(BBB))
116+
@test ArrayInterfaceCore.fast_matrix_colors(typeof(BBB))
91117
end
92118

93119
@testset "OffsetArrays" begin
120+
A = zeros(3, 4, 5);
94121
O = OffsetArray(A, 3, 7, 10);
95122
Op = PermutedDimsArray(O,(3,1,2));
96-
@test @inferred(ArrayInterface.offsets(O)) === (4, 8, 11)
97-
@test @inferred(ArrayInterface.offsets(Op)) === (11, 4, 8)
123+
@test @inferred(ArrayInterfaceCore.offsets(O)) === (4, 8, 11)
124+
@test @inferred(ArrayInterfaceCore.offsets(Op)) === (11, 4, 8)
98125

99-
@test @inferred(ArrayInterface.offsets((1,2,3))) === (StaticInt(1),)
126+
@test @inferred(ArrayInterfaceCore.offsets((1,2,3))) === (StaticInt(1),)
100127
o = OffsetArray(vec(A), 8);
101-
@test @inferred(ArrayInterface.offset1(o)) === 9
128+
@test @inferred(ArrayInterfaceCore.offset1(o)) === 9
102129

103-
@test @inferred(device(OffsetArray(view(PermutedDimsArray(A, (3,1,2)), 1, :, 2:4)', 3, -173))) === ArrayInterface.CPUPointer()
104-
@test @inferred(device(view(OffsetArray(A,2,3,-12), 4, :, -11:-9))) === ArrayInterface.CPUPointer()
105-
@test @inferred(device(view(OffsetArray(A,2,3,-12), 3, :, [-11,-10,-9])')) === ArrayInterface.CPUIndex()
130+
@test @inferred(device(OffsetArray(view(PermutedDimsArray(A, (3,1,2)), 1, :, 2:4)', 3, -173))) === ArrayInterfaceCore.CPUPointer()
131+
@test @inferred(device(view(OffsetArray(A,2,3,-12), 4, :, -11:-9))) === ArrayInterfaceCore.CPUPointer()
132+
@test @inferred(device(view(OffsetArray(A,2,3,-12), 3, :, [-11,-10,-9])')) === ArrayInterfaceCore.CPUIndex()
106133

107-
@test @inferred(ArrayInterface.indices(OffsetArray(view(PermutedDimsArray(A, (3,1,2)), 1, :, 2:4)', 3, -173),1)) === Base.Slice(ArrayInterface.OptionallyStaticUnitRange(4,6))
108-
@test @inferred(ArrayInterface.indices(OffsetArray(view(PermutedDimsArray(A, (3,1,2)), 1, :, 2:4)', 3, -173),2)) === Base.Slice(ArrayInterface.OptionallyStaticUnitRange(-172,-170))
134+
@test @inferred(ArrayInterfaceCore.indices(OffsetArray(view(PermutedDimsArray(A, (3,1,2)), 1, :, 2:4)', 3, -173),1)) === Base.Slice(ArrayInterfaceCore.OptionallyStaticUnitRange(4,6))
135+
@test @inferred(ArrayInterfaceCore.indices(OffsetArray(view(PermutedDimsArray(A, (3,1,2)), 1, :, 2:4)', 3, -173),2)) === Base.Slice(ArrayInterfaceCore.OptionallyStaticUnitRange(-172,-170))
109136

110-
@test @inferred(device(OffsetArray(@SArray(zeros(2,2,2)),-123,29,3231))) === ArrayInterface.CPUTuple()
111-
@test @inferred(device(OffsetArray(@view(@SArray(zeros(2,2,2))[1,1:2,:]),-3,4))) === ArrayInterface.CPUTuple()
112-
@test @inferred(device(OffsetArray(@MArray(zeros(2,2,2)),8,-2,-5))) === ArrayInterface.CPUPointer()
137+
@test @inferred(device(OffsetArray(@SArray(zeros(2,2,2)),-123,29,3231))) === ArrayInterfaceCore.CPUTuple()
138+
@test @inferred(device(OffsetArray(@view(@SArray(zeros(2,2,2))[1,1:2,:]),-3,4))) === ArrayInterfaceCore.CPUTuple()
139+
@test @inferred(device(OffsetArray(@MArray(zeros(2,2,2)),8,-2,-5))) === ArrayInterfaceCore.CPUPointer()
113140
end
114141

0 commit comments

Comments
 (0)