|
1 | 1 | using MatrixAlgebraKit: TruncationStrategy, diagview
|
2 | 2 |
|
| 3 | +""" |
| 4 | + BlockPermutedDiagonalTruncationStrategy(strategy::MatrixAlgebraKit.TruncationStrategy) |
| 5 | +
|
| 6 | +A wrapper for `MatrixAlgebraKit.TruncationStrategy` that implements the wrapped strategy on |
| 7 | +a block-by-block basis, which is possible if the input matrix is a block-diagonal matrix or |
| 8 | +a block permuted block-diagonal matrix. |
| 9 | +""" |
| 10 | +struct BlockPermutedDiagonalTruncationStrategy{T<:MatrixAlgebraKit.TruncationStrategy} <: |
| 11 | + MatrixAlgebraKit.TruncationStrategy |
| 12 | + strategy::T |
| 13 | +end |
| 14 | + |
3 | 15 | const TBlockUSVᴴ = Tuple{
|
4 | 16 | <:AbstractBlockSparseMatrix,<:AbstractBlockSparseMatrix,<:AbstractBlockSparseMatrix
|
5 | 17 | }
|
6 | 18 |
|
7 | 19 | function MatrixAlgebraKit.truncate!(
|
8 |
| - ::typeof(svd_trunc!), (U, S, Vᴴ)::TBlockUSVᴴ, strategy::TruncationStrategy |
| 20 | + ::typeof(svd_trunc!), |
| 21 | + (U, S, Vᴴ)::TBlockUSVᴴ, |
| 22 | + strategy::BlockPermutedDiagonalTruncationStrategy, |
9 | 23 | )
|
10 |
| - ind = MatrixAlgebraKit.findtruncated(diagview(S), strategy) |
| 24 | + ind = MatrixAlgebraKit.findtruncated(diagview(S), strategy.strategy) |
11 | 25 | # cannot use regular slicing here: I want to slice without altering blockstructure
|
12 | 26 | # solution: use boolean indexing and slice the mask, effectively cheaply inverting the map
|
13 | 27 | indexmask = falses(size(S, 1))
|
@@ -60,3 +74,4 @@ function MatrixAlgebraKit.truncate!(
|
60 | 74 |
|
61 | 75 | return Ũ, S̃, Ṽᴴ
|
62 | 76 | end
|
| 77 | + |
0 commit comments