diff --git a/src/BlockArrays.jl b/src/BlockArrays.jl index 15341192..808942ab 100644 --- a/src/BlockArrays.jl +++ b/src/BlockArrays.jl @@ -20,16 +20,16 @@ export blockappend!, blockpush!, blockpushfirst!, blockpop!, blockpopfirst! import Base: @propagate_inbounds, Array, AbstractArray, to_indices, to_index, unsafe_indices, first, last, size, length, unsafe_length, unsafe_convert, - getindex, setindex!, ndims, show, view, + getindex, setindex!, ndims, show, print_array, view, step, - broadcast, eltype, convert, similar, + broadcast, eltype, convert, similar, collect, tail, reindex, RangeIndex, Int, Integer, Number, Tuple, +, -, *, /, \, min, max, isless, in, copy, copyto!, axes, @deprecate, - BroadcastStyle, checkbounds, + BroadcastStyle, checkbounds, checkindex, ensure_indexable, oneunit, ones, zeros, intersect, Slice, resize! -using Base: ReshapedArray, dataids, oneto +using Base: ReshapedArray, LogicalIndex, dataids, oneto import Base: (:), IteratorSize, iterate, axes1, strides, isempty import Base.Broadcast: broadcasted, DefaultArrayStyle, AbstractArrayStyle, Broadcasted, broadcastable diff --git a/src/blockedarray.jl b/src/blockedarray.jl index 8d042f79..06daeb50 100644 --- a/src/blockedarray.jl +++ b/src/blockedarray.jl @@ -193,6 +193,10 @@ AbstractArray{T,N}(A::BlockedArray) where {T,N} = BlockedArray(AbstractArray{T,N copy(A::BlockedArray) = BlockedArray(copy(A.blocks), A.axes) +# Blocked version of `collect(::AbstractArray)` that preserves the +# block structure. +blockcollect(a::AbstractArray) = BlockedArray(collect(a), axes(a)) + Base.dataids(A::BlockedArray) = Base.dataids(A.blocks) ########################### diff --git a/src/views.jl b/src/views.jl index 157acbe5..bd276685 100644 --- a/src/views.jl +++ b/src/views.jl @@ -59,6 +59,50 @@ to_index(::BlockRange) = throw(ArgumentError("BlockRange must be converted by to @inline to_indices(A, I::Tuple{AbstractVector{<:BlockIndex{1}}, Vararg{Any}}) = to_indices(A, axes(A), I) @inline to_indices(A, I::Tuple{AbstractVector{<:BlockIndexRange{1}}, Vararg{Any}}) = to_indices(A, axes(A), I) +## BlockedLogicalIndex +# Blocked version of `LogicalIndex`: +# https://github.com/JuliaLang/julia/blob/3e2f90fbb8f6b0651f2601d7599c55d4e3efd496/base/multidimensional.jl#L819-L831 +const BlockedLogicalIndex{T,R<:LogicalIndex{T},BS<:Tuple{AbstractUnitRange{<:Integer}}} = BlockedVector{T,R,BS} +function BlockedLogicalIndex(I::AbstractVector{Bool}) + blocklengths = map(b -> count(view(I, b)), BlockRange(I)) + return BlockedVector(LogicalIndex(I), blocklengths) +end +# https://github.com/JuliaLang/julia/blob/3e2f90fbb8f6b0651f2601d7599c55d4e3efd496/base/multidimensional.jl#L838-L839 +show(io::IO, r::BlockedLogicalIndex) = print(io, blockcollect(r)) +print_array(io::IO, X::BlockedLogicalIndex) = print_array(io, blockcollect(X)) + +# Blocked version of `to_index(::AbstractArray{Bool})`: +# https://github.com/JuliaLang/julia/blob/3e2f90fbb8f6b0651f2601d7599c55d4e3efd496/base/indices.jl#L309 +function to_index(I::AbstractBlockVector{Bool}) + return BlockedLogicalIndex(I) +end + +# Blocked version of `collect(::LogicalIndex)`: +# https://github.com/JuliaLang/julia/blob/3e2f90fbb8f6b0651f2601d7599c55d4e3efd496/base/multidimensional.jl#L837 +# Without this definition, `collect` will try to call `getindex` on the `LogicalIndex` +# which isn't defined. +collect(I::BlockedLogicalIndex) = collect(I.blocks) + +# Iteration of BlockedLogicalIndex is just iteration over the underlying +# LogicalIndex, which is implemented here: +# https://github.com/JuliaLang/julia/blob/3e2f90fbb8f6b0651f2601d7599c55d4e3efd496/base/multidimensional.jl#L840-L890 +@inline iterate(I::BlockedLogicalIndex) = iterate(I.blocks) +@inline iterate(I::BlockedLogicalIndex, s) = iterate(I.blocks, s) + +## Boundscheck for BlockLogicalindex +# Like for LogicalIndex, map all calls to mask: +# https://github.com/JuliaLang/julia/blob/3e2f90fbb8f6b0651f2601d7599c55d4e3efd496/base/multidimensional.jl#L892-L897 +checkbounds(::Type{Bool}, A::AbstractArray, i::BlockedLogicalIndex) = checkbounds(Bool, A, i.blocks.mask) +# `checkbounds_indices` has been handled via `I::AbstractArray` fallback +checkindex(::Type{Bool}, inds::AbstractUnitRange, i::BlockedLogicalIndex) = checkindex(Bool, inds, i.blocks.mask) +checkindex(::Type{Bool}, inds::Tuple, i::BlockedLogicalIndex) = checkindex(Bool, inds, i.blocks.mask) + +# Instantiate the BlockedLogicalIndex when constructing a SubArray, similar to +# `ensure_indexable(I::Tuple{LogicalIndex,Vararg{Any}})`: +# https://github.com/JuliaLang/julia/blob/3e2f90fbb8f6b0651f2601d7599c55d4e3efd496/base/multidimensional.jl#L918 +@inline ensure_indexable(I::Tuple{BlockedLogicalIndex,Vararg{Any}}) = + (blockcollect(I[1]), ensure_indexable(tail(I))...) + @propagate_inbounds reindex(idxs::Tuple{BlockSlice{<:BlockRange}, Vararg{Any}}, subidxs::Tuple{BlockSlice{<:BlockIndexRange}, Vararg{Any}}) = (BlockSlice(BlockIndexRange(Block(idxs[1].block.indices[1][Int(subidxs[1].block.block)]), diff --git a/test/test_blockarrays.jl b/test/test_blockarrays.jl index 6a5dc6e7..b8c2fc5d 100644 --- a/test/test_blockarrays.jl +++ b/test/test_blockarrays.jl @@ -1,7 +1,7 @@ module TestBlockArrays using SparseArrays, BlockArrays, FillArrays, LinearAlgebra, Test, OffsetArrays, Images -import BlockArrays: _BlockArray +import BlockArrays: _BlockArray, blockcollect const Fill = FillArrays.Fill @@ -255,6 +255,32 @@ end @test zero(b) isa typeof(b) end + @testset "blockcollect" begin + a = randn(6, 6) + @test blockcollect(a) == a + @test blockcollect(a) ≢ a + @test blockcollect(a).blocks ≢ a + # TODO: Maybe special case this to call `collect` and return a `Matrix`? + @test blockcollect(a) isa BlockedMatrix{Float64,Matrix{Float64}} + @test blockisequal(axes(blockcollect(a)), axes(a)) + @test blocksize(blockcollect(a)) == (1, 1) + + b = BlockedArray(randn(6, 6), [3, 3], [3, 3]) + @test blockcollect(b) == b + @test blockcollect(b) ≢ b + @test blockcollect(b).blocks ≢ b + @test blockcollect(b) isa BlockedMatrix{Float64,Matrix{Float64}} + @test blockisequal(axes(blockcollect(b)), axes(b)) + @test blocksize(blockcollect(b)) == (2, 2) + + c = BlockArray(randn(6, 6), [3, 3], [3, 3]) + @test blockcollect(c) == c + @test blockcollect(c) ≢ c + @test blockcollect(c) isa BlockedMatrix{Float64,Matrix{Float64}} + @test blockisequal(axes(blockcollect(c)), axes(c)) + @test blocksize(blockcollect(c)) == (2, 2) + end + @test_throws DimensionMismatch BlockArray([1,2,3],[1,1]) @testset "mortar" begin diff --git a/test/test_blockviews.jl b/test/test_blockviews.jl index f701b7b5..aac1fd5f 100644 --- a/test/test_blockviews.jl +++ b/test/test_blockviews.jl @@ -2,6 +2,8 @@ module TestBlockViews using BlockArrays, ArrayLayouts, Test using FillArrays +import BlockArrays: BlockedLogicalIndex +import Base: LogicalIndex # useds to force SubArray return bview(a, b) = Base.invoke(view, Tuple{AbstractArray,Any}, a, b) @@ -353,6 +355,25 @@ bview(a, b) = Base.invoke(view, Tuple{AbstractArray,Any}, a, b) @test MemoryLayout(v) == MemoryLayout(a) @test v[Block(1)] == a[Block(1)] end + + @testset "BlockedLogicalIndex" begin + a = randn(6, 6) + for mask in ([true, true, false, false, true, false], BitVector([true, true, false, false, true, false])) + I = BlockedVector(mask, [3, 3]) + @test to_indices(a, (I, I)) == to_indices(a, (mask, mask)) + @test to_indices(a, (I, I)) == (BlockedVector(LogicalIndex(mask), [2, 1]), BlockedVector(LogicalIndex(mask), [2, 1])) + @test to_indices(a, (I, I)) isa Tuple{BlockedLogicalIndex{Int},BlockedLogicalIndex{Int}} + @test blocklengths.(Base.axes1.(to_indices(a, (I, I)))) == ([2, 1], [2, 1]) + for b in (view(a, I, I), a[I, I]) + @test size(b) == (3, 3) + @test blocklengths.(axes(b)) == ([2, 1], [2, 1]) + @test b == a[mask, mask] + end + @test parentindices(view(a, I, I)) == (BlockedVector([1, 2, 5], [2, 1]), BlockedVector([1, 2, 5], [2, 1])) + @test parentindices(view(a, I, I)) isa Tuple{BlockedVector{Int,Vector{Int}},BlockedVector{Int,Vector{Int}}} + @test blocklengths.(Base.axes1.(parentindices(view(a, I, I)))) == ([2, 1], [2, 1]) + end + end end end # module