|
| 1 | +using BlockArrays: |
| 2 | + BlockArrays, |
| 3 | + AbstractBlockedUnitRange, |
| 4 | + AbstractBlockVector, |
| 5 | + Block, |
| 6 | + BlockIndex, |
| 7 | + BlockIndexRange, |
| 8 | + BlockRange, |
| 9 | + BlockSlice, |
| 10 | + BlockVector, |
| 11 | + block, |
| 12 | + blockindex, |
| 13 | + findblock, |
| 14 | + findblockindex, |
| 15 | + mortar |
| 16 | + |
| 17 | +# Custom `BlockedUnitRange` constructor that takes a unit range |
| 18 | +# and a set of block lengths, similar to `BlockArray(::AbstractArray, blocklengths...)`. |
| 19 | +function blockedunitrange(a::AbstractUnitRange, blocklengths) |
| 20 | + blocklengths_shifted = copy(blocklengths) |
| 21 | + blocklengths_shifted[1] += (first(a) - 1) |
| 22 | + blocklasts = cumsum(blocklengths_shifted) |
| 23 | + return BlockArrays._BlockedUnitRange(first(a), blocklasts) |
| 24 | +end |
| 25 | + |
| 26 | +# TODO: Move this to a `BlockArraysExtensions` library. |
| 27 | +# TODO: Rename this. `BlockArrays.findblock(a, k)` finds the |
| 28 | +# block of the value `k`, while this finds the block of the index `k`. |
| 29 | +# This could make use of the `BlockIndices` object, i.e. `block(BlockIndices(a)[index])`. |
| 30 | +function blockedunitrange_findblock(a::AbstractBlockedUnitRange, index::Integer) |
| 31 | + @boundscheck index in 1:length(a) || throw(BoundsError(a, index)) |
| 32 | + return @inbounds findblock(a, index + first(a) - 1) |
| 33 | +end |
| 34 | + |
| 35 | +# TODO: Move this to a `BlockArraysExtensions` library. |
| 36 | +# TODO: Rename this. `BlockArrays.findblockindex(a, k)` finds the |
| 37 | +# block index of the value `k`, while this finds the block index of the index `k`. |
| 38 | +# This could make use of the `BlockIndices` object, i.e. `BlockIndices(a)[index]`. |
| 39 | +function blockedunitrange_findblockindex(a::AbstractBlockedUnitRange, index::Integer) |
| 40 | + @boundscheck index in 1:length(a) || throw(BoundsError()) |
| 41 | + return @inbounds findblockindex(a, index + first(a) - 1) |
| 42 | +end |
| 43 | + |
| 44 | +function blockedunitrange_getindices(a::AbstractUnitRange, indices) |
| 45 | + return a[indices] |
| 46 | +end |
| 47 | + |
| 48 | +# TODO: Move this to a `BlockArraysExtensions` library. |
| 49 | +# Like `a[indices]` but preserves block structure. |
| 50 | +# TODO: Consider calling this something else, for example |
| 51 | +# `blocked_getindex`. See the discussion here: |
| 52 | +# https://github.com/JuliaArrays/BlockArrays.jl/issues/347 |
| 53 | +function blockedunitrange_getindices( |
| 54 | + a::AbstractBlockedUnitRange, indices::AbstractUnitRange{<:Integer} |
| 55 | +) |
| 56 | + first_blockindex = blockedunitrange_findblockindex(a, first(indices)) |
| 57 | + last_blockindex = blockedunitrange_findblockindex(a, last(indices)) |
| 58 | + first_block = block(first_blockindex) |
| 59 | + last_block = block(last_blockindex) |
| 60 | + blocklengths = if first_block == last_block |
| 61 | + [length(indices)] |
| 62 | + else |
| 63 | + map(first_block:last_block) do block |
| 64 | + if block == first_block |
| 65 | + return length(a[first_block]) - blockindex(first_blockindex) + 1 |
| 66 | + end |
| 67 | + if block == last_block |
| 68 | + return blockindex(last_blockindex) |
| 69 | + end |
| 70 | + return length(a[block]) |
| 71 | + end |
| 72 | + end |
| 73 | + return blockedunitrange(indices .+ (first(a) - 1), blocklengths) |
| 74 | +end |
| 75 | + |
| 76 | +# TODO: Make sure this handles block labels (AbstractGradedUnitRange) correctly. |
| 77 | +# TODO: Make a special case for `BlockedVector{<:Block{1},<:BlockRange{1}}`? |
| 78 | +# For example: |
| 79 | +# ```julia |
| 80 | +# blocklengths = map(bs -> sum(b -> length(a[b]), bs), blocks(indices)) |
| 81 | +# return blockedrange(blocklengths) |
| 82 | +# ``` |
| 83 | +function blockedunitrange_getindices( |
| 84 | + a::AbstractBlockedUnitRange, indices::AbstractBlockVector{<:Block{1}} |
| 85 | +) |
| 86 | + blks = map(bs -> mortar(map(b -> a[b], bs)), blocks(indices)) |
| 87 | + # We pass `length.(blks)` to `mortar` in order |
| 88 | + # to pass block labels to the axes of the output, |
| 89 | + # if they exist. This makes it so that |
| 90 | + # `only(axes(a[indices])) isa `GradedUnitRange` |
| 91 | + # if `a isa `GradedUnitRange`, for example. |
| 92 | + # Note there is a more specialized definition: |
| 93 | + # ```julia |
| 94 | + # function blockedunitrange_getindices( |
| 95 | + # a::AbstractGradedUnitRange, indices::AbstractBlockVector{<:Block{1}} |
| 96 | + # ) |
| 97 | + # ``` |
| 98 | + # that does a better job of preserving labels, since `length` |
| 99 | + # may drop labels for certain block types. |
| 100 | + return mortar(blks, length.(blks)) |
| 101 | +end |
| 102 | + |
| 103 | +# TODO: Move this to a `BlockArraysExtensions` library. |
| 104 | +function blockedunitrange_getindices(a::AbstractBlockedUnitRange, indices::BlockIndexRange) |
| 105 | + return a[block(indices)][only(indices.indices)] |
| 106 | +end |
| 107 | + |
| 108 | +# TODO: Move this to a `BlockArraysExtensions` library. |
| 109 | +function blockedunitrange_getindices(a::AbstractBlockedUnitRange, indices::BlockSlice) |
| 110 | + # TODO: Is this a good definition? It ignores `indices.indices`. |
| 111 | + return a[indices.block] |
| 112 | +end |
| 113 | + |
| 114 | +# TODO: Move this to a `BlockArraysExtensions` library. |
| 115 | +function blockedunitrange_getindices( |
| 116 | + a::AbstractBlockedUnitRange, indices::Vector{<:Integer} |
| 117 | +) |
| 118 | + return map(index -> a[index], indices) |
| 119 | +end |
| 120 | + |
| 121 | +# TODO: Move this to a `BlockArraysExtensions` library. |
| 122 | +# TODO: Make a special definition for `BlockedVector{<:Block{1}}` in order |
| 123 | +# to merge blocks. |
| 124 | +function blockedunitrange_getindices( |
| 125 | + a::AbstractBlockedUnitRange, indices::AbstractVector{<:Union{Block{1},BlockIndexRange{1}}} |
| 126 | +) |
| 127 | + # Without converting `indices` to `Vector`, |
| 128 | + # mapping `indices` outputs a `BlockVector` |
| 129 | + # which is harder to reason about. |
| 130 | + blocks = map(index -> a[index], Vector(indices)) |
| 131 | + # We pass `length.(blocks)` to `mortar` in order |
| 132 | + # to pass block labels to the axes of the output, |
| 133 | + # if they exist. This makes it so that |
| 134 | + # `only(axes(a[indices])) isa `GradedUnitRange` |
| 135 | + # if `a isa `GradedUnitRange`, for example. |
| 136 | + return mortar(blocks, length.(blocks)) |
| 137 | +end |
| 138 | + |
| 139 | +# TODO: Move this to a `BlockArraysExtensions` library. |
| 140 | +function blockedunitrange_getindices(a::AbstractBlockedUnitRange, indices::Block{1}) |
| 141 | + return a[indices] |
| 142 | +end |
| 143 | + |
| 144 | +function blockedunitrange_getindices( |
| 145 | + a::AbstractBlockedUnitRange, |
| 146 | + indices::BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}}, |
| 147 | +) |
| 148 | + return mortar(map(b -> a[b], blocks(indices))) |
| 149 | +end |
| 150 | + |
| 151 | +# TODO: Move this to a `BlockArraysExtensions` library. |
| 152 | +function blockedunitrange_getindices(a::AbstractBlockedUnitRange, indices) |
| 153 | + return error("Not implemented.") |
| 154 | +end |
| 155 | + |
| 156 | +# The blocks of the corresponding slice. |
| 157 | +_blocks(a::AbstractUnitRange, indices) = error("Not implemented") |
| 158 | +function _blocks(a::AbstractUnitRange, indices::AbstractUnitRange) |
| 159 | + return findblock(a, first(indices)):findblock(a, last(indices)) |
| 160 | +end |
| 161 | +function _blocks(a::AbstractUnitRange, indices::BlockRange) |
| 162 | + return indices |
| 163 | +end |
| 164 | + |
| 165 | +# Slice `a` by `I`, returning a: |
| 166 | +# `BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}}` |
| 167 | +# with the `BlockIndex{1}` corresponding to each value of `I`. |
| 168 | +function to_blockindices(a::AbstractBlockedUnitRange{<:Integer}, I::UnitRange{<:Integer}) |
| 169 | + return mortar( |
| 170 | + map(blocks(blockedunitrange_getindices(a, I))) do r |
| 171 | + bi_first = findblockindex(a, first(r)) |
| 172 | + bi_last = findblockindex(a, last(r)) |
| 173 | + @assert block(bi_first) == block(bi_last) |
| 174 | + return block(bi_first)[blockindex(bi_first):blockindex(bi_last)] |
| 175 | + end, |
| 176 | + ) |
| 177 | +end |
| 178 | + |
| 179 | +# This handles non-blocked slices. |
| 180 | +# For example: |
| 181 | +# a = BlockSparseArray{Float64}([2, 2, 2, 2]) |
| 182 | +# I = BlockedVector(Block.(1:4), [2, 2]) |
| 183 | +# @views a[I][Block(1)] |
| 184 | +to_blockindices(a::Base.OneTo{<:Integer}, I::UnitRange{<:Integer}) = I |
0 commit comments