Skip to content

Commit 4e44292

Browse files
committed
Towards more general truncation and slicing
1 parent 21eb0e9 commit 4e44292

File tree

4 files changed

+53
-22
lines changed

4 files changed

+53
-22
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "BlockSparseArrays"
22
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.7.21"
4+
version = "0.7.22"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/BlockArraysExtensions/blockedunitrange.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,8 @@ end
314314
# `Base.getindex(a::Block, b...)`.
315315
_getindex(a::Block{N}, b::Vararg{Any,N}) where {N} = GenericBlockIndex(a, b)
316316
_getindex(a::Block{N}, b::Vararg{Integer,N}) where {N} = a[b...]
317+
_getindex(a::Block{N}, b::Vararg{AbstractUnitRange{<:Integer},N}) where {N} = a[b...]
318+
_getindex(a::Block{N}, b::Vararg{AbstractVector,N}) where {N} = BlockIndexVector(a, b)
317319
# Fix ambiguity.
318320
_getindex(a::Block{0}) = a[]
319321

@@ -372,7 +374,11 @@ function blockedunitrange_getindices(
372374
a::AbstractBlockedUnitRange,
373375
indices::BlockVector{<:GenericBlockIndex{1},<:Vector{<:BlockIndexVector{1}}},
374376
)
375-
return mortar(map(b -> a[b], blocks(indices)))
377+
blks = map(b -> a[b], blocks(indices))
378+
# Preserve any extra structure in the axes, like a
379+
# Kronecker structure, symmetry sectors, etc.
380+
ax = mortar_axis(map(b -> axis(a[b]), blocks(indices)))
381+
return mortar(blks, (ax,))
376382
end
377383

378384
# This is a specialization of `BlockArrays.unblock`:

src/abstractblocksparsearray/linearalgebra.jl

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using LinearAlgebra: LinearAlgebra, Adjoint, Transpose, norm, tr
1+
using LinearAlgebra: LinearAlgebra, Adjoint, Transpose, diag, norm, tr
22

33
# Like: https://github.com/JuliaLang/julia/blob/v1.11.1/stdlib/LinearAlgebra/src/transpose.jl#L184
44
# but also takes the dual of the axes.
@@ -33,6 +33,24 @@ function LinearAlgebra.tr(a::AnyAbstractBlockSparseMatrix)
3333
return tr_a
3434
end
3535

36+
# TODO: Define in DiagonalArrays.jl.
37+
function diagaxis(a::AbstractArray)
38+
LinearAlgebra.checksquare(a)
39+
return axes(a, 1)
40+
end
41+
function LinearAlgebra.diag(a::AnyAbstractBlockSparseMatrix)
42+
# TODO: Add `checkblocksquare` to also check it is square blockwise.
43+
LinearAlgebra.checksquare(a)
44+
diagaxes = map(blockdiagindices(a)) do I
45+
return diagaxis(@view(a[I]))
46+
end
47+
r = blockrange(diagaxes)
48+
stored_blocks = Dict((
49+
Tuple(I)[1] => diag(@view!(a[I])) for I in eachstoredblockdiagindex(a)
50+
))
51+
return blocksparse(stored_blocks, (r,))
52+
end
53+
3654
# TODO: Define `SparseArraysBase.isdiag`, define as
3755
# `isdiag(blocks(a))`.
3856
function blockisdiag(a::AbstractArray)

src/factorizations/truncation.jl

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,11 @@
1-
using MatrixAlgebraKit: TruncationStrategy, diagview, eig_trunc!, eigh_trunc!, svd_trunc!
2-
3-
function MatrixAlgebraKit.diagview(A::BlockSparseMatrix{T,Diagonal{T,Vector{T}}}) where {T}
4-
D = BlockSparseVector{T}(undef, axes(A, 1))
5-
for I in eachblockstoredindex(A)
6-
if ==(Int.(Tuple(I))...)
7-
D[Tuple(I)[1]] = diagview(A[I])
8-
end
9-
end
10-
return D
11-
end
1+
using MatrixAlgebraKit:
2+
TruncationStrategy,
3+
diagview,
4+
eig_trunc!,
5+
eigh_trunc!,
6+
findtruncated,
7+
svd_trunc!,
8+
truncate!
129

1310
"""
1411
BlockPermutedDiagonalTruncationStrategy(strategy::TruncationStrategy)
@@ -27,7 +24,7 @@ function MatrixAlgebraKit.truncate!(
2724
strategy::TruncationStrategy,
2825
)
2926
# TODO assert blockdiagonal
30-
return MatrixAlgebraKit.truncate!(
27+
return truncate!(
3128
svd_trunc!, (U, S, Vᴴ), BlockPermutedDiagonalTruncationStrategy(strategy)
3229
)
3330
end
@@ -38,9 +35,7 @@ for f in [:eig_trunc!, :eigh_trunc!]
3835
(D, V)::NTuple{2,AbstractBlockSparseMatrix},
3936
strategy::TruncationStrategy,
4037
)
41-
return MatrixAlgebraKit.truncate!(
42-
$f, (D, V), BlockPermutedDiagonalTruncationStrategy(strategy)
43-
)
38+
return truncate!($f, (D, V), BlockPermutedDiagonalTruncationStrategy(strategy))
4439
end
4540
end
4641
end
@@ -50,18 +45,30 @@ end
5045
function MatrixAlgebraKit.findtruncated(
5146
values::AbstractVector, strategy::BlockPermutedDiagonalTruncationStrategy
5247
)
53-
ind = MatrixAlgebraKit.findtruncated(values, strategy.strategy)
48+
ind = findtruncated(Vector(values), strategy.strategy)
5449
indexmask = falses(length(values))
5550
indexmask[ind] .= true
56-
return indexmask
51+
return to_truncated_indices(values, indexmask)
52+
end
53+
54+
# Allow customizing the indices output by `findtruncated`
55+
# based on the type of `values`, for example to preserve
56+
# a block or Kronecker structure.
57+
to_truncated_indices(values, I) = I
58+
function to_truncated_indices(values::AbstractBlockVector, I::AbstractVector{Bool})
59+
I′ = BlockedVector(I, blocklengths(axis(values)))
60+
blocks = map(BlockRange(values)) do b
61+
return _getindex(b, to_truncated_indices(values[b], I′[b]))
62+
end
63+
return blocks
5764
end
5865

5966
function MatrixAlgebraKit.truncate!(
6067
::typeof(svd_trunc!),
6168
(U, S, Vᴴ)::NTuple{3,AbstractBlockSparseMatrix},
6269
strategy::BlockPermutedDiagonalTruncationStrategy,
6370
)
64-
I = MatrixAlgebraKit.findtruncated(diagview(S), strategy)
71+
I = MatrixAlgebraKit.findtruncated(diag(S), strategy)
6572
return (U[:, I], S[I, I], Vᴴ[I, :])
6673
end
6774
for f in [:eig_trunc!, :eigh_trunc!]
@@ -71,7 +78,7 @@ for f in [:eig_trunc!, :eigh_trunc!]
7178
(D, V)::NTuple{2,AbstractBlockSparseMatrix},
7279
strategy::BlockPermutedDiagonalTruncationStrategy,
7380
)
74-
I = MatrixAlgebraKit.findtruncated(diagview(D), strategy)
81+
I = MatrixAlgebraKit.findtruncated(diag(D), strategy)
7582
return (D[I, I], V[:, I])
7683
end
7784
end

0 commit comments

Comments
 (0)