Skip to content

Commit 37c843f

Browse files
committed
[WIP] Logical indexing
1 parent d373143 commit 37c843f

File tree

6 files changed

+71
-4
lines changed

6 files changed

+71
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "BlockSparseArrays"
22
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.6.9"
4+
version = "0.6.10"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/BlockArraysExtensions/BlockArraysExtensions.jl

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,12 +67,19 @@ for f in (:axes, :unsafe_indices, :axes1, :first, :last, :size, :length, :unsafe
6767
@eval Base.$f(S::BlockIndices) = Base.$f(S.indices)
6868
end
6969
Base.getindex(S::BlockIndices, i::Integer) = getindex(S.indices, i)
70+
71+
function _blockslice(x, y::AbstractUnitRange{<:Integer})
72+
return BlockSlice(x, y)
73+
end
74+
function _blockslice(x, y::AbstractVector{<:Integer})
75+
return BlockIndices(x, y)
76+
end
7077
function Base.getindex(S::BlockIndices, i::BlockSlice{<:Block{1}})
7178
# TODO: Check that `i.indices` is consistent with `S.indices`.
7279
# It seems like this isn't handling the case where `i` is a
7380
# subslice of a block correctly (i.e. it ignores `i.indices`).
7481
@assert length(S.indices[Block(i)]) == length(i.indices)
75-
return BlockSlice(S.blocks[Int(Block(i))], S.indices[Block(i)])
82+
return _blockslice(S.blocks[Int(Block(i))], S.indices[Block(i)])
7683
end
7784

7885
# This is used in slicing like:
@@ -347,7 +354,7 @@ function blockrange(
347354
axis::AbstractUnitRange,
348355
r::BlockVector{<:BlockIndex{1},<:AbstractVector{<:BlockIndexRange{1}}},
349356
)
350-
return map(b -> Block(b), blocks(r))
357+
return map(Block, blocks(r))
351358
end
352359

353360
# This handles slicing with `:`/`Colon()`.
@@ -365,6 +372,14 @@ function blockrange(axis::AbstractUnitRange, r::AbstractVector{<:Integer})
365372
return Block.(Base.OneTo(1))
366373
end
367374

375+
function blockrange(axis::AbstractUnitRange, r::BlockIndexVector)
376+
return Block(r):Block(r)
377+
end
378+
379+
function blockrange(axis::AbstractUnitRange, r::Vector{<:BlockIndexVector})
380+
return map(Block, r)
381+
end
382+
368383
function blockrange(axis::AbstractUnitRange, r)
369384
return error("Slicing not implemented for range of type `$(typeof(r))`.")
370385
end

src/BlockArraysExtensions/blockedunitrange.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ using BlockArrays:
1010
BlockVector,
1111
block,
1212
blockedrange,
13+
blockfirsts,
1314
blockindex,
1415
blocklengths,
1516
findblock,
@@ -169,6 +170,19 @@ function blockedunitrange_getindices(
169170
return mortar(map(b -> a[b], blocks(indices)))
170171
end
171172

173+
function blockedunitrange_getindices(
174+
a::AbstractBlockedUnitRange, indices::AbstractArray{Bool}
175+
)
176+
blocked_indices = BlockedVector(indices, axes(a))
177+
return mortar(
178+
map(Base.OneTo(blocklength(blocked_indices))) do b
179+
binds = blocked_indices[Block(b)]
180+
bstart = blockfirsts(only(axes(blocked_indices)))[b]
181+
return findall(binds) .+ (bstart - 1)
182+
end,
183+
)
184+
end
185+
172186
# TODO: Move this to a `BlockArraysExtensions` library.
173187
function blockedunitrange_getindices(a::AbstractBlockedUnitRange, indices)
174188
return error("Not implemented.")
@@ -197,6 +211,27 @@ function to_blockindices(a::AbstractBlockedUnitRange{<:Integer}, I::UnitRange{<:
197211
)
198212
end
199213

214+
struct BlockIndexVector{T<:Integer,I<:AbstractVector{T},TB<:Integer} <:
215+
AbstractVector{BlockIndex{1,Tuple{TB},Tuple{T}}}
216+
block::Block{1,TB}
217+
indices::I
218+
end
219+
Base.length(a::BlockIndexVector) = length(a.indices)
220+
Base.size(a::BlockIndexVector) = (length(a),)
221+
BlockArrays.Block(a::BlockIndexVector) = a.block
222+
Base.getindex(a::BlockIndexVector, I::Integer) = Block(a)[a.indices[I]]
223+
Base.copy(a::BlockIndexVector) = BlockIndexVector(a.block, copy(a.indices))
224+
225+
function to_blockindices(a::AbstractBlockedUnitRange{<:Integer}, I::AbstractArray{Bool})
226+
I_blocks = blocks(BlockedVector(I, blocklengths(a)))
227+
return mortar(
228+
map(eachindex(I_blocks)) do b
229+
I_b = findall(I_blocks[b])
230+
BlockIndexVector(Block(b), I_b)
231+
end,
232+
)
233+
end
234+
200235
# This handles non-blocked slices.
201236
# For example:
202237
# a = BlockSparseArray{Float64}([2, 2, 2, 2])

src/abstractblocksparsearray/unblockedsubarray.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@ using BlockArrays: BlockArrays, Block, BlockIndexRange, BlockSlice
44
using TypeParameterAccessors: TypeParameterAccessors, parenttype, similartype
55

66
const UnblockedIndices = Union{
7-
Vector{<:Integer},BlockSlice{<:Block{1}},BlockSlice{<:BlockIndexRange{1}}
7+
Vector{<:Integer},
8+
BlockSlice{<:Block{1}},
9+
BlockSlice{<:BlockIndexRange{1}},
10+
BlockSlice{<:BlockIndexVector},
811
}
912

1013
const UnblockedSubArray{T,N} = SubArray{

src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,12 @@ function Base.to_indices(
4040
return @interface BlockSparseArrayInterface() to_indices(a, inds, I)
4141
end
4242

43+
function Base.to_indices(
44+
a::AnyAbstractBlockSparseArray, inds, I::Tuple{AbstractArray{Bool},Vararg{Any}}
45+
)
46+
return @interface BlockSparseArrayInterface() to_indices(a, inds, I)
47+
end
48+
4349
# a[[Block(2), Block(1)], [Block(2), Block(1)]]
4450
function Base.to_indices(
4551
a::AnyAbstractBlockSparseArray, inds, I::Tuple{Vector{<:Block{1}},Vararg{Any}}

src/blocksparsearrayinterface/blocksparsearrayinterface.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,14 @@ end
146146
return (I1, to_indices(a, Base.tail(inds), Base.tail(I))...)
147147
end
148148

149+
@interface ::AbstractBlockSparseArrayInterface function Base.to_indices(
150+
a, inds, I::Tuple{AbstractArray{Bool},Vararg{Any}}
151+
)
152+
bs1 = to_blockindices(inds[1], I[1])
153+
I1 = BlockIndices(bs1, blockedunitrange_getindices(inds[1], I[1]))
154+
return (I1, to_indices(a, Base.tail(inds), Base.tail(I))...)
155+
end
156+
149157
# Special case when there is no blocking.
150158
@interface ::AbstractBlockSparseArrayInterface function Base.to_indices(
151159
a,

0 commit comments

Comments
 (0)