Skip to content

Commit 7d3f1bf

Browse files
authored
Add support for logical indexing that preserves block sparsity (#131)
1 parent 6bee699 commit 7d3f1bf

File tree

12 files changed

+229
-66
lines changed

12 files changed

+229
-66
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "BlockSparseArrays"
22
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.6.8"
4+
version = "0.6.9"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/BlockArraysExtensions/BlockArraysExtensions.jl

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,12 +67,19 @@ for f in (:axes, :unsafe_indices, :axes1, :first, :last, :size, :length, :unsafe
6767
@eval Base.$f(S::BlockIndices) = Base.$f(S.indices)
6868
end
6969
Base.getindex(S::BlockIndices, i::Integer) = getindex(S.indices, i)
70+
71+
function _blockslice(x, y::AbstractUnitRange{<:Integer})
72+
return BlockSlice(x, y)
73+
end
74+
function _blockslice(x, y::AbstractVector{<:Integer})
75+
return BlockIndices(x, y)
76+
end
7077
function Base.getindex(S::BlockIndices, i::BlockSlice{<:Block{1}})
7178
# TODO: Check that `i.indices` is consistent with `S.indices`.
7279
# It seems like this isn't handling the case where `i` is a
7380
# subslice of a block correctly (i.e. it ignores `i.indices`).
7481
@assert length(S.indices[Block(i)]) == length(i.indices)
75-
return BlockSlice(S.blocks[Int(Block(i))], S.indices[Block(i)])
82+
return _blockslice(S.blocks[Int(Block(i))], S.indices[Block(i)])
7683
end
7784

7885
# This is used in slicing like:
@@ -151,9 +158,18 @@ end
151158
const BlockSliceCollection = Union{
152159
Base.Slice,BlockSlice{<:BlockRange{1}},BlockIndices{<:Vector{<:Block{1}}}
153160
}
154-
const SubBlockSliceCollection = BlockIndices{
161+
const BlockIndexRangeSlice = BlockSlice{
162+
<:BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}}
163+
}
164+
const BlockIndexRangeSlices = BlockIndices{
155165
<:BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}}
156166
}
167+
const BlockIndexVectorSlices = BlockIndices{
168+
<:BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexVector}}
169+
}
170+
const SubBlockSliceCollection = Union{
171+
BlockIndexRangeSlice,BlockIndexRangeSlices,BlockIndexVectorSlices
172+
}
157173

158174
# TODO: This is type piracy. This is used in `reindex` when making
159175
# views of blocks of sliced block arrays, for example:
@@ -347,7 +363,7 @@ function blockrange(
347363
axis::AbstractUnitRange,
348364
r::BlockVector{<:BlockIndex{1},<:AbstractVector{<:BlockIndexRange{1}}},
349365
)
350-
return map(b -> Block(b), blocks(r))
366+
return map(Block, blocks(r))
351367
end
352368

353369
# This handles slicing with `:`/`Colon()`.
@@ -365,6 +381,17 @@ function blockrange(axis::AbstractUnitRange, r::AbstractVector{<:Integer})
365381
return Block.(Base.OneTo(1))
366382
end
367383

384+
function blockrange(axis::AbstractUnitRange, r::BlockIndexVector)
385+
return Block(r):Block(r)
386+
end
387+
388+
function blockrange(
389+
axis::AbstractUnitRange,
390+
r::BlockVector{<:BlockIndex{1},<:AbstractVector{<:BlockIndexVector}},
391+
)
392+
return map(Block, blocks(r))
393+
end
394+
368395
function blockrange(axis::AbstractUnitRange, r)
369396
return error("Slicing not implemented for range of type `$(typeof(r))`.")
370397
end

src/BlockArraysExtensions/blockedunitrange.jl

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ using BlockArrays:
1010
BlockVector,
1111
block,
1212
blockedrange,
13+
blockfirsts,
1314
blockindex,
1415
blocklengths,
1516
findblock,
@@ -134,7 +135,7 @@ end
134135

135136
# TODO: Move this to a `BlockArraysExtensions` library.
136137
function blockedunitrange_getindices(
137-
a::AbstractBlockedUnitRange, indices::Vector{<:Integer}
138+
a::AbstractBlockedUnitRange, indices::AbstractVector{<:Integer}
138139
)
139140
return map(index -> a[index], indices)
140141
end
@@ -169,6 +170,18 @@ function blockedunitrange_getindices(
169170
return mortar(map(b -> a[b], blocks(indices)))
170171
end
171172

173+
function blockedunitrange_getindices(
174+
a::AbstractBlockedUnitRange, indices::AbstractVector{Bool}
175+
)
176+
blocked_indices = BlockedVector(indices, axes(a))
177+
bs = map(Base.OneTo(blocklength(blocked_indices))) do b
178+
binds = blocked_indices[Block(b)]
179+
bstart = blockfirsts(only(axes(blocked_indices)))[b]
180+
return findall(binds) .+ (bstart - 1)
181+
end
182+
return mortar(filter(!isempty, bs))
183+
end
184+
172185
# TODO: Move this to a `BlockArraysExtensions` library.
173186
function blockedunitrange_getindices(a::AbstractBlockedUnitRange, indices)
174187
return error("Not implemented.")
@@ -197,6 +210,26 @@ function to_blockindices(a::AbstractBlockedUnitRange{<:Integer}, I::UnitRange{<:
197210
)
198211
end
199212

213+
struct BlockIndexVector{T<:Integer,I<:AbstractVector{T},TB<:Integer} <:
214+
AbstractVector{BlockIndex{1,Tuple{TB},Tuple{T}}}
215+
block::Block{1,TB}
216+
indices::I
217+
end
218+
Base.length(a::BlockIndexVector) = length(a.indices)
219+
Base.size(a::BlockIndexVector) = (length(a),)
220+
BlockArrays.Block(a::BlockIndexVector) = a.block
221+
Base.getindex(a::BlockIndexVector, I::Integer) = Block(a)[a.indices[I]]
222+
Base.copy(a::BlockIndexVector) = BlockIndexVector(a.block, copy(a.indices))
223+
224+
function to_blockindices(a::AbstractBlockedUnitRange{<:Integer}, I::AbstractArray{Bool})
225+
I_blocks = blocks(BlockedVector(I, blocklengths(a)))
226+
I′_blocks = map(eachindex(I_blocks)) do b
227+
I_b = findall(I_blocks[b])
228+
BlockIndexVector(Block(b), I_b)
229+
end
230+
return mortar(filter(!isempty, I′_blocks))
231+
end
232+
200233
# This handles non-blocked slices.
201234
# For example:
202235
# a = BlockSparseArray{Float64}([2, 2, 2, 2])

src/abstractblocksparsearray/unblockedsubarray.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@ using BlockArrays: BlockArrays, Block, BlockIndexRange, BlockSlice
44
using TypeParameterAccessors: TypeParameterAccessors, parenttype, similartype
55

66
const UnblockedIndices = Union{
7-
Vector{<:Integer},BlockSlice{<:Block{1}},BlockSlice{<:BlockIndexRange{1}}
7+
Vector{<:Integer},
8+
BlockSlice{<:Block{1}},
9+
BlockSlice{<:BlockIndexRange{1}},
10+
BlockSlice{<:BlockIndexVector},
811
}
912

1013
const UnblockedSubArray{T,N} = SubArray{

src/abstractblocksparsearray/views.jl

Lines changed: 47 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -92,11 +92,14 @@ end
9292
# TODO: Move to `GradedUnitRanges` or `BlockArraysExtensions`.
9393
to_block(I::Block{1}) = I
9494
to_block(I::BlockIndexRange{1}) = Block(I)
95+
to_block(I::BlockIndexVector) = Block(I)
9596
to_block_indices(I::Block{1}) = Colon()
9697
to_block_indices(I::BlockIndexRange{1}) = only(I.indices)
98+
to_block_indices(I::BlockIndexVector) = I.indices
9799

98100
function Base.view(
99-
a::AbstractBlockSparseArray{<:Any,N}, I::Vararg{Union{Block{1},BlockIndexRange{1}},N}
101+
a::AbstractBlockSparseArray{<:Any,N},
102+
I::Vararg{Union{Block{1},BlockIndexRange{1},BlockIndexVector},N},
100103
) where {N}
101104
return @views a[to_block.(I)...][to_block_indices.(I)...]
102105
end
@@ -108,7 +111,7 @@ function Base.view(
108111
end
109112
function Base.view(
110113
a::SubArray{T,N,<:AbstractBlockSparseArray{T,N}},
111-
I::Vararg{Union{Block{1},BlockIndexRange{1}},N},
114+
I::Vararg{Union{Block{1},BlockIndexRange{1},BlockIndexVector},N},
112115
) where {T,N}
113116
return @views a[to_block.(I)...][to_block_indices.(I)...]
114117
end
@@ -205,8 +208,21 @@ function BlockArrays.viewblock(
205208
end
206209

207210
function to_blockindexrange(
208-
a::BlockIndices{<:BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}}},
209-
I::Block{1},
211+
a::BlockSlice{<:BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}}}, I::Block{1}
212+
)
213+
# TODO: Ideally we would just use `a.blocks[I]` but that doesn't
214+
# work right now.
215+
return blocks(a.block)[Int(I)]
216+
end
217+
function to_blockindexrange(
218+
a::BlockIndices{<:BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange}}}, I::Block{1}
219+
)
220+
# TODO: Ideally we would just use `a.blocks[I]` but that doesn't
221+
# work right now.
222+
return blocks(a.blocks)[Int(I)]
223+
end
224+
function to_blockindexrange(
225+
a::BlockIndices{<:BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexVector}}}, I::Block{1}
210226
)
211227
# TODO: Ideally we would just use `a.blocks[I]` but that doesn't
212228
# work right now.
@@ -245,47 +261,61 @@ function BlockArrays.viewblock(
245261
return view(viewblock(a, Block.(block)...), map(b -> only(b.indices), block)...)
246262
end
247263

248-
# Block slice of the result of slicing `@view a[2:5, 2:5]`.
249-
# TODO: Move this to `BlockArraysExtensions`.
250-
const BlockedSlice = BlockSlice{
251-
<:BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}}
252-
}
253-
254264
function Base.view(
255-
a::SubArray{T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{BlockedSlice,N}}},
265+
a::SubArray{
266+
T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{SubBlockSliceCollection,N}}
267+
},
256268
block::Union{Block{N},BlockIndexRange{N}},
257269
) where {T,N}
258270
return viewblock(a, block)
259271
end
260272
function Base.view(
261-
a::SubArray{T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{BlockedSlice,N}}},
262-
block::Vararg{Union{Block{1},BlockIndexRange{1}},N},
273+
a::SubArray{T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{BlockIndexRangeSlice,N}}},
274+
block::Union{Block{N},BlockIndexRange{N}},
275+
) where {T,N}
276+
return viewblock(a, block)
277+
end
278+
function Base.view(
279+
a::SubArray{
280+
T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{SubBlockSliceCollection,N}}
281+
},
282+
block::Vararg{Union{Block{1},BlockIndexRange{1},BlockIndexVector},N},
263283
) where {T,N}
264284
return viewblock(a, block...)
265285
end
266286
function BlockArrays.viewblock(
267-
a::SubArray{T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{BlockedSlice,N}}},
287+
a::SubArray{
288+
T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{SubBlockSliceCollection,N}}
289+
},
268290
block::Union{Block{N},BlockIndexRange{N}},
269291
) where {T,N}
270292
return viewblock(a, to_tuple(block)...)
271293
end
294+
295+
blockedslice_blocks(x::BlockSlice) = x.block
296+
blockedslice_blocks(x::BlockIndices) = x.blocks
297+
272298
# TODO: Define `@interface BlockSparseArrayInterface() viewblock`.
273299
function BlockArrays.viewblock(
274-
a::SubArray{T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{BlockedSlice,N}}},
300+
a::SubArray{
301+
T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{SubBlockSliceCollection,N}}
302+
},
275303
I::Vararg{Block{1},N},
276304
) where {T,N}
277305
# TODO: Use `reindex`, `to_indices`, etc.
278306
brs = ntuple(ndims(a)) do dim
279307
# TODO: Ideally we would use this but it outputs a Vector,
280308
# not a range:
281309
# return parentindices(a)[dim].block[I[dim]]
282-
return blocks(parentindices(a)[dim].block)[Int(I[dim])]
310+
return blocks(blockedslice_blocks(parentindices(a)[dim]))[Int(I[dim])]
283311
end
284312
return @view parent(a)[brs...]
285313
end
286314
# TODO: Define `@interface BlockSparseArrayInterface() viewblock`.
287315
function BlockArrays.viewblock(
288-
a::SubArray{T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{BlockedSlice,N}}},
316+
a::SubArray{
317+
T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{SubBlockSliceCollection,N}}
318+
},
289319
block::Vararg{BlockIndexRange{1},N},
290320
) where {T,N}
291321
return view(viewblock(a, Block.(block)...), map(b -> only(b.indices), block)...)

src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,19 @@ function Base.to_indices(
4040
return @interface BlockSparseArrayInterface() to_indices(a, inds, I)
4141
end
4242

43+
function Base.to_indices(
44+
a::AnyAbstractBlockSparseArray, inds, I::Tuple{AbstractArray{Bool},Vararg{Any}}
45+
)
46+
return @interface BlockSparseArrayInterface() to_indices(a, inds, I)
47+
end
48+
# Fix ambiguity error with Base for logical indexing in Julia 1.10.
49+
# TODO: Delete this once we drop support for Julia 1.10.
50+
function Base.to_indices(
51+
a::AnyAbstractBlockSparseArray, inds, I::Union{Tuple{BitArray{N}},Tuple{Array{Bool,N}}}
52+
) where {N}
53+
return @interface BlockSparseArrayInterface() to_indices(a, inds, I)
54+
end
55+
4356
# a[[Block(2), Block(1)], [Block(2), Block(1)]]
4457
function Base.to_indices(
4558
a::AnyAbstractBlockSparseArray, inds, I::Tuple{Vector{<:Block{1}},Vararg{Any}}

src/blocksparsearrayinterface/blocksparsearrayinterface.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,14 @@ end
146146
return (I1, to_indices(a, Base.tail(inds), Base.tail(I))...)
147147
end
148148

149+
@interface ::AbstractBlockSparseArrayInterface function Base.to_indices(
150+
a, inds, I::Tuple{AbstractArray{Bool},Vararg{Any}}
151+
)
152+
bs1 = to_blockindices(inds[1], I[1])
153+
I1 = BlockIndices(bs1, blockedunitrange_getindices(inds[1], I[1]))
154+
return (I1, to_indices(a, Base.tail(inds), Base.tail(I))...)
155+
end
156+
149157
# Special case when there is no blocking.
150158
@interface ::AbstractBlockSparseArrayInterface function Base.to_indices(
151159
a,

src/factorizations/lq.jl

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,29 @@
1-
using MatrixAlgebraKit: MatrixAlgebraKit, lq_compact!, lq_full!
2-
3-
# TODO: this is a hardcoded for now to get around this function not being defined in the
4-
# type domain
5-
function default_blocksparse_lq_algorithm(A::AbstractMatrix; kwargs...)
6-
blocktype(A) <: StridedMatrix{<:LinearAlgebra.BLAS.BlasFloat} ||
7-
error("unsupported type: $(blocktype(A))")
8-
alg = MatrixAlgebraKit.LAPACK_HouseholderLQ(; kwargs...)
9-
return BlockPermutedDiagonalAlgorithm(alg)
1+
using MatrixAlgebraKit: MatrixAlgebraKit, default_lq_algorithm, lq_compact!, lq_full!
2+
3+
# TODO: Delete once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/32 is merged.
4+
function MatrixAlgebraKit.default_lq_algorithm(A::AbstractBlockSparseMatrix; kwargs...)
5+
return default_lq_algorithm(typeof(A); kwargs...)
106
end
7+
8+
# TODO: Delete once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/32 is merged.
119
function MatrixAlgebraKit.default_algorithm(
12-
::typeof(lq_compact!), A::AbstractBlockSparseMatrix; kwargs...
10+
::typeof(lq_compact!), A::Type{<:AbstractBlockSparseMatrix}; kwargs...
1311
)
14-
return default_blocksparse_lq_algorithm(A; kwargs...)
12+
return default_lq_algorithm(A; kwargs...)
1513
end
14+
15+
# TODO: Delete once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/32 is merged.
1616
function MatrixAlgebraKit.default_algorithm(
17-
::typeof(lq_full!), A::AbstractBlockSparseMatrix; kwargs...
17+
::typeof(lq_full!), A::Type{<:AbstractBlockSparseMatrix}; kwargs...
1818
)
19-
return default_blocksparse_lq_algorithm(A; kwargs...)
19+
return default_lq_algorithm(A; kwargs...)
20+
end
21+
22+
function MatrixAlgebraKit.default_lq_algorithm(
23+
A::Type{<:AbstractBlockSparseMatrix}; kwargs...
24+
)
25+
alg = default_lq_algorithm(blocktype(A); kwargs...)
26+
return BlockPermutedDiagonalAlgorithm(alg)
2027
end
2128

2229
function similar_output(

0 commit comments

Comments
 (0)