diff --git a/Project.toml b/Project.toml index c15f9724..327cea0e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "BlockSparseArrays" uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4" authors = ["ITensor developers and contributors"] -version = "0.6.8" +version = "0.6.9" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/BlockArraysExtensions/BlockArraysExtensions.jl b/src/BlockArraysExtensions/BlockArraysExtensions.jl index b4c2d926..5c9f1231 100644 --- a/src/BlockArraysExtensions/BlockArraysExtensions.jl +++ b/src/BlockArraysExtensions/BlockArraysExtensions.jl @@ -67,12 +67,19 @@ for f in (:axes, :unsafe_indices, :axes1, :first, :last, :size, :length, :unsafe @eval Base.$f(S::BlockIndices) = Base.$f(S.indices) end Base.getindex(S::BlockIndices, i::Integer) = getindex(S.indices, i) + +function _blockslice(x, y::AbstractUnitRange{<:Integer}) + return BlockSlice(x, y) +end +function _blockslice(x, y::AbstractVector{<:Integer}) + return BlockIndices(x, y) +end function Base.getindex(S::BlockIndices, i::BlockSlice{<:Block{1}}) # TODO: Check that `i.indices` is consistent with `S.indices`. # It seems like this isn't handling the case where `i` is a # subslice of a block correctly (i.e. it ignores `i.indices`). @assert length(S.indices[Block(i)]) == length(i.indices) - return BlockSlice(S.blocks[Int(Block(i))], S.indices[Block(i)]) + return _blockslice(S.blocks[Int(Block(i))], S.indices[Block(i)]) end # This is used in slicing like: @@ -151,9 +158,18 @@ end const BlockSliceCollection = Union{ Base.Slice,BlockSlice{<:BlockRange{1}},BlockIndices{<:Vector{<:Block{1}}} } -const SubBlockSliceCollection = BlockIndices{ +const BlockIndexRangeSlice = BlockSlice{ + <:BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}} +} +const BlockIndexRangeSlices = BlockIndices{ <:BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}} } +const BlockIndexVectorSlices = BlockIndices{ + <:BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexVector}} +} +const SubBlockSliceCollection = Union{ + BlockIndexRangeSlice,BlockIndexRangeSlices,BlockIndexVectorSlices +} # TODO: This is type piracy. This is used in `reindex` when making # views of blocks of sliced block arrays, for example: @@ -347,7 +363,7 @@ function blockrange( axis::AbstractUnitRange, r::BlockVector{<:BlockIndex{1},<:AbstractVector{<:BlockIndexRange{1}}}, ) - return map(b -> Block(b), blocks(r)) + return map(Block, blocks(r)) end # This handles slicing with `:`/`Colon()`. @@ -365,6 +381,17 @@ function blockrange(axis::AbstractUnitRange, r::AbstractVector{<:Integer}) return Block.(Base.OneTo(1)) end +function blockrange(axis::AbstractUnitRange, r::BlockIndexVector) + return Block(r):Block(r) +end + +function blockrange( + axis::AbstractUnitRange, + r::BlockVector{<:BlockIndex{1},<:AbstractVector{<:BlockIndexVector}}, +) + return map(Block, blocks(r)) +end + function blockrange(axis::AbstractUnitRange, r) return error("Slicing not implemented for range of type `$(typeof(r))`.") end diff --git a/src/BlockArraysExtensions/blockedunitrange.jl b/src/BlockArraysExtensions/blockedunitrange.jl index f685b761..2e916729 100644 --- a/src/BlockArraysExtensions/blockedunitrange.jl +++ b/src/BlockArraysExtensions/blockedunitrange.jl @@ -10,6 +10,7 @@ using BlockArrays: BlockVector, block, blockedrange, + blockfirsts, blockindex, blocklengths, findblock, @@ -134,7 +135,7 @@ end # TODO: Move this to a `BlockArraysExtensions` library. function blockedunitrange_getindices( - a::AbstractBlockedUnitRange, indices::Vector{<:Integer} + a::AbstractBlockedUnitRange, indices::AbstractVector{<:Integer} ) return map(index -> a[index], indices) end @@ -169,6 +170,18 @@ function blockedunitrange_getindices( return mortar(map(b -> a[b], blocks(indices))) end +function blockedunitrange_getindices( + a::AbstractBlockedUnitRange, indices::AbstractVector{Bool} +) + blocked_indices = BlockedVector(indices, axes(a)) + bs = map(Base.OneTo(blocklength(blocked_indices))) do b + binds = blocked_indices[Block(b)] + bstart = blockfirsts(only(axes(blocked_indices)))[b] + return findall(binds) .+ (bstart - 1) + end + return mortar(filter(!isempty, bs)) +end + # TODO: Move this to a `BlockArraysExtensions` library. function blockedunitrange_getindices(a::AbstractBlockedUnitRange, indices) return error("Not implemented.") @@ -197,6 +210,26 @@ function to_blockindices(a::AbstractBlockedUnitRange{<:Integer}, I::UnitRange{<: ) end +struct BlockIndexVector{T<:Integer,I<:AbstractVector{T},TB<:Integer} <: + AbstractVector{BlockIndex{1,Tuple{TB},Tuple{T}}} + block::Block{1,TB} + indices::I +end +Base.length(a::BlockIndexVector) = length(a.indices) +Base.size(a::BlockIndexVector) = (length(a),) +BlockArrays.Block(a::BlockIndexVector) = a.block +Base.getindex(a::BlockIndexVector, I::Integer) = Block(a)[a.indices[I]] +Base.copy(a::BlockIndexVector) = BlockIndexVector(a.block, copy(a.indices)) + +function to_blockindices(a::AbstractBlockedUnitRange{<:Integer}, I::AbstractArray{Bool}) + I_blocks = blocks(BlockedVector(I, blocklengths(a))) + I′_blocks = map(eachindex(I_blocks)) do b + I_b = findall(I_blocks[b]) + BlockIndexVector(Block(b), I_b) + end + return mortar(filter(!isempty, I′_blocks)) +end + # This handles non-blocked slices. # For example: # a = BlockSparseArray{Float64}([2, 2, 2, 2]) diff --git a/src/abstractblocksparsearray/unblockedsubarray.jl b/src/abstractblocksparsearray/unblockedsubarray.jl index 355ae824..fc80e92f 100644 --- a/src/abstractblocksparsearray/unblockedsubarray.jl +++ b/src/abstractblocksparsearray/unblockedsubarray.jl @@ -4,7 +4,10 @@ using BlockArrays: BlockArrays, Block, BlockIndexRange, BlockSlice using TypeParameterAccessors: TypeParameterAccessors, parenttype, similartype const UnblockedIndices = Union{ - Vector{<:Integer},BlockSlice{<:Block{1}},BlockSlice{<:BlockIndexRange{1}} + Vector{<:Integer}, + BlockSlice{<:Block{1}}, + BlockSlice{<:BlockIndexRange{1}}, + BlockSlice{<:BlockIndexVector}, } const UnblockedSubArray{T,N} = SubArray{ diff --git a/src/abstractblocksparsearray/views.jl b/src/abstractblocksparsearray/views.jl index f2451988..a7c93aed 100644 --- a/src/abstractblocksparsearray/views.jl +++ b/src/abstractblocksparsearray/views.jl @@ -92,11 +92,14 @@ end # TODO: Move to `GradedUnitRanges` or `BlockArraysExtensions`. to_block(I::Block{1}) = I to_block(I::BlockIndexRange{1}) = Block(I) +to_block(I::BlockIndexVector) = Block(I) to_block_indices(I::Block{1}) = Colon() to_block_indices(I::BlockIndexRange{1}) = only(I.indices) +to_block_indices(I::BlockIndexVector) = I.indices function Base.view( - a::AbstractBlockSparseArray{<:Any,N}, I::Vararg{Union{Block{1},BlockIndexRange{1}},N} + a::AbstractBlockSparseArray{<:Any,N}, + I::Vararg{Union{Block{1},BlockIndexRange{1},BlockIndexVector},N}, ) where {N} return @views a[to_block.(I)...][to_block_indices.(I)...] end @@ -108,7 +111,7 @@ function Base.view( end function Base.view( a::SubArray{T,N,<:AbstractBlockSparseArray{T,N}}, - I::Vararg{Union{Block{1},BlockIndexRange{1}},N}, + I::Vararg{Union{Block{1},BlockIndexRange{1},BlockIndexVector},N}, ) where {T,N} return @views a[to_block.(I)...][to_block_indices.(I)...] end @@ -205,8 +208,21 @@ function BlockArrays.viewblock( end function to_blockindexrange( - a::BlockIndices{<:BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}}}, - I::Block{1}, + a::BlockSlice{<:BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}}}, I::Block{1} +) + # TODO: Ideally we would just use `a.blocks[I]` but that doesn't + # work right now. + return blocks(a.block)[Int(I)] +end +function to_blockindexrange( + a::BlockIndices{<:BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange}}}, I::Block{1} +) + # TODO: Ideally we would just use `a.blocks[I]` but that doesn't + # work right now. + return blocks(a.blocks)[Int(I)] +end +function to_blockindexrange( + a::BlockIndices{<:BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexVector}}}, I::Block{1} ) # TODO: Ideally we would just use `a.blocks[I]` but that doesn't # work right now. @@ -245,33 +261,45 @@ function BlockArrays.viewblock( return view(viewblock(a, Block.(block)...), map(b -> only(b.indices), block)...) end -# Block slice of the result of slicing `@view a[2:5, 2:5]`. -# TODO: Move this to `BlockArraysExtensions`. -const BlockedSlice = BlockSlice{ - <:BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}} -} - function Base.view( - a::SubArray{T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{BlockedSlice,N}}}, + a::SubArray{ + T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{SubBlockSliceCollection,N}} + }, block::Union{Block{N},BlockIndexRange{N}}, ) where {T,N} return viewblock(a, block) end function Base.view( - a::SubArray{T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{BlockedSlice,N}}}, - block::Vararg{Union{Block{1},BlockIndexRange{1}},N}, + a::SubArray{T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{BlockIndexRangeSlice,N}}}, + block::Union{Block{N},BlockIndexRange{N}}, +) where {T,N} + return viewblock(a, block) +end +function Base.view( + a::SubArray{ + T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{SubBlockSliceCollection,N}} + }, + block::Vararg{Union{Block{1},BlockIndexRange{1},BlockIndexVector},N}, ) where {T,N} return viewblock(a, block...) end function BlockArrays.viewblock( - a::SubArray{T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{BlockedSlice,N}}}, + a::SubArray{ + T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{SubBlockSliceCollection,N}} + }, block::Union{Block{N},BlockIndexRange{N}}, ) where {T,N} return viewblock(a, to_tuple(block)...) end + +blockedslice_blocks(x::BlockSlice) = x.block +blockedslice_blocks(x::BlockIndices) = x.blocks + # TODO: Define `@interface BlockSparseArrayInterface() viewblock`. function BlockArrays.viewblock( - a::SubArray{T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{BlockedSlice,N}}}, + a::SubArray{ + T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{SubBlockSliceCollection,N}} + }, I::Vararg{Block{1},N}, ) where {T,N} # TODO: Use `reindex`, `to_indices`, etc. @@ -279,13 +307,15 @@ function BlockArrays.viewblock( # TODO: Ideally we would use this but it outputs a Vector, # not a range: # return parentindices(a)[dim].block[I[dim]] - return blocks(parentindices(a)[dim].block)[Int(I[dim])] + return blocks(blockedslice_blocks(parentindices(a)[dim]))[Int(I[dim])] end return @view parent(a)[brs...] end # TODO: Define `@interface BlockSparseArrayInterface() viewblock`. function BlockArrays.viewblock( - a::SubArray{T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{BlockedSlice,N}}}, + a::SubArray{ + T,N,<:AbstractBlockSparseArray{T,N},<:Tuple{Vararg{SubBlockSliceCollection,N}} + }, block::Vararg{BlockIndexRange{1},N}, ) where {T,N} return view(viewblock(a, Block.(block)...), map(b -> only(b.indices), block)...) diff --git a/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl b/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl index 84c24ed8..c22907b8 100644 --- a/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl +++ b/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl @@ -40,6 +40,19 @@ function Base.to_indices( return @interface BlockSparseArrayInterface() to_indices(a, inds, I) end +function Base.to_indices( + a::AnyAbstractBlockSparseArray, inds, I::Tuple{AbstractArray{Bool},Vararg{Any}} +) + return @interface BlockSparseArrayInterface() to_indices(a, inds, I) +end +# Fix ambiguity error with Base for logical indexing in Julia 1.10. +# TODO: Delete this once we drop support for Julia 1.10. +function Base.to_indices( + a::AnyAbstractBlockSparseArray, inds, I::Union{Tuple{BitArray{N}},Tuple{Array{Bool,N}}} +) where {N} + return @interface BlockSparseArrayInterface() to_indices(a, inds, I) +end + # a[[Block(2), Block(1)], [Block(2), Block(1)]] function Base.to_indices( a::AnyAbstractBlockSparseArray, inds, I::Tuple{Vector{<:Block{1}},Vararg{Any}} diff --git a/src/blocksparsearrayinterface/blocksparsearrayinterface.jl b/src/blocksparsearrayinterface/blocksparsearrayinterface.jl index 0015de0b..d45533cb 100644 --- a/src/blocksparsearrayinterface/blocksparsearrayinterface.jl +++ b/src/blocksparsearrayinterface/blocksparsearrayinterface.jl @@ -146,6 +146,14 @@ end return (I1, to_indices(a, Base.tail(inds), Base.tail(I))...) end +@interface ::AbstractBlockSparseArrayInterface function Base.to_indices( + a, inds, I::Tuple{AbstractArray{Bool},Vararg{Any}} +) + bs1 = to_blockindices(inds[1], I[1]) + I1 = BlockIndices(bs1, blockedunitrange_getindices(inds[1], I[1])) + return (I1, to_indices(a, Base.tail(inds), Base.tail(I))...) +end + # Special case when there is no blocking. @interface ::AbstractBlockSparseArrayInterface function Base.to_indices( a, diff --git a/src/factorizations/lq.jl b/src/factorizations/lq.jl index 4a07cfa6..acb1046f 100644 --- a/src/factorizations/lq.jl +++ b/src/factorizations/lq.jl @@ -1,22 +1,29 @@ -using MatrixAlgebraKit: MatrixAlgebraKit, lq_compact!, lq_full! - -# TODO: this is a hardcoded for now to get around this function not being defined in the -# type domain -function default_blocksparse_lq_algorithm(A::AbstractMatrix; kwargs...) - blocktype(A) <: StridedMatrix{<:LinearAlgebra.BLAS.BlasFloat} || - error("unsupported type: $(blocktype(A))") - alg = MatrixAlgebraKit.LAPACK_HouseholderLQ(; kwargs...) - return BlockPermutedDiagonalAlgorithm(alg) +using MatrixAlgebraKit: MatrixAlgebraKit, default_lq_algorithm, lq_compact!, lq_full! + +# TODO: Delete once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/32 is merged. +function MatrixAlgebraKit.default_lq_algorithm(A::AbstractBlockSparseMatrix; kwargs...) + return default_lq_algorithm(typeof(A); kwargs...) end + +# TODO: Delete once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/32 is merged. function MatrixAlgebraKit.default_algorithm( - ::typeof(lq_compact!), A::AbstractBlockSparseMatrix; kwargs... + ::typeof(lq_compact!), A::Type{<:AbstractBlockSparseMatrix}; kwargs... ) - return default_blocksparse_lq_algorithm(A; kwargs...) + return default_lq_algorithm(A; kwargs...) end + +# TODO: Delete once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/32 is merged. function MatrixAlgebraKit.default_algorithm( - ::typeof(lq_full!), A::AbstractBlockSparseMatrix; kwargs... + ::typeof(lq_full!), A::Type{<:AbstractBlockSparseMatrix}; kwargs... ) - return default_blocksparse_lq_algorithm(A; kwargs...) + return default_lq_algorithm(A; kwargs...) +end + +function MatrixAlgebraKit.default_lq_algorithm( + A::Type{<:AbstractBlockSparseMatrix}; kwargs... +) + alg = default_lq_algorithm(blocktype(A); kwargs...) + return BlockPermutedDiagonalAlgorithm(alg) end function similar_output( diff --git a/src/factorizations/polar.jl b/src/factorizations/polar.jl index 9b9c2831..d8abba7c 100644 --- a/src/factorizations/polar.jl +++ b/src/factorizations/polar.jl @@ -7,6 +7,34 @@ using MatrixAlgebraKit: right_polar!, svd_compact! +# TODO: Delete once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/32 is merged. +function MatrixAlgebraKit.default_algorithm( + ::typeof(left_polar!), A::AbstractBlockSparseMatrix; kwargs... +) + return default_algorithm(left_polar!, typeof(A); kwargs...) +end + +# TODO: Delete once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/32 is merged. +function MatrixAlgebraKit.default_algorithm( + ::typeof(left_polar!), A::Type{<:AbstractBlockSparseMatrix}; kwargs... +) + return PolarViaSVD(default_algorithm(svd_compact!, A; kwargs...)) +end + +# TODO: Delete once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/32 is merged. +function MatrixAlgebraKit.default_algorithm( + ::typeof(right_polar!), A::AbstractBlockSparseMatrix; kwargs... +) + return default_algorithm(right_polar!, typeof(A); kwargs...) +end + +# TODO: Delete once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/32 is merged. +function MatrixAlgebraKit.default_algorithm( + ::typeof(right_polar!), A::Type{<:AbstractBlockSparseMatrix}; kwargs... +) + return PolarViaSVD(default_algorithm(svd_compact!, A; kwargs...)) +end + function MatrixAlgebraKit.check_input(::typeof(left_polar!), A::AbstractBlockSparseMatrix) @views for I in eachblockstoredindex(A) m, n = size(A[I]) @@ -46,14 +74,3 @@ function MatrixAlgebraKit.right_polar!(A::AbstractBlockSparseMatrix, alg::PolarV P = U * S * copy(U') return (P, Wᴴ) end - -function MatrixAlgebraKit.default_algorithm( - ::typeof(left_polar!), a::AbstractBlockSparseMatrix; kwargs... -) - return PolarViaSVD(default_algorithm(svd_compact!, a; kwargs...)) -end -function MatrixAlgebraKit.default_algorithm( - ::typeof(right_polar!), a::AbstractBlockSparseMatrix; kwargs... -) - return PolarViaSVD(default_algorithm(svd_compact!, a; kwargs...)) -end diff --git a/src/factorizations/qr.jl b/src/factorizations/qr.jl index 55a0b93e..74e587b7 100644 --- a/src/factorizations/qr.jl +++ b/src/factorizations/qr.jl @@ -1,25 +1,32 @@ using MatrixAlgebraKit: MatrixAlgebraKit, default_qr_algorithm, lq_compact!, lq_full!, qr_compact!, qr_full! -# TODO: this is a hardcoded for now to get around this function not being defined in the -# type domain +# TODO: Delete once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/32 is merged. function MatrixAlgebraKit.default_qr_algorithm(A::AbstractBlockSparseMatrix; kwargs...) - blocktype(A) <: StridedMatrix{<:LinearAlgebra.BLAS.BlasFloat} || - error("unsupported type: $(blocktype(A))") - alg = MatrixAlgebraKit.LAPACK_HouseholderQR(; kwargs...) - return BlockPermutedDiagonalAlgorithm(alg) + return default_qr_algorithm(typeof(A); kwargs...) end + +# TODO: Delete once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/32 is merged. function MatrixAlgebraKit.default_algorithm( - ::typeof(qr_compact!), A::AbstractBlockSparseMatrix; kwargs... + ::typeof(qr_compact!), A::Type{<:AbstractBlockSparseMatrix}; kwargs... ) return default_qr_algorithm(A; kwargs...) end + +# TODO: Delete once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/32 is merged. function MatrixAlgebraKit.default_algorithm( - ::typeof(qr_full!), A::AbstractBlockSparseMatrix; kwargs... + ::typeof(qr_full!), A::Type{<:AbstractBlockSparseMatrix}; kwargs... ) return default_qr_algorithm(A; kwargs...) end +function MatrixAlgebraKit.default_qr_algorithm( + A::Type{<:AbstractBlockSparseMatrix}; kwargs... +) + alg = default_qr_algorithm(blocktype(A); kwargs...) + return BlockPermutedDiagonalAlgorithm(alg) +end + function similar_output( ::typeof(qr_compact!), A, R_axis, alg::MatrixAlgebraKit.AbstractAlgorithm ) diff --git a/src/factorizations/svd.jl b/src/factorizations/svd.jl index 187795ea..1d4d88bf 100644 --- a/src/factorizations/svd.jl +++ b/src/factorizations/svd.jl @@ -1,5 +1,14 @@ using MatrixAlgebraKit: MatrixAlgebraKit, default_svd_algorithm, svd_compact!, svd_full! +# TODO: Delete once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/32 is merged. +using MatrixAlgebraKit: TruncatedAlgorithm, select_truncation, svd_trunc! +function MatrixAlgebraKit.select_algorithm( + ::typeof(svd_trunc!), A::Type{<:AbstractBlockSparseMatrix}, alg; trunc=nothing, kwargs... +) + alg_svd = select_algorithm(svd_compact!, A, alg; kwargs...) + return TruncatedAlgorithm(alg_svd, select_truncation(trunc)) +end + """ BlockPermutedDiagonalAlgorithm(A::MatrixAlgebraKit.AbstractAlgorithm) @@ -12,27 +21,32 @@ struct BlockPermutedDiagonalAlgorithm{A<:MatrixAlgebraKit.AbstractAlgorithm} <: alg::A end -function MatrixAlgebraKit.default_svd_algorithm(A; kwargs...) - blocktype(A) <: StridedMatrix{<:LinearAlgebra.BLAS.BlasFloat} || - error("unsupported type: $(blocktype(A))") - # TODO: this is a hardcoded for now to get around this function not being defined in the - # type domain - # alg = MatrixAlgebraKit.default_algorithm(f, blocktype(A); kwargs...) - alg = MatrixAlgebraKit.LAPACK_DivideAndConquer(; kwargs...) - return BlockPermutedDiagonalAlgorithm(alg) +# TODO: Delete once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/32 is merged. +function MatrixAlgebraKit.default_svd_algorithm(A::AbstractBlockSparseMatrix; kwargs...) + return default_svd_algorithm(typeof(A), kwargs...) end +# TODO: Delete once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/32 is merged. function MatrixAlgebraKit.default_algorithm( - f::typeof(svd_compact!), A::AbstractBlockSparseMatrix; kwargs... + f::typeof(svd_compact!), A::Type{<:AbstractBlockSparseMatrix}; kwargs... ) return default_svd_algorithm(A; kwargs...) end + +# TODO: Delete once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/32 is merged. function MatrixAlgebraKit.default_algorithm( - f::typeof(svd_full!), A::AbstractBlockSparseMatrix; kwargs... + f::typeof(svd_full!), A::Type{<:AbstractBlockSparseMatrix}; kwargs... ) return default_svd_algorithm(A; kwargs...) end +function MatrixAlgebraKit.default_svd_algorithm( + A::Type{<:AbstractBlockSparseMatrix}; kwargs... +) + alg = default_svd_algorithm(blocktype(A); kwargs...) + return BlockPermutedDiagonalAlgorithm(alg) +end + function similar_output( ::typeof(svd_compact!), A, S_axes, alg::MatrixAlgebraKit.AbstractAlgorithm ) diff --git a/test/test_basics.jl b/test/test_basics.jl index b0ee9a58..d8321aba 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -588,7 +588,11 @@ arrayts = (Array, JLArray) @views for b in [Block(1, 2), Block(2, 1)] a[b] = dev(randn(elt, size(a[b]))) end - for b in (a[2:4, 2:4], @view(a[2:4, 2:4])) + I = 2:4 + I′, J′ = falses.(size(a)) + I′[I] .= true + J′[I] .= true + for b in (a[I, I], @view(a[I, I]), a[I′, J′], @view(a[I′, J′])) @allowscalar @test b == Array(a)[2:4, 2:4] @test size(b) == (3, 3) @test blocksize(b) == (2, 2)