Skip to content

Commit 18687b2

Browse files
committed
Start adding tests
1 parent 7ca3801 commit 18687b2

File tree

3 files changed

+161
-15
lines changed

3 files changed

+161
-15
lines changed

src/BlockArraysExtensions/blockedunitrange.jl

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -292,35 +292,54 @@ end
292292
return b[GenericBlockIndex(tuple(K, J...))]
293293
end
294294

295-
function blockindextype(TB::Type{<:Integer}, TI::Vararg{Type{<:Integer},N}) where {N}
296-
return BlockIndex{N,NTuple{N,TB},Tuple{TI...}}
297-
end
298-
function blockindextype(TB::Type{<:Integer}, TI::Vararg{Type,N}) where {N}
299-
return GenericBlockIndex{N,NTuple{N,TB},Tuple{TI...}}
300-
end
301-
302-
struct BlockIndexVector{N,I<:NTuple{N,AbstractVector},TB<:Integer,BT} <: AbstractArray{BT,N}
295+
# Work around the fact that it is type piracy to define
296+
# `Base.getindex(a::Block, b...)`.
297+
_getindex(a::Block{N}, b::Vararg{Any,N}) where {N} = GenericBlockIndex(a, b)
298+
_getindex(a::Block{N}, b::Vararg{Integer,N}) where {N} = a[b...]
299+
# Fix ambiguity.
300+
_getindex(a::Block{0}) = a[]
301+
302+
## function blockindextype(TB::Type{<:Integer}, TI::Vararg{Type{<:Integer},N}) where {N}
303+
## return BlockIndex{N,NTuple{N,TB},Tuple{TI...}}
304+
## end
305+
## function blockindextype(TB::Type{<:Integer}, TI::Vararg{Type,N}) where {N}
306+
## return GenericBlockIndex{N,NTuple{N,TB},Tuple{TI...}}
307+
## end
308+
309+
struct BlockIndexVector{N,BT,I<:NTuple{N,AbstractVector},TB<:Integer} <: AbstractArray{BT,N}
303310
block::Block{N,TB}
304311
indices::I
305-
function BlockIndexVector(
312+
function BlockIndexVector{N,BT}(
306313
block::Block{N,TB}, indices::I
307-
) where {N,I<:NTuple{N,AbstractVector},TB<:Integer}
308-
BT = blockindextype(TB, eltype.(indices)...)
309-
return new{N,I,TB,BT}(block, indices)
314+
) where {N,BT,I<:NTuple{N,AbstractVector},TB<:Integer}
315+
return new{N,BT,I,TB}(block, indices)
310316
end
311317
end
318+
function BlockIndexVector{1,BT}(block::Block{1}, indices::AbstractVector) where {BT}
319+
return BlockIndexVector{1,BT}(block, (indices,))
320+
end
321+
function BlockIndexVector(
322+
block::Block{N,TB}, indices::NTuple{N,AbstractVector}
323+
) where {N,TB<:Integer}
324+
BT = Base.promote_op(_getindex, typeof(block), eltype.(indices)...)
325+
return BlockIndexVector{N,BT}(block, indices)
326+
end
312327
function BlockIndexVector(block::Block{1}, indices::AbstractVector)
313328
return BlockIndexVector(block, (indices,))
314329
end
315330
Base.size(a::BlockIndexVector) = length.(a.indices)
316331
function Base.getindex(a::BlockIndexVector{N}, I::Vararg{Integer,N}) where {N}
317-
return a.block[map((r, i) -> r[i], a.indices, I)...]
332+
return _getindex(Block(a), getindex.(a.indices, I)...)
318333
end
319334
BlockArrays.block(b::BlockIndexVector) = b.block
320335
BlockArrays.Block(b::BlockIndexVector) = b.block
321336

322337
Base.copy(a::BlockIndexVector) = BlockIndexVector(a.block, copy.(a.indices))
323338

339+
function Base.getindex(b::AbstractBlockedUnitRange, Kkr::BlockIndexVector{1})
340+
b[block(Kkr)][Kkr.indices...]
341+
end
342+
324343
using ArrayLayouts: LayoutArray
325344
@propagate_inbounds Base.getindex(b::AbstractArray{T,N}, K::BlockIndexVector{N}) where {T,N} = b[block(
326345
K

src/abstractblocksparsearray/views.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,10 +92,10 @@ end
9292
# TODO: Move to `GradedUnitRanges` or `BlockArraysExtensions`.
9393
to_block(I::Block{1}) = I
9494
to_block(I::BlockIndexRange{1}) = Block(I)
95-
to_block(I::BlockIndexVector) = Block(I)
95+
to_block(I::BlockIndexVector{1}) = Block(I)
9696
to_block_indices(I::Block{1}) = Colon()
9797
to_block_indices(I::BlockIndexRange{1}) = only(I.indices)
98-
to_block_indices(I::BlockIndexVector) = only(I.indices)
98+
to_block_indices(I::BlockIndexVector{1}) = only(I.indices)
9999

100100
function Base.view(
101101
a::AbstractBlockSparseArray{<:Any,N},

test/test_genericblockindex.jl

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
using BlockArrays: Block, BlockIndex, BlockedVector, block, blockindex
2+
using BlockSparseArrays: BlockSparseArrays, BlockIndexVector, GenericBlockIndex
3+
using Test: @test, @test_broken, @testset
4+
5+
# blockrange
6+
# checkindex
7+
# to_indices
8+
# to_index
9+
# blockedunitrange_getindices
10+
# viewblock
11+
# to_blockindexrange
12+
13+
@testset "GenericBlockIndex" begin
14+
i1 = GenericBlockIndex(Block(1), ("x",))
15+
i2 = GenericBlockIndex(Block(2), ("y",))
16+
i = GenericBlockIndex(Block(1, 2), ("x", "y"))
17+
@test sprint(show, i) == "Block(1, 2)[x, y]"
18+
@test i isa GenericBlockIndex{2,Tuple{Int64,Int64},Tuple{String,String}}
19+
@test GenericBlockIndex(Block(1), "x") === i1
20+
@test GenericBlockIndex(1, "x") === i1
21+
@test GenericBlockIndex(1, ("x",)) === i1
22+
@test GenericBlockIndex((1,), "x") === i1
23+
@test GenericBlockIndex((1, 2), ("x", "y")) === i
24+
@test GenericBlockIndex((Block(1), Block(2)), ("x", "y")) === i
25+
@test GenericBlockIndex((i1, i2)) === i
26+
@test block(i1) == Block(1)
27+
@test block(i) == Block(1, 2)
28+
@test blockindex(i1) == "x"
29+
@test GenericBlockIndex((), ()) == GenericBlockIndex(Block(), ())
30+
@test GenericBlockIndex(Block(1, 2), ("x",)) == GenericBlockIndex(Block(1, 2), ("x", 1))
31+
32+
i1 = GenericBlockIndex(Block(1), (1,))
33+
i2 = GenericBlockIndex(Block(2), (2,))
34+
i = GenericBlockIndex(Block(1, 2), (1, 2))
35+
v = BlockedVector(["a", "b", "c", "d"], [2, 2])
36+
@test v[i1] == "a"
37+
@test v[i2] == "d"
38+
39+
a = collect(Iterators.product(v, v))
40+
@test a[i1, i1] == ("a", "a")
41+
@test a[i2, i1] == ("d", "a")
42+
@test a[i1, i2] == ("a", "d")
43+
@test a[i] == ("a", "d")
44+
@test a[i2, i2] == ("d", "d")
45+
46+
I = BlockIndexVector(Block(1), [1, 2])
47+
@test eltype(I) === BlockIndex{1,Tuple{Int},Tuple{Int}}
48+
@test ndims(I) === 1
49+
@test length(I) === 2
50+
@test size(I) === (2,)
51+
@test I[1] === Block(1)[1]
52+
@test I[2] === Block(1)[2]
53+
@test block(I) === Block(1)
54+
@test Block(I) === Block(1)
55+
@test copy(I) == BlockIndexVector(Block(1), [1, 2])
56+
57+
I = BlockIndexVector(Block(1, 2), ([1, 2], [3, 4]))
58+
@test eltype(I) === BlockIndex{2,Tuple{Int,Int},Tuple{Int,Int}}
59+
@test ndims(I) === 2
60+
@test length(I) === 4
61+
@test size(I) === (2, 2)
62+
@test I[1, 1] === Block(1, 2)[1, 3]
63+
@test I[2, 1] === Block(1, 2)[2, 3]
64+
@test I[1, 2] === Block(1, 2)[1, 4]
65+
@test I[2, 2] === Block(1, 2)[2, 4]
66+
@test block(I) === Block(1, 2)
67+
@test Block(I) === Block(1, 2)
68+
@test copy(I) == BlockIndexVector(Block(1, 2), ([1, 2], [3, 4]))
69+
70+
I = BlockIndexVector(Block(1), ["x", "y"])
71+
@test eltype(I) === GenericBlockIndex{1,Tuple{Int},Tuple{String}}
72+
@test ndims(I) === 1
73+
@test length(I) === 2
74+
@test size(I) === (2,)
75+
@test I[1] === GenericBlockIndex(Block(1), "x")
76+
@test I[2] === GenericBlockIndex(Block(1), "y")
77+
@test block(I) === Block(1)
78+
@test Block(I) === Block(1)
79+
@test copy(I) == BlockIndexVector(Block(1), ["x", "y"])
80+
81+
I = BlockIndexVector(Block(1, 2), (["x", "y"], ["z", "w"]))
82+
@test eltype(I) === GenericBlockIndex{2,Tuple{Int,Int},Tuple{String,String}}
83+
@test ndims(I) === 2
84+
@test length(I) === 4
85+
@test size(I) === (2, 2)
86+
@test I[1, 1] === GenericBlockIndex(Block(1, 2), ("x", "z"))
87+
@test I[2, 1] === GenericBlockIndex(Block(1, 2), ("y", "z"))
88+
@test I[1, 2] === GenericBlockIndex(Block(1, 2), ("x", "w"))
89+
@test I[2, 2] === GenericBlockIndex(Block(1, 2), ("y", "w"))
90+
@test block(I) === Block(1, 2)
91+
@test Block(I) === Block(1, 2)
92+
@test copy(I) == BlockIndexVector(Block(1, 2), (["x", "y"], ["z", "w"]))
93+
94+
v = BlockedVector(["a", "b", "c", "d"], [2, 2])
95+
i = BlockIndexVector(Block(1), [2, 1])
96+
@test v[i] == ["b", "a"]
97+
i = BlockIndexVector(Block(2), [2, 1])
98+
@test v[i] == ["d", "c"]
99+
100+
v = BlockedVector(["a", "b", "c", "d"], [2, 2])
101+
i = BlockIndexVector{1,GenericBlockIndex{1,Tuple{Int},Tuple{String}}}(Block(1), [2, 1])
102+
@test v[i] == ["b", "a"]
103+
i = BlockIndexVector(Block(2), [2, 1])
104+
@test v[i] == ["d", "c"]
105+
106+
a = collect(Iterators.product(v, v))
107+
i1 = BlockIndexVector(Block(1), [2, 1])
108+
i2 = BlockIndexVector(Block(2), [1, 2])
109+
i = BlockIndexVector(Block(1, 2), ([2, 1], [1, 2]))
110+
@test a[i1, i1] == [("b", "b") ("b", "a"); ("a", "b") ("a", "a")]
111+
@test a[i2, i1] == [("c", "b") ("c", "a"); ("d", "b") ("d", "a")]
112+
@test a[i1, i2] == [("b", "c") ("b", "d"); ("a", "c") ("a", "d")]
113+
@test a[i] == [("b", "c") ("b", "d"); ("a", "c") ("a", "d")]
114+
@test a[i2, i2] == [("c", "c") ("c", "d"); ("d", "c") ("d", "d")]
115+
116+
a = collect(Iterators.product(v, v))
117+
i1 = BlockIndexVector{1,GenericBlockIndex{1,Tuple{Int},Tuple{String}}}(Block(1), [2, 1])
118+
i2 = BlockIndexVector{1,GenericBlockIndex{1,Tuple{Int},Tuple{String}}}(Block(2), [1, 2])
119+
i = BlockIndexVector{2,GenericBlockIndex{2,Tuple{Int,Int},Tuple{String,String}}}(
120+
Block(1, 2), ([2, 1], [1, 2])
121+
)
122+
@test a[i1, i1] == [("b", "b") ("b", "a"); ("a", "b") ("a", "a")]
123+
@test a[i2, i1] == [("c", "b") ("c", "a"); ("d", "b") ("d", "a")]
124+
@test a[i1, i2] == [("b", "c") ("b", "d"); ("a", "c") ("a", "d")]
125+
@test a[i] == [("b", "c") ("b", "d"); ("a", "c") ("a", "d")]
126+
@test a[i2, i2] == [("c", "c") ("c", "d"); ("d", "c") ("d", "d")]
127+
end

0 commit comments

Comments
 (0)