From 4e4429260413d3d6583f983f65af279e140312f3 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 1 Jul 2025 18:15:05 -0400 Subject: [PATCH 1/3] Towards more general truncation and slicing --- Project.toml | 2 +- src/BlockArraysExtensions/blockedunitrange.jl | 8 +++- src/abstractblocksparsearray/linearalgebra.jl | 20 ++++++++- src/factorizations/truncation.jl | 45 +++++++++++-------- 4 files changed, 53 insertions(+), 22 deletions(-) diff --git a/Project.toml b/Project.toml index 17a00d07..318969d5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "BlockSparseArrays" uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4" authors = ["ITensor developers and contributors"] -version = "0.7.21" +version = "0.7.22" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/BlockArraysExtensions/blockedunitrange.jl b/src/BlockArraysExtensions/blockedunitrange.jl index 2adf198f..cdd90161 100644 --- a/src/BlockArraysExtensions/blockedunitrange.jl +++ b/src/BlockArraysExtensions/blockedunitrange.jl @@ -314,6 +314,8 @@ end # `Base.getindex(a::Block, b...)`. _getindex(a::Block{N}, b::Vararg{Any,N}) where {N} = GenericBlockIndex(a, b) _getindex(a::Block{N}, b::Vararg{Integer,N}) where {N} = a[b...] +_getindex(a::Block{N}, b::Vararg{AbstractUnitRange{<:Integer},N}) where {N} = a[b...] +_getindex(a::Block{N}, b::Vararg{AbstractVector,N}) where {N} = BlockIndexVector(a, b) # Fix ambiguity. _getindex(a::Block{0}) = a[] @@ -372,7 +374,11 @@ function blockedunitrange_getindices( a::AbstractBlockedUnitRange, indices::BlockVector{<:GenericBlockIndex{1},<:Vector{<:BlockIndexVector{1}}}, ) - return mortar(map(b -> a[b], blocks(indices))) + blks = map(b -> a[b], blocks(indices)) + # Preserve any extra structure in the axes, like a + # Kronecker structure, symmetry sectors, etc. + ax = mortar_axis(map(b -> axis(a[b]), blocks(indices))) + return mortar(blks, (ax,)) end # This is a specialization of `BlockArrays.unblock`: diff --git a/src/abstractblocksparsearray/linearalgebra.jl b/src/abstractblocksparsearray/linearalgebra.jl index 0c9483fb..2b03de9e 100644 --- a/src/abstractblocksparsearray/linearalgebra.jl +++ b/src/abstractblocksparsearray/linearalgebra.jl @@ -1,4 +1,4 @@ -using LinearAlgebra: LinearAlgebra, Adjoint, Transpose, norm, tr +using LinearAlgebra: LinearAlgebra, Adjoint, Transpose, diag, norm, tr # Like: https://github.com/JuliaLang/julia/blob/v1.11.1/stdlib/LinearAlgebra/src/transpose.jl#L184 # but also takes the dual of the axes. @@ -33,6 +33,24 @@ function LinearAlgebra.tr(a::AnyAbstractBlockSparseMatrix) return tr_a end +# TODO: Define in DiagonalArrays.jl. +function diagaxis(a::AbstractArray) + LinearAlgebra.checksquare(a) + return axes(a, 1) +end +function LinearAlgebra.diag(a::AnyAbstractBlockSparseMatrix) + # TODO: Add `checkblocksquare` to also check it is square blockwise. + LinearAlgebra.checksquare(a) + diagaxes = map(blockdiagindices(a)) do I + return diagaxis(@view(a[I])) + end + r = blockrange(diagaxes) + stored_blocks = Dict(( + Tuple(I)[1] => diag(@view!(a[I])) for I in eachstoredblockdiagindex(a) + )) + return blocksparse(stored_blocks, (r,)) +end + # TODO: Define `SparseArraysBase.isdiag`, define as # `isdiag(blocks(a))`. function blockisdiag(a::AbstractArray) diff --git a/src/factorizations/truncation.jl b/src/factorizations/truncation.jl index 1ad4e72e..07b6924e 100644 --- a/src/factorizations/truncation.jl +++ b/src/factorizations/truncation.jl @@ -1,14 +1,11 @@ -using MatrixAlgebraKit: TruncationStrategy, diagview, eig_trunc!, eigh_trunc!, svd_trunc! - -function MatrixAlgebraKit.diagview(A::BlockSparseMatrix{T,Diagonal{T,Vector{T}}}) where {T} - D = BlockSparseVector{T}(undef, axes(A, 1)) - for I in eachblockstoredindex(A) - if ==(Int.(Tuple(I))...) - D[Tuple(I)[1]] = diagview(A[I]) - end - end - return D -end +using MatrixAlgebraKit: + TruncationStrategy, + diagview, + eig_trunc!, + eigh_trunc!, + findtruncated, + svd_trunc!, + truncate! """ BlockPermutedDiagonalTruncationStrategy(strategy::TruncationStrategy) @@ -27,7 +24,7 @@ function MatrixAlgebraKit.truncate!( strategy::TruncationStrategy, ) # TODO assert blockdiagonal - return MatrixAlgebraKit.truncate!( + return truncate!( svd_trunc!, (U, S, Vᴴ), BlockPermutedDiagonalTruncationStrategy(strategy) ) end @@ -38,9 +35,7 @@ for f in [:eig_trunc!, :eigh_trunc!] (D, V)::NTuple{2,AbstractBlockSparseMatrix}, strategy::TruncationStrategy, ) - return MatrixAlgebraKit.truncate!( - $f, (D, V), BlockPermutedDiagonalTruncationStrategy(strategy) - ) + return truncate!($f, (D, V), BlockPermutedDiagonalTruncationStrategy(strategy)) end end end @@ -50,10 +45,22 @@ end function MatrixAlgebraKit.findtruncated( values::AbstractVector, strategy::BlockPermutedDiagonalTruncationStrategy ) - ind = MatrixAlgebraKit.findtruncated(values, strategy.strategy) + ind = findtruncated(Vector(values), strategy.strategy) indexmask = falses(length(values)) indexmask[ind] .= true - return indexmask + return to_truncated_indices(values, indexmask) +end + +# Allow customizing the indices output by `findtruncated` +# based on the type of `values`, for example to preserve +# a block or Kronecker structure. +to_truncated_indices(values, I) = I +function to_truncated_indices(values::AbstractBlockVector, I::AbstractVector{Bool}) + I′ = BlockedVector(I, blocklengths(axis(values))) + blocks = map(BlockRange(values)) do b + return _getindex(b, to_truncated_indices(values[b], I′[b])) + end + return blocks end function MatrixAlgebraKit.truncate!( @@ -61,7 +68,7 @@ function MatrixAlgebraKit.truncate!( (U, S, Vᴴ)::NTuple{3,AbstractBlockSparseMatrix}, strategy::BlockPermutedDiagonalTruncationStrategy, ) - I = MatrixAlgebraKit.findtruncated(diagview(S), strategy) + I = MatrixAlgebraKit.findtruncated(diag(S), strategy) return (U[:, I], S[I, I], Vᴴ[I, :]) end for f in [:eig_trunc!, :eigh_trunc!] @@ -71,7 +78,7 @@ for f in [:eig_trunc!, :eigh_trunc!] (D, V)::NTuple{2,AbstractBlockSparseMatrix}, strategy::BlockPermutedDiagonalTruncationStrategy, ) - I = MatrixAlgebraKit.findtruncated(diagview(D), strategy) + I = MatrixAlgebraKit.findtruncated(diag(D), strategy) return (D[I, I], V[:, I]) end end From 5f4563d7d04d442582c982232ab9e19febb962e2 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 2 Jul 2025 10:37:14 -0400 Subject: [PATCH 2/3] Preserve sector information better in slicing --- .../BlockArraysExtensions.jl | 6 ++ src/BlockArraysExtensions/blockedunitrange.jl | 6 +- test/test_factorizations.jl | 64 ++++++++++++------- 3 files changed, 51 insertions(+), 25 deletions(-) diff --git a/src/BlockArraysExtensions/BlockArraysExtensions.jl b/src/BlockArraysExtensions/BlockArraysExtensions.jl index c4589782..4a194d24 100644 --- a/src/BlockArraysExtensions/BlockArraysExtensions.jl +++ b/src/BlockArraysExtensions/BlockArraysExtensions.jl @@ -68,6 +68,12 @@ for f in (:axes, :unsafe_indices, :axes1, :first, :last, :size, :length, :unsafe end Base.getindex(S::BlockIndices, i::Integer) = getindex(S.indices, i) +# TODO: Move this to a `BlockArraysExtensions` library. +function blockedunitrange_getindices(a::AbstractBlockedUnitRange, indices::BlockIndices) + # TODO: Is this a good definition? It ignores `indices.indices`. + return a[indices.blocks] +end + # Generalization of to `BlockArrays._blockslice`: # https://github.com/JuliaArrays/BlockArrays.jl/blob/v1.6.3/src/views.jl#L13-L14 # Used by `BlockArrays.unblock`, which is used in `to_indices` diff --git a/src/BlockArraysExtensions/blockedunitrange.jl b/src/BlockArraysExtensions/blockedunitrange.jl index cdd90161..604bef09 100644 --- a/src/BlockArraysExtensions/blockedunitrange.jl +++ b/src/BlockArraysExtensions/blockedunitrange.jl @@ -368,7 +368,11 @@ function blockedunitrange_getindices( a::AbstractBlockedUnitRange, indices::BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexVector{1}}}, ) - return mortar(map(b -> a[b], blocks(indices))) + blks = map(b -> a[b], blocks(indices)) + # Preserve any extra structure in the axes, like a + # Kronecker structure, symmetry sectors, etc. + ax = mortar_axis(map(b -> axis(a[b]), blocks(indices))) + return mortar(blks, (ax,)) end function blockedunitrange_getindices( a::AbstractBlockedUnitRange, diff --git a/test/test_factorizations.jl b/test/test_factorizations.jl index 5113497b..651a8f65 100644 --- a/test/test_factorizations.jl +++ b/test/test_factorizations.jl @@ -146,20 +146,23 @@ test_params = Iterators.product(blockszs, eltypes) @test test_svd(a, usv_empty) # test blockdiagonal + rng = StableRNG(123) for i in LinearAlgebra.diagind(blocks(a)) I = CartesianIndices(blocks(a))[i] - a[Block(I.I...)] = rand(T, size(blocks(a)[i])) + a[Block(I.I...)] = rand(rng, T, size(blocks(a)[i])) end usv = svd_compact(a) @test test_svd(a, usv) - perm = Random.randperm(length(m)) + rng = StableRNG(123) + perm = Random.randperm(rng, length(m)) b = a[Block.(perm), Block.(1:length(n))] usv = svd_compact(b) @test test_svd(b, usv) # test permuted blockdiagonal with missing row/col - I_removed = rand(eachblockstoredindex(b)) + rng = StableRNG(123) + I_removed = rand(rng, eachblockstoredindex(b)) c = copy(b) delete!(blocks(c).storage, CartesianIndex(Int.(Tuple(I_removed)))) usv = svd_compact(c) @@ -176,20 +179,23 @@ end @test test_svd(a, usv_empty; full=true) # test blockdiagonal + rng = StableRNG(123) for i in LinearAlgebra.diagind(blocks(a)) I = CartesianIndices(blocks(a))[i] - a[Block(I.I...)] = rand(T, size(blocks(a)[i])) + a[Block(I.I...)] = rand(rng, T, size(blocks(a)[i])) end usv = svd_full(a) @test test_svd(a, usv; full=true) - perm = Random.randperm(length(m)) + rng = StableRNG(123) + perm = Random.randperm(rng, length(m)) b = a[Block.(perm), Block.(1:length(n))] usv = svd_full(b) @test test_svd(b, usv; full=true) # test permuted blockdiagonal with missing row/col - I_removed = rand(eachblockstoredindex(b)) + rng = StableRNG(123) + I_removed = rand(rng, eachblockstoredindex(b)) c = copy(b) delete!(blocks(c).storage, CartesianIndex(Int.(Tuple(I_removed)))) usv = svd_full(c) @@ -203,9 +209,10 @@ end a = BlockSparseArray{T}(undef, m, n) # test blockdiagonal + rng = StableRNG(123) for i in LinearAlgebra.diagind(blocks(a)) I = CartesianIndices(blocks(a))[i] - a[Block(I.I...)] = rand(T, size(blocks(a)[i])) + a[Block(I.I...)] = rand(rng, T, size(blocks(a)[i])) end minmn = min(size(a)...) @@ -236,7 +243,8 @@ end @test (V1ᴴ * V1ᴴ' ≈ LinearAlgebra.I) # test permuted blockdiagonal - perm = Random.randperm(length(m)) + rng = StableRNG(123) + perm = Random.randperm(rng, length(m)) b = a[Block.(perm), Block.(1:length(n))] for trunc in (truncrank(r), trunctol(atol)) U1, S1, V1ᴴ = svd_trunc(b; trunc) @@ -270,8 +278,9 @@ end @testset "qr_compact (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64) for i in [2, 3], j in [2, 3], k in [2, 3], l in [2, 3] A = BlockSparseArray{T}(undef, ([i, j], [k, l])) - A[Block(1, 1)] = randn(T, i, k) - A[Block(2, 2)] = randn(T, j, l) + rng = StableRNG(123) + A[Block(1, 1)] = randn(rng, T, i, k) + A[Block(2, 2)] = randn(rng, T, j, l) Q, R = qr_compact(A) @test Matrix(Q'Q) ≈ LinearAlgebra.I @test A ≈ Q * R @@ -281,8 +290,9 @@ end @testset "qr_full (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64) for i in [2, 3], j in [2, 3], k in [2, 3], l in [2, 3] A = BlockSparseArray{T}(undef, ([i, j], [k, l])) - A[Block(1, 1)] = randn(T, i, k) - A[Block(2, 2)] = randn(T, j, l) + rng = StableRNG(123) + A[Block(1, 1)] = randn(rng, T, i, k) + A[Block(2, 2)] = randn(rng, T, j, l) Q, R = qr_full(A) Q′, R′ = qr_full(Matrix(A)) @test size(Q) == size(Q′) @@ -296,8 +306,9 @@ end @testset "lq_compact" for T in (Float32, Float64, ComplexF32, ComplexF64) for i in [2, 3], j in [2, 3], k in [2, 3], l in [2, 3] A = BlockSparseArray{T}(undef, ([i, j], [k, l])) - A[Block(1, 1)] = randn(T, i, k) - A[Block(2, 2)] = randn(T, j, l) + rng = StableRNG(123) + A[Block(1, 1)] = randn(rng, T, i, k) + A[Block(2, 2)] = randn(rng, T, j, l) L, Q = lq_compact(A) @test Matrix(Q * Q') ≈ LinearAlgebra.I @test A ≈ L * Q @@ -307,8 +318,9 @@ end @testset "lq_full" for T in (Float32, Float64, ComplexF32, ComplexF64) for i in [2, 3], j in [2, 3], k in [2, 3], l in [2, 3] A = BlockSparseArray{T}(undef, ([i, j], [k, l])) - A[Block(1, 1)] = randn(T, i, k) - A[Block(2, 2)] = randn(T, j, l) + rng = StableRNG(123) + A[Block(1, 1)] = randn(rng, T, i, k) + A[Block(2, 2)] = randn(rng, T, j, l) L, Q = lq_full(A) L′, Q′ = lq_full(Matrix(A)) @test size(L) == size(L′) @@ -321,8 +333,9 @@ end @testset "left_polar (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64) A = BlockSparseArray{T}(undef, ([3, 4], [2, 3])) - A[Block(1, 1)] = randn(T, 3, 2) - A[Block(2, 2)] = randn(T, 4, 3) + rng = StableRNG(123) + A[Block(1, 1)] = randn(rng, T, 3, 2) + A[Block(2, 2)] = randn(rng, T, 4, 3) U, C = left_polar(A) @test U * C ≈ A @@ -331,8 +344,9 @@ end @testset "right_polar (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64) A = BlockSparseArray{T}(undef, ([2, 3], [3, 4])) - A[Block(1, 1)] = randn(T, 2, 3) - A[Block(2, 2)] = randn(T, 3, 4) + rng = StableRNG(123) + A[Block(1, 1)] = randn(rng, T, 2, 3) + A[Block(2, 2)] = randn(rng, T, 3, 4) C, U = right_polar(A) @test C * U ≈ A @@ -341,8 +355,9 @@ end @testset "left_orth (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64) A = BlockSparseArray{T}(undef, ([3, 4], [2, 3])) - A[Block(1, 1)] = randn(T, 3, 2) - A[Block(2, 2)] = randn(T, 4, 3) + rng = StableRNG(123) + A[Block(1, 1)] = randn(rng, T, 3, 2) + A[Block(2, 2)] = randn(rng, T, 4, 3) for kind in (:polar, :qr, :svd) U, C = left_orth(A; kind) @@ -358,8 +373,9 @@ end @testset "right_orth (T=$T)" for T in (Float32, Float64, ComplexF32, ComplexF64) A = BlockSparseArray{T}(undef, ([2, 3], [3, 4])) - A[Block(1, 1)] = randn(T, 2, 3) - A[Block(2, 2)] = randn(T, 3, 4) + rng = StableRNG(123) + A[Block(1, 1)] = randn(rng, T, 2, 3) + A[Block(2, 2)] = randn(rng, T, 3, 4) for kind in (:lq, :polar, :svd) C, U = right_orth(A; kind) From 9d4edc1b5f900ee9583064d28b01f3b8243e8b68 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 2 Jul 2025 10:43:59 -0400 Subject: [PATCH 3/3] Namespace --- src/factorizations/truncation.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/factorizations/truncation.jl b/src/factorizations/truncation.jl index 07b6924e..d173f88d 100644 --- a/src/factorizations/truncation.jl +++ b/src/factorizations/truncation.jl @@ -68,7 +68,7 @@ function MatrixAlgebraKit.truncate!( (U, S, Vᴴ)::NTuple{3,AbstractBlockSparseMatrix}, strategy::BlockPermutedDiagonalTruncationStrategy, ) - I = MatrixAlgebraKit.findtruncated(diag(S), strategy) + I = findtruncated(diag(S), strategy) return (U[:, I], S[I, I], Vᴴ[I, :]) end for f in [:eig_trunc!, :eigh_trunc!] @@ -78,7 +78,7 @@ for f in [:eig_trunc!, :eigh_trunc!] (D, V)::NTuple{2,AbstractBlockSparseMatrix}, strategy::BlockPermutedDiagonalTruncationStrategy, ) - I = MatrixAlgebraKit.findtruncated(diag(D), strategy) + I = findtruncated(diag(D), strategy) return (D[I, I], V[:, I]) end end