Skip to content

Commit a992d38

Browse files
committed
Also fix factorizations
1 parent b175dfc commit a992d38

File tree

1 file changed

+38
-20
lines changed

1 file changed

+38
-20
lines changed

src/factorizations/truncation.jl

Lines changed: 38 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
using MatrixAlgebraKit:
2+
MatrixAlgebraKit,
3+
TruncatedAlgorithm,
24
TruncationStrategy,
35
diagview,
46
eig_trunc!,
@@ -8,42 +10,58 @@ using MatrixAlgebraKit:
810
truncate!
911

1012
"""
11-
BlockPermutedDiagonalTruncationStrategy(strategy::TruncationStrategy)
13+
BlockDiagonalTruncationStrategy(strategy::TruncationStrategy)
1214
1315
A wrapper for `TruncationStrategy` that implements the wrapped strategy on a block-by-block
14-
basis, which is possible if the input matrix is a block-diagonal matrix or a block permuted
15-
block-diagonal matrix.
16+
basis, which is possible if the input matrix is a block-diagonal matrix.
1617
"""
17-
struct BlockPermutedDiagonalTruncationStrategy{T<:TruncationStrategy} <: TruncationStrategy
18+
struct BlockDiagonalTruncationStrategy{T<:TruncationStrategy} <: TruncationStrategy
1819
strategy::T
1920
end
2021

21-
function MatrixAlgebraKit.truncate!(
22-
::typeof(svd_trunc!),
23-
(U, S, Vᴴ)::NTuple{3,AbstractBlockSparseMatrix},
24-
strategy::TruncationStrategy,
22+
function BlockDiagonalTruncationStrategy(alg::BlockPermutedDiagonalAlgorithm)
23+
return BlockDiagonalTruncationStrategy(alg.strategy)
24+
end
25+
26+
function MatrixAlgebraKit.svd_trunc!(
27+
A::AbstractBlockSparseMatrix,
28+
out,
29+
alg::TruncatedAlgorithm{<:BlockPermutedDiagonalAlgorithm},
2530
)
26-
# TODO assert blockdiagonal
27-
return truncate!(
28-
svd_trunc!, (U, S, Vᴴ), BlockPermutedDiagonalTruncationStrategy(strategy)
29-
)
31+
Ad, (invrowperm, invcolperm) = blockdiagonalize(A)
32+
blockalg = BlockDiagonalAlgorithm(alg.alg)
33+
blockstrategy = BlockDiagonalTruncationStrategy(alg.trunc)
34+
Ud, S, Vᴴd = svd_trunc!(Ad, TruncatedAlgorithm(blockalg, blockstrategy))
35+
36+
U = transform_rows(Ud, invrowperm)
37+
Vᴴ = transform_cols(Vᴴd, invcolperm)
38+
39+
return U, S, Vᴴ
3040
end
41+
3142
for f in [:eig_trunc!, :eigh_trunc!]
3243
@eval begin
33-
function MatrixAlgebraKit.truncate!(
34-
::typeof($f),
35-
(D, V)::NTuple{2,AbstractBlockSparseMatrix},
36-
strategy::TruncationStrategy,
44+
function MatrixAlgebraKit.$f(
45+
A::AbstractBlockSparseMatrix,
46+
out,
47+
alg::TruncatedAlgorithm{<:BlockPermutedDiagonalAlgorithm},
3748
)
38-
return truncate!($f, (D, V), BlockPermutedDiagonalTruncationStrategy(strategy))
49+
Ad, (invrowperm, invcolperm) = blockdiagonalize(A)
50+
blockalg = BlockDiagonalAlgorithm(alg.alg)
51+
blockstrategy = BlockDiagonalTruncationStrategy(alg.trunc)
52+
Dd, Vd = $f(Ad, TruncatedAlgorithm(blockalg, blockstrategy))
53+
54+
D = transform_rows(Dd, invrowperm)
55+
V = transform_cols(Vd, invcolperm)
56+
return D, V
3957
end
4058
end
4159
end
4260

4361
# cannot use regular slicing here: I want to slice without altering blockstructure
4462
# solution: use boolean indexing and slice the mask, effectively cheaply inverting the map
4563
function MatrixAlgebraKit.findtruncated(
46-
values::AbstractVector, strategy::BlockPermutedDiagonalTruncationStrategy
64+
values::AbstractVector, strategy::BlockDiagonalTruncationStrategy
4765
)
4866
ind = findtruncated(Vector(values), strategy.strategy)
4967
indexmask = falses(length(values))
@@ -66,7 +84,7 @@ end
6684
function MatrixAlgebraKit.truncate!(
6785
::typeof(svd_trunc!),
6886
(U, S, Vᴴ)::NTuple{3,AbstractBlockSparseMatrix},
69-
strategy::BlockPermutedDiagonalTruncationStrategy,
87+
strategy::BlockDiagonalTruncationStrategy,
7088
)
7189
I = findtruncated(diag(S), strategy)
7290
return (U[:, I], S[I, I], Vᴴ[I, :])
@@ -76,7 +94,7 @@ for f in [:eig_trunc!, :eigh_trunc!]
7694
function MatrixAlgebraKit.truncate!(
7795
::typeof($f),
7896
(D, V)::NTuple{2,AbstractBlockSparseMatrix},
79-
strategy::BlockPermutedDiagonalTruncationStrategy,
97+
strategy::BlockDiagonalTruncationStrategy,
8098
)
8199
I = findtruncated(diag(D), strategy)
82100
return (D[I, I], V[:, I])

0 commit comments

Comments
 (0)