1
- using MatrixAlgebraKit: TruncationStrategy, diagview, eig_trunc!, eigh_trunc!, svd_trunc!
2
-
3
- function MatrixAlgebraKit. diagview (A:: BlockSparseMatrix{T,Diagonal{T,Vector{T}}} ) where {T}
4
- D = BlockSparseVector {T} (undef, axes (A, 1 ))
5
- for I in eachblockstoredindex (A)
6
- if == (Int .(Tuple (I))... )
7
- D[Tuple (I)[1 ]] = diagview (A[I])
8
- end
9
- end
10
- return D
11
- end
1
+ using MatrixAlgebraKit:
2
+ TruncationStrategy,
3
+ diagview,
4
+ eig_trunc!,
5
+ eigh_trunc!,
6
+ findtruncated,
7
+ svd_trunc!,
8
+ truncate!
12
9
13
10
"""
14
11
BlockPermutedDiagonalTruncationStrategy(strategy::TruncationStrategy)
@@ -27,7 +24,7 @@ function MatrixAlgebraKit.truncate!(
27
24
strategy:: TruncationStrategy ,
28
25
)
29
26
# TODO assert blockdiagonal
30
- return MatrixAlgebraKit . truncate! (
27
+ return truncate! (
31
28
svd_trunc!, (U, S, Vᴴ), BlockPermutedDiagonalTruncationStrategy (strategy)
32
29
)
33
30
end
@@ -38,9 +35,7 @@ for f in [:eig_trunc!, :eigh_trunc!]
38
35
(D, V):: NTuple{2,AbstractBlockSparseMatrix} ,
39
36
strategy:: TruncationStrategy ,
40
37
)
41
- return MatrixAlgebraKit. truncate! (
42
- $ f, (D, V), BlockPermutedDiagonalTruncationStrategy (strategy)
43
- )
38
+ return truncate! ($ f, (D, V), BlockPermutedDiagonalTruncationStrategy (strategy))
44
39
end
45
40
end
46
41
end
50
45
function MatrixAlgebraKit. findtruncated (
51
46
values:: AbstractVector , strategy:: BlockPermutedDiagonalTruncationStrategy
52
47
)
53
- ind = MatrixAlgebraKit . findtruncated (values, strategy. strategy)
48
+ ind = findtruncated (Vector ( values) , strategy. strategy)
54
49
indexmask = falses (length (values))
55
50
indexmask[ind] .= true
56
- return indexmask
51
+ return to_truncated_indices (values, indexmask)
52
+ end
53
+
54
+ # Allow customizing the indices output by `findtruncated`
55
+ # based on the type of `values`, for example to preserve
56
+ # a block or Kronecker structure.
57
+ to_truncated_indices (values, I) = I
58
+ function to_truncated_indices (values:: AbstractBlockVector , I:: AbstractVector{Bool} )
59
+ I′ = BlockedVector (I, blocklengths (axis (values)))
60
+ blocks = map (BlockRange (values)) do b
61
+ return _getindex (b, to_truncated_indices (values[b], I′[b]))
62
+ end
63
+ return blocks
57
64
end
58
65
59
66
function MatrixAlgebraKit. truncate! (
60
67
:: typeof (svd_trunc!),
61
68
(U, S, Vᴴ):: NTuple{3,AbstractBlockSparseMatrix} ,
62
69
strategy:: BlockPermutedDiagonalTruncationStrategy ,
63
70
)
64
- I = MatrixAlgebraKit. findtruncated (diagview (S), strategy)
71
+ I = MatrixAlgebraKit. findtruncated (diag (S), strategy)
65
72
return (U[:, I], S[I, I], Vᴴ[I, :])
66
73
end
67
74
for f in [:eig_trunc! , :eigh_trunc! ]
@@ -71,7 +78,7 @@ for f in [:eig_trunc!, :eigh_trunc!]
71
78
(D, V):: NTuple{2,AbstractBlockSparseMatrix} ,
72
79
strategy:: BlockPermutedDiagonalTruncationStrategy ,
73
80
)
74
- I = MatrixAlgebraKit. findtruncated (diagview (D), strategy)
81
+ I = MatrixAlgebraKit. findtruncated (diag (D), strategy)
75
82
return (D[I, I], V[:, I])
76
83
end
77
84
end
0 commit comments