1
1
using MatrixAlgebraKit:
2
+ MatrixAlgebraKit,
3
+ TruncatedAlgorithm,
2
4
TruncationStrategy,
3
5
diagview,
4
6
eig_trunc!,
@@ -8,42 +10,58 @@ using MatrixAlgebraKit:
8
10
truncate!
9
11
10
12
"""
11
- BlockPermutedDiagonalTruncationStrategy (strategy::TruncationStrategy)
13
+ BlockDiagonalTruncationStrategy (strategy::TruncationStrategy)
12
14
13
15
A wrapper for `TruncationStrategy` that implements the wrapped strategy on a block-by-block
14
- basis, which is possible if the input matrix is a block-diagonal matrix or a block permuted
15
- block-diagonal matrix.
16
+ basis, which is possible if the input matrix is a block-diagonal matrix.
16
17
"""
17
- struct BlockPermutedDiagonalTruncationStrategy {T<: TruncationStrategy } <: TruncationStrategy
18
+ struct BlockDiagonalTruncationStrategy {T<: TruncationStrategy } <: TruncationStrategy
18
19
strategy:: T
19
20
end
20
21
21
- function MatrixAlgebraKit. truncate! (
22
- :: typeof (svd_trunc!),
23
- (U, S, Vᴴ):: NTuple{3,AbstractBlockSparseMatrix} ,
24
- strategy:: TruncationStrategy ,
22
+ function BlockDiagonalTruncationStrategy (alg:: BlockPermutedDiagonalAlgorithm )
23
+ return BlockDiagonalTruncationStrategy (alg. strategy)
24
+ end
25
+
26
+ function MatrixAlgebraKit. svd_trunc! (
27
+ A:: AbstractBlockSparseMatrix ,
28
+ out,
29
+ alg:: TruncatedAlgorithm{<:BlockPermutedDiagonalAlgorithm} ,
25
30
)
26
- # TODO assert blockdiagonal
27
- return truncate! (
28
- svd_trunc!, (U, S, Vᴴ), BlockPermutedDiagonalTruncationStrategy (strategy)
29
- )
31
+ Ad, (invrowperm, invcolperm) = blockdiagonalize (A)
32
+ blockalg = BlockDiagonalAlgorithm (alg. alg)
33
+ blockstrategy = BlockDiagonalTruncationStrategy (alg. trunc)
34
+ Ud, S, Vᴴd = svd_trunc! (Ad, TruncatedAlgorithm (blockalg, blockstrategy))
35
+
36
+ U = transform_rows (Ud, invrowperm)
37
+ Vᴴ = transform_cols (Vᴴd, invcolperm)
38
+
39
+ return U, S, Vᴴ
30
40
end
41
+
31
42
for f in [:eig_trunc! , :eigh_trunc! ]
32
43
@eval begin
33
- function MatrixAlgebraKit. truncate! (
34
- :: typeof ( $ f) ,
35
- (D, V) :: NTuple{2,AbstractBlockSparseMatrix} ,
36
- strategy :: TruncationStrategy ,
44
+ function MatrixAlgebraKit. $f (
45
+ A :: AbstractBlockSparseMatrix ,
46
+ out ,
47
+ alg :: TruncatedAlgorithm{<:BlockPermutedDiagonalAlgorithm} ,
37
48
)
38
- return truncate! ($ f, (D, V), BlockPermutedDiagonalTruncationStrategy (strategy))
49
+ Ad, (invrowperm, invcolperm) = blockdiagonalize (A)
50
+ blockalg = BlockDiagonalAlgorithm (alg. alg)
51
+ blockstrategy = BlockDiagonalTruncationStrategy (alg. trunc)
52
+ Dd, Vd = $ f (Ad, TruncatedAlgorithm (blockalg, blockstrategy))
53
+
54
+ D = transform_rows (Dd, invrowperm)
55
+ V = transform_cols (Vd, invcolperm)
56
+ return D, V
39
57
end
40
58
end
41
59
end
42
60
43
61
# cannot use regular slicing here: I want to slice without altering blockstructure
44
62
# solution: use boolean indexing and slice the mask, effectively cheaply inverting the map
45
63
function MatrixAlgebraKit. findtruncated (
46
- values:: AbstractVector , strategy:: BlockPermutedDiagonalTruncationStrategy
64
+ values:: AbstractVector , strategy:: BlockDiagonalTruncationStrategy
47
65
)
48
66
ind = findtruncated (Vector (values), strategy. strategy)
49
67
indexmask = falses (length (values))
66
84
function MatrixAlgebraKit. truncate! (
67
85
:: typeof (svd_trunc!),
68
86
(U, S, Vᴴ):: NTuple{3,AbstractBlockSparseMatrix} ,
69
- strategy:: BlockPermutedDiagonalTruncationStrategy ,
87
+ strategy:: BlockDiagonalTruncationStrategy ,
70
88
)
71
89
I = findtruncated (diag (S), strategy)
72
90
return (U[:, I], S[I, I], Vᴴ[I, :])
@@ -76,7 +94,7 @@ for f in [:eig_trunc!, :eigh_trunc!]
76
94
function MatrixAlgebraKit. truncate! (
77
95
:: typeof ($ f),
78
96
(D, V):: NTuple{2,AbstractBlockSparseMatrix} ,
79
- strategy:: BlockPermutedDiagonalTruncationStrategy ,
97
+ strategy:: BlockDiagonalTruncationStrategy ,
80
98
)
81
99
I = findtruncated (diag (D), strategy)
82
100
return (D[I, I], V[:, I])
0 commit comments