Skip to content

Add support for logical indexing that preserves block sparsity #131

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jun 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "BlockSparseArrays"
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.6.8"
version = "0.6.9"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
33 changes: 30 additions & 3 deletions src/BlockArraysExtensions/BlockArraysExtensions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,19 @@ for f in (:axes, :unsafe_indices, :axes1, :first, :last, :size, :length, :unsafe
@eval Base.$f(S::BlockIndices) = Base.$f(S.indices)
end
Base.getindex(S::BlockIndices, i::Integer) = getindex(S.indices, i)

function _blockslice(x, y::AbstractUnitRange{<:Integer})
return BlockSlice(x, y)
end
function _blockslice(x, y::AbstractVector{<:Integer})
return BlockIndices(x, y)
end
function Base.getindex(S::BlockIndices, i::BlockSlice{<:Block{1}})
# TODO: Check that `i.indices` is consistent with `S.indices`.
# It seems like this isn't handling the case where `i` is a
# subslice of a block correctly (i.e. it ignores `i.indices`).
@assert length(S.indices[Block(i)]) == length(i.indices)
return BlockSlice(S.blocks[Int(Block(i))], S.indices[Block(i)])
return _blockslice(S.blocks[Int(Block(i))], S.indices[Block(i)])
end

# This is used in slicing like:
Expand Down Expand Up @@ -151,9 +158,18 @@ end
const BlockSliceCollection = Union{
Base.Slice,BlockSlice{<:BlockRange{1}},BlockIndices{<:Vector{<:Block{1}}}
}
const SubBlockSliceCollection = BlockIndices{
const BlockIndexRangeSlice = BlockSlice{
<:BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}}
}
const BlockIndexRangeSlices = BlockIndices{
<:BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}}
}
const BlockIndexVectorSlices = BlockIndices{
<:BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexVector}}
}
const SubBlockSliceCollection = Union{
BlockIndexRangeSlice,BlockIndexRangeSlices,BlockIndexVectorSlices
}

# TODO: This is type piracy. This is used in `reindex` when making
# views of blocks of sliced block arrays, for example:
Expand Down Expand Up @@ -347,7 +363,7 @@ function blockrange(
axis::AbstractUnitRange,
r::BlockVector{<:BlockIndex{1},<:AbstractVector{<:BlockIndexRange{1}}},
)
return map(b -> Block(b), blocks(r))
return map(Block, blocks(r))
end

# This handles slicing with `:`/`Colon()`.
Expand All @@ -365,6 +381,17 @@ function blockrange(axis::AbstractUnitRange, r::AbstractVector{<:Integer})
return Block.(Base.OneTo(1))
end

function blockrange(axis::AbstractUnitRange, r::BlockIndexVector)
return Block(r):Block(r)
end

function blockrange(
axis::AbstractUnitRange,
r::BlockVector{<:BlockIndex{1},<:AbstractVector{<:BlockIndexVector}},
)
return map(Block, blocks(r))
end

function blockrange(axis::AbstractUnitRange, r)
return error("Slicing not implemented for range of type `$(typeof(r))`.")
end
Expand Down
35 changes: 34 additions & 1 deletion src/BlockArraysExtensions/blockedunitrange.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ using BlockArrays:
BlockVector,
block,
blockedrange,
blockfirsts,
blockindex,
blocklengths,
findblock,
Expand Down Expand Up @@ -134,7 +135,7 @@ end

# TODO: Move this to a `BlockArraysExtensions` library.
function blockedunitrange_getindices(
a::AbstractBlockedUnitRange, indices::Vector{<:Integer}
a::AbstractBlockedUnitRange, indices::AbstractVector{<:Integer}
)
return map(index -> a[index], indices)
end
Expand Down Expand Up @@ -169,6 +170,18 @@ function blockedunitrange_getindices(
return mortar(map(b -> a[b], blocks(indices)))
end

function blockedunitrange_getindices(
a::AbstractBlockedUnitRange, indices::AbstractVector{Bool}
)
blocked_indices = BlockedVector(indices, axes(a))
bs = map(Base.OneTo(blocklength(blocked_indices))) do b
binds = blocked_indices[Block(b)]
bstart = blockfirsts(only(axes(blocked_indices)))[b]
return findall(binds) .+ (bstart - 1)
end
return mortar(filter(!isempty, bs))
end

# TODO: Move this to a `BlockArraysExtensions` library.
function blockedunitrange_getindices(a::AbstractBlockedUnitRange, indices)
return error("Not implemented.")
Expand Down Expand Up @@ -197,6 +210,26 @@ function to_blockindices(a::AbstractBlockedUnitRange{<:Integer}, I::UnitRange{<:
)
end

struct BlockIndexVector{T<:Integer,I<:AbstractVector{T},TB<:Integer} <:
AbstractVector{BlockIndex{1,Tuple{TB},Tuple{T}}}
block::Block{1,TB}
indices::I
end
Base.length(a::BlockIndexVector) = length(a.indices)
Base.size(a::BlockIndexVector) = (length(a),)
BlockArrays.Block(a::BlockIndexVector) = a.block
Base.getindex(a::BlockIndexVector, I::Integer) = Block(a)[a.indices[I]]
Base.copy(a::BlockIndexVector) = BlockIndexVector(a.block, copy(a.indices))

function to_blockindices(a::AbstractBlockedUnitRange{<:Integer}, I::AbstractArray{Bool})
I_blocks = blocks(BlockedVector(I, blocklengths(a)))
I′_blocks = map(eachindex(I_blocks)) do b
I_b = findall(I_blocks[b])
BlockIndexVector(Block(b), I_b)
end
return mortar(filter(!isempty, I′_blocks))
end

# This handles non-blocked slices.
# For example:
# a = BlockSparseArray{Float64}([2, 2, 2, 2])
Expand Down
5 changes: 4 additions & 1 deletion src/abstractblocksparsearray/unblockedsubarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ using BlockArrays: BlockArrays, Block, BlockIndexRange, BlockSlice
using TypeParameterAccessors: TypeParameterAccessors, parenttype, similartype

const UnblockedIndices = Union{
Vector{<:Integer},BlockSlice{<:Block{1}},BlockSlice{<:BlockIndexRange{1}}
Vector{<:Integer},
BlockSlice{<:Block{1}},
BlockSlice{<:BlockIndexRange{1}},
BlockSlice{<:BlockIndexVector},
}

const UnblockedSubArray{T,N} = SubArray{
Expand Down
64 changes: 47 additions & 17 deletions src/abstractblocksparsearray/views.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,14 @@ end
# TODO: Move to `GradedUnitRanges` or `BlockArraysExtensions`.
to_block(I::Block{1}) = I
to_block(I::BlockIndexRange{1}) = Block(I)
to_block(I::BlockIndexVector) = Block(I)
to_block_indices(I::Block{1}) = Colon()
to_block_indices(I::BlockIndexRange{1}) = only(I.indices)
to_block_indices(I::BlockIndexVector) = I.indices

function Base.view(
a::AbstractBlockSparseArray{<:Any,N}, I::Vararg{Union{Block{1},BlockIndexRange{1}},N}
a::AbstractBlockSparseArray{<:Any,N},
I::Vararg{Union{Block{1},BlockIndexRange{1},BlockIndexVector},N},
) where {N}
return @views a[to_block.(I)...][to_block_indices.(I)...]
end
Expand All @@ -108,7 +111,7 @@ function Base.view(
end
function Base.view(
a::SubArray{T,N,<:AbstractBlockSparseArray{T,N}},
I::Vararg{Union{Block{1},BlockIndexRange{1}},N},
I::Vararg{Union{Block{1},BlockIndexRange{1},BlockIndexVector},N},
) where {T,N}
return @views a[to_block.(I)...][to_block_indices.(I)...]
end
Expand Down Expand Up @@ -205,8 +208,21 @@ function BlockArrays.viewblock(
end

function to_blockindexrange(
a::BlockIndices{<:BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}}},
I::Block{1},
a::BlockSlice{<:BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}}}, I::Block{1}
)
# TODO: Ideally we would just use `a.blocks[I]` but that doesn't
# work right now.
return blocks(a.block)[Int(I)]
end
function to_blockindexrange(
a::BlockIndices{<:BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange}}}, I::Block{1}
)
# TODO: Ideally we would just use `a.blocks[I]` but that doesn't
# work right now.
return blocks(a.blocks)[Int(I)]
end
function to_blockindexrange(
a::BlockIndices{<:BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexVector}}}, I::Block{1}
)
# TODO: Ideally we would just use `a.blocks[I]` but that doesn't
# work right now.
Expand Down Expand Up @@ -245,47 +261,61 @@ function BlockArrays.viewblock(
return view(viewblock(a, Block.(block)...), map(b -> only(b.indices), block)...)
end

# Block slice of the result of slicing `@view a[2:5, 2:5]`.
# TODO: Move this to `BlockArraysExtensions`.
const BlockedSlice = BlockSlice{
<:BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}}
}

function Base.view(
a::SubArray{T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{BlockedSlice,N}}},
a::SubArray{
T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{SubBlockSliceCollection,N}}
},
block::Union{Block{N},BlockIndexRange{N}},
) where {T,N}
return viewblock(a, block)
end
function Base.view(
a::SubArray{T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{BlockedSlice,N}}},
block::Vararg{Union{Block{1},BlockIndexRange{1}},N},
a::SubArray{T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{BlockIndexRangeSlice,N}}},
block::Union{Block{N},BlockIndexRange{N}},
) where {T,N}
return viewblock(a, block)
end
function Base.view(
a::SubArray{
T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{SubBlockSliceCollection,N}}
},
block::Vararg{Union{Block{1},BlockIndexRange{1},BlockIndexVector},N},
) where {T,N}
return viewblock(a, block...)
end
function BlockArrays.viewblock(
a::SubArray{T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{BlockedSlice,N}}},
a::SubArray{
T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{SubBlockSliceCollection,N}}
},
block::Union{Block{N},BlockIndexRange{N}},
) where {T,N}
return viewblock(a, to_tuple(block)...)
end

blockedslice_blocks(x::BlockSlice) = x.block
blockedslice_blocks(x::BlockIndices) = x.blocks

# TODO: Define `@interface BlockSparseArrayInterface() viewblock`.
function BlockArrays.viewblock(
a::SubArray{T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{BlockedSlice,N}}},
a::SubArray{
T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{SubBlockSliceCollection,N}}
},
I::Vararg{Block{1},N},
) where {T,N}
# TODO: Use `reindex`, `to_indices`, etc.
brs = ntuple(ndims(a)) do dim
# TODO: Ideally we would use this but it outputs a Vector,
# not a range:
# return parentindices(a)[dim].block[I[dim]]
return blocks(parentindices(a)[dim].block)[Int(I[dim])]
return blocks(blockedslice_blocks(parentindices(a)[dim]))[Int(I[dim])]
end
return @view parent(a)[brs...]
end
# TODO: Define `@interface BlockSparseArrayInterface() viewblock`.
function BlockArrays.viewblock(
a::SubArray{T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{BlockedSlice,N}}},
a::SubArray{
T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{SubBlockSliceCollection,N}}
},
block::Vararg{BlockIndexRange{1},N},
) where {T,N}
return view(viewblock(a, Block.(block)...), map(b -> only(b.indices), block)...)
Expand Down
13 changes: 13 additions & 0 deletions src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,19 @@ function Base.to_indices(
return @interface BlockSparseArrayInterface() to_indices(a, inds, I)
end

function Base.to_indices(
a::AnyAbstractBlockSparseArray, inds, I::Tuple{AbstractArray{Bool},Vararg{Any}}
)
return @interface BlockSparseArrayInterface() to_indices(a, inds, I)
end
# Fix ambiguity error with Base for logical indexing in Julia 1.10.
# TODO: Delete this once we drop support for Julia 1.10.
function Base.to_indices(
a::AnyAbstractBlockSparseArray, inds, I::Union{Tuple{BitArray{N}},Tuple{Array{Bool,N}}}
) where {N}
return @interface BlockSparseArrayInterface() to_indices(a, inds, I)
end

# a[[Block(2), Block(1)], [Block(2), Block(1)]]
function Base.to_indices(
a::AnyAbstractBlockSparseArray, inds, I::Tuple{Vector{<:Block{1}},Vararg{Any}}
Expand Down
8 changes: 8 additions & 0 deletions src/blocksparsearrayinterface/blocksparsearrayinterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,14 @@ end
return (I1, to_indices(a, Base.tail(inds), Base.tail(I))...)
end

@interface ::AbstractBlockSparseArrayInterface function Base.to_indices(
a, inds, I::Tuple{AbstractArray{Bool},Vararg{Any}}
)
bs1 = to_blockindices(inds[1], I[1])
I1 = BlockIndices(bs1, blockedunitrange_getindices(inds[1], I[1]))
return (I1, to_indices(a, Base.tail(inds), Base.tail(I))...)
end

# Special case when there is no blocking.
@interface ::AbstractBlockSparseArrayInterface function Base.to_indices(
a,
Expand Down
33 changes: 20 additions & 13 deletions src/factorizations/lq.jl
Original file line number Diff line number Diff line change
@@ -1,22 +1,29 @@
using MatrixAlgebraKit: MatrixAlgebraKit, lq_compact!, lq_full!

# TODO: this is a hardcoded for now to get around this function not being defined in the
# type domain
function default_blocksparse_lq_algorithm(A::AbstractMatrix; kwargs...)
blocktype(A) <: StridedMatrix{<:LinearAlgebra.BLAS.BlasFloat} ||
error("unsupported type: $(blocktype(A))")
alg = MatrixAlgebraKit.LAPACK_HouseholderLQ(; kwargs...)
return BlockPermutedDiagonalAlgorithm(alg)
using MatrixAlgebraKit: MatrixAlgebraKit, default_lq_algorithm, lq_compact!, lq_full!

# TODO: Delete once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/32 is merged.
function MatrixAlgebraKit.default_lq_algorithm(A::AbstractBlockSparseMatrix; kwargs...)
return default_lq_algorithm(typeof(A); kwargs...)
end

# TODO: Delete once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/32 is merged.
function MatrixAlgebraKit.default_algorithm(
::typeof(lq_compact!), A::AbstractBlockSparseMatrix; kwargs...
::typeof(lq_compact!), A::Type{<:AbstractBlockSparseMatrix}; kwargs...
)
return default_blocksparse_lq_algorithm(A; kwargs...)
return default_lq_algorithm(A; kwargs...)
end

# TODO: Delete once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/32 is merged.
function MatrixAlgebraKit.default_algorithm(
::typeof(lq_full!), A::AbstractBlockSparseMatrix; kwargs...
::typeof(lq_full!), A::Type{<:AbstractBlockSparseMatrix}; kwargs...
)
return default_blocksparse_lq_algorithm(A; kwargs...)
return default_lq_algorithm(A; kwargs...)
end

function MatrixAlgebraKit.default_lq_algorithm(
A::Type{<:AbstractBlockSparseMatrix}; kwargs...
)
alg = default_lq_algorithm(blocktype(A); kwargs...)
return BlockPermutedDiagonalAlgorithm(alg)
end

function similar_output(
Expand Down
Loading
Loading